GRU(Gated Recurrent Unit)门控循环单元结构
作者:互联网
GRU(Gated Recurrent Unit)也称门控循环单元结构
nn.GRU类初始化主要参数解释:
input_size: 输入张量x中特征维度的大小.
hidden_size: 隐层张量h中特征维度的大小.
num_layers: 隐含层的数量.
nonlinearity: 激活函数的选择, 默认是tanh.
bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
nn.GRU类实例化对象主要参数解释:
input: 输入张量x.
h0: 初始化的隐层张量h.
代码示例:
import torch
import torch.nn as nn
rnn = nn.GRU(5,6,2) #数据向量维数5, 隐藏元维度6, 2个LSTM层串联(如果是1,可以省略,默认为1)
input = torch.randn(1,3,5) # 序列长度seq_len=1, batch_size=3, 数据向量维数=5
h0 = torch.randn(2,3,6) # 2个LSTM层,batch_size=3,隐藏元维度6
output, hn = rnn(input,h0)
print(output)
print(output.type())
print(output.shape)
print(hn)
代码运行结果
tensor([[[-0.0307, 0.0718, -0.2517, 0.0565, 0.0613, 0.5001],
[-0.6239, 1.0618, 0.7506, 0.3475, -0.8536, -0.8410],
[-0.0949, 0.5698, 0.4491, -0.0122, 0.5413, -0.2383]]],
grad_fn=<StackBackward0>)
torch.FloatTensor
torch.Size([1, 3, 6])
tensor([[[-0.5540, 0.3067, -1.2936, -0.3727, -0.4141, 0.2967],
[-1.2364, 0.7779, 0.4355, -1.2783, -0.0382, 0.5875],
[ 1.4438, 1.2898, -0.3959, -0.5599, -1.1615, 0.3538]],
[[-0.0307, 0.0718, -0.2517, 0.0565, 0.0613, 0.5001],
[-0.6239, 1.0618, 0.7506, 0.3475, -0.8536, -0.8410],
[-0.0949, 0.5698, 0.4491, -0.0122, 0.5413, -0.2383]]],
grad_fn=<StackBackward0>)
GRU的优势:
GRU和LSTM作用相同, 在捕捉长序列语义关联时, 能有效抑制梯度消失或爆炸, 效果都优于传统RNN且计算复杂度相比LSTM要小.
GRU的缺点:
GRU仍然不能完全解决梯度消失问题, 同时其作用RNN的变体, 有着RNN结构本身的一大弊端, 即不可并行计算, 这在数据量和模型体量逐步增大的未来, 是RNN发展的关键瓶颈.
标签:主要参数,GRU,torch,nn,Recurrent,Gated,input,LSTM 来源: https://blog.csdn.net/weixin_41862755/article/details/123140739