pytorch
633 lines · 22.3 KB
1#!/usr/bin/env python3
2
3from __future__ import annotations
4
5import argparse
6import json
7import sys
8from typing import Any
9
10import yaml
11from gen_op_registration_allowlist import (
12canonical_name,
13gen_transitive_closure,
14load_op_dep_graph,
15)
16
17from torchgen.selective_build.operator import (
18merge_operator_dicts,
19SelectiveBuildOperator,
20)
21from torchgen.selective_build.selector import merge_kernel_metadata
22
23
24# Generate YAML file containing the operators used for a specific PyTorch model.
25# ------------------------------------------------------------------------------
26#
27# This binary is responsible for generating the model_operators.yaml file for
28# each model from a pt_operator_library() BUCK macro invocation.
29#
30# Output YAML file format:
31# ------------------------
32#
33# <BEGIN FILE CONTENTS>
34# include_all_non_op_selectives: False
35# include_all_operators: False
36# debug_info:
37# - model1@v100
38# - model2@v50
39# operators:
40# aten::add:
41# is_root_operator: Yes
42# is_used_for_training: Yes
43# include_all_overloads: No
44# debug_info:
45# - model1@v100
46# - model2@v50
47# aten::add.int:
48# is_root_operator: No
49# is_used_for_training: No
50# include_all_overloads: Yes
51# kernel_metadata:
52# add_kernel:
53# - Int8
54# - UInt32
55# sub_kernel:
56# - Int16
57# - Float
58# <END FILE CONTENTS>
59#
60# There are a few main inputs to this application
61# -----------------------------------------------
62#
63# 1. Inference Root Operators (--root-ops): Root operators (called directly
64# from TorchScript) used by inference use-cases.
65#
66# 2. Training Root Operators (--training-root-ops): Root operators used
67# by training use-cases. Currently, this list is the list of all operators
68# used by training, and not just the root operators. All Training ops are
69# also considered for inference, so these are merged into inference ops.
70#
# 3. Operator Dependency Graph (--dep-graph-yaml-path): A path to the
# operator dependency graph used to determine which operators depend on
# which other operators for correct functioning. This is used for
# generating the transitive closure of all the operators used by the
# model based on the root operators when static selective build is used.
# For tracing based selective build, we don't need to perform this
# transitive closure.
78#
79# 4. Model Metadata (--model-name, --model-versions, --model-assets,
80# --model-backends): Self-descriptive. These are used to tell this
81# script which model operator lists to fetch from the Model
82# Build Metadata YAML files.
83#
84# 5. Model YAML files (--models-yaml-path): These yaml files contains
85# (for each model/version/asset/backend) the set of used root and traced
86# operators. This is used to extract the actual set of operators
87# needed to be included in the build.
88#
89
90
def canonical_opnames(opnames: list[str]) -> list[str]:
    """Apply ``canonical_name`` to every operator name in the list."""
    return list(map(canonical_name, opnames))
93
94
def make_filter_from_options(
    model_name: str,
    model_versions: list[str],
    model_assets: list[str] | None,
    model_backends: list[str] | None,
):
    """Build a predicate over model-YAML entries.

    The returned callable accepts one parsed model YAML dict and reports
    whether its name/version/asset match the given constraints.  A ``None``
    asset list means "any asset".  ``model_backends`` is currently ignored.
    """

    def is_model_included(model_info) -> bool:
        model = model_info["model"]
        # TODO: Handle backend later
        return (
            model["name"] == model_name
            and str(model["version"]) in model_versions
            and (model_assets is None or model["asset"] in model_assets)
        )

    return is_model_included
113
114
# Reports whether the pt_operator_library rule is "new style": both a model
# name and an explicit version list were supplied.
def is_new_style_rule(model_name: str, model_versions: list[str] | None):
    return not (model_name is None or model_versions is None)
118
119
def verify_all_specified_present(
    model_assets: list[str] | None,
    model_versions: list[str],
    selected_models_yaml: list[dict[str, Any]],
    rule_name: str,
    model_name: str,
    new_style_rule: bool,
) -> None:
    """Verify that the specified model_name, and all specified versions and
    assets, appear in at least one selected model YAML entry.

    Validation is only performed for new-style rules.  Returns None on
    success; raises RuntimeError listing every missing version/asset.
    """

    def find_missing_items(model_items, key, selected_models_yaml):
        # Old-style rules, and unspecified item lists, are never validated.
        if not new_style_rule or not model_items:
            return []
        # Keep each requested item that no selected model's metadata matches.
        # (any() short-circuits, unlike the previous flag-and-full-scan loop.)
        return [
            item
            for item in model_items
            if not any(
                str(model["model"][key]) == item for model in selected_models_yaml
            )
        ]

    missing_assets = find_missing_items(model_assets, "asset", selected_models_yaml)
    missing_versions = find_missing_items(
        model_versions, "version", selected_models_yaml
    )

    if len(missing_versions) > 0 or len(missing_assets) > 0:  # at least one is missing
        name_warning = ""
        if len(selected_models_yaml) == 0:
            name_warning = (
                "WARNING: 0 yaml's were found for target rule. This could be because the "
                + "provided model name: {name} is incorrect. Please check that field as well as "
                + "the assets and versions."
            ).format(name=model_name)
        raise RuntimeError(
            (
                "Error: From the pt_operator_library rule for Rule: {name}, at least one entry for the "
                + "following fields was expected -- Model: {model_name} Expected Assets: {expected_assets}, Expected Versions: "
                + "{expected_versions}. {name_warning} In all_mobile_models.yaml either no assets were on one of the "
                + "specified versions, one of the specified assets was not present on any of the specified "
                + "versions, or both. Assets not found: {missing_assets}, Versions not found: {missing_versions} "
                + "For questions please ask in https://fb.workplace.com/groups/2148543255442743/"
            ).format(
                name=rule_name,
                model_name=model_name,
                expected_versions=model_versions,
                expected_assets=model_assets
                if model_assets
                else "<All model assets present on specified versions>",
                name_warning=name_warning,
                missing_versions=missing_versions
                if len(missing_versions) > 0
                else "<All specified versions had at least one asset>",
                missing_assets=missing_assets
                if len(missing_assets) > 0
                else "<All specified assets are present on at least 1 version>",
            )
        )
181
182
# Uses the selected models configs and then combines them into one dictionary,
# formats them as a string, and places the string into output as a top level debug_info
def create_debug_info_from_selected_models(
    output: dict[str, object],
    selected_models: list[dict],
    new_style_rule: bool,
) -> None:
    """Record per-asset md5 hashes of the selected models as a JSON string
    under ``output["debug_info"]``.

    The payload is later consumed by gen_oplist to generate the
    model/version/asset checking.
    """
    model_dict = {
        "asset_info": {},  # maps asset name -> dict of asset metadata like hashes
        "is_new_style_rule": new_style_rule,
    }

    for model in selected_models:
        model_info = model["model"]
        asset = model_info["asset"]
        # Renamed from `hash` so the builtin is not shadowed.
        md5_hash = model_info["md5_hash"]

        asset_info = model_dict["asset_info"].setdefault(asset, {})
        asset_info.setdefault("md5_hash", []).append(md5_hash)

    output["debug_info"] = [json.dumps(model_dict)]
206
207
def _selective_build_ops_bucket(
    op_names,
    *,
    is_root_operator: bool,
    is_used_for_training: bool,
    include_all_overloads: bool,
    model_name: str,
) -> dict[str, SelectiveBuildOperator]:
    """Build one operator "bucket": a dict mapping each op name to a
    SelectiveBuildOperator carrying the given root/training/overload flags,
    with the model name recorded as debug info.

    Replaces eight near-identical construction loops that previously lived
    inline in fill_output().
    """
    return {
        op_name: SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": include_all_overloads,
                "debug_info": [model_name],
            },
        )
        for op_name in op_names
    }


def fill_output(output: dict[str, object], options: object) -> None:
    """Populate the output dict with the information required to serialize
    the YAML file used for selective build.
    """
    # (Renamed from `dept_graph` — it is the operator dependency graph.)
    dep_graph = load_op_dep_graph(options.dep_graph_yaml_path)

    model_versions = (
        options.model_versions.split(",") if options.model_versions is not None else []
    )
    model_assets = (
        options.model_assets.split(",") if options.model_assets is not None else None
    )

    # Load every model metadata YAML file that was passed in.
    all_models_yaml = []
    if options.models_yaml_path:
        for yaml_path in options.models_yaml_path:
            with open(yaml_path, "rb") as f:
                all_models_yaml.append(yaml.safe_load(f))

    model_filter_func = make_filter_from_options(
        options.model_name, model_versions, model_assets, options.model_backends
    )

    selected_models_yaml = list(filter(model_filter_func, all_models_yaml))

    verify_all_specified_present(
        model_assets=model_assets,
        model_versions=model_versions,
        selected_models_yaml=selected_models_yaml,
        rule_name=options.rule_name,
        model_name=options.model_name,
        new_style_rule=is_new_style_rule(options.model_name, options.model_versions),
    )

    create_debug_info_from_selected_models(
        output,
        selected_models_yaml,
        is_new_style_rule(options.model_name, options.model_versions),
    )

    # Initialize variables for static build from the pt_operator_library rule.
    if options.root_ops is not None:
        static_root_ops = set(filter(lambda x: len(x) > 0, options.root_ops.split(",")))
    else:
        static_root_ops = set()

    static_training_root_ops = set(
        filter(
            lambda x: len(x) > 0,
            (options.training_root_ops or "").split(","),
        )
    )
    # All training ops are also considered for inference.
    if len(static_training_root_ops) > 0:
        static_root_ops = static_root_ops | static_training_root_ops

    root_ops_unexpand = set()
    traced_ops = set()
    training_root_ops_unexpand = set()
    traced_training_ops = set()
    all_kernel_metadata = []
    all_custom_classes = set()
    all_build_features = set()

    # Go through each yaml file and retrieve operator information.
    for model_info in selected_models_yaml:
        if "traced_operators" not in model_info:
            # No traced operators => static analysis selective build: update
            # static_root_ops with root operators, all of whose overloads must
            # be included. These will be further expanded below using the
            # transitive closure of operator dependencies.
            static_root_ops = static_root_ops | set(model_info["root_operators"])
        else:
            # Traced operators present => tracing based selective build: root
            # operators whose overloads don't need to be included and which
            # will NOT be expanded by the transitive closure. Training ops go
            # into separate sets when the train flag is set.
            if model_info["train"]:
                training_root_ops_unexpand = training_root_ops_unexpand | set(
                    model_info["root_operators"]
                )
                traced_training_ops = traced_training_ops | set(
                    model_info["traced_operators"]
                )
            else:
                root_ops_unexpand = root_ops_unexpand | set(
                    model_info["root_operators"]
                )
                traced_ops = traced_ops | set(model_info["traced_operators"])

        if "kernel_metadata" in model_info:
            all_kernel_metadata.append(model_info["kernel_metadata"])

        if "custom_classes" in model_info:
            all_custom_classes = all_custom_classes | set(model_info["custom_classes"])

        if "build_features" in model_info:
            all_build_features = all_build_features | set(model_info["build_features"])

    # The following section on transitive closure is relevant to static build
    # only. If no canonical root ops exist, skip the closure computation:
    # otherwise we would include __BASE__ and __ROOT__ ops and mark them as
    # required for inference.
    canonical_root_ops = canonical_opnames(static_root_ops)
    if len(canonical_root_ops) > 0:
        closure_op_list = gen_transitive_closure(dep_graph, canonical_root_ops)
    else:
        closure_op_list = set()

    # Same reasoning as above, but for training.
    canonical_training_root_ops = canonical_opnames(static_training_root_ops)
    if len(canonical_training_root_ops) > 0:
        closure_training_op_list = gen_transitive_closure(
            dep_graph, canonical_training_root_ops, train=True
        )
    else:
        closure_training_op_list = set()

    # bucketed_ops holds one dict per semantic bucket: every combination of
    # (is_root_operator, is_used_for_training, include_all_overloads) that
    # occurs. Duplicated ops across buckets are merged below.
    bucketed_ops = [
        # -- Static build ops --
        _selective_build_ops_bucket(
            static_root_ops,
            is_root_operator=True,
            is_used_for_training=False,
            include_all_overloads=not options.not_include_all_overloads_static_root_ops,
            model_name=options.model_name,
        ),
        _selective_build_ops_bucket(
            closure_op_list,
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=not options.not_include_all_overloads_closure_ops,
            model_name=options.model_name,
        ),
        _selective_build_ops_bucket(
            static_training_root_ops,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            model_name=options.model_name,
        ),
        _selective_build_ops_bucket(
            closure_training_op_list,
            is_root_operator=False,
            is_used_for_training=True,
            include_all_overloads=True,
            model_name=options.model_name,
        ),
        # -- Tracing based build ops --
        _selective_build_ops_bucket(
            root_ops_unexpand,
            is_root_operator=True,
            is_used_for_training=False,
            include_all_overloads=False,
            model_name=options.model_name,
        ),
        _selective_build_ops_bucket(
            traced_ops,
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            model_name=options.model_name,
        ),
        _selective_build_ops_bucket(
            training_root_ops_unexpand,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=False,
            model_name=options.model_name,
        ),
        _selective_build_ops_bucket(
            traced_training_ops,
            is_root_operator=False,
            is_used_for_training=True,
            include_all_overloads=False,
            model_name=options.model_name,
        ),
    ]

    # Merge dictionaries together to remove op duplication.
    operators: dict[str, SelectiveBuildOperator] = {}
    for ops_dict in bucketed_ops:
        operators = merge_operator_dicts(operators, ops_dict)

    # If any operator specifies that all overloads need to be included, set
    # include_all_non_op_selectives, since that indicates this operator list
    # came from something other than a traced operator list.
    include_all_non_op_selectives = any(
        op_info.include_all_overloads for op_info in operators.values()
    )

    output["operators"] = {name: op.to_dict() for name, op in operators.items()}
    # NOTE(review): these two are Python sets — presumably the downstream YAML
    # serialization handles them; confirm before changing.
    output["custom_classes"] = all_custom_classes
    output["build_features"] = all_build_features
    output["include_all_non_op_selectives"] = include_all_non_op_selectives

    if len(all_kernel_metadata) > 0:
        kernel_metadata = {}
        for kt in all_kernel_metadata:
            kernel_metadata = merge_kernel_metadata(kernel_metadata, kt)
        output["kernel_metadata"] = kernel_metadata
495
496
def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register every command line flag this tool understands on ``parser``.

    Each flag is accepted under both its dashed ("--foo-bar") and its
    underscored ("--foo_bar") spelling.  Returns the same parser so calls
    can be chained.
    """

    def _add(dashed: str, **kwargs) -> None:
        # Derive the "--foo_bar" alias from the canonical "--foo-bar" name.
        underscored = "--" + dashed[2:].replace("-", "_")
        parser.add_argument(dashed, underscored, **kwargs)

    _add(
        "--root-ops",
        help="A comma separated list of root operators used by the model",
        required=False,
    )
    _add(
        "--training-root-ops",
        help="A comma separated list of root operators used for training",
        required=False,
    )
    _add(
        "--output-path",
        help="The location of the output yaml file.",
        required=True,
    )
    _add(
        "--dep-graph-yaml-path",
        type=str,
        help="A path to the Operator Dependency Graph YAML file.",
        required=True,
    )
    _add(
        "--model-name",
        type=str,
        help="The name of the model that uses the specified root operators.",
        required=True,
    )
    _add(
        "--model-versions",
        type=str,
        help="A comma separated list of model versions.",
        required=False,
    )
    _add(
        "--model-assets",
        type=str,
        help="A comma separate list of model asset names (if absent, defaults to all assets for this model).",
        required=False,
    )
    _add(
        "--model-backends",
        type=str,
        default="CPU",
        help="A comma separated list of model backends.",
        required=False,
    )
    _add(
        "--models-yaml-path",
        type=str,
        help="The paths to the mobile model config YAML files.",
        required=False,
        nargs="+",
    )
    _add(
        "--include-all-operators",
        action="store_true",
        default=False,
        help="Set this flag to request inclusion of all operators (i.e. build is not selective).",
        required=False,
    )
    _add(
        "--rule-name",
        type=str,
        help="The name of pt_operator_library rule resulting in this generation",
        required=True,
    )
    _add(
        "--not-include-all-overloads-static-root-ops",
        action="store_true",
        default=False,
        help="Set this flag to not include all overloaded operators for static root ops bucket in fill_output() subroutine",
        required=False,
    )
    _add(
        "--not-include-all-overloads-closure-ops",
        action="store_true",
        default=False,
        help="Set this flag to not include all overloaded operators for closure ops bucket in fill_output() subroutine",
        required=False,
    )
    return parser
592
593
def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Parse the process command line (sys.argv) with the given parser."""
    options = parser.parse_args()
    return options
596
597
def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Attach all known flags to the parser, then parse the command line."""
    return parse_options(add_arguments_parser(parser))
601
602
def main(argv) -> None:
    """Entry point: parse flags and write the selective-build operators YAML
    to ``--output-path``.  Always returns None (exit status 0)."""
    parser = argparse.ArgumentParser(description="Generate used operators YAML")
    options = get_parser_options(parser)

    # Default debug_info; fill_output() replaces it with per-model info when
    # the build is actually selective.
    model_dict = {
        "model_name": options.model_name,
        "asset_info": {},
        "is_new_style_rule": False,
    }
    output = {
        "debug_info": [json.dumps(model_dict)],
    }

    if options.include_all_operators:
        # Non-selective build: no per-operator information is required.
        output["include_all_operators"] = True
        output["operators"] = {}
        output["kernel_metadata"] = {}
    else:
        fill_output(output, options)

    serialized = yaml.safe_dump(output, default_flow_style=False).encode("utf-8")
    with open(options.output_path, "wb") as out_file:
        out_file.write(serialized)
630
631
if __name__ == "__main__":
    # main() returns None, so sys.exit() reports success (exit status 0).
    sys.exit(main(sys.argv))
634