基于pytorch从模型中取出指定特征层的输出
作者:互联网
重构网络模型
问题:若网络模型为ResNet-50,想取出layer4和layer3的输出作为网络的最终输出。这是语义分割中常用的操作:将layer4作为主输出,layer3用于辅助输出。分类网络中也有类似的辅助输出设计,如GoogLeNet的辅助分类器。
解决:首先需要重构原本的ResNet-50,将layer4、layer3与其输出名称定义为一个字典:
{'layer4': 'main_out', 'layer3': 'aux_out'}。然后使用IntermediateLayerGetter将网络最终的输出调整为有序字典(OrderedDict)的形式,layer4、layer3的输出分别对应'main_out'和'aux_out'。
通过return_layer以字典的方式传入指定层的名称以及对应的输出名称,如:{'layer4': 'main_out', 'layer3': 'aux_out'}。最终网络输出main_out和aux_out,即layer4、layer3的特征图。
class IntermediateLayerGetter(nn.ModuleDict):
    """Wrap a model so that ``forward`` returns outputs of selected child modules.

    The direct children of ``model`` are copied in registration order, up to and
    including the last child named in ``return_layer``. Any later children (for
    example a pooling / fully connected head) are dropped. ``forward`` runs the
    kept children one after another and collects the requested intermediate
    outputs.

    Note: only ``model.named_children()`` is used. Any custom logic inside
    ``model.forward`` (e.g. a flatten between children) is NOT reproduced.

    Args:
        model: Module whose direct children are executed sequentially.
        return_layer: Mapping ``{child_name: output_name}``, e.g.
            ``{'layer4': 'main_out', 'layer3': 'aux_out'}``.

    Raises:
        ValueError: If a key of ``return_layer`` is not the name of a direct
            child of ``model``.
    """

    def __init__(self, model: nn.Module, return_layer: Dict[str, str]):
        # Every requested key must name a direct child of ``model``.
        child_names = [name for name, _ in model.named_children()]
        if not set(return_layer).issubset(child_names):
            raise ValueError('return_layers are not present in model')
        # String-keyed working copy; it is consumed below, so the caller's dict
        # is never mutated.
        remaining = {str(k): str(v) for k, v in return_layer.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            remaining.pop(name, None)
            # Stop once every requested layer is included; later children are unused.
            if not remaining:
                break
        super().__init__(layers)
        # Keep a private copy. Storing the caller's dict by reference would let
        # later mutations of it silently change what forward() returns.
        self.return_layer = dict(return_layer)

    def forward(self, x):
        """Run the kept children in order and return the requested outputs.

        Returns:
            OrderedDict mapping output names to child outputs, ordered by
            execution order of the children (not by ``return_layer`` order).
        """
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layer:
                out[self.return_layer[name]] = x
        return out
整体程序
import torch
from ResNet_dilation import ResNet50
import torch.nn as nn
from typing import Dict
from collections import OrderedDict
import torch.nn.functional as F
class IntermediateLayerGetter(nn.ModuleDict):
    """Wrap a model so that ``forward`` returns outputs of selected child modules.

    The direct children of ``model`` are copied in registration order, up to and
    including the last child named in ``return_layer``. Any later children (for
    example a pooling / fully connected head) are dropped. ``forward`` runs the
    kept children one after another and collects the requested intermediate
    outputs.

    Note: only ``model.named_children()`` is used. Any custom logic inside
    ``model.forward`` (e.g. a flatten between children) is NOT reproduced.

    Args:
        model: Module whose direct children are executed sequentially.
        return_layer: Mapping ``{child_name: output_name}``, e.g.
            ``{'layer4': 'main_out', 'layer3': 'aux_out'}``.

    Raises:
        ValueError: If a key of ``return_layer`` is not the name of a direct
            child of ``model``.
    """

    def __init__(self, model: nn.Module, return_layer: Dict[str, str]):
        # Every requested key must name a direct child of ``model``.
        child_names = [name for name, _ in model.named_children()]
        if not set(return_layer).issubset(child_names):
            raise ValueError('return_layers are not present in model')
        # String-keyed working copy; it is consumed below, so the caller's dict
        # is never mutated.
        remaining = {str(k): str(v) for k, v in return_layer.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            remaining.pop(name, None)
            # Stop once every requested layer is included; later children are unused.
            if not remaining:
                break
        super().__init__(layers)
        # Keep a private copy. Storing the caller's dict by reference would let
        # later mutations of it silently change what forward() returns.
        self.return_layer = dict(return_layer)

    def forward(self, x):
        """Run the kept children in order and return the requested outputs.

        Returns:
            OrderedDict mapping output names to child outputs, ordered by
            execution order of the children (not by ``return_layer`` order).
        """
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layer:
                out[self.return_layer[name]] = x
        return out
class FCNHead(nn.Sequential):
    """FCN classification head: 3x3 conv bottleneck followed by a 1x1 classifier.

    Layout: Conv3x3(in -> in // 4, no bias) -> BatchNorm -> ReLU -> Dropout(0.1)
    -> Conv1x1(in // 4 -> channels). The bottleneck reduces channels by 4x,
    e.g. 1024 -> 256 for a layer3 feature map or 2048 -> 512 for layer4.

    Args:
        in_channels: Number of channels of the incoming feature map.
        channels: Number of output channels (typically the number of classes).
    """

    def __init__(self, in_channels, channels):
        reduced = in_channels // 4
        super().__init__(
            nn.Conv2d(in_channels, reduced, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(reduced),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(reduced, channels, kernel_size=1),
        )
class FCN(nn.Module):
    """Fully convolutional segmentation model: backbone + main head (+ optional aux head).

    Args:
        backbone: Module returning a mapping that contains ``'main_out'`` and,
            when ``aux_classifier`` is not None, also ``'aux_out'``.
        main_classifier: Head applied to ``features['main_out']``.
        aux_classifier: Optional head applied to ``features['aux_out']``; pass
            None to disable the auxiliary output.
    """

    def __init__(self, backbone, main_classifier, aux_classifier):
        super().__init__()
        self.backbone = backbone
        self.main_classifier = main_classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        """Return an OrderedDict of per-head logits upsampled to the input size.

        ``x`` is expected to be a 4-D (N, C, H, W) batch: its trailing two dims
        are the target size for bilinear upsampling. The result always has
        ``'main_out'``, followed by ``'aux_out'`` when an aux head is configured.
        """
        target_size = x.shape[2:]
        feats = self.backbone(x)
        # Heads are evaluated in this order: main first, then aux.
        heads = [('main_out', self.main_classifier)]
        if self.aux_classifier is not None:
            heads.append(('aux_out', self.aux_classifier))
        result = OrderedDict()
        for key, head in heads:
            logits = head(feats[key])
            result[key] = F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)
        return result
def fcn_resnet50(aux, num_classes=21):
    """Build an FCN segmentation model on top of the project's ResNet50.

    layer4 (2048 channels) feeds the main head. When ``aux`` is truthy, layer3
    (1024 channels) additionally feeds an auxiliary head. The model's forward
    returns an OrderedDict with ``'main_out'`` and, if enabled, ``'aux_out'``.

    Args:
        aux: Whether to build the auxiliary (layer3) head.
        num_classes: Number of output classes for each head.
    """
    # NOTE(review): dilation_replace_stride is interpreted by ResNet_dilation;
    # presumably it swaps the stride of layer3/layer4 for dilation -- confirm there.
    resnet = ResNet50(num_classes=num_classes, dilation_replace_stride=[False, True, True])

    return_layers = {'layer4': 'main_out'}
    if aux:
        return_layers['layer3'] = 'aux_out'
    # Backbone now yields an OrderedDict of the tapped feature maps.
    backbone = IntermediateLayerGetter(resnet, return_layer=return_layers)

    # Heads are created aux-first, then main, keeping parameter-init order stable.
    aux_classifier = FCNHead(1024, num_classes) if aux else None
    main_classifier = FCNHead(2048, num_classes)
    return FCN(backbone, main_classifier, aux_classifier)
标签:输出,layer,return,name,self,指定,pytorch,aux,out 来源: https://blog.csdn.net/weixin_44422920/article/details/123627903