pytorch
187 строк · 6.3 Кб
1#!/usr/bin/env python3
2
3from __future__ import annotations4
5import argparse6import json7import os8import sys9from functools import reduce10from typing import Any11
12import yaml13from tools.lite_interpreter.gen_selected_mobile_ops_header import (14write_selected_mobile_ops,15)
16
17from torchgen.selective_build.selector import (18combine_selective_builders,19SelectiveBuilder,20)
21
22
def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of every operator held by ``selective_builder``.

    The names are the keys of ``selective_builder.operators``.
    """
    operator_names = selective_builder.operators.keys()
    return set(operator_names)
26
def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of the operators whose ``is_used_for_training`` is truthy."""
    return {
        name
        for name, operator in selective_builder.operators.items()
        if operator.is_used_for_training
    }
34
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
    """Fail if any operator in ``selective_builder`` includes all of its overloads.

    Args:
        selective_builder: Builder whose ``operators`` mapping is inspected.

    Raises:
        Exception: If at least one operator has ``include_all_overloads`` set.
            The message lists the offending operator names, comma-separated.
    """
    offending_ops = [
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.include_all_overloads
    ]
    if offending_ops:
        # main() calls this check only when --allow-include-all-overloads is
        # absent, so the message must say the flag was *not* specified.
        raise Exception(  # noqa: TRY002
            (
                "Operators that include all overloads are "
                "not allowed since --allow-include-all-overloads "
                "was not specified: {}"
            ).format(", ".join(offending_ops))
        )
49
def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
    """Write ``SupportedMobileModelsRegistration.cpp`` into ``output_dir``.

    The generated C++ source registers a set of MD5 hashes. Hashes are
    collected from every model dict that has a ``"debug_info"`` entry whose
    first element is a JSON document with a truthy ``"is_new_style_rule"``;
    each value of that document's ``"asset_info"`` mapping contributes its
    ``"md5_hash"``.

    Args:
        model_dicts: Parsed model YAML dictionaries.
        output_dir: Directory in which the ``.cpp`` file is created.
    """
    supported_mobile_models_source = """/*
* Generated by gen_oplist.py
*/
#include "fb/supported_mobile_models/SupportedMobileModels.h"


struct SupportedMobileModelCheckerRegistry {{
SupportedMobileModelCheckerRegistry() {{
auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton();
ref.set_supported_md5_hashes(std::unordered_set<std::string>{{
{supported_hashes_template}
}});
}}
}};

// This is a global object, initializing which causes the registration to happen.
SupportedMobileModelCheckerRegistry register_model_versions;


"""

    # Generate SupportedMobileModelsRegistration.cpp
    md5_hashes: set[str] = set()
    for model_dict in model_dicts:
        if "debug_info" not in model_dict:
            continue
        debug_info = json.loads(model_dict["debug_info"][0])
        if not debug_info["is_new_style_rule"]:
            continue
        for asset_info in debug_info["asset_info"].values():
            md5_hash = asset_info["md5_hash"]
            if isinstance(md5_hash, str):
                # set.update() on a bare string would add its individual
                # characters; treat a string as one hash instead.
                md5_hashes.add(md5_hash)
            else:
                md5_hashes.update(md5_hash)

    # Sort so the generated file is identical from run to run; plain set
    # iteration order for strings varies with hash randomization.
    supported_hashes = "".join(f'"{md5}",\n' for md5 in sorted(md5_hashes))
    source = supported_mobile_models_source.format(
        supported_hashes_template=supported_hashes
    )
    output_path = os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp")
    with open(output_path, "wb") as out_file:
        out_file.write(source.encode("utf-8"))
92
def main(argv: list[Any]) -> None:
    """This binary generates 3 files:

    1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
       dtypes captured by tracing
    2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
    3. SupportedMobileModelsRegistration.cpp: Registration of the MD5 hashes found in the
       models' debug info (written by gen_supported_mobile_models)

    Args:
        argv: Command-line arguments, without the program name.

    Raises:
        AssertionError: If the model file list path is not an existing file
            and does not start with "@".
        Exception: If --allow-include-all-overloads is not given and some
            operator includes all overloads.
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help=(
            "The directory to store the output yaml files (selected_mobile_ops.h, "
            + "selected_kernel_dtypes.h, selected_operators.yaml)"
        ),
        required=True,
    )
    parser.add_argument(
        "--model-file-list-path",
        "--model_file_list_path",
        help=(
            "Path to a file that contains the locations of individual "
            + "model YAML files that contain the set of used operators. This "
            + "file path must have a leading @-symbol, which will be stripped "
            + "out before processing."
        ),
        required=True,
    )
    parser.add_argument(
        "--allow-include-all-overloads",
        "--allow_include_all_overloads",
        help=(
            "Flag to allow operators that include all overloads. "
            + "If not set, operators registered without using the traced style will "
            + "break the build."
        ),
        action="store_true",
        default=False,
        required=False,
    )
    options = parser.parse_args(argv)

    model_dicts = []
    if os.path.isfile(options.model_file_list_path):
        print("Processing model file: ", options.model_file_list_path)
        # Opened in binary mode, like the per-model files below, and closed
        # deterministically instead of leaking the handle.
        with open(options.model_file_list_path, "rb") as model_file:
            model_dicts.append(yaml.safe_load(model_file))
    else:
        print("Processing model directory: ", options.model_file_list_path)
        # Explicit check rather than `assert` so it still runs under -O and
        # handles an empty path; AssertionError is kept as the exception type.
        if not options.model_file_list_path.startswith("@"):
            raise AssertionError(
                "--model-file-list-path must be an existing file or start with '@': "
                f"{options.model_file_list_path!r}"
            )
        model_file_list_path = options.model_file_list_path[1:]

        with open(model_file_list_path) as model_list_file:
            model_file_names = model_list_file.read().split()
        for model_file_name in model_file_names:
            with open(model_file_name, "rb") as model_file:
                model_dicts.append(yaml.safe_load(model_file))

    selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

    # While we have the model_dicts generate the supported mobile models api
    gen_supported_mobile_models(model_dicts, options.output_dir)

    # We may have 0 selective builders since there may not be any viable
    # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
    # This is potentially an error, and we should probably raise an assertion
    # failure here. However, this needs to be investigated further.
    selective_builder = SelectiveBuilder.from_yaml_dict({})
    if selective_builders:
        selective_builder = reduce(
            combine_selective_builders,
            selective_builders,
        )

    if not options.allow_include_all_overloads:
        throw_if_any_op_includes_overloads(selective_builder)

    with open(
        os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
    ) as out_file:
        out_file.write(
            yaml.safe_dump(
                selective_builder.to_dict(), default_flow_style=False
            ).encode("utf-8"),
        )

    write_selected_mobile_ops(
        os.path.join(options.output_dir, "selected_mobile_ops.h"),
        selective_builder,
    )
185
# Script entry point: pass the command-line arguments, minus the program
# name, to main() when this file is executed directly.
if __name__ == "__main__":
    main(sys.argv[1:])