其他分享
首页 > 其他分享> > PyTorch学习笔记——(4)autograd.grad()函数和backward()函数介绍及使用

PyTorch学习笔记——(4)autograd.grad()函数和backward()函数介绍及使用

作者:互联网

目录

1、torch.autograd.grad()

torch.autograd.grad(
		outputs, 
		inputs, 
		grad_outputs=None, 
		retain_graph=None, 
		create_graph=False, 
		only_inputs=True, 
		allow_unused=False)

用一个最简单的例子说明:

现在有一个函数: y = w ∗ x y=w*x y=w∗x其中, x x x为输入数据, w w w为可训练参数, y y y为预测值。我们定义损失函数为: l o s s = ( y − w ∗ x ) 2 loss = \sqrt {{{(y - w*x)}^2}} loss=(y−w∗x)2 ​现在求 l o s s loss loss对 w w w的梯度,应该怎么做呢?

答:

x = torch.ones(1) # 输入数据
w = torch.full([1], 2.) # 初始化训练参数
y = torch.ones(1) # 标签数据
loss = F.mse_loss(y, x*w)  # 计算mse损失
torch.autograd.grad(mse, w) # 错误,发现w不能求导,因为我们们没指定w是能求导的,所以出错了
# 那怎么做呢?接着看:
w.requires_grad_() # 指定w这个参数需要求梯度
# 然后我们再试一试:
torch.autograd.grad(mse, w) 
"""
发现还是出错,这是为啥啊,因为pytorch每次有了计算梯度的参数,那么就会计算一个动态图,
所以我们前面的所有步骤都要重新来一遍。
"""

# 重新来一遍:
w.requires_grad_() # 指定w这个参数需要求梯度,当然在初始化w的时候指定requires_grad=True也是可以的
loss = F.mse_loss(y, x*w) # 求mse误差
torch.autograd.grad(mse, w) # 这个就求出了loss对w的导数了

上面这个函数只能对指定的参数求导,那么我们要是想一次求出所有参数的导数,我们应该怎么做呢?
接着看backward()函数。

2、.torch.autograd.backward()

torch.autograd.backward(
		tensors, 
		grad_tensors=None, 
		retain_graph=None, 
		create_graph=False, 
		grad_variables=None)

第一个方式只是对单独的一个参数进行求导,下面我们使用backward()来对动态图的所有参数一次性求导。

例子:

# 1.准备数据
x = torch.ones(2) # 输入特征数据
w = torch.full([2], 2., requires_grad=True) # 初始化训练参数
y = torch.ones_like(x) # 标签数据

# 2.求mse误差
loss = F.mse_loss(y, x*w) # 这个是个标量

# 3.求梯度
mse.backward() # 动态图所有参数的梯度都计算了,但是不会显示出来。注意:通常在调用一次backward后,
# pytorch会自动把计算图销毁,所以要想对某个变量重复调用backward,则需要将retain_graph参数设置为True

# 4.获取梯度
w.grad
# 输出:
tensor([1., 1.])

上面的损失loss我注释里面说是标量,那要不是标量怎么弄的呢?看一个例子:

a = torch.randn(4) # 初始化
a.requires_grad_() # 指定可求导

p = F.softmax(a, dim=0) # 经过sofrmax函数,输出是[1,4]

p.backward(torch.ones(4)) # 里面的参数必须指定,不然报错

这里的P是多维的向量,里面要加参数grad_tensors=torch.ones(4)

为啥:这个知乎的解释感觉是对的:https://zhuanlan.zhihu.com/p/83172023

标签:loss,函数,autograd,torch,PyTorch,梯度,backward,grad
来源: https://blog.csdn.net/weixin_45901519/article/details/113813934