
Gradient Clipping Code


Gradient clipping must be applied after the loss's backward pass and before `optimizer.step()`.

The function below clamps every gradient element to the range [-grad_clip, grad_clip]:

def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
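A minimal sketch of where the call fits in a training step, using a hypothetical one-layer model and dummy data (the model, optimizer, and loss here are illustrative assumptions, not part of the original post). The `clip_gradient` function from above is repeated so the snippet is self-contained:

```python
import torch
import torch.nn as nn


def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element to [-grad_clip, grad_clip]."""
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


# Hypothetical setup for illustration only.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

x = torch.randn(8, 4)
y = torch.randn(8, 1)

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()                          # gradients are computed here
clip_gradient(optimizer, grad_clip=0.5)  # clip after backward...
optimizer.step()                         # ...and before the update

# After clamping, no gradient element exceeds grad_clip in magnitude.
for p in model.parameters():
    assert p.grad.abs().max() <= 0.5
```

Note that element-wise clamping changes the direction of the gradient vector whenever any component is clipped, which is the main difference from the norm-based alternative below.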


Alternatively, clip by the global norm of the gradients instead:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
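A minimal sketch of the norm-based variant, again with a hypothetical model and dummy input. Unlike element-wise clamping, `clip_grad_norm_` rescales all gradients together so their combined L2 norm does not exceed `max_norm`, preserving the gradient direction:

```python
import torch
import torch.nn as nn

# Hypothetical setup for illustration only.
model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Returns the total norm *before* clipping; gradients are scaled in place
# so that the global L2 norm afterwards is at most max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, norm_type=2
)

# The combined gradient norm now respects the bound.
grads = torch.cat([p.grad.flatten() for p in model.parameters()])
assert grads.norm(2) <= 1.0 + 1e-6
```

Because the whole gradient vector is scaled by a single factor, this variant is often preferred when the relative magnitudes of gradients across parameters should be preserved.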

Source: https://www.cnblogs.com/yuzhoutaiyang/p/16215614.html