Pytorch中的checkPoint: torch.utils.checkpoint.checkpoint
作者:互联网
torch.utils.checkpoint.checkpoint笔记,内容来源于官方手册
仅作笔记只用,不完整之处请查阅官方手册
https://pytorch.org/docs/stable/checkpoint.html
checkpoint是通过在backward期间为每个checkpoint段重新运行forward-pass segment来实现的。
这可能会导致像 RNG 状态这样的持久状态比没有checkpoint的情况更先进。默认情况下,checkpoint包括处理 RNG 状态的逻辑,以便与非checkpoint传递相比,使用 RNG 的checkpoint传递(例如通过 dropout)具有确定性输出。
根据checkpoint操作的运行时间,存储和恢复 RNG 状态的逻辑可能会导致一定的性能下降。
注:RNG状态指随机数状态
如果不需要确定性输出,则为checkpoint或 checkpoint_sequential 设置preserve_rng_state=False ,以在每个checkpoint期间省略存储和恢复RNG 状态。
换言之,1.9版本中,checkpoint能处理随机数状态了!
存储逻辑将当前设备和所有 cuda Tensor 参数的设备的 RNG 状态保存并恢复到 run_fn。但是,逻辑无法预测用户是否会将张量移动到 run_fn 本身内的新设备。因此,如果您在 run_fn 中将张量移动到一个新设备确定性输出永远无法保证。
尽量使用原生的 torch.utils.checkpoint.checkpoint
checkpoint的工作原理是用计算换取内存。与存储整个计算图的所有中间激活用于反向计算不同,checkpoint部分不保存中间激活,而是在反向传递中重新计算它们.它可以应用于模型的任何部分。
具体来说,在前向传递中,函数将以 torch.no_grad() 方式运行,即不存储中间激活。相反,前向传递保存输入元组和函数参数。在向后传递中,检索保存的输入和函数,并再次对函数计算前向传递,现在跟踪中间激活,然后使用这些激活值计算梯度。
函数的输出可以包含非 Tensor 值,并且仅对 Tensor 值执行梯度记录。请注意,如果输出包含由张量组成的嵌套结构(例如:自定义对象、列表、字典等),则这些嵌套在自定义结构中的张量将不会被视为 autograd 的一部分。
标签:checkPoint,状态,torch,RNG,张量,checkpoint,传递,Pytorch 来源: https://blog.csdn.net/ftimes/article/details/120678872