# pytorch
# 63 lines · 1.9 KB
# mypy: allow-untyped-defs
import logging

from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
8
def load_onnx_graph(fname):
    """Load the ONNX model stored at ``fname`` and convert its graph.

    ``onnx`` is imported inside the function, so importing this module does
    not require ``onnx`` to be installed.

    Args:
        fname: path (or whatever ``onnx.load`` accepts) of the model to read.

    Returns:
        The result of ``parse`` applied to the loaded model's ``graph``.
    """
    import onnx

    model = onnx.load(fname)  # type: ignore[attr-defined]
    return parse(model.graph)
16
def parse(graph):
    """Convert an ONNX graph into a TensorBoard ``GraphDef``.

    Each graph input and output becomes an ``op="Variable"`` node that carries
    the tensor's element type and static shape. Each operator node becomes a
    node named after its first output, with its input names as edges. Its
    attributes are flattened into a single ``"parameters"`` string attribute.

    Args:
        graph: an ONNX graph (presumably ``onnx.GraphProto`` -- only its
            ``input``, ``output`` and ``node`` fields are read).

    Returns:
        A ``GraphDef`` holding the converted nodes, with ``producer=22``.

    Raises:
        IndexError: if an operator node has no outputs (its first output is
            used as the node name).
    """
    import itertools

    # Node names are traced at debug level rather than printed to stdout.
    logger = logging.getLogger(__name__)
    nodes = []

    for value_info in itertools.chain(graph.input, graph.output):
        logger.debug("%s", value_info.name)
        tensor_type = value_info.type.tensor_type
        # NOTE(review): symbolic dimensions (dim_param) presumably read back as
        # dim_value == 0 here rather than TensorFlow's -1 "unknown" -- confirm.
        shapeproto = TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=d.dim_value) for d in tensor_type.shape.dim]
        )
        nodes.append(
            NodeDef(
                name=value_info.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    # NOTE(review): the ONNX TensorProto.DataType value is
                    # forwarded unchanged into TensorFlow's DataType field; the
                    # two enums appear to be numbered differently, so the dtype
                    # TensorBoard displays may be wrong -- verify before mapping.
                    "dtype": AttrValue(type=tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            )
        )

    for node in graph.node:
        # One "v1 = v2 = ..." entry per attribute, built from the values of the
        # fields that are set on it; entries are comma-separated.
        attr = ", ".join(
            " = ".join(str(value) for _, value in s.ListFields())
            for s in node.attribute
        ).encode(encoding="utf_8")
        logger.debug("%s", node.output[0])
        nodes.append(
            NodeDef(
                name=node.output[0].encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            )
        )

    return GraphDef(node=nodes, versions=VersionDef(producer=22))