
Computing FLOPs and params in PyTorch

Author: Internet

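Computing FLOPs and params with thop

thop's `profile` runs one forward pass on a dummy input and returns the operation count and the number of parameters.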

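import torch
from thop import profile
from resnet_18 import Resnet_18  # the author's own ResNet-18 module

model = Resnet_18()
model.eval()
dummy_input = torch.randn(1, 3, 256, 256)  # one 3-channel 256x256 image
# `inputs` must be a tuple, hence the trailing comma in (dummy_input,).
# thop counts multiply-accumulate operations (MACs); FLOPs are commonly taken as about 2x MACs.
flops, params = profile(model, inputs=(dummy_input,))
print(flops)
print(params)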
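To see where these numbers come from, the cost of a single convolution can be worked out by hand. The pure-Python sketch below assumes a standard Conv2d (groups=1), counts one MAC per weight per output pixel, and ignores the bias additions. The layer shape used (the 7x7, 64-filter, stride-2 stem of a torchvision-style ResNet-18 on a 256x256 input) is an assumption and may not match the author's `Resnet_18`; the helper names `conv2d_out_size` and `conv2d_cost` are hypothetical.

```python
def conv2d_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2-D convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_cost(c_in, c_out, kernel, in_h, in_w, stride=1, padding=0, bias=False):
    """Return (params, macs) for a standard (groups=1) Conv2d layer.

    params: c_out * c_in * k * k weights, plus c_out biases if bias=True
    macs:   one multiply-accumulate per weight per output pixel
            (bias additions are not counted here)
    """
    weights = c_out * c_in * kernel * kernel
    params = weights + (c_out if bias else 0)
    out_h = conv2d_out_size(in_h, kernel, stride, padding)
    out_w = conv2d_out_size(in_w, kernel, stride, padding)
    macs = weights * out_h * out_w
    return params, macs


# Assumed layer: torchvision-style ResNet-18 stem, 3 -> 64 channels,
# 7x7 kernel, stride 2, padding 3, no bias, on a 256x256 input.
params, macs = conv2d_cost(3, 64, 7, 256, 256, stride=2, padding=3)
print(params)    # 9408
print(macs)      # 154140672
print(2 * macs)  # FLOPs under the "1 MAC = 2 FLOPs" convention: 308281344
```

Summing `conv2d_cost` over every layer gives a rough whole-model estimate; profilers can differ in how they treat biases, batch norm, and activations, so small mismatches against thop's total are expected.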

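**Computing FPS**

Time each forward pass over the test set. Because CUDA kernels run asynchronously, call `torch.cuda.synchronize()` before starting and before stopping the clock; otherwise the measured time covers little more than the kernel launches. FPS is then the reciprocal of the mean time per batch.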

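import time
import torch

# model_rgb and test_loader come from the author's own test script.
res = []  # per-batch inference times in seconds
model_rgb.eval()
with torch.no_grad():
    for data, depth, img_name, img_size in test_loader:
        data, depth = data.cuda(), depth.cuda()
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
        start = time.time()
        predict = model_rgb(data, depth)  # adapt to your model's forward() signature
        torch.cuda.synchronize()  # wait for the asynchronous CUDA kernels to finish
        end = time.time()
        res.append(end - start)
time_sum = sum(res)
# 1 / (mean time per batch); equals frames per second when the batch size is 1
print("FPS: %f" % (1.0 / (time_sum / len(res))))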
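The timing loop can also be factored into a small, framework-agnostic helper. This is a sketch under stated assumptions: `measure_latencies` and `fps_from_latencies` are hypothetical names, the `warmup` calls are left untimed because the first iterations often include one-off costs such as CUDA context initialization and memory allocation, and `sync` is meant to be `torch.cuda.synchronize` when timing on a GPU. FPS is computed as frames processed divided by total time, which is the loop's formula multiplied by the batch size.

```python
import time


def measure_latencies(run_once, n_iters, warmup=5, sync=None):
    """Call run_once() `warmup` times untimed, then time it n_iters times.

    sync: optional callable that blocks until queued device work is done
    (assumed to be torch.cuda.synchronize when timing a CUDA model).
    Returns a list of per-call durations in seconds.
    """
    for _ in range(warmup):
        run_once()
    if sync is not None:
        sync()  # make sure warm-up work is not counted in the first timing
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        run_once()
        if sync is not None:
            sync()  # include the device work launched by run_once()
        times.append(time.perf_counter() - start)
    return times


def fps_from_latencies(times, batch_size=1):
    """Frames per second = frames processed / total time spent."""
    return len(times) * batch_size / sum(times)


# Made-up latencies (seconds per batch), not real measurements:
print(fps_from_latencies([0.01, 0.01, 0.02]))                # about 75.0
print(fps_from_latencies([0.01, 0.01, 0.02], batch_size=4))  # about 300.0
```

With the author's model, one might call it as `measure_latencies(lambda: model_rgb(data, depth), n_iters=100, sync=torch.cuda.synchronize)` inside `torch.no_grad()` on a fixed batch, then pass the result to `fps_from_latencies` with the loader's batch size.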

Source: https://blog.csdn.net/fanlily913/article/details/122005710