其他分享
首页 > 其他分享> > Pytorch 链式法则求梯度

Pytorch 链式法则求梯度

作者:互联网

x经过参数w1和b1得到y1,y1再通过w2和b2得到y2,要求y2对w1的导数,可以求y2对y1然后y1对w1的导数。PyTorch可以自动使用链式法则对复杂的导数求解。

import torch

x = torch.tensor(1.2)
w1 = torch.tensor(2.3, requires_grad=True)
b1 = torch.tensor(1.3)
y1 = x * w1 + b1

w2 = torch.tensor(2.2)
b2 = torch.tensor(1.4)
y2 = y1 * w2 + b2

# PyTorch自动实现链式法则的求导
dy2_dw1 = torch.autograd.grad(y2, [w1], retain_graph=True)
print(dy2_dw1[0])

# 手动用链式法则的方式求一下看看
dy2_dy1 = torch.autograd.grad(y2, [y1], retain_graph=True)
dy1_dw1 = torch.autograd.grad(y1, [w1], retain_graph=True)
print(dy2_dy1[0] * dy1_dw1[0])

输出结果:

tensor(2.6400)
tensor(2.6400)

 

标签:链式法则,tensor,梯度,torch,dy1,Pytorch,w1,y1,y2
来源: https://blog.csdn.net/weicao1990/article/details/97754077