其他分享
首页 > 其他分享 > 基于pytorch从模型中取出指定特征层的输出

基于pytorch从模型中取出指定特征层的输出

作者:互联网

重构网络模型

问题:若网络模型为ResNet-50,想取出layer4和layer3作为网络的最终输出。这是语义分割中常用的操作:将layer4作为主输出,layer3用于辅助输出。分类网络(如GoogLeNet)中也有类似的辅助输出。
解决:首先需要重构原本的ResNet-50,将layer4、layer3与其输出名称定义为一个字典:
{'layer4': 'main_out', 'layer3': 'aux_out'}, 然后使用IntermediateLayerGetter将网络最终的输出调整为一个有序字典的形式,layer4、layer3的输出分别对应'main_out'和'aux_out'。
return_layer中以字典的方式传入指定层的名称以及输出后的名称,如:{'layer4': 'main_out', 'layer3': 'aux_out'}。IntermediateLayerGetter只保留模型中到最后一个指定层为止的子模块,并按顺序依次执行它们;其后的子模块(如池化层、全连接层)会被丢弃。最终网络输出main_out和aux_out,即layer4、layer3的特征值。

class IntermediateLayerGetter(nn.ModuleDict):
    """Wrap a model so that ``forward`` returns the outputs of selected child layers.

    The wrapper keeps the model's direct children, in registration order, up to
    and including the last child named in ``return_layer``. Later children are
    dropped. The kept children are applied one after another to the input, so
    this only reproduces the original model when its forward pass is a plain
    chain of its direct children.

    Args:
        model: module whose direct children (``model.named_children()``) are chained.
        return_layer: mapping ``{child_name: output_name}``, e.g.
            ``{'layer4': 'main_out', 'layer3': 'aux_out'}``. Keys and values are
            converted to ``str``, so int indices of an ``nn.Sequential`` also work.

    Raises:
        ValueError: if a key of ``return_layer`` is not a direct child of ``model``.
    """

    def __init__(self, model: nn.Module, return_layer: Dict[str, str]):
        # Normalise to str *before* validating: child names are always str, so
        # non-str keys (e.g. Sequential indices) must be converted to match.
        # The private copy also isolates us from later mutation of the caller's dict.
        return_layer = {str(k): str(v) for k, v in return_layer.items()}

        child_names = {name for name, _ in model.named_children()}
        missing = [k for k in return_layer if k not in child_names]
        if missing:
            raise ValueError(f'return_layers are not present in model: {missing}')

        layers = OrderedDict()
        remaining = set(return_layer)
        for name, module in model.named_children():
            layers[name] = module
            remaining.discard(name)
            # Stop once every requested layer is included; later children are unused.
            if not remaining:
                break
        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layer = return_layer

    def forward(self, x):
        """Run the kept children in order and collect the requested outputs.

        Returns:
            OrderedDict mapping each output name to the corresponding child's
            output, in the order the children are executed.
        """
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layer:
                out[self.return_layer[name]] = x
        return out

整体程序

import torch
from ResNet_dilation import ResNet50
import torch.nn as nn
from typing import Dict
from collections import OrderedDict
import torch.nn.functional as F


class IntermediateLayerGetter(nn.ModuleDict):
    """Wrap a model so that ``forward`` returns the outputs of selected child layers.

    The wrapper keeps the model's direct children, in registration order, up to
    and including the last child named in ``return_layer``. Later children are
    dropped. The kept children are applied one after another to the input, so
    this only reproduces the original model when its forward pass is a plain
    chain of its direct children.

    Args:
        model: module whose direct children (``model.named_children()``) are chained.
        return_layer: mapping ``{child_name: output_name}``, e.g.
            ``{'layer4': 'main_out', 'layer3': 'aux_out'}``. Keys and values are
            converted to ``str``, so int indices of an ``nn.Sequential`` also work.

    Raises:
        ValueError: if a key of ``return_layer`` is not a direct child of ``model``.
    """

    def __init__(self, model: nn.Module, return_layer: Dict[str, str]):
        # Normalise to str *before* validating: child names are always str, so
        # non-str keys (e.g. Sequential indices) must be converted to match.
        # The private copy also isolates us from later mutation of the caller's dict.
        return_layer = {str(k): str(v) for k, v in return_layer.items()}

        child_names = {name for name, _ in model.named_children()}
        missing = [k for k in return_layer if k not in child_names]
        if missing:
            raise ValueError(f'return_layers are not present in model: {missing}')

        layers = OrderedDict()
        remaining = set(return_layer)
        for name, module in model.named_children():
            layers[name] = module
            remaining.discard(name)
            # Stop once every requested layer is included; later children are unused.
            if not remaining:
                break
        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layer = return_layer

    def forward(self, x):
        """Run the kept children in order and collect the requested outputs.

        Returns:
            OrderedDict mapping each output name to the corresponding child's
            output, in the order the children are executed.
        """
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layer:
                out[self.return_layer[name]] = x
        return out


class FCNHead(nn.Sequential):
    """Fully convolutional head: 3x3 conv -> BatchNorm -> ReLU -> Dropout(0.1) -> 1x1 conv.

    Args:
        in_channels: number of channels of the incoming feature map.
        channels: number of output channels (e.g. the number of classes).
    """

    def __init__(self, in_channels, channels):
        # Hidden width is a quarter of the input width (e.g. 1024 -> 256).
        hidden = in_channels // 4
        super(FCNHead, self).__init__(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden, channels, kernel_size=1),
        )


class FCN(nn.Module):
    """Segmentation model: backbone features -> per-branch classifier -> upsample to input size.

    Args:
        backbone: module returning a mapping with key ``'main_out'`` and, when an
            auxiliary classifier is used, ``'aux_out'``.
        main_classifier: head applied to ``features['main_out']``.
        aux_classifier: optional head applied to ``features['aux_out']``; pass
            ``None`` to disable the auxiliary branch.
    """

    def __init__(self, backbone, main_classifier, aux_classifier):
        super(FCN, self).__init__()
        self.backbone = backbone
        self.main_classifier = main_classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        """Return an OrderedDict with ``'main_out'`` (and ``'aux_out'`` if enabled).

        Each branch output is bilinearly resized to the spatial size of ``x``
        (``x.shape[2:]``).
        """
        target_size = x.shape[2:]
        feats = self.backbone(x)

        branches = [('main_out', self.main_classifier)]
        if self.aux_classifier is not None:
            branches.append(('aux_out', self.aux_classifier))

        outputs = OrderedDict()
        for key, head in branches:
            logits = head(feats[key])
            outputs[key] = F.interpolate(
                logits, size=target_size, mode='bilinear', align_corners=False
            )
        return outputs


def fcn_resnet50(aux, num_classes=21):
    """Build an FCN segmentation model on top of a ResNet-50 backbone.

    ``layer4`` (2048 channels) feeds the main head. When ``aux`` is truthy,
    ``layer3`` (1024 channels) also feeds an auxiliary head.

    Args:
        aux: whether to add the auxiliary (layer3) branch.
        num_classes: number of output channels of each head.

    Returns:
        An ``FCN`` whose forward returns an OrderedDict with ``'main_out'``
        (and ``'aux_out'`` when ``aux`` is truthy).
    """
    # NOTE(review): `dilation_replace_stride` is interpreted by the project-local
    # ResNet50; presumably it swaps strides for dilation in layer3/layer4 — confirm.
    resnet = ResNet50(num_classes=num_classes, dilation_replace_stride=[False, True, True])

    main_layer, main_channels = 'layer4', 2048
    aux_layer, aux_channels = 'layer3', 1024

    wanted = {main_layer: 'main_out'}
    if aux:
        wanted[aux_layer] = 'aux_out'
    # The wrapped backbone returns an OrderedDict keyed by the names in `wanted`.
    features = IntermediateLayerGetter(resnet, return_layer=wanted)

    # Heads are created aux-first to keep the same parameter-initialisation order.
    aux_head = FCNHead(aux_channels, num_classes) if aux else None
    main_head = FCNHead(main_channels, num_classes)

    return FCN(features, main_head, aux_head)

标签:输出,layer,return,name,self,指定,pytorch,aux,out
来源: https://blog.csdn.net/weixin_44422920/article/details/123627903