# 1.ONNX 模型在底层是用什么格式存储的?
ONNX 在底层是用 Protobuf 定义的。Protobuf,全称 Protocol Buffer,是 Google 提出的一套表示和序列化数据的机制。使用 Protobuf 时,用户需要先写一份数据定义文件,再根据这份定义文件把数据存储进一份二进制文件。可以说,数据定义文件就是数据类,二进制文件就是数据类的实例。一个 Protobuf 数据定义文件的例子:
message Person {
required string name = 1;
required int32 id = 2;
optional string email = 3;
}
这段定义表示在 Person 这种数据类型中,必须包含 name、id 这两个字段,选择性包含 email 字段。根据这份定义文件,用户就可以选择一种编程语言,定义一个含有成员变量 name、id、email 的 Person 类,把这个类的某个实例用 Protobuf 存储成二进制文件;反之,用户也可以用二进制文件和对应的数据定义文件,读取出一个 Person 类的实例。
对于 ONNX ,它的 Protobuf 数据定义文件在其开源库 (opens new window)中,这些文件定义了神经网络中模型、节点、张量的数据类型规范;而数据定义文件对应的二进制文件就是我们熟悉的“.onnx”文件,每一个 “.onnx” 文件按照数据定义规范,存储了一个神经网络的所有相关数据。直接用 Protobuf 生成 ONNX 模型还是比较麻烦的。幸运的是,ONNX 提供了很多实用 API,我们可以在完全不了解 Protobuf 的前提下,构造和读取 ONNX 模型。
神经网络本质上是一个计算图。计算图的节点是算子,边是参与运算的张量。而通过可视化 ONNX 模型,我们知道 ONNX 记录了所有算子节点的属性信息,并把参与运算的张量信息存储在算子节点的输入输出信息中。ONNX 模型的结构可以用类图大致表示如下:
如图所示,一个 ONNX 模型可以用 ModelProto 类表示。ModelProto 包含了版本、创建者等日志信息,还包含了存储计算图结构的 graph。GraphProto 类则由输入张量信息、输出张量信息、节点信息组成。张量信息 ValueInfoProto 类包括张量名、基本数据类型、形状。节点信息 NodeProto 类包含了算子名、算子输入张量名、算子输出张量名。 让我们来看一个具体的例子。假如我们有一个描述 output=a*x+b 的 ONNX 模型 model,用 print(model) 可以输出以下内容:
ir_version: 8
graph {
node {
input: "a"
input: "x"
output: "c"
op_type: "Mul"
}
node {
input: "c"
input: "b"
output: "output"
op_type: "Add"
}
name: "linear_func"
input {
name: "a"
type {
tensor_type {
elem_type: 1
shape {
dim {dim_value: 10}
dim {dim_value: 10}
}
}
}
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {dim_value: 10}
dim {dim_value: 10}
}
}
}
}
input {
name: "b"
type {
tensor_type {
elem_type: 1
shape {
dim {dim_value: 10}
dim {dim_value: 10}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim { dim_value: 10}
dim { dim_value: 10}
}
}
}
}
}
opset_import {version: 15}
对应上文中的类图,这个模型的信息由 ir_version,opset_import 等全局信息和 graph 图信息组成。而 graph 包含一个乘法节点、一个加法节点、三个输入张量 a, x, b 以及一个输出张量 output。在下一节里,我们会用 API 构造出这个模型,并输出这段结果。根据这里定义的模型结构信息,我们就可以按这个数据结构进行模型的编辑了。
# 2.使用graphsurgeon进行模型的编辑
onnx本身提供了很多API能够加载解析onnx模型,但修改模型时不够方便,譬如要修改某个边(也就是tensor)的shape信息时,使用onnx时操作为:
model = onnx.load(onnx_file)
model.inputs[0].type.tensor_type.shape.dim.add().dim_value = 100
可以看到比较麻烦,nvidia的开发者开发了一套专门用来编辑onnx模型的工具graphsurgeon
使用graphsurgeon修改tensor的shape的示例
model = onnx.load(onnx_file)
graph = graphsurgeon.import_onnx(model)
graph.inputs[0].shape[1] = 100
可以看到比手动修改onnx model要方便很多,此外还提供了诸如cleanup清除孤立的node,toposort对node点进行拓扑排序,tensors获取所有tensor等API,可以方便的进行node的增删修改。
# ---------------- 1. 构建模型 ----------------
import torch
import torch.nn as nn
import numpy as np
import onnx
import onnxruntime as ort
import onnx_graphsurgeon as gs
import os
onnx_file = "convnet.onnx"
simp_onnx_file = "convnet_simplified.onnx"
slice_onnx_file = "convnet_simplified_sliced.onnx"
sub_slice_onnx_file = "sub_convnet_simplified_sliced.onnx"
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.features1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1), # 输入 3 通道
nn.ReLU(inplace=True),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)) # 输出 32×1×1
)
self.features2 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1), # 输入 3 通道
nn.ReLU(inplace=True),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)) # 输出 32×1×1
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x):
x1 = self.features1(x) # -> (N,32,1,1)
x2 = self.features2(x) # -> (N,32,1,1)
x = torch.concat([x1, x2], dim=1) # -> (N,64,1,1)
x = x.view(x.size(0), -1) # -> (N,64)
return self.classifier(x) # -> (N,10)
def covert_to_onnx():
model = ConvNet().eval()
# ---------------- 2. 构造假输入 ----------------
dummy_input = torch.randn(1, 3, 224, 224) # NCHW
with torch.no_grad():
torch_out = model(dummy_input)
print("PyTorch output shape:", torch_out.shape) # 应为 [1,10]
# ---------------- 3. 导出 ONNX ----------------
torch.onnx.export(
model,
dummy_input,
onnx_file,
export_params=True, # 同时保存权重
opset_version=13, # 推荐 ≥11,支持更多算子
do_constant_folding=True, # 折叠常量,减小体积
input_names=['input'], # 名称可自定义
output_names=['output'],
dynamic_axes={ # 支持动态 batch / 分辨率
'input': {0: 'batch', 2: 'height', 3: 'width'},
'output': {0: 'batch'}
}
)
print(f"ONNX 已导出到 {onnx_file}")
# ---------------- 4. 用 onnxruntime 验证 ----------------
sess = ort.InferenceSession(onnx_file)
ort_out = sess.run(None, {'input': dummy_input.numpy()})[0]
np.testing.assert_allclose(torch_out.numpy(), ort_out, rtol=1e-5, atol=1e-5)
print("ONNXRuntime 与 PyTorch 输出一致,导出成功!")
def simplify_onnx():
import onnx
from onnxsim import simplify
model = onnx.load(onnx_file)
model_simp, check = simplify(model)
assert check, "简化后的模型校验失败!"
onnx.save(model_simp, simp_onnx_file)
print(f"简化后的 ONNX 模型已保存到 {simp_onnx_file}")
def add_slice_node():
# 1. 载入模型 → 转 graphsurgeon
graph = gs.import_onnx(onnx.load(onnx_file))
print("Original graph inputs:", graph.inputs)
print("Original graph outputs:", graph.outputs)
print("Original graph nodes:", graph.nodes)
# 2. 找到原输入 tensor
old_in = graph.inputs[0] # 名字是 "input" shape=[1,3,224,224]
old_in.shape = [1, 3, 224, 224]
# 3. 构造 Slice 常量节点
starts = gs.Constant("starts", np.array([56, 56], dtype=np.int64))
ends = gs.Constant("ends", np.array([56+112, 56+112], dtype=np.int64))
axes = gs.Constant("axes", np.array([2, 3], dtype=np.int64))
# (batch/channel 维度不裁,所以 start=0, end=0 表示“到最大”)
# 4. 插入 Slice 节点
sliced = gs.Variable("sliced_112", dtype=old_in.dtype, shape=[1, 3, 112, 112])
slice_node = gs.Node(
op="Slice",
inputs=[old_in, starts, ends, axes],
outputs=[sliced],
name="center_crop_112"
)
graph.nodes.append(slice_node)
# 5. 把原网络所有对 old_in 的引用,改成 sliced
for node in graph.nodes:
if node is slice_node:
continue
for i, inp in enumerate(node.inputs):
if inp == old_in:
node.inputs[i] = sliced
# 6. 重新指定 graph 输入输出(输入仍是 old_in,但数据流先经过 Slice)
graph.inputs = [old_in]
graph.outputs = graph.outputs # 不变
graph.cleanup() # 删除悬空节点,重新拓扑
graph.toposort()
model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model) # ← 关键一步 确保 shape 信息正确传播
# 7. 导出
onnx.save(model, slice_onnx_file)
print(f"Done! 裁剪后模型已保存为 {slice_onnx_file}")
# 8. 验证新模型
for t in model.graph.value_info:
if t.name == "sliced_112":
print("sliced_112 shape:", [d.dim_value for d in t.type.tensor_type.shape.dim])
def get_subgraph_by_in_out_names():
"""根据输入输出名字,提取子图"""
input_names = ["input"]
output_names = ["/Concat_output_0"]
model = onnx.load(slice_onnx_file)
graph = gs.import_onnx(model)
all_tensor_maps = graph.tensors()
print("All tensor names in the graph:", all_tensor_maps)
in_tensors = [all_tensor_maps[name] for name in input_names]
out_tensors = [all_tensor_maps[name] for name in output_names]
graph.inputs = in_tensors
graph.outputs = out_tensors
graph.outputs[0].name = "subgraph_output"
print("Subgraph inputs:", graph.__doc__)
graph.cleanup().toposort()
model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model) # ← 关键一步 确保 shape 信息正确传播
onnx.save(model, sub_slice_onnx_file)
print(f"提取的子图ONNX 模型已保存到 {sub_slice_onnx_file}")
def main():
# 导出模型
if not os.path.exists(onnx_file):
covert_to_onnx()
# 简化模型
if not os.path.exists(simp_onnx_file):
simplify_onnx()
# 添加 Slice 节点
if not os.path.exists(slice_onnx_file):
add_slice_node()
# 提取子图
if not os.path.exists(sub_slice_onnx_file):
get_subgraph_by_in_out_names()
if __name__ == "__main__":
main()