PyTorch学习笔记之CBOW模型实践
作者:互联网
"""Train a small CBOW (continuous bag-of-words) model on a toy corpus.

Each training example pairs the CONTEXT_SIZE words on either side of a
position with the word at that position; the model learns to predict that
centre word from the concatenated context embeddings.
"""
import torch
from torch import nn, optim
from torch.autograd import Variable  # noqa: F401  (legacy; plain tensors are used below)
import torch.nn.functional as F

CONTEXT_SIZE = 2  # number of words taken on each side of the target
raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells.".split(' ')

vocab = set(raw_text)
# Enumerate the vocabulary in sorted order: iteration order of a set of
# strings depends on hash randomization, so enumerating the set directly
# would assign different indices to words on every run.
word_to_idx = {word: i for i, word in enumerate(sorted(vocab))}


def build_cbow_pairs(tokens, context_size=CONTEXT_SIZE):
    """Return a list of ``(context_words, target_word)`` pairs.

    For every position ``i`` that has ``context_size`` tokens on both sides,
    the context is the ``context_size`` tokens to the left followed by the
    ``context_size`` tokens to the right (original order), and the target
    is ``tokens[i]``. Sequences too short for a full window yield ``[]``.
    """
    return [
        (
            list(tokens[i - context_size:i]) + list(tokens[i + 1:i + 1 + context_size]),
            tokens[i],
        )
        for i in range(context_size, len(tokens) - context_size)
    ]


# The window follows CONTEXT_SIZE, which CBOW's first layer also depends on.
data = build_cbow_pairs(raw_text, CONTEXT_SIZE)


class CBOW(nn.Module):
    """Embedding -> Linear(128) -> ReLU -> Linear(n_word) -> log-softmax.

    Args:
        n_word: vocabulary size.
        n_dim: embedding dimension.
        context_size: words per side; ``forward`` expects exactly
            ``2 * context_size`` word indices per call.
    """

    def __init__(self, n_word, n_dim, context_size):
        super(CBOW, self).__init__()
        self.embedding = nn.Embedding(n_word, n_dim)
        self.linear1 = nn.Linear(2 * context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, n_word)

    def forward(self, x):
        """Map a 1-D index tensor of one context to log-probs of shape ``(1, n_word)``."""
        x = self.embedding(x)
        # Flatten all context embeddings into a single row (batch of size 1).
        x = x.view(1, -1)
        x = self.linear1(x)
        x = F.relu(x, inplace=True)
        x = self.linear2(x)
        # Explicit dim: calling log_softmax without dim is deprecated.
        return F.log_softmax(x, dim=1)


def train(model, pairs, word_to_idx, epochs=100, lr=1e-3, device=None):
    """Train ``model`` with SGD on ``(context_words, target_word)`` pairs.

    Prints the epoch header and average loss per epoch, and returns the
    list of per-epoch average losses (Python floats).

    Args:
        model: a CBOW-like module returning log-probabilities.
        pairs: examples as produced by ``build_cbow_pairs``.
        word_to_idx: mapping from word to vocabulary index.
        epochs: number of passes over ``pairs``.
        lr: SGD learning rate.
        device: where to place index tensors; defaults to the device of the
            model's parameters.

    Raises:
        ValueError: if ``pairs`` is empty (the average loss is undefined).
    """
    if not pairs:
        raise ValueError("train() needs at least one (context, target) pair")
    if device is None:
        device = next(model.parameters()).device
    # The model already emits log-probabilities, so NLLLoss is the matching
    # criterion; CrossEntropyLoss would apply log_softmax a second time.
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # Convert words to index tensors once rather than on every epoch.
    samples = [
        (
            torch.tensor([word_to_idx[w] for w in context], dtype=torch.long, device=device),
            torch.tensor([word_to_idx[target]], dtype=torch.long, device=device),
        )
        for context, target in pairs
    ]
    history = []
    for epoch in range(epochs):
        print('epoch {}'.format(epoch))
        print('*' * 10)
        running_loss = 0.0
        for context, target in samples:
            # forward
            out = model(context)
            loss = criterion(out, target)
            # .item() extracts the scalar; indexing a 0-dim tensor raises.
            running_loss += loss.item()
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_loss = running_loss / len(samples)
        print('loss: {:.6f}'.format(avg_loss))
        history.append(avg_loss)
    return history


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CBOW(len(word_to_idx), 100, CONTEXT_SIZE).to(device)
    train(model, data, word_to_idx, epochs=100, lr=1e-3, device=device)
标签:word,target,text,self,torch,笔记,PyTorch,context,CBOW 来源: https://www.cnblogs.com/jfdwd/p/11076977.html