# 1.ONNX 模型在底层是用什么格式存储的?

ONNX 在底层是用 Protobuf 定义的。Protobuf,全称 Protocol Buffer,是 Google 提出的一套表示和序列化数据的机制。使用 Protobuf 时,用户需要先写一份数据定义文件,再根据这份定义文件把数据存储进一份二进制文件。可以说,数据定义文件就是数据类,二进制文件就是数据类的实例。一个 Protobuf 数据定义文件的例子:

message Person {
  required string name = 1;
  required int32 id = 2;
  optional string email = 3;
}

这段定义表示在 Person 这种数据类型中,必须包含 name、id 这两个字段,选择性包含 email 字段。根据这份定义文件,用户就可以选择一种编程语言,定义一个含有成员变量 name、id、email 的 Person 类,把这个类的某个实例用 Protobuf 存储成二进制文件;反之,用户也可以用二进制文件和对应的数据定义文件,读取出一个 Person 类的实例。

对于 ONNX ,它的 Protobuf 数据定义文件在其开源库 (opens new window)中,这些文件定义了神经网络中模型、节点、张量的数据类型规范;而数据定义文件对应的二进制文件就是我们熟悉的“.onnx”文件,每一个 “.onnx” 文件按照数据定义规范,存储了一个神经网络的所有相关数据。直接用 Protobuf 生成 ONNX 模型还是比较麻烦的。幸运的是,ONNX 提供了很多实用 API,我们可以在完全不了解 Protobuf 的前提下,构造和读取 ONNX 模型。

神经网络本质上是一个计算图。计算图的节点是算子,边是参与运算的张量。而通过可视化 ONNX 模型,我们知道 ONNX 记录了所有算子节点的属性信息,并把参与运算的张量信息存储在算子节点的输入输出信息中。ONNX 模型的结构可以用类图大致表示如下:

如图所示,一个 ONNX 模型可以用 ModelProto 类表示。ModelProto 包含了版本、创建者等日志信息,还包含了存储计算图结构的 graph。GraphProto 类则由输入张量信息、输出张量信息、节点信息组成。张量信息 ValueInfoProto 类包括张量名、基本数据类型、形状。节点信息 NodeProto 类包含了算子名、算子输入张量名、算子输出张量名。 让我们来看一个具体的例子。假如我们有一个描述 output=a*x+b 的 ONNX 模型 model,用 print(model) 可以输出以下内容:

ir_version: 8
graph {
  node {
    input: "a"
    input: "x"
    output: "c"
    op_type: "Mul"
  }
  node {
    input: "c"
    input: "b"
    output: "output"
    op_type: "Add"
  }
  name: "linear_func"
  input {
    name: "a"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {dim_value: 10}
          dim {dim_value: 10}
        }
      }
    }
  }
  input {
    name: "x"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {dim_value: 10}
          dim {dim_value: 10}
        }
      }
    }
  }
  input {
    name: "b"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {dim_value: 10}
          dim {dim_value: 10}
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim { dim_value: 10}
          dim { dim_value: 10}
        }
      }
    }
  }
}
opset_import {version: 15}

对应上文中的类图,这个模型的信息由 ir_version,opset_import 等全局信息和 graph 图信息组成。而 graph 包含一个乘法节点、一个加法节点、三个输入张量 a, x, b 以及一个输出张量 output。在下一节里,我们会用 API 构造出这个模型,并输出这段结果。根据这里定义的模型结构信息,我们就可以按这个数据结构进行模型的编辑了。

# 2.使用graphsurgeon进行模型的编辑

onnx本身提供了很多API能够加载解析onnx模型,但修改模型时不够方便,譬如要修改某个边(也就是tensor)的shape信息时,使用onnx时操作为:

model = onnx.load(onnx_file)
model.inputs[0].type.tensor_type.shape.dim.add().dim_value = 100

可以看到比较麻烦,nvidia的开发者开发了一套专门用来编辑onnx模型的工具graphsurgeon

使用graphsurgeon修改tensor的shape的示例

model = onnx.load(onnx_file)
graph = graphsurgeon.import_onnx(model)
graph.inputs[0].shape[1] = 100

可以看到比手动修改onnx model要方便很多,此外还提供了诸如cleanup清除孤立的node,toposort对node点进行拓扑排序,tensors获取所有tensor等API,可以方便的进行node的增删修改。

# ---------------- 1. 构建模型 ----------------
import torch
import torch.nn as nn
import numpy as np
import onnx
import onnxruntime as ort
import onnx_graphsurgeon as gs
import os

onnx_file = "convnet.onnx"
simp_onnx_file = "convnet_simplified.onnx"
slice_onnx_file  = "convnet_simplified_sliced.onnx"
sub_slice_onnx_file = "sub_convnet_simplified_sliced.onnx"

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 输入 3 通道
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))                 # 输出 32×1×1
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 输入 3 通道
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))                 # 输出 32×1×1
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x1 = self.features1(x)      # -> (N,32,1,1)
        x2 = self.features2(x)      # -> (N,32,1,1)
        x = torch.concat([x1, x2], dim=1) # -> (N,64,1,1)
        x = x.view(x.size(0), -1) # -> (N,64)
        return self.classifier(x) # -> (N,10)


def covert_to_onnx():
    model = ConvNet().eval()

    # ---------------- 2. 构造假输入 ----------------
    dummy_input = torch.randn(1, 3, 224, 224)   # NCHW
    with torch.no_grad():
        torch_out = model(dummy_input)
    print("PyTorch output shape:", torch_out.shape)  # 应为 [1,10]

    # ---------------- 3. 导出 ONNX ----------------
    torch.onnx.export(
        model,
        dummy_input,
        onnx_file,
        export_params=True,        # 同时保存权重
        opset_version=13,          # 推荐 ≥11,支持更多算子
        do_constant_folding=True,  # 折叠常量,减小体积
        input_names=['input'],     # 名称可自定义
        output_names=['output'],
        dynamic_axes={             # 支持动态 batch / 分辨率
            'input': {0: 'batch', 2: 'height', 3: 'width'},
            'output': {0: 'batch'}
        }
    )
    print(f"ONNX 已导出到 {onnx_file}")

    # ---------------- 4. 用 onnxruntime 验证 ----------------

    sess = ort.InferenceSession(onnx_file)
    ort_out = sess.run(None, {'input': dummy_input.numpy()})[0]

    np.testing.assert_allclose(torch_out.numpy(), ort_out, rtol=1e-5, atol=1e-5)
    print("ONNXRuntime 与 PyTorch 输出一致,导出成功!")

def simplify_onnx():
    import onnx
    from onnxsim import simplify

    model = onnx.load(onnx_file)

    model_simp, check = simplify(model)
    assert check, "简化后的模型校验失败!"

    onnx.save(model_simp, simp_onnx_file)
    print(f"简化后的 ONNX 模型已保存到 {simp_onnx_file}")

def add_slice_node():

    # 1. 载入模型 → 转 graphsurgeon
    graph = gs.import_onnx(onnx.load(onnx_file))

    print("Original graph inputs:", graph.inputs)
    print("Original graph outputs:", graph.outputs)
    print("Original graph nodes:", graph.nodes)

    # 2. 找到原输入 tensor
    old_in = graph.inputs[0]          # 名字是 "input"  shape=[1,3,224,224]
    old_in.shape = [1, 3, 224, 224]
    # 3. 构造 Slice 常量节点
    starts = gs.Constant("starts", np.array([56, 56], dtype=np.int64))
    ends   = gs.Constant("ends",   np.array([56+112, 56+112], dtype=np.int64))
    axes   = gs.Constant("axes",   np.array([2, 3], dtype=np.int64))
    # (batch/channel 维度不裁,所以 start=0, end=0 表示“到最大”)

    # 4. 插入 Slice 节点
    sliced = gs.Variable("sliced_112", dtype=old_in.dtype, shape=[1, 3, 112, 112])
    slice_node = gs.Node(
        op="Slice",
        inputs=[old_in, starts, ends, axes],
        outputs=[sliced],
        name="center_crop_112"
    )
    graph.nodes.append(slice_node)

    # 5. 把原网络所有对 old_in 的引用,改成 sliced
    for node in graph.nodes:
        if node is slice_node:
            continue
        for i, inp in enumerate(node.inputs):
            if inp == old_in:
                node.inputs[i] = sliced

    # 6. 重新指定 graph 输入输出(输入仍是 old_in,但数据流先经过 Slice)
    graph.inputs = [old_in]
    graph.outputs = graph.outputs   # 不变
    graph.cleanup()                 # 删除悬空节点,重新拓扑
    graph.toposort()
    model = gs.export_onnx(graph)
    model = onnx.shape_inference.infer_shapes(model)   # ← 关键一步  确保 shape 信息正确传播

    # 7. 导出
    onnx.save(model, slice_onnx_file)
    print(f"Done! 裁剪后模型已保存为 {slice_onnx_file}")

    # 8. 验证新模型
    for t in model.graph.value_info:
        if t.name == "sliced_112":
            print("sliced_112 shape:", [d.dim_value for d in t.type.tensor_type.shape.dim])

def get_subgraph_by_in_out_names():
    """根据输入输出名字,提取子图"""
    input_names = ["input"]
    output_names = ["/Concat_output_0"]
    model = onnx.load(slice_onnx_file)
    graph = gs.import_onnx(model)
    all_tensor_maps = graph.tensors()
    print("All tensor names in the graph:", all_tensor_maps)
    in_tensors = [all_tensor_maps[name] for name in input_names]
    out_tensors = [all_tensor_maps[name] for name in output_names]
    graph.inputs = in_tensors
    graph.outputs = out_tensors
    graph.outputs[0].name = "subgraph_output"
    print("Subgraph inputs:", graph.__doc__)
    graph.cleanup().toposort()
    model = gs.export_onnx(graph)
    model = onnx.shape_inference.infer_shapes(model)   # ← 关键一步  确保 shape 信息正确传播
    onnx.save(model, sub_slice_onnx_file)
    print(f"提取的子图ONNX 模型已保存到 {sub_slice_onnx_file}")
    
def main():
    # 导出模型
    if not os.path.exists(onnx_file):
        covert_to_onnx()
    # 简化模型
    if not os.path.exists(simp_onnx_file):
        simplify_onnx()
    # 添加 Slice 节点
    if not os.path.exists(slice_onnx_file):
        add_slice_node()
    # 提取子图
    if not os.path.exists(sub_slice_onnx_file):
        get_subgraph_by_in_out_names()

if __name__ == "__main__":
    main()

# reference

1.https://mmdeploy.readthedocs.io/zh-cn/stable/tutorial/05_onnx_model_editing.html (opens new window)