MeLU模型复现
作者:互联网
MeLU算是推荐系统冷启动中非常经典的一个模型,在近两年很多冷启动相关的论文都拿它做baseline。以下总结一些个人觉得值得关注的地方。代码参考自MELU_pytorch
class Linear(nn.Linear):
def __init__(self, in_features, out_features):
super(Linear, self).__init__(in_features, out_features)
self.weight.fast = None
self.bias.fast = None
def forward(self, x):
if self.weight.fast is not None and self.bias.fast is not None:
out = F.linear(x, self.weight.fast, self.bias.fast)
else:
out = super(Linear, self).forward(x)
return out
首先是Linear的重写,因为MeLU中涉及到元学习中的MAML,会涉及到两个梯度,普通的Linear无法实现这种操作,因此在原有的Linear上又加入了fast,fast是inner loop更新后的参数。
fast_parameters = []
for k, weight in enumerate(model.final_part.parameters()):
if weight.fast is None:
weight.fast = weight - args.lr_inner * grad[k]
else:
weight.fast = weight.fast - args.lr_inner * grad[k]
fast_parameters.append(weight.fast)
inner loop的更新,这里只更新除了用户与物品属性之外的参数的embedding。
logits_q = model(x_qry[i])
loss_q = F.mse_loss(logits_q, y_qry[i])
loss_after.append(loss_q.item())
task_grad_test = torch.autograd.grad(loss_q, model.parameters())
for g in range(len(task_grad_test)):
meta_grad[g] += task_grad_test[g].detach()
meta_optimizer.zero_grad()
for c, param in enumerate(model.parameters()):
param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
param.grad.data.clamp_(-10, 10)
meta_optimizer.step()
outer loop的更新,这里包括梯度截断等操作。
MeLU总体的代码还是比较容易看懂的,代码中一半都是用来处理数据,实际的模型代码并不长,核心的部分就是上述内容。
MeLU有着比较明显的优缺点,优点是它使用了元学习中的MAML,可以为冷启动用户或物品生成一个较为通用的表示,使得仅使用少量数据就可以使冷启动用户和物品快速适应推荐系统。缺点是在推荐时,仅使用到了用户和物品的相关属性,没有使用到富有价值的用户历史交互序列这一信息,而且MeLU为每一类属性生成的embedding都是固定的,有可能存在相同属性的用户偏爱不同类型的物品,这种情况下MeLU效果就会很糟糕。
标签:Linear,weight,模型,fast,MeLU,复现,grad,self 来源: https://www.cnblogs.com/ambition-hhn/p/16687762.html