发布时间:2023-11-13 13:30
在PyTorch中的调用也非常简单,使用 nn.RNN()即可调用,下面依次介绍其中的参数。
RNN() 里面的参数有
input_size 表示输入 xt 的特征维度
hidden_size 表示输出的特征维度
num_layers 表示网络的层数
nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
bias 表示是否使用偏置,默认使用
batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
dropout 表示是否在输出层应用 dropout
bidirectional 表示是否使用双向的 rnn,默认是 False。
接着再介绍网络接收的输入和输出。网络会接收一个序列输入xt和记忆输入h0,xt的维度是(seq,batch,feature),分别表示序列长度、批量和输入的特征维度,h0也叫隐藏状态,它的维度是(layersdirection,batch,hidden),分别表示层数乘方向(如果是单向,就是1,如果是双向就是2)、批量和输出的维度。网络会输出output和hn,output表示网络实际的输出,维度是(seq,batch, hiddendirection),分别表示序列长度、批量和输出维度乘上方向,hn表示记忆单元,维度是(layer*direction,batch,hidden),分别表示层数乘方向、批量和输出维度。
对于定义好的RNN,可以通过 weight_ih_l0来访问第一层中的 w i h w_{ih} wih,另外要访问第二层网络可以使用 weight_ih_l1。对于 w h h w_{hh} whh,可以用weight_hh_l0来访问,而 b i h b_{ih} bih 则可以用bias_ih_l0来访问.当然可以对它进行自定义的初始化,只需要记得这些参数都是Variable,取出它们的data,对它进行自定的初始化即可。
import torch
from torch.autograd import Variable
from torch import nn
# 构造一个序列,长为 6,batch 是 5, 特征是 100
x = Variable(torch.randn(6, 5, 100)) # 这是 rnn 的输入格式
rnn_seq = nn.RNN(100, 200)
# 访问其中的参数
print(rnn_seq.weight_hh_l0) #与h相乘的权重
print(rnn_seq.weight_ih_l0) #与x相乘的权重
out, h_t = rnn_seq(x) # 使用默认的全 0 隐藏状态
# 自己定义初始的隐藏状态
h_0 = Variable(torch.randn(1, 5, 200))
out, h_t = rnn_seq(x, h_0)
print(out.shape,h_t.shape)
#输出:torch.Size([6, 5, 200]) torch.Size([1, 5, 200])
如果在传入网络的时候不特别注明隐藏状态h0 ,那么初始的隐藏状态默认参数全是0,当然也可以用上面的方式来自定义隐藏状态的初始化。
SAP ABAP 处理 Excel 的标准函数 TEXT_CONVERT_XLS_TO_SAP 介绍试读版
Identity Server 4使用OpenID Connect添加用户身份验证(三)
vue使用element上传视频功能(上传前进行判断,自定义上传)
墨天轮访谈 | 百度云邱学达:GaiaDB如何解决云上场景的业务需求?
Android 中TextureView和SurfaceView的属性方法及示例说明
呕心沥血,一整套完整的基于SpringSecurity的自定义认证、动态授权、权限控制以及注销功能的实现