PyTorch：修改 ResNet-18 的输入通道数
作者:互联网
方法一：不修改网络结构，使用 `Tensor.expand()` 将 1 通道输入扩展为 3 通道
# Method 1: keep the stock 3-channel ResNet-18 and expand a 1-channel input to 3 channels.
# Imports are local so this snippet runs on its own; the file's main imports come later.
import torch
from torchsummary import summary
from torchvision.models import resnet18

model = resnet18(pretrained=False)  # backbone feature extractor
# strict=False tolerates key mismatches between the checkpoint and the model; print the
# returned missing/unexpected keys so a wrong or partial checkpoint is not silently ignored.
incompatible = model.load_state_dict(
    torch.load('./resnet18-5c106cde.pth', map_location='cpu'), strict=False
)
print(incompatible)
print(model)
par = summary(model, (3, 224, 224), device='cpu')
print(par)

# NOTE(review): RFNet is defined elsewhere. From this call it appears to take the backbone,
# an output class count (num_classes = 1) and a use_bn flag -- confirm against its definition.
net = RFNet(model, 1, use_bn=True)
# print(model)

# Create a genuinely single-channel tensor so the expand below really converts 1 -> 3 channels.
input1 = torch.rand((1, 1, 256, 256))
# expand() broadcasts the size-1 channel axis to 3 without copying: all three channels share
# the same storage, so avoid in-place writes (use .repeat() or .contiguous() for a real copy).
# -1 keeps the batch and spatial sizes unchanged.
input1 = input1.expand(-1, 3, -1, -1)
print(input1.shape)  # torch.Size([1, 3, 256, 256])
input2 = torch.rand((1, 1, 256, 256))
# NOTE(review): net receives a 3-channel and a 1-channel tensor; what each input means
# depends on RFNet's forward() -- confirm there.
output = net(input1, input2)
print(output.shape)
方法二：把 conv1 替换为 1 通道卷积，并修改预训练权重字典（state_dict）中的 conv1.weight，使其形状与新的 conv1 匹配
import torchvision.models as models
import torch
import torch.nn as nn
from torchsummary import summary
# Method 2: build a ResNet-18 whose stem accepts single-channel input while reusing
# the pretrained checkpoint.
resnet18 = models.resnet18(pretrained=False)
# Replace the 3-channel stem with a 1-channel conv; kernel/stride/padding/bias match the
# original conv1 so every other checkpoint tensor still fits the model.
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# print(resnet18)

pretrained_dict = torch.load('./resnet/resnet18-5c106cde.pth', map_location='cpu')
# To inspect the checkpoint keys:
# for k in pretrained_dict:
#     print(k)

# The stock checkpoint's conv1.weight is (64, 3, 7, 7), but the new stem expects (64, 1, 7, 7).
# Summing the pretrained filters over the input-channel axis keeps the learned first-layer
# features: conv(x, sum_c W_c) equals the original conv applied to x replicated as R=G=B
# (up to ImageNet per-channel normalization). Overwriting it with torch.rand noise would
# discard the pretrained stem entirely.
pretrained_dict["conv1.weight"] = pretrained_dict["conv1.weight"].sum(dim=1, keepdim=True)
conv1 = pretrained_dict["conv1.weight"]
print(conv1.shape)  # expected: torch.Size([64, 1, 7, 7])

# strict (default) loading: every key and shape must now match the modified model.
resnet18.load_state_dict(pretrained_dict)
# print(resnet18)
par = summary(resnet18, (1, 224, 224), device='cpu')
print(par)
标签：resnet18, torch, pytorch, dict, pretrained, print, 256, 输入
来源：https://blog.csdn.net/fanlily913/article/details/121194863