# onnxruntime
# 54 lines · 1.8 KB
1import argparse2import os3
4import onnx5
6
def export_and_recurse(node, attribute, output_dir, level, parent_model=None, index=None):
    """Save the subgraph held by one attribute of ``node`` as an ONNX model, then recurse into it.

    The file is written to ``output_dir`` as
    ``L<level>_<op_type>_<attribute name>_<node name>.onnx``.

    Note: a subgraph may reference values from enclosing graphs, so the exported file is
    not guaranteed to be runnable on its own; it is primarily useful for inspection.

    Args:
        node: NodeProto that owns the subgraph.
        attribute: AttributeProto of ``node`` whose ``g`` field holds the subgraph.
        output_dir: Directory the file is written into (not created here).
        level: Nesting depth of the graph containing ``node``; 0 for the main graph.
        parent_model: Optional ModelProto containing ``node``. When given, its ``ir_version``
            and ``opset_import`` are copied into the exported model.
        index: Optional position of ``node`` within its graph. It is used as the name when
            ``node.name`` is empty.
    """
    # Node names often contain path separators (e.g. "/encoder/If"). Flatten them so the
    # file lands directly in output_dir on every platform.
    name = node.name.replace("/", "_").replace("\\", "_")
    if not name and index is not None:
        # ONNX node names are optional. Use the node's position so that several unnamed nodes
        # of the same op type at the same level do not overwrite each other's file.
        name = "node" + str(index)

    sub_model = onnx.ModelProto()
    if parent_model is not None:
        # A ModelProto without ir_version / opset_import is rejected by onnx.checker and by
        # ONNX Runtime. A subgraph uses the same operator sets as the model that contains it.
        sub_model.ir_version = parent_model.ir_version
        sub_model.opset_import.extend(parent_model.opset_import)
    sub_model.graph.MergeFrom(attribute.g)

    filename = "L" + str(level) + "_" + node.op_type + "_" + attribute.name + "_" + name + ".onnx"
    onnx.save_model(sub_model, os.path.join(output_dir, filename))

    # Subgraphs can nest (e.g. an If inside a Loop body). Recurse one level deeper.
    dump_subgraph(sub_model, output_dir, level + 1)


def dump_subgraph(model, output_dir, level=0):
    """Export every subgraph nested in ``model``'s main graph, recursively.

    Every node attribute carrying a single graph (``AttributeProto.g``) is exported. This covers
    If (then_branch / else_branch), Loop and Scan (body), and any other standard or custom
    operator that stores a subgraph in a graph attribute. List-of-graphs attributes
    (``AttributeProto.graphs``) are not handled.

    Args:
        model: ModelProto whose ``graph.node`` list is scanned.
        output_dir: Directory the exported ``.onnx`` files are written into.
        level: Nesting depth of ``model``'s graph; used as the ``L<level>`` filename prefix.
    """
    for index, node in enumerate(model.graph.node):
        for attribute in node.attribute:
            if attribute.HasField("g"):
                export_and_recurse(node, attribute, output_dir, level, parent_model=model, index=index)
def parse_args():
    """Read the command line and return the parsed namespace.

    Both options are mandatory:
        -m / --model  path of the ONNX model to scan.
        -o / --out    directory that receives the dumped subgraph files.
    """
    cli = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description="Dump all subgraphs from an ONNX model into separate onnx files.",
    )
    option_specs = (
        ("-m", "--model", "model file"),
        ("-o", "--out", "output directory"),
    )
    for short_flag, long_flag, help_text in option_specs:
        cli.add_argument(short_flag, long_flag, required=True, help=help_text)
    return cli.parse_args()
def main():
    """Command-line entry point: dump every subgraph of ``--model`` into the ``--out`` directory."""
    args = parse_args()

    model_path = args.model
    out = os.path.abspath(args.out)

    # Load the model first so a bad --model path fails before any directory is created.
    model = onnx.load_model(model_path)

    # exist_ok=True avoids the exists()/makedirs() race. If ``out`` exists but is a regular
    # file, this raises FileExistsError up front instead of failing when the first subgraph is saved.
    os.makedirs(out, exist_ok=True)

    dump_subgraph(model, out)
# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()