TorchSummary无法载入Dict数据类型解决方法
作者:互联网
前言
torchsummary 是一个用于查看网络结构(逐层输出形状与参数量)的实用工具包,但目前它只支持输入和输出均为 torch.Tensor 类型的网络。在一些较为复杂的网络模型中,网络的输入并不一定是 tensor,也可能是 list 或 dict 类型的数据。
为了支持自己的网络结构,我简单修改了 torchsummary 的源码,使其支持 dict 类型的输入。
示例
from utils.summary import summary
# summary() takes two arguments: the network, and the input data exactly as you
# would normally feed it to the network (here a dict of tensors).
# If the network has been moved to CUDA, the input data must be moved there
# too with .cuda().
# NOTE: `torch`, `net` and `data_cfg` are not defined in this snippet; they are
# expected to come from your own project.
fake_data = {
'img': torch.rand(8, 3, data_cfg.img_height, data_cfg.img_width).cuda()*255.0,
'gt_mask': torch.rand(8, data_cfg.img_height, data_cfg.img_width).cuda()*3
}
net_summary = summary(net, fake_data)
代码
直接复制粘贴,将下面的 summary 代码保存为本地文件(例如 utils/summary.py,以便与上面示例中的 from utils.summary import summary 对应)。
import torch
import torch.nn as nn
from torch.autograd import Variable
from collections import OrderedDict
import numpy as np
def _tensor_shapes(obj):
    """Return the shapes of all tensors contained in *obj*.

    A tensor yields its own shape; lists, tuples and dict values are searched
    recursively in iteration order; any other object contributes nothing.
    Each shape is returned as a plain ``list`` of ints.
    """
    if isinstance(obj, torch.Tensor):
        return [list(obj.size())]
    if isinstance(obj, dict):
        obj = list(obj.values())
    if isinstance(obj, (list, tuple)):
        return [shape for item in obj for shape in _tensor_shapes(item)]
    return []


def _with_batch(shape, batch_size):
    """Return a copy of *shape* whose leading (batch) dimension is *batch_size*.

    An empty shape (zero-dimensional tensor) is returned unchanged.
    """
    return [batch_size] + shape[1:] if shape else list(shape)


def summary(model, data, batch_size=-1, device="cuda"):
    """Print a torchsummary-style layer table for *model*, accepting dict/list input.

    Adapted from ``torchsummary.summary``. Instead of an input size, the caller
    passes ready-made sample *data*: a tensor, or a list / tuple / dict that
    contains tensors. It is passed unchanged as the single argument of
    ``model(data)``.

    Args:
        model: The ``nn.Module`` to inspect. It is run forward twice: once
            before hooks are attached (so blocks that build layers from their
            first input can initialise), then once with hooks attached.
        data: Sample input, already on the same device as *model*.
        batch_size: Batch size to report. ``-1`` (default) uses the leading
            dimension of the first tensor found in *data*.
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive). It is only
            validated; *data* is not moved.

    Returns:
        An ``OrderedDict`` mapping ``"<ClassName>-<index>"`` to a per-layer
        dict with ``input_shape``, ``output_shape``, ``output_numel``,
        ``nb_params`` and, for layers that have a ``weight``, ``trainable``.

    Raises:
        AssertionError: if *device* is not ``"cuda"`` or ``"cpu"``.
        ValueError: if *batch_size* is ``-1`` and *data* holds no tensor with
            a batch dimension.
    """
    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    # Resolve the batch size from the first tensor in the caller's data.
    data_shapes = _tensor_shapes(data)
    if batch_size == -1:
        if not data_shapes or not data_shapes[0]:
            raise ValueError(
                "cannot infer batch_size: data contains no tensor with a "
                "batch dimension; pass batch_size explicitly"
            )
        batch_size = data_shapes[0][0]
    # Per-sample element count of *every* input tensor, scaled to batch_size.
    input_numel = batch_size * sum(int(np.prod(s[1:])) for s in data_shapes)

    layers = OrderedDict()
    hooks = []

    def hook(module, inputs, output):
        # Number rows in call order; a module called twice gets two rows.
        m_key = f"{type(module).__name__}-{len(layers) + 1}"
        entry = OrderedDict()
        layers[m_key] = entry

        # PyTorch hands the positional arguments over as a tuple; as in
        # torchsummary, only the first argument is described.
        first = inputs[0] if inputs else None
        if isinstance(first, torch.Tensor):
            entry["input_shape"] = _with_batch(list(first.size()), batch_size)
        else:
            entry["input_shape"] = [
                _with_batch(s, batch_size) for s in _tensor_shapes(first)
            ]

        output_shapes = [_with_batch(s, batch_size) for s in _tensor_shapes(output)]
        if isinstance(output, dict):
            # Dict outputs (e.g. named losses) are shown by their keys.
            entry["output_shape"] = list(output.keys())
        elif isinstance(output, torch.Tensor):
            entry["output_shape"] = output_shapes[0]
        else:
            entry["output_shape"] = output_shapes
        # Element count of every output tensor, used for the activation estimate.
        entry["output_numel"] = sum(int(np.prod(s)) for s in output_shapes)

        params = 0
        weight = getattr(module, "weight", None)
        if hasattr(weight, "size"):
            params += weight.numel()
            entry["trainable"] = weight.requires_grad
        bias = getattr(module, "bias", None)
        if hasattr(bias, "size"):
            params += bias.numel()
        entry["nb_params"] = params

    def register_hook(module):
        # Sequential/ModuleList containers and the top-level model are not
        # reported as separate rows.
        if isinstance(module, (nn.Sequential, nn.ModuleList)) or module is model:
            return
        hooks.append(module.register_forward_hook(hook))

    # Some blocks create their layers from the shape of the first input they
    # see, so run one forward pass *before* the hooks are registered.
    model(data)

    model.apply(register_hook)
    try:
        model(data)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for h in hooks:
            h.remove()

    rule = "--------------------------------------------------------------------------"
    double_rule = "=========================================================================="

    print(rule)
    print("{:>25} {:>30} {:>15}".format("Layer (type)", "Output Shape", "Param #"))
    print(double_rule)
    for name, entry in layers.items():
        out_shape = entry["output_shape"]
        if out_shape and isinstance(out_shape[0], list):
            # Several output tensors: print every shape on one row.
            shape_text = "".join(f" {s}" for s in out_shape)
        else:
            shape_text = str(out_shape)
        print(
            "{:>25} {:>30} {:>15}".format(
                name,
                shape_text,
                "{0:,}".format(entry["nb_params"]),
            )
        )

    total_params = sum(e["nb_params"] for e in layers.values())
    trainable_params = sum(
        e["nb_params"] for e in layers.values() if e.get("trainable", False)
    )
    # Like torchsummary, sum the outputs of every reported row (container
    # modules that are not excluded above are included too).
    total_output = sum(e["output_numel"] for e in layers.values())

    # Assume 4 bytes per number (float32).
    bytes_to_mb = 4.0 / (1024 ** 2.0)
    total_input_size = abs(input_numel * bytes_to_mb)
    total_output_size = abs(2.0 * total_output * bytes_to_mb)  # x2 for gradients
    total_params_size = abs(total_params * bytes_to_mb)
    total_size = total_params_size + total_output_size + total_input_size

    print(double_rule)
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print(rule)
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print(rule)
    return layers
输出示例
--------------------------------------------------------------------------
Layer (type) Output Shape Param #
==========================================================================
Conv2d-1 [8, 64, 176, 320] 9,408
BatchNorm2d-2 [8, 64, 176, 320] 128
ReLU-3 [8, 64, 176, 320] 0
MaxPool2d-4 [8, 64, 88, 160] 0
Conv2d-5 [8, 64, 88, 160] 36,864
BatchNorm2d-6 [8, 64, 88, 160] 128
ReLU-7 [8, 64, 88, 160] 0
Conv2d-8 [8, 64, 88, 160] 36,864
BatchNorm2d-9 [8, 64, 88, 160] 128
ReLU-10 [8, 64, 88, 160] 0
BasicBlock-11 [8, 64, 88, 160] 0
Conv2d-12 [8, 64, 88, 160] 36,864
BatchNorm2d-13 [8, 64, 88, 160] 128
ReLU-14 [8, 64, 88, 160] 0
Conv2d-15 [8, 64, 88, 160] 36,864
BatchNorm2d-16 [8, 64, 88, 160] 128
ReLU-17 [8, 64, 88, 160] 0
BasicBlock-18 [8, 64, 88, 160] 0
Conv2d-19 [8, 128, 44, 80] 73,728
BatchNorm2d-20 [8, 128, 44, 80] 256
ReLU-21 [8, 128, 44, 80] 0
Conv2d-22 [8, 128, 44, 80] 147,456
BatchNorm2d-23 [8, 128, 44, 80] 256
Conv2d-24 [8, 128, 44, 80] 8,192
BatchNorm2d-25 [8, 128, 44, 80] 256
ReLU-26 [8, 128, 44, 80] 0
BasicBlock-27 [8, 128, 44, 80] 0
Conv2d-28 [8, 128, 44, 80] 147,456
BatchNorm2d-29 [8, 128, 44, 80] 256
ReLU-30 [8, 128, 44, 80] 0
Conv2d-31 [8, 128, 44, 80] 147,456
BatchNorm2d-32 [8, 128, 44, 80] 256
ReLU-33 [8, 128, 44, 80] 0
BasicBlock-34 [8, 128, 44, 80] 0
Conv2d-35 [8, 256, 22, 40] 294,912
BatchNorm2d-36 [8, 256, 22, 40] 512
ReLU-37 [8, 256, 22, 40] 0
Conv2d-38 [8, 256, 22, 40] 589,824
BatchNorm2d-39 [8, 256, 22, 40] 512
Conv2d-40 [8, 256, 22, 40] 32,768
BatchNorm2d-41 [8, 256, 22, 40] 512
ReLU-42 [8, 256, 22, 40] 0
BasicBlock-43 [8, 256, 22, 40] 0
Conv2d-44 [8, 256, 22, 40] 589,824
BatchNorm2d-45 [8, 256, 22, 40] 512
ReLU-46 [8, 256, 22, 40] 0
Conv2d-47 [8, 256, 22, 40] 589,824
BatchNorm2d-48 [8, 256, 22, 40] 512
ReLU-49 [8, 256, 22, 40] 0
BasicBlock-50 [8, 256, 22, 40] 0
Conv2d-51 [8, 512, 11, 20] 1,179,648
BatchNorm2d-52 [8, 512, 11, 20] 1,024
ReLU-53 [8, 512, 11, 20] 0
Conv2d-54 [8, 512, 11, 20] 2,359,296
BatchNorm2d-55 [8, 512, 11, 20] 1,024
Conv2d-56 [8, 512, 11, 20] 131,072
BatchNorm2d-57 [8, 512, 11, 20] 1,024
ReLU-58 [8, 512, 11, 20] 0
BasicBlock-59 [8, 512, 11, 20] 0
Conv2d-60 [8, 512, 11, 20] 2,359,296
BatchNorm2d-61 [8, 512, 11, 20] 1,024
ReLU-62 [8, 512, 11, 20] 0
Conv2d-63 [8, 512, 11, 20] 2,359,296
BatchNorm2d-64 [8, 512, 11, 20] 1,024
ReLU-65 [8, 512, 11, 20] 0
BasicBlock-66 [8, 512, 11, 20] 0
ResNet-67 [8, 64, 88, 160] [8, 128, 44, 80] [8, 256, 22, 40] [8, 512, 11, 20] 0
ResNetWrapper-68 [8, 64, 88, 160] [8, 128, 44, 80] [8, 256, 22, 40] [8, 512, 11, 20] 0
Conv2d-69 [8, 128, 88, 160] 8,320
Conv2d-70 [8, 128, 44, 80] 16,512
Conv2d-71 [8, 128, 22, 40] 32,896
Conv2d-72 [8, 128, 11, 20] 65,664
Conv2d-73 [8, 128, 88, 160] 147,584
Conv2d-74 [8, 128, 44, 80] 147,584
Conv2d-75 [8, 128, 22, 40] 147,584
Conv2d-76 [8, 128, 11, 20] 147,584
FPN-77 [8, 128, 88, 160] [8, 128, 44, 80] [8, 128, 22, 40] [8, 128, 11, 20] 0
Core-78 [8, 128, 44, 80] 0
Conv2d-79 [8, 84, 44, 80] 96,852
Conv2d-80 [8, 42, 44, 80] 31,794
Conv2d-81 [8, 6, 44, 80] 2,274
BatchNorm2d-82 [8, 6, 44, 80] 12
UpsamplingBilinear2d-83 [8, 6, 88, 160] 0
BatchNorm2d-84 [8, 6, 88, 160] 12
Conv2d-85 [8, 6, 88, 160] 330
UpsamplingBilinear2d-86 [8, 6, 176, 320] 0
Conv2d-87 [8, 6, 176, 320] 330
UpsamplingBilinear2d-88 [8, 6, 352, 640] 0
ReLU-89 [8, 6, 352, 640] 0
Conv2d-90 [8, 6, 352, 640] 330
Sigmoid-91 [8, 6, 352, 640] 0
Decoder-92 [8, 6, 352, 640] 0
VLaneLoss-93 ['total_loss', 'ce_loss'] 0
==========================================================================
Total params: 12,022,174
Trainable params: 12,022,174
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 20.62
Forward/backward pass size (MB): 82.50
Params size (MB): 45.86
Estimated Total Size (MB): 148.99
--------------------------------------------------------------------------
标签:数据类型,summary,Dict,TorchSummary,output,128,input,Conv2d,size 来源: https://www.cnblogs.com/vase/p/15704968.html