pytorch

Форк
0
/
_onnx_graph.py 
63 строки · 1.9 Кб
1
# mypy: allow-untyped-defs
2
from tensorboard.compat.proto.graph_pb2 import GraphDef
3
from tensorboard.compat.proto.node_def_pb2 import NodeDef
4
from tensorboard.compat.proto.versions_pb2 import VersionDef
5
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
6
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
7

8

9
def load_onnx_graph(fname):
10
    import onnx
11

12
    m = onnx.load(fname)  # type: ignore[attr-defined]
13
    g = m.graph
14
    return parse(g)
15

16

17
def parse(graph):
18
    nodes = []
19
    import itertools
20

21
    nodes_proto = list(itertools.chain(graph.input, graph.output))
22

23
    for node in nodes_proto:
24
        print(node.name)
25
        shapeproto = TensorShapeProto(
26
            dim=[
27
                TensorShapeProto.Dim(size=d.dim_value)
28
                for d in node.type.tensor_type.shape.dim
29
            ]
30
        )
31
        nodes.append(
32
            NodeDef(
33
                name=node.name.encode(encoding="utf_8"),
34
                op="Variable",
35
                input=[],
36
                attr={
37
                    "dtype": AttrValue(type=node.type.tensor_type.elem_type),
38
                    "shape": AttrValue(shape=shapeproto),
39
                },
40
            )
41
        )
42

43
    for node in graph.node:
44
        _attr = []
45
        for s in node.attribute:
46
            _attr.append(" = ".join([str(f[1]) for f in s.ListFields()]))
47
        attr = ", ".join(_attr).encode(encoding="utf_8")
48
        print(node.output[0])
49
        nodes.append(
50
            NodeDef(
51
                name=node.output[0].encode(encoding="utf_8"),
52
                op=node.op_type,
53
                input=node.input,
54
                attr={"parameters": AttrValue(s=attr)},
55
            )
56
        )
57

58
    # two pass token replacement, appends opname to object id
59
    mapping = {}
60
    for node in nodes:
61
        mapping[node.name] = node.op + "_" + node.name
62

63
    return GraphDef(node=nodes, versions=VersionDef(producer=22))
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.