
PyTorch model becomes very large after conversion to ONNX

Author: Internet

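Introduction

After exporting a PyTorch model to ONNX, the file grew roughly 37 times larger:

Model file          Size
raw.pth             6.58 MB
convert_raw.onnx    242 MB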

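Solution

In this case the extra size came from duplicated initializers (the constant weight tensors stored in the ONNX graph). The script below finds initializers with identical contents, makes every node that reads a duplicate read the first copy instead, and then lets update_graph() drop the initializers that are no longer used. On this model it brings the file down to 7.50 MB.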

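from onnxruntime.transformers.onnx_model import OnnxModel
from onnx import numpy_helper
from tqdm import tqdm
import onnx


def has_same_value(val_one, val_two):
    # Comparing raw_data alone is not enough: tensors stored in typed fields
    # (float_data, int64_data, ...) have empty raw_data, and tensors with the
    # same bytes but a different shape or dtype must not be merged.
    if val_one.data_type != val_two.data_type or list(val_one.dims) != list(val_two.dims):
        return False
    if val_one.data_type == onnx.TensorProto.STRING:
        return list(val_one.string_data) == list(val_two.string_data)
    return (numpy_helper.to_array(val_one).tobytes()
            == numpy_helper.to_array(val_two).tobytes())


path = "convert_raw.onnx"  # 242 MB
output_path = "slim_convert.onnx"  # 7.50 MB
model = onnx.load(path)
onnx_model = OnnxModel(model)

initializers = model.graph.initializer
count = len(initializers)
# same[j] = i means initializer j duplicates the earlier initializer i; -1 means it is kept
same = [-1] * count
for i in tqdm(range(count - 1)):
    if same[i] >= 0:  # already marked as a copy of an earlier initializer
        continue
    for j in range(i + 1, count):
        if same[j] < 0 and has_same_value(initializers[i], initializers[j]):
            same[j] = i

# Make every node that reads a duplicate read the first copy instead
for i in tqdm(range(count)):
    if same[i] >= 0:
        onnx_model.replace_input_of_all_nodes(initializers[i].name,
                                              initializers[same[i]].name)
# update_graph() removes initializers that no node uses any more
onnx_model.update_graph()
onnx_model.save_model_to_file(output_path)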

Source: https://blog.csdn.net/shiwanghualuo/article/details/120250602