
PyTorch Deep Learning Practice, Lecture 12: Recurrent Neural Networks (Basics)

Author: Internet

import torch

input_size=4
hidden_size=4
batch_size=1
# Prepare data
idx2char=['e','h','l','o']
x_data=[1,0,2,2,3] # hello
y_data=[3,1,2,3,2] # ohlol

one_hot_lookup=[[1,0,0,0],
                [0,1,0,0],
                [0,0,1,0],
                [0,0,0,1]] # row i is the one-hot code of index i (0..3)
x_one_hot=[one_hot_lookup[x] for x in x_data] # one-hot encode each character of "hello"
print('x_one_hot:',x_one_hot)

# Build inputs with shape (seq_len, batch_size, input_size) and labels with shape (seq_len, 1)
inputs=torch.Tensor(x_one_hot).view(-1,batch_size,input_size)
labels=torch.LongTensor(y_data).view(-1,1)

# design model
class Model(torch.nn.Module):
    def __init__(self,input_size,hidden_size,batch_size):
        super(Model, self).__init__()
        self.batch_size=batch_size
        self.input_size=input_size
        self.hidden_size=hidden_size
        self.rnncell=torch.nn.RNNCell(input_size=self.input_size,
                                      hidden_size=self.hidden_size)

    def forward(self,x,hidden):
        # one time step: x is (batch_size, input_size), hidden is (batch_size, hidden_size)
        hidden=self.rnncell(x,hidden)
        return hidden

    def init_hidden(self):
        return torch.zeros(self.batch_size,self.hidden_size)
net=Model(input_size,hidden_size,batch_size)

# loss and optimizer
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.1)

# train cycle
for epoch in range(15):
    loss=0
    optimizer.zero_grad()
    hidden=net.init_hidden()
    print('Predicted String:',end='')
    for x,label in zip(inputs,labels):  # iterate over the 5 time steps
        hidden=net(x,hidden)
        loss+=criterion(hidden,label)  # hidden doubles as the 4-class logits
        _, idx=hidden.max(dim=1)
        print(idx2char[idx.item()],end='')
    loss.backward()
    optimizer.step()
    print(',Epoch [%d/15] loss=%.4f' % (epoch+1,loss.item()))
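Each call to net(x, hidden) above runs a single torch.nn.RNNCell step. With its default tanh nonlinearity the cell computes h' = tanh(x W_ih^T + b_ih + h W_hh^T + b_hh). The minimal sketch below checks that formula against the cell's own weight_ih, weight_hh, bias_ih and bias_hh parameters; the seed and the random x and h are arbitrary choices for illustration only.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility only

input_size, hidden_size, batch_size = 4, 4, 1
cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

x = torch.randn(batch_size, input_size)   # one time step of input
h = torch.randn(batch_size, hidden_size)  # previous hidden state

# Built-in cell (default nonlinearity='tanh')
h_cell = cell(x, h)

# The same update written out by hand:
# h' = tanh(x @ W_ih^T + b_ih + h @ W_hh^T + b_hh)
h_manual = torch.tanh(
    x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
)

print(h_cell.shape)                                  # torch.Size([1, 4])
print(torch.allclose(h_cell, h_manual, atol=1e-6))   # True
```

Because hidden_size equals the number of classes (4), the tutorial feeds the hidden state straight into CrossEntropyLoss as logits. Since tanh keeps every logit in (-1, 1), this caps how confident the predictions can become.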

Output:

x_one_hot: [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
Predicted String:eeeee,Epoch [1/15] loss=9.1647
Predicted String:eehoe,Epoch [2/15] loss=7.5995
Predicted String:elhol,Epoch [3/15] loss=6.7802
Predicted String:olool,Epoch [4/15] loss=6.1220
Predicted String:olool,Epoch [5/15] loss=5.5625
Predicted String:ololl,Epoch [6/15] loss=5.1270
Predicted String:ololl,Epoch [7/15] loss=4.8060
Predicted String:ololl,Epoch [8/15] loss=4.5607
Predicted String:oholl,Epoch [9/15] loss=4.3423
Predicted String:oholl,Epoch [10/15] loss=4.1480
Predicted String:oholl,Epoch [11/15] loss=3.9697
Predicted String:oholl,Epoch [12/15] loss=3.8007
Predicted String:oholl,Epoch [13/15] loss=3.6583
Predicted String:oholl,Epoch [14/15] loss=3.5437
Predicted String:oholl,Epoch [15/15] loss=3.4412
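The training loop unrolls the sequence by hand, one RNNCell call per character. As a minimal sketch, assuming the same data and hyperparameters, torch.nn.RNN can consume the whole (seq_len, batch_size, input_size) tensor in one call. Here torch.nn.functional.one_hot replaces the hand-written lookup table, and the class name RNNModel is a hypothetical choice. CrossEntropyLoss averages over the 5 time steps instead of summing them, so the printed losses are on a smaller scale and are not comparable line by line with the output above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, for reproducibility only

input_size, hidden_size, batch_size, num_layers = 4, 4, 1, 1
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # hello
y_data = [3, 1, 2, 3, 2]  # ohlol

# One-hot encode with F.one_hot; reshape to (seq_len, batch_size, input_size),
# the default (batch_first=False) layout expected by nn.RNN.
inputs = F.one_hot(torch.tensor(x_data), num_classes=input_size).float()
inputs = inputs.view(-1, batch_size, input_size)
labels = torch.tensor(y_data)  # shape (seq_len * batch_size,)


class RNNModel(torch.nn.Module):  # hypothetical name
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = torch.nn.RNN(input_size=input_size,
                                hidden_size=hidden_size,
                                num_layers=num_layers)

    def forward(self, x):
        # h0: (num_layers, batch_size, hidden_size), all zeros
        h0 = torch.zeros(self.num_layers, x.size(1), self.hidden_size)
        out, _ = self.rnn(x, h0)  # out: (seq_len, batch_size, hidden_size)
        return out.reshape(-1, self.hidden_size)  # (seq_len * batch_size, hidden_size)


net = RNNModel(input_size, hidden_size, num_layers)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

losses = []
for epoch in range(15):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)  # mean over the 5 time steps
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    idx = outputs.argmax(dim=1)
    predicted = ''.join(idx2char[i] for i in idx.tolist())
    print(f'Predicted string: {predicted}, Epoch [{epoch + 1}/15] loss={loss.item():.4f}')
```

Processing the full sequence in one call lets PyTorch handle the time loop internally, and the labels become a flat vector because the outputs of all time steps are stacked into one (seq_len * batch_size, hidden_size) matrix.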

Source: https://blog.csdn.net/lcnana/article/details/121237826