其他分享
首页 > 其他分享> > GRU(Gated Recurrent Unit)门控循环单元结构

GRU(Gated Recurrent Unit)门控循环单元结构

作者:互联网

GRU(Gated Recurrent Unit)也称门控循环单元结构

nn.GRU类初始化主要参数解释:

    input_size: 输入张量x中特征维度的大小.
    hidden_size: 隐层张量h中特征维度的大小.
    num_layers: 隐含层的数量.
    nonlinearity: 激活函数的选择, 默认是tanh.
    bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.

nn.GRU类实例化对象主要参数解释:

    input: 输入张量x.
    h0: 初始化的隐层张量h.

代码示例:

import torch
import torch.nn as nn
rnn = nn.GRU(5,6,2) #数据向量维数5, 隐藏元维度6, 2个LSTM层串联(如果是1,可以省略,默认为1)
input = torch.randn(1,3,5) # 序列长度seq_len=1, batch_size=3, 数据向量维数=5
h0 = torch.randn(2,3,6) # 2个LSTM层,batch_size=3,隐藏元维度6
output, hn = rnn(input,h0)
print(output)
print(output.type())
print(output.shape)
print(hn)

代码运行结果

tensor([[[-0.0307,  0.0718, -0.2517,  0.0565,  0.0613,  0.5001],
         [-0.6239,  1.0618,  0.7506,  0.3475, -0.8536, -0.8410],
         [-0.0949,  0.5698,  0.4491, -0.0122,  0.5413, -0.2383]]],
       grad_fn=<StackBackward0>)
torch.FloatTensor
torch.Size([1, 3, 6])
tensor([[[-0.5540,  0.3067, -1.2936, -0.3727, -0.4141,  0.2967],
         [-1.2364,  0.7779,  0.4355, -1.2783, -0.0382,  0.5875],
         [ 1.4438,  1.2898, -0.3959, -0.5599, -1.1615,  0.3538]],

        [[-0.0307,  0.0718, -0.2517,  0.0565,  0.0613,  0.5001],
         [-0.6239,  1.0618,  0.7506,  0.3475, -0.8536, -0.8410],
         [-0.0949,  0.5698,  0.4491, -0.0122,  0.5413, -0.2383]]],
       grad_fn=<StackBackward0>)

GRU的优势:

    GRU和LSTM作用相同, 在捕捉长序列语义关联时, 能有效抑制梯度消失或爆炸, 效果都优于传统RNN且计算复杂度相比LSTM要小.

GRU的缺点:

    GRU仍然不能完全解决梯度消失问题, 同时其作用RNN的变体, 有着RNN结构本身的一大弊端, 即不可并行计算, 这在数据量和模型体量逐步增大的未来, 是RNN发展的关键瓶颈.

标签:主要参数,GRU,torch,nn,Recurrent,Gated,input,LSTM
来源: https://blog.csdn.net/weixin_41862755/article/details/123140739