其他分享
首页 > 其他分享 > TorchSummary无法载入Dict数据类型解决方法

TorchSummary无法载入Dict数据类型解决方法

作者:互联网

前言

torchsummary是一个比较实用的网络结构查看工具包(可列出各层的输出形状与参数量),但目前它只支持输入和输出均为torch.Tensor类型的网络。而在一些较为复杂的网络模型中,输入不一定是tensor类型,也可能是list或dict类型的数据。
为了支持自己的网络结构,我简单修改了torchsummary的源码,使其支持dict类型的输入。

示例

from utils.summary import summary
# summary() takes two arguments: the network, and an input built exactly the way
# it is normally passed to that network (here a dict of tensors).
# If the network has been moved to CUDA, the input tensors must be moved there too (.cuda()).
# NOTE(review): this snippet assumes `torch`, `net` and `data_cfg` are defined elsewhere.
fake_data = {
    'img': torch.rand(8, 3, data_cfg.img_height, data_cfg.img_width).cuda()*255.0,
    'gt_mask': torch.rand(8, data_cfg.img_height, data_cfg.img_width).cuda()*3
}
# summary() prints its table; the function has no `return`, so net_summary is None.
net_summary = summary(net, fake_data)

代码

将下面的summary代码直接复制粘贴,保存为本地文件(上面示例中为utils/summary.py)即可使用。

import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def _iter_tensors(obj):
    """Yield every ``torch.Tensor`` nested inside ``obj``, depth-first.

    A tensor is yielded as-is; lists/tuples are walked in order and dicts by
    value in insertion order. Any other object yields nothing.
    """
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from _iter_tensors(value)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _iter_tensors(item)


def _batched_shape(tensor, batch_size):
    """Return ``tensor``'s shape as a list with dim 0 replaced by ``batch_size``.

    A 0-dim tensor (e.g. a reduced loss) has no batch dimension; ``[]`` is
    returned for it instead of raising ``IndexError``.
    """
    shape = list(tensor.size())
    if shape:
        shape[0] = batch_size
    return shape


def _input_shape(inputs, batch_size):
    """Describe the positional input a forward hook received.

    Only the first positional argument is described: a tensor gives its
    batched shape, a list/tuple/dict gives one batched shape per nested
    tensor, and a call with no positional arguments gives ``[]``.
    """
    if isinstance(inputs, torch.Tensor):
        return _batched_shape(inputs, batch_size)
    if not inputs:
        return []
    first = inputs[0]
    if isinstance(first, torch.Tensor):
        return _batched_shape(first, batch_size)
    return [_batched_shape(t, batch_size) for t in _iter_tensors(first)]


def _output_shape(output, batch_size):
    """Describe a module's output for the "Output Shape" column.

    * tensor     -> its batched shape (``[]`` for a 0-dim tensor);
    * list/tuple -> one entry per element: a batched shape for tensors,
      the element's type name otherwise (e.g. a nested tuple);
    * dict       -> the list of its keys;
    * other      -> a one-item list holding its type name.
    """
    if isinstance(output, torch.Tensor):
        return _batched_shape(output, batch_size)
    if isinstance(output, (list, tuple)):
        return [
            _batched_shape(o, batch_size) if isinstance(o, torch.Tensor) else type(o).__name__
            for o in output
        ]
    if isinstance(output, dict):
        return list(output.keys())
    return [type(output).__name__]


def summary(model, data, batch_size=-1, device="cuda"):
    """Print a torchsummary-style table of ``model``'s layers for input ``data``.

    Adapted from ``torchsummary.summary`` so that ``data`` may be a tensor or a
    (possibly nested) list/tuple/dict of tensors. It is passed to the model
    unchanged, as ``model(data)``.

    Args:
        model: the ``nn.Module`` to inspect. It is run forward twice: once
            without hooks (so layers built lazily on the first forward exist),
            then once with forward hooks that record every sub-module except
            ``nn.Sequential``/``nn.ModuleList`` containers and ``model`` itself.
        data: the model input, already on the model's device.
        batch_size: value shown as dim 0 of every shape. ``-1`` (default)
            means "use dim 0 of the first tensor in ``data`` that has one".
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive). Only validated;
            nothing is moved to this device.

    Returns:
        None. The table and memory estimates are printed to stdout.

    Raises:
        AssertionError: if ``device`` is not ``"cuda"``/``"cpu"`` (when
            assertions are enabled).
        ValueError: if ``batch_size`` is ``-1`` and ``data`` contains no
            tensor with at least one dimension.
    """
    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    # Every input tensor that has a batch dimension; all are assumed to share it.
    batched_inputs = [t for t in _iter_tensors(data) if t.dim() > 0]
    if batch_size == -1:
        if not batched_inputs:
            raise ValueError(
                "cannot infer batch_size: data contains no tensor with a batch "
                "dimension; pass batch_size explicitly"
            )
        batch_size = batched_inputs[0].size(0)
    input_elems_per_sample = sum(int(np.prod(t.shape[1:])) for t in batched_inputs)

    layers = OrderedDict()  # "<ClassName>-<call index>" -> recorded info, in call order
    hooks = []

    def register_hook(module):
        def hook(module, inputs, output):
            m_key = "%s-%i" % (type(module).__name__, len(layers) + 1)
            record = OrderedDict()
            layers[m_key] = record
            record["input_shape"] = _input_shape(inputs, batch_size)
            record["output_shape"] = _output_shape(output, batch_size)
            # Elements of every output tensor (list/dict outputs included),
            # used for the forward/backward memory estimate.
            record["nb_outputs"] = sum(
                int(np.prod(_batched_shape(t, batch_size)))
                for t in _iter_tensors(output)
            )

            # Only the module's own `weight`/`bias` are counted, not its children's.
            params = 0
            weight = getattr(module, "weight", None)
            if hasattr(weight, "size"):
                params += weight.numel()
                record["trainable"] = weight.requires_grad
            bias = getattr(module, "bias", None)
            if hasattr(bias, "size"):
                params += bias.numel()
            record["nb_params"] = params

        # Skip container modules and the root model itself; only their children are listed.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and module is not model:
            hooks.append(module.register_forward_hook(hook))

    # Plain forward pass first: some blocks create layers from the first input
    # they see, so those layers must exist before hooks are registered.
    model(data)

    try:
        model.apply(register_hook)
        model(data)
    finally:
        # Remove the hooks even if the hooked pass raises, so the model is not
        # left permanently instrumented.
        for h in hooks:
            h.remove()

    print("--------------------------------------------------------------------------")
    line_new = "{:>25}  {:>30} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("==========================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, record in layers.items():
        total_params += record["nb_params"]
        total_output += record["nb_outputs"]
        if record.get("trainable"):
            trainable_params += record["nb_params"]

        output_shape = record["output_shape"]
        if output_shape and isinstance(output_shape[0], list):
            # Multi-output layer: each shape preceded by two spaces.
            output_shape = "".join(f"  {shape}" for shape in output_shape)

        line_new = "{:>25}  {:>30} {:>15}".format(
            layer,
            str(output_shape),
            "{0:,}".format(record["nb_params"]),
        )
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # The forward/backward figure sums the outputs of every recorded layer, so
    # wrapper modules' outputs are counted in addition to their children's.
    total_input_size = abs(input_elems_per_sample * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("==========================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("--------------------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("--------------------------------------------------------------------------")

输出示例

--------------------------------------------------------------------------
             Layer (type)                    Output Shape         Param #
==========================================================================
                 Conv2d-1               [8, 64, 176, 320]           9,408
            BatchNorm2d-2               [8, 64, 176, 320]             128
                   ReLU-3               [8, 64, 176, 320]               0
              MaxPool2d-4                [8, 64, 88, 160]               0
                 Conv2d-5                [8, 64, 88, 160]          36,864
            BatchNorm2d-6                [8, 64, 88, 160]             128
                   ReLU-7                [8, 64, 88, 160]               0
                 Conv2d-8                [8, 64, 88, 160]          36,864
            BatchNorm2d-9                [8, 64, 88, 160]             128
                  ReLU-10                [8, 64, 88, 160]               0
            BasicBlock-11                [8, 64, 88, 160]               0
                Conv2d-12                [8, 64, 88, 160]          36,864
           BatchNorm2d-13                [8, 64, 88, 160]             128
                  ReLU-14                [8, 64, 88, 160]               0
                Conv2d-15                [8, 64, 88, 160]          36,864
           BatchNorm2d-16                [8, 64, 88, 160]             128
                  ReLU-17                [8, 64, 88, 160]               0
            BasicBlock-18                [8, 64, 88, 160]               0
                Conv2d-19                [8, 128, 44, 80]          73,728
           BatchNorm2d-20                [8, 128, 44, 80]             256
                  ReLU-21                [8, 128, 44, 80]               0
                Conv2d-22                [8, 128, 44, 80]         147,456
           BatchNorm2d-23                [8, 128, 44, 80]             256
                Conv2d-24                [8, 128, 44, 80]           8,192
           BatchNorm2d-25                [8, 128, 44, 80]             256
                  ReLU-26                [8, 128, 44, 80]               0
            BasicBlock-27                [8, 128, 44, 80]               0
                Conv2d-28                [8, 128, 44, 80]         147,456
           BatchNorm2d-29                [8, 128, 44, 80]             256
                  ReLU-30                [8, 128, 44, 80]               0
                Conv2d-31                [8, 128, 44, 80]         147,456
           BatchNorm2d-32                [8, 128, 44, 80]             256
                  ReLU-33                [8, 128, 44, 80]               0
            BasicBlock-34                [8, 128, 44, 80]               0
                Conv2d-35                [8, 256, 22, 40]         294,912
           BatchNorm2d-36                [8, 256, 22, 40]             512
                  ReLU-37                [8, 256, 22, 40]               0
                Conv2d-38                [8, 256, 22, 40]         589,824
           BatchNorm2d-39                [8, 256, 22, 40]             512
                Conv2d-40                [8, 256, 22, 40]          32,768
           BatchNorm2d-41                [8, 256, 22, 40]             512
                  ReLU-42                [8, 256, 22, 40]               0
            BasicBlock-43                [8, 256, 22, 40]               0
                Conv2d-44                [8, 256, 22, 40]         589,824
           BatchNorm2d-45                [8, 256, 22, 40]             512
                  ReLU-46                [8, 256, 22, 40]               0
                Conv2d-47                [8, 256, 22, 40]         589,824
           BatchNorm2d-48                [8, 256, 22, 40]             512
                  ReLU-49                [8, 256, 22, 40]               0
            BasicBlock-50                [8, 256, 22, 40]               0
                Conv2d-51                [8, 512, 11, 20]       1,179,648
           BatchNorm2d-52                [8, 512, 11, 20]           1,024
                  ReLU-53                [8, 512, 11, 20]               0
                Conv2d-54                [8, 512, 11, 20]       2,359,296
           BatchNorm2d-55                [8, 512, 11, 20]           1,024
                Conv2d-56                [8, 512, 11, 20]         131,072
           BatchNorm2d-57                [8, 512, 11, 20]           1,024
                  ReLU-58                [8, 512, 11, 20]               0
            BasicBlock-59                [8, 512, 11, 20]               0
                Conv2d-60                [8, 512, 11, 20]       2,359,296
           BatchNorm2d-61                [8, 512, 11, 20]           1,024
                  ReLU-62                [8, 512, 11, 20]               0
                Conv2d-63                [8, 512, 11, 20]       2,359,296
           BatchNorm2d-64                [8, 512, 11, 20]           1,024
                  ReLU-65                [8, 512, 11, 20]               0
            BasicBlock-66                [8, 512, 11, 20]               0
                ResNet-67    [8, 64, 88, 160]  [8, 128, 44, 80]  [8, 256, 22, 40]  [8, 512, 11, 20]               0
         ResNetWrapper-68    [8, 64, 88, 160]  [8, 128, 44, 80]  [8, 256, 22, 40]  [8, 512, 11, 20]               0
                Conv2d-69               [8, 128, 88, 160]           8,320
                Conv2d-70                [8, 128, 44, 80]          16,512
                Conv2d-71                [8, 128, 22, 40]          32,896
                Conv2d-72                [8, 128, 11, 20]          65,664
                Conv2d-73               [8, 128, 88, 160]         147,584
                Conv2d-74                [8, 128, 44, 80]         147,584
                Conv2d-75                [8, 128, 22, 40]         147,584
                Conv2d-76                [8, 128, 11, 20]         147,584
                   FPN-77    [8, 128, 88, 160]  [8, 128, 44, 80]  [8, 128, 22, 40]  [8, 128, 11, 20]               0
                  Core-78                [8, 128, 44, 80]               0
                Conv2d-79                 [8, 84, 44, 80]          96,852
                Conv2d-80                 [8, 42, 44, 80]          31,794
                Conv2d-81                  [8, 6, 44, 80]           2,274
           BatchNorm2d-82                  [8, 6, 44, 80]              12
  UpsamplingBilinear2d-83                 [8, 6, 88, 160]               0
           BatchNorm2d-84                 [8, 6, 88, 160]              12
                Conv2d-85                 [8, 6, 88, 160]             330
  UpsamplingBilinear2d-86                [8, 6, 176, 320]               0
                Conv2d-87                [8, 6, 176, 320]             330
  UpsamplingBilinear2d-88                [8, 6, 352, 640]               0
                  ReLU-89                [8, 6, 352, 640]               0
                Conv2d-90                [8, 6, 352, 640]             330
               Sigmoid-91                [8, 6, 352, 640]               0
               Decoder-92                [8, 6, 352, 640]               0
             VLaneLoss-93       ['total_loss', 'ce_loss']               0
==========================================================================
Total params: 12,022,174
Trainable params: 12,022,174
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 20.62
Forward/backward pass size (MB): 82.50
Params size (MB): 45.86
Estimated Total Size (MB): 148.99
--------------------------------------------------------------------------

标签:数据类型,summary,Dict,TorchSummary,output,128,input,Conv2d,size
来源: https://www.cnblogs.com/vase/p/15704968.html