2021-10-19
作者:互联网
江大白共学计划课程–Pytorch课程作业(1)
import torch
import torchvision.models as models
from thop import profile
# 加载模型结构
model = models.resnet18()
# 读取权重和载入权重
pretrained_state_dict = torch.load('Lesson2/weights/resnet18-5c106cde.pth')
model_state_dict = model.load_state_dict(pretrained_state_dict, strict=False)
# 让模型变为推理状态
model.eval()
# 将模型放置到cuda上
model.to(torch.device('cuda'))
# 构建一个项目推理时需要的输入大小的单精度Tensor,并且放置模型所在的设备(cpu或cuda)
inputs = torch.ones([1, 3, 224, 224]).type(torch.float32).to(torch.device('cuda'))
flops, params = profile(model=model, inputs=(inputs))
print('Model:{:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))
# 生成onnx:将训练好的模型(包括结构和权重)保存成onnx。
torch.onnx.export(model, inputs, 'Lesson2/weights/resnet18.onnx', verbose = False)
输入:[1, 3, 224, 224]
输出:resnet18模型的参数量和计算量
Model:0.31 GFLOPs and 3.50M parameters
输入:[1, 3, 448, 448]
输出:resnet18模型的参数量和计算量
Model:1.25 GFLOPs and 3.50M parameters
Netron可视化ResNet18的网络结构
标签:resnet18,10,19,onnx,torch,state,dict,2021,model 来源: https://blog.csdn.net/weixin_42538848/article/details/120855207