其他分享
首页 > 其他分享> > 2021-10-19

2021-10-19

作者:互联网

江大白共学计划课程–Pytorch课程作业(1)

import torch
import torchvision.models as models
from thop import profile
# 加载模型结构
model = models.resnet18()
# 读取权重和载入权重
pretrained_state_dict = torch.load('Lesson2/weights/resnet18-5c106cde.pth')
model_state_dict = model.load_state_dict(pretrained_state_dict, strict=False)


# 让模型变为推理状态
model.eval()
# 将模型放置到cuda上
model.to(torch.device('cuda'))

# 构建一个项目推理时需要的输入大小的单精度Tensor,并且放置模型所在的设备(cpu或cuda)
inputs = torch.ones([1, 3, 224, 224]).type(torch.float32).to(torch.device('cuda'))
flops, params = profile(model=model, inputs=(inputs))
print('Model:{:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))

# 生成onnx:将训练好的模型(包括结构和权重)保存成onnx。
torch.onnx.export(model, inputs, 'Lesson2/weights/resnet18.onnx', verbose = False)
输入:[1, 3, 224, 224]
输出:resnet18模型的参数量和计算量
Model:0.31 GFLOPs and 3.50M parameters

输入:[1, 3, 448, 448]
输出:resnet18模型的参数量和计算量
Model:1.25 GFLOPs and 3.50M parameters

Netron可视化ResNet18的网络结构

在这里插入图片描述

标签:resnet18,10,19,onnx,torch,state,dict,2021,model
来源: https://blog.csdn.net/weixin_42538848/article/details/120855207