其他分享
首页 > 其他分享> > 深度学习学习——将LSTM,GRU等模型加入 nn.Sequential中

深度学习学习——将LSTM,GRU等模型加入 nn.Sequential中

作者:互联网

因为nn.GRU还有nn.LSTM的输出是两个元素,直接加到nn.Sequential中会报错,因此需要借助一个元素选择的小组件 SelectItem 来挑选

class SelectItem(nn.Module):
    def __init__(self, item_index):
        super(SelectItem, self).__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        return inputs[self.item_index]

SelectItem 可以用于到Sequential中选择隐含状态:

    net = nn.Sequential(
        nn.GRU(dim_in, dim_out, batch_first=True),
        SelectItem(1),
        nn.Dropout(0.2),
        )

标签:index,GRU,nn,SelectItem,self,item,Sequential
来源: https://blog.csdn.net/m0_37876745/article/details/123032207