pytorch(6)–深度置信网络

pytorch(6)–深度置信网络一、前言本文主要使用pytorch实现的DBN网络,用于对数据做回归,单个数据维度为(N,21),其中N为不定长,输出则为(N,1),对应N个值DBN网络结构:首层神经元数量输入为变量长度21,中间为RBM网络,如本篇使用的网络结构诶[128,64,32,16],为一个4层的RBM网络结构,训练时RBM需要逐层做训练;在RBM训练后,再接上BP神经网络,再对BP网络做微调,回归损失函数使用MSEloss。二、深度置信网络实现代码#DBN.pyimpor…

一、前言

    本文主要使用pytorch 实现的DBN网络,用于对数据做回归,单个数据维度为(N,21),其中N为不定长,输出则为(N,1),对应N个值

DBN网络结构:

    pytorch(6)--深度置信网络

    首层神经元数量输入为变量长度21,中间为RBM网络,如本篇使用的网络结构诶[128,64,32,16],为一个4层的RBM网络结构,训练时RBM需要逐层做训练;在RBM训练后,再接上BP神经网络,再对BP网络做微调,回归损失函数使用MSE loss。

二、深度置信网络实现代码

#DBN.py
import torch
import warnings
import torch.nn as nn
import numpy as np

from RBM import RBM
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.optim import Adam, SGD

from genCsvData import indefDataSet,DataLoader


class DBN(nn.Module

今天的文章pytorch(6)–深度置信网络分享到此就结束了,感谢您的阅读,如果确实帮到您,您可以动动手指转发给其他人。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/24720.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注