# pytorch
# 239 lines · 8.1 KB
1from __future__ import annotations2
3import argparse4import os5import sys6from pathlib import Path7from typing import Any, cast8
9import yaml10
11
try:
    # use faster C loader if available
    from yaml import CSafeLoader as YamlLoader
except ImportError:
    # Pure-Python fallback when PyYAML was built without libyaml bindings.
    from yaml import SafeLoader as YamlLoader  # type: ignore[assignment, misc]

# Default repo-relative locations of the ATen operator schema and tag
# definitions, used whenever the corresponding CLI paths are not supplied.
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
TAGS_PATH = "aten/src/ATen/native/tags.yaml"
21
def generate_code(
    gen_dir: Path,
    native_functions_path: str | None = None,
    tags_path: str | None = None,
    install_dir: str | None = None,
    subset: str | None = None,
    disable_autograd: bool = False,
    force_schema_registration: bool = False,
    operator_selector: Any = None,
) -> None:
    """Run the autograd/annotation code generators into ``gen_dir``.

    Args:
        gen_dir: Root output directory; generated files land under
            ``gen_dir/torch/csrc`` and ``gen_dir/torch/testing/_internal/generated``
            unless ``install_dir`` overrides both.
        native_functions_path: Path to ``native_functions.yaml``; falls back to
            ``NATIVE_FUNCTIONS_PATH`` when ``None``.
        tags_path: Path to ``tags.yaml``; falls back to ``TAGS_PATH`` when ``None``.
        install_dir: Deprecated override for both output directories.
        subset: One of ``"pybindings"``, ``"libtorch"``, ``"python"`` to generate
            only that subset; any falsy value generates all three.
        disable_autograd: Forwarded to ``gen_autograd`` to skip autograd codegen.
        force_schema_registration: Currently unused in this body — kept for
            interface compatibility with callers.  # NOTE(review): confirm callers
        operator_selector: A ``SelectiveBuilder``-like selector; defaults to the
            no-op selector (all operators selected).
    """
    # Imports are deferred: these modules depend on pyyaml/torchgen being
    # importable, which is only guaranteed at generation time.
    from tools.autograd.gen_annotated_fn_args import gen_annotated
    from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python

    from torchgen.selective_build.selector import SelectiveBuilder

    # Build ATen based Variable classes
    if install_dir is None:
        install_dir = os.fspath(gen_dir / "torch/csrc")
        python_install_dir = os.fspath(gen_dir / "torch/testing/_internal/generated")
    else:
        # Legacy behavior: a single explicit install dir hosts both outputs.
        python_install_dir = install_dir
    autograd_gen_dir = os.path.join(install_dir, "autograd", "generated")
    for d in (autograd_gen_dir, python_install_dir):
        os.makedirs(d, exist_ok=True)
    # tools/autograd, resolved relative to this file's location.
    autograd_dir = os.fspath(Path(__file__).parent.parent / "autograd")

    # Python binding generation (and the default "generate everything" case).
    if subset == "pybindings" or not subset:
        gen_autograd_python(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,
            autograd_gen_dir,
            autograd_dir,
        )

    # Only gen_autograd consumes the selector; default to selecting everything.
    if operator_selector is None:
        operator_selector = SelectiveBuilder.get_nop_selector()

    # C++ libtorch autograd codegen.
    if subset == "libtorch" or not subset:
        gen_autograd(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,
            autograd_gen_dir,
            autograd_dir,
            disable_autograd=disable_autograd,
            operator_selector=operator_selector,
        )

    # Annotated fn-args used by torch.testing internals.
    if subset == "python" or not subset:
        gen_annotated(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,
            python_install_dir,
            autograd_dir,
        )
77
def get_selector_from_legacy_operator_selection_list(
    selected_op_list_path: str,
) -> Any:
    """Build a ``SelectiveBuilder`` from a legacy YAML operator allow-list.

    The YAML file is expected to contain a list of operator names, possibly
    with ``.overload`` suffixes, which are stripped.

    Args:
        selected_op_list_path: Path to the legacy allow-list YAML file.

    Returns:
        A ``SelectiveBuilder`` covering the listed base operator names.
    """
    with open(selected_op_list_path) as op_list_file:
        raw_names = yaml.load(op_list_file, Loader=YamlLoader)

    # strip out the overload part
    # It's only for legacy config - do NOT copy this code!
    base_op_names = set()
    for raw_name in raw_names:
        base_op_names.add(raw_name.split(".", 1)[0])

    from torchgen.selective_build.selector import SelectiveBuilder

    # Internal build doesn't use this flag any more. Only used by OSS
    # build now. Every operator should be considered a root operator
    # (hence generating unboxing code for it, which is consistent with
    # the current behavior), and also be considered as used for
    # training, since OSS doesn't support training on mobile for now.
    return SelectiveBuilder.from_legacy_op_registration_allow_list(
        base_op_names,
        True,  # is_root_operator
        True,  # is_used_for_training
    )
107
108def get_selector(109selected_op_list_path: str | None,110operators_yaml_path: str | None,111) -> Any:112# cwrap depends on pyyaml, so we can't import it earlier113root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))114sys.path.insert(0, root)115from torchgen.selective_build.selector import SelectiveBuilder116
117assert not (118selected_op_list_path is not None and operators_yaml_path is not None119), (120"Expected at most one of selected_op_list_path and "121+ "operators_yaml_path to be set."122)123
124if selected_op_list_path is None and operators_yaml_path is None:125return SelectiveBuilder.get_nop_selector()126elif selected_op_list_path is not None:127return get_selector_from_legacy_operator_selection_list(selected_op_list_path)128else:129return SelectiveBuilder.from_yaml_path(cast(str, operators_yaml_path))130
131
def main() -> None:
    """CLI entry point: parse arguments and run the code generators."""
    parser = argparse.ArgumentParser(description="Autogenerate code")
    parser.add_argument("--native-functions-path")
    parser.add_argument("--tags-path")
    parser.add_argument(
        "--gen-dir",
        type=Path,
        default=Path("."),
        help="Root directory where to install files. Defaults to the current working directory.",
    )
    parser.add_argument(
        "--install-dir",
        "--install_dir",
        help=(
            "Deprecated. Use --gen-dir instead. The semantics are different, do not change "
            "blindly."
        ),
    )
    parser.add_argument(
        "--subset",
        # Fixed help text: generate_code also accepts the "python" subset.
        help='Subset of source files to generate. Can be "libtorch", "pybindings" or "python". Generates all when omitted.',
    )
    parser.add_argument(
        "--disable-autograd",
        default=False,
        action="store_true",
        help="It can skip generating autograd related code when the flag is set",
    )
    parser.add_argument(
        "--selected-op-list-path",
        help="Path to the YAML file that contains the list of operators to include for custom build.",
    )
    parser.add_argument(
        "--operators-yaml-path",
        "--operators_yaml_path",
        help="Path to the model YAML file that contains the list of operators to include for custom build.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        # Fixed missing space between the implicitly-concatenated fragments
        # (previously rendered as "...that are notlisted on...").
        help="force it to generate schema-only registrations for ops that are not "
        "listed on --selected-op-list",
    )
    parser.add_argument(
        "--gen-lazy-ts-backend",
        "--gen_lazy_ts_backend",
        action="store_true",
        help="Enable generation of the torch::lazy TorchScript backend",
    )
    parser.add_argument(
        "--per-operator-headers",
        "--per_operator_headers",
        action="store_true",
        help="Build lazy tensor ts backend with per-operator ATen headers, must match how ATen was built",
    )
    options = parser.parse_args()

    generate_code(
        options.gen_dir,
        options.native_functions_path,
        options.tags_path,
        options.install_dir,
        options.subset,
        options.disable_autograd,
        options.force_schema_registration,
        operator_selector=get_selector(
            options.selected_op_list_path, options.operators_yaml_path
        ),
    )

    if options.gen_lazy_ts_backend:
        # Fall back to the default schema path so this branch does not crash
        # with os.path.dirname(None) when --native-functions-path is omitted
        # (generate_code above applies the same fallback).
        native_functions_path = options.native_functions_path or NATIVE_FUNCTIONS_PATH
        aten_path = os.path.dirname(os.path.dirname(native_functions_path))
        ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
        ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
        ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
        install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
        lazy_install_dir = os.path.join(install_dir, "lazy/generated")
        os.makedirs(lazy_install_dir, exist_ok=True)

        # Explicit raises instead of `assert`: these are input-validation
        # checks and must survive `python -O`.
        if not os.path.isfile(ts_backend_yaml):
            raise RuntimeError(f"Unable to access ts_backend_yaml: {ts_backend_yaml}")
        if not os.path.isfile(ts_native_functions):
            raise RuntimeError(f"Unable to access {ts_native_functions}")
        from torchgen.dest.lazy_ir import GenTSLazyIR
        from torchgen.gen_lazy_tensor import run_gen_lazy_tensor

        run_gen_lazy_tensor(
            aten_path=aten_path,
            source_yaml=ts_backend_yaml,
            backend_name="TorchScript",
            output_dir=lazy_install_dir,
            dry_run=False,
            impl_path=ts_native_functions,
            node_base="TsNode",
            node_base_hdr=ts_node_base,
            build_in_tree=True,
            lazy_ir_generator=GenTSLazyIR,
            per_operator_headers=options.per_operator_headers,
            gen_forced_fallback_code=True,
        )
237
# Script entry point: run code generation when invoked directly.
if __name__ == "__main__":
    main()