其他分享
首页 > 其他分享> > torch工具箱

torch工具箱

作者:互联网

inpt = torch.ones(size=(4, ))
w = torch.tensor(2.0, requires_grad=True)
l = inpt * w
loss = l.mean()

# 钩子函数,在反向传播时将l的梯度打印,随后销毁
l.register_hook(lambda grad: print(f'l.grad:{grad}'))  

# 保存中间变量的梯度
# l.retain_grad()

loss.backward()  # 执行反向传播

print(w.grad)
print(w.grad_fn)  # None,因为是用户自创建
print(l.grad)     # None,应为是非叶子变量,所以默认不保存梯度

# <MulBackward0 object at xxx>,torch内定义了基本操作的反向传播函数
print(l.grad_fn)  

# 以下两命令执行顺序不同,有不同的效果

torch.manual_seed(1)  # 指定seed

torch.seed()  # 随机seed

标签:loss,变量,torch,seed,print,工具箱,grad
来源: https://www.cnblogs.com/wjw-cat/p/16602374.html