其他分享
首页 > 其他分享> > 【推理引擎】ONNX 模型解析

【推理引擎】ONNX 模型解析

作者:互联网

定义模型结构

首先使用 PyTorch 定义一个简单的网络模型:

class ConvBnReluBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(3, 1)

        self.conv2 = nn.Conv2d(64, 32, 3)
        self.bn2 = nn.BatchNorm2d(32)
    
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool1(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        return out

在导出模型之前,需要提前定义一些变量:

model = ConvBnReluBlock()     # 定义模型对象
x = torch.randn(2, 3, 255, 255)      # 定义输入张量

然后使用 PyTorch 官方 API(torch.onnx.export)导出 ONNX 格式的模型:

# way1:
torch.onnx.export(model, (x), "conv_bn_relu_evalmode.onnx", input_names=["input"], output_names=['output'])

# way2:
import torch._C as _C
TrainingMode = _C._onnx.TrainingMode
torch.onnx.export(model, (x), "conv_bn_relu_trainmode.onnx", input_names=["input"], output_names=['output'],
                opset_version=12,                    # 默认版本为9,但是如果低于12,将不能正确导出 Dropout 和 BatchNorm 节点
                training=TrainingMode.TRAINING,      # 默认模式为 TrainingMode.EVAL
                do_constant_folding=False)           # 常量折叠,默认为True,但是如果使用TrainingMode.TRAINING模式,则需要将其关闭

# way3
torch.onnx.export(model,
                (x),
                "conv_bn_relu_dynamic.onnx",
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size', 2: 'input_width', 3: 'input_height'},
                            'output': {0: 'batch_size', 2: 'output_width', 3: 'output_height'}})

可以看到,这里主要以三种方式导出模型,下面分别介绍区别:

下图分别将这三种导出方式的模型结构使用 Netron 可视化:

分析模型结构

这里参考了BBuf大佬的讲解:【传送门:https://zhuanlan.zhihu.com/p/346511883】
接下来主要针对 way1 方式导出的ONNX模型进行深入分析。

ONNX格式定义:https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
在这个文件中,定义了多个核心对象:ModelProto、GraphProto、NodeProto、ValueInfoProto、TensorProto 和 AttributeProto。

在加载ONNX模型之后,就获得了一个ModelProto,其中包含一些

在 GraphProto 中,有如下几个属性需要注意:

至此,我们已经分析完 GraphProto 的内容,下面根据图中的一个节点可视化说明以上内容:

从图中可以发现,Conv 节点的输入包含三个部分:输入的图像(input)、权重(这里以数字23代表该节点权重W的名字)以及偏置(这里以数字24表示该节点偏置B的名字);输出内容的名字为22;属性信息包括dilations、group、kernel_shape、pads和strides,不同节点会具有不同的属性信息。在initializer数组中,我们可以找到该Conv节点权重(name:23)对应的值(raw_data),并且可以清楚地看到维度信息(64X3X3X3)。

标签:dim,name,output,ONNX,ints,type,input,解析,推理
来源: https://www.cnblogs.com/xxxxxxxxx/p/16061087.html