Computing FLOPs and params in PyTorch
Author: Internet
**Computing params (and FLOPs) with thop**
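```python
import torch
from thop import profile, clever_format

# Resnet_18 comes from the author's own resnet_18.py; substitute your own model here
from resnet_18 import Resnet_18

model = Resnet_18()
model.eval()
dummy_input = torch.randn(1, 3, 256, 256)  # batch of 1, 3 channels, 256x256

# `inputs` must be a tuple, (x,) not (x): thop calls model(*inputs)
macs, params = profile(model, inputs=(dummy_input,))
# thop counts multiply-accumulates (MACs); FLOPs are often quoted as ~2 x MACs
print(macs, params)
print(clever_format([macs, params], "%.3f"))  # human-readable, e.g. "G" / "M" suffixes
```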
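thop's `profile` walks the model and sums per-layer counts. Its README names the first return value `macs` because it counts multiply-accumulate operations, and FLOPs are commonly quoted as roughly 2 × MACs. To show where such numbers come from, here is a minimal sketch that hand-computes params and MACs for two assumed layers: a ResNet-style 7×7, stride-2 stem convolution (3→64 channels, no bias) on a 256×256 input, and a 512→1000 fully connected layer. The helper names (`conv2d_out_size`, `conv2d_cost`, `linear_cost`) are hypothetical. thop's exact conventions, for example whether bias additions are counted, may differ between versions.

```python
def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, using the formula from PyTorch's Conv2d docs."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_cost(in_ch, out_ch, kernel, in_hw, stride=1, padding=0, groups=1, bias=True):
    """(params, MACs) of a square-kernel Conv2d applied to a square in_hw x in_hw input."""
    out_hw = conv2d_out_size(in_hw, kernel, stride, padding)
    weights = out_ch * (in_ch // groups) * kernel * kernel
    params = weights + (out_ch if bias else 0)
    # At every output pixel each weight is used once: one multiply-accumulate per weight.
    # Bias additions are not counted as MACs in this sketch.
    macs = out_hw * out_hw * weights
    return params, macs


def linear_cost(in_features, out_features, bias=True):
    """(params, MACs) of a fully connected layer for a single input vector."""
    params = in_features * out_features + (out_features if bias else 0)
    macs = in_features * out_features
    return params, macs


stem_params, stem_macs = conv2d_cost(3, 64, 7, 256, stride=2, padding=3, bias=False)
fc_params, fc_macs = linear_cost(512, 1000)
print(stem_params, stem_macs)   # 9408 154140672
print(fc_params, fc_macs)       # 513000 512000
print("stem FLOPs ~", 2 * stem_macs)
```

A profiler like thop does essentially this. It attaches forward hooks to each module and adds up the per-layer figures, using the tensor shapes it observes during one forward pass of the dummy input.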
**FPS calculation**
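```python
import time
import torch

# model_rgb and test_loader are defined elsewhere in the author's project
model_rgb.eval()
res = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for i, (data, depth, img_name, img_size) in enumerate(test_loader):
        data, depth = data.cuda(), depth.cuda()
        torch.cuda.synchronize()  # wait for pending GPU work before starting the clock
        start = time.perf_counter()
        predict = model_rgb(data, depth)  # adjust to your model's forward() signature
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them to finish
        end = time.perf_counter()
        res.append(end - start)

time_sum = sum(res)
# FPS = 1 / mean time per iteration (batches/s; multiply by batch size for images/s)
print("FPS: %f" % (1.0 / (time_sum / len(res))))
```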
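The timing loop averages over every batch, including the first few. Those early batches often absorb one-off costs such as CUDA context setup, memory allocation, or cuDNN autotuning when it is enabled, and so they can skew the average. A minimal sketch of a more careful measurement, under the assumption that one "step" is one forward pass, adds untimed warm-up iterations and reports images per second instead of batches per second. The names `measure_fps`, `step`, `sync` and `n_warmup` are hypothetical, not from the original post or any library.

```python
import time
from typing import Callable, Optional


def measure_fps(step: Callable[[], object], n_iters: int, batch_size: int = 1,
                n_warmup: int = 5, sync: Optional[Callable[[], None]] = None) -> float:
    """Mean throughput in images/s of `step`, ignoring the first `n_warmup` calls.

    `sync` must block until queued device work is done (for CUDA you would pass
    torch.cuda.synchronize); it defaults to a no-op, which is fine for CPU code.
    """
    if n_iters <= 0:
        raise ValueError("n_iters must be positive")
    sync = sync or (lambda: None)

    for _ in range(n_warmup):  # untimed warm-up: one-off setup costs land here
        step()
    sync()

    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        step()
        sync()  # stop the clock only after the work has really finished
        times.append(time.perf_counter() - start)

    mean_time = sum(times) / len(times)
    return batch_size / mean_time if mean_time > 0 else float("inf")


# Demo with a stand-in "forward pass" that sleeps 10 ms per batch of 4 images.
fps = measure_fps(lambda: time.sleep(0.01), n_iters=20, batch_size=4, n_warmup=2)
print(f"{fps:.1f} images/s")  # at most 400, since every call sleeps at least 10 ms
```

To mirror the author's loop, `step` could run `model_rgb(data, depth)` on one pre-loaded batch inside `torch.no_grad()`, with `sync=torch.cuda.synchronize` and `batch_size` set to the loader's batch size. Reusing a single batch isolates model time from data-loading time. Iterating over `test_loader`, as the original does, averages over real inputs instead.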
Source: https://blog.csdn.net/fanlily913/article/details/122005710