其他分享
首页 > 其他分享> > Pytorch register_forward_hook()简单用法

Pytorch register_forward_hook()简单用法

作者:互联网

简单来说就是在不改动网络结构的情况下获取网络中间层输出
比如有个LeNet:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)     
        out = F.max_pool2d(out, 2)      
        
        out = self.conv2(out)
        out = F.relu(out)  
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

如果我们要获取conv2的输出,一种最直观的思路是这样:

def forward(self, x):
    out = self.conv1(x)
    out = F.relu(out)     
    out = F.max_pool2d(out, 2)      
    
    out = self.conv2(out)
    out_conv2 = out
    out = F.relu(out)
    out = F.max_pool2d(out, 2)
    
    out = out.view(out.size(0), -1)
    out = F.relu(self.fc1(out))
    out = F.relu(self.fc2(out))
    out = self.fc3(out)
    return out, out_conv2

直接修改forward部分的代码,将conv2的中间结果return即可。

但很多时候,我们并没有办法去直接修改网络的源代码,比如在pytorch中已经封装好的网络,那么这个时候就可以利用hook从外部获取Module的中间输出结果了。即:

features = []
def hook(module, input, output): 
    features.append(output.clone().detach())

net = LeNet() 
x = torch.randn(2, 3, 32, 32)  
handle = net.conv2.register_forward_hook(hook)
y = net(x)
print(features[0])
handle.remove()

取出网络的相应层后,对该层调用register_forward_hook方法。这个方法需要传入一个hook方法:

hook(module, input, output) -> None or modified output

从这里可以发现hook甚至可以更改输入输出(不过并不会影响网络forward的实际结果),不过在这里我们只是简单地将output给保存下来。
需要注意的是hook函数在使用后应及时删除,以避免每次都运行增加运行负载。

参考:
https://blog.csdn.net/winycg/article/details/100695373
https://blog.csdn.net/foneone/article/details/107099060

标签:nn,conv2,self,register,relu,hook,Pytorch,out
来源: https://blog.csdn.net/qq_40714949/article/details/114702690