Pytorch register_forward_hook()简单用法
作者:互联网
简单来说就是在不改动网络结构的情况下获取网络中间层输出。
比如有个LeNet:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = self.conv1(x)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = self.conv2(out)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
如果我们要获取conv2的输出,一种最直观的思路是这样:
def forward(self, x):
out = self.conv1(x)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = self.conv2(out)
out_conv2 = out
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out, out_conv2
直接修改forward部分的代码,将conv2的中间结果return即可。
但很多时候,我们并没有办法去直接修改网络的源代码,比如在pytorch中已经封装好的网络,那么这个时候就可以利用hook从外部获取Module的中间输出结果了。即:
features = []
def hook(module, input, output):
features.append(output.clone().detach())
net = LeNet()
x = torch.randn(2, 3, 32, 32)
handle = net.conv2.register_forward_hook(hook)
y = net(x)
print(features[0])
handle.remove()
取出网络的相应层后,对该层调用register_forward_hook方法。这个方法需要传入一个hook方法:
hook(module, input, output) -> None or modified output
- module:表示该层网络
- input:该层网络的输入
- output:该层网络的输出
从这里可以发现hook甚至可以更改输入输出(不过并不会影响网络forward的实际结果),不过在这里我们只是简单地将output给保存下来。
需要注意的是hook函数在使用后应及时删除,以避免每次都运行增加运行负载。
参考:
https://blog.csdn.net/winycg/article/details/100695373
https://blog.csdn.net/foneone/article/details/107099060
标签:nn,conv2,self,register,relu,hook,Pytorch,out 来源: https://blog.csdn.net/qq_40714949/article/details/114702690