onnxruntime

Форк
0
/
dump_subgraphs.py 
54 строки · 1.8 Кб
1
import argparse
2
import os
3

4
import onnx
5

6

7
def export_and_recurse(node, attribute, output_dir, level):
8
    name = node.name
9
    name = name.replace("/", "_")
10
    sub_model = onnx.ModelProto()
11
    sub_model.graph.MergeFrom(attribute.g)
12
    filename = "L" + str(level) + "_" + node.op_type + "_" + attribute.name + "_" + name + ".onnx"
13
    onnx.save_model(sub_model, os.path.join(output_dir, filename))
14
    dump_subgraph(sub_model, output_dir, level + 1)
15

16

17
def dump_subgraph(model, output_dir, level=0):
18
    graph = model.graph
19

20
    for node in graph.node:
21
        if node.op_type == "Scan" or node.op_type == "Loop":
22
            body_attribute = next(iter(filter(lambda attr: attr.name == "body", node.attribute)))
23
            export_and_recurse(node, body_attribute, output_dir, level)
24
        if node.op_type == "If":
25
            then_attribute = next(iter(filter(lambda attr: attr.name == "then_branch", node.attribute)))
26
            else_attribute = next(iter(filter(lambda attr: attr.name == "else_branch", node.attribute)))
27
            export_and_recurse(node, then_attribute, output_dir, level)
28
            export_and_recurse(node, else_attribute, output_dir, level)
29

30

31
def parse_args():
32
    parser = argparse.ArgumentParser(
33
        os.path.basename(__file__), description="Dump all subgraphs from an ONNX model into separate onnx files."
34
    )
35
    parser.add_argument("-m", "--model", required=True, help="model file")
36
    parser.add_argument("-o", "--out", required=True, help="output directory")
37
    return parser.parse_args()
38

39

40
def main():
41
    args = parse_args()
42

43
    model_path = args.model
44
    out = os.path.abspath(args.out)
45

46
    if not os.path.exists(out):
47
        os.makedirs(out)
48

49
    model = onnx.load_model(model_path)
50
    dump_subgraph(model, out)
51

52

53
if __name__ == "__main__":
54
    main()
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.