pytorch

Форк
0
/
gen_operators_yaml.py 
633 строки · 22.3 Кб
1
#!/usr/bin/env python3
2

3
from __future__ import annotations
4

5
import argparse
6
import json
7
import sys
8
from typing import Any
9

10
import yaml
11
from gen_op_registration_allowlist import (
12
    canonical_name,
13
    gen_transitive_closure,
14
    load_op_dep_graph,
15
)
16

17
from torchgen.selective_build.operator import (
18
    merge_operator_dicts,
19
    SelectiveBuildOperator,
20
)
21
from torchgen.selective_build.selector import merge_kernel_metadata
22

23

24
# Generate YAML file containing the operators used for a specific PyTorch model.
25
# ------------------------------------------------------------------------------
26
#
27
# This binary is responsible for generating the model_operators.yaml file for
28
# each model from a pt_operator_library() BUCK macro invocation.
29
#
30
# Output YAML file format:
31
# ------------------------
32
#
33
# <BEGIN FILE CONTENTS>
34
# include_all_non_op_selectives: False
35
# include_all_operators: False
36
# debug_info:
37
#   - model1@v100
38
#   - model2@v50
39
# operators:
40
#   aten::add:
41
#     is_root_operator: Yes
42
#     is_used_for_training: Yes
43
#     include_all_overloads: No
44
#     debug_info:
45
#       - model1@v100
46
#       - model2@v50
47
#   aten::add.int:
48
#     is_root_operator: No
49
#     is_used_for_training: No
50
#     include_all_overloads: Yes
51
# kernel_metadata:
52
#   add_kernel:
53
#     - Int8
54
#     - UInt32
55
#   sub_kernel:
56
#     - Int16
57
#     - Float
58
# <END FILE CONTENTS>
59
#
60
# There are a few main inputs to this application
61
# -----------------------------------------------
62
#
63
# 1. Inference Root Operators (--root-ops): Root operators (called directly
64
#    from TorchScript) used by inference use-cases.
65
#
66
# 2. Training Root Operators (--training-root-ops): Root operators used
67
#    by training use-cases. Currently, this list is the list of all operators
68
#    used by training, and not just the root operators. All Training ops are
69
#    also considered for inference, so these are merged into inference ops.
70
#
71
# 3. Operator Dependency Graph (--dep-graph-yaml-path): A path to the
72
#    operator dependency graph used to determine which operators depend on
73
#    which other operators for correct functioning. This is used for
74
#    generating the transitive closure of all the operators used by the
75
#    model based on the root operators when static selective build is used.
76
#    For tracing based selective build, we don't need to perform this
77
#    transitive closure.
78
#
79
# 4. Model Metadata (--model-name, --model-versions, --model-assets,
80
#    --model-backends): Self-descriptive. These are used to tell this
81
#    script which model operator lists to fetch from the Model
82
#    Build Metadata YAML files.
83
#
84
# 5. Model YAML files (--models-yaml-path): These YAML files contain
85
#    (for each model/version/asset/backend) the set of used root and traced
86
#    operators. This is used to extract the actual set of operators
87
#    needed to be included in the build.
88
#
89

90

91
def canonical_opnames(opnames: list[str]) -> list[str]:
    """Return ``canonical_name(op)`` for every op in *opnames*, in iteration order.

    Callers in this file also pass sets; any iterable of names works.
    """
    return list(map(canonical_name, opnames))
93

94

95
def make_filter_from_options(
    model_name: str,
    model_versions: list[str],
    model_assets: list[str] | None,
    model_backends: list[str] | None,
):
    """Build a predicate that selects model-YAML entries matching the rule.

    The returned callable takes one parsed model YAML dict and returns True
    when its ``model`` section has:

    * ``name`` equal to *model_name*,
    * ``str(version)`` contained in *model_versions*, and
    * ``asset`` contained in *model_assets* (skipped when *model_assets* is None).

    *model_backends* is accepted but not consulted yet (backend filtering is
    still a TODO).
    """

    def is_model_included(model_info) -> bool:
        entry = model_info["model"]
        # Short-circuit order matters: a non-matching name never touches the
        # version/asset keys, exactly like the original guard sequence.
        return (
            entry["name"] == model_name
            and str(entry["version"]) in model_versions
            and (model_assets is None or entry["asset"] in model_assets)
        )

    return is_model_included
113

114

115
def is_new_style_rule(
    model_name: str | None, model_versions: str | list[str] | None
) -> bool:
    """Return whether a ``pt_operator_library`` rule is new-style.

    A rule is new-style when it specifies both a model name and model
    versions; otherwise it is an old-style rule.  Only the ``None``-ness of the
    arguments is inspected.  ``fill_output`` passes the raw comma-separated
    ``--model-versions`` string here, so *model_versions* may be a ``str`` as
    well as a list.
    """
    return model_name is not None and model_versions is not None
118

119

120
def verify_all_specified_present(
    model_assets: list[str] | None,
    model_versions: list[str],
    selected_models_yaml: list[dict[str, Any]],
    rule_name: str,
    model_name: str,
    new_style_rule: bool,
) -> None:
    """Check that every requested version and asset occurs in a selected model YAML.

    An item counts as present when ``str(entry["model"][key])`` equals it for
    at least one entry of *selected_models_yaml*.  Nothing is checked for
    old-style rules (``new_style_rule`` False), and an empty/``None`` list of
    assets or versions is not checked.  *model_name* is not verified itself;
    it only appears in the error message.

    Returns:
        None when nothing is missing.

    Raises:
        RuntimeError: describing the missing assets/versions.  When
            *selected_models_yaml* is empty, the message also warns that the
            model name may be wrong.
    """

    def _missing_items(requested, key, models):
        if not (new_style_rule and requested):
            return []
        available = {str(entry["model"][key]) for entry in models}
        return [item for item in requested if item not in available]

    missing_assets = _missing_items(model_assets, "asset", selected_models_yaml)
    missing_versions = _missing_items(model_versions, "version", selected_models_yaml)

    if not missing_versions and not missing_assets:
        return

    name_warning = ""
    if not selected_models_yaml:
        name_warning = (
            "WARNING: 0 yaml's were found for target rule. This could be because the "
            + "provided model name: {name} is incorrect. Please check that field as well as "
            + "the assets and versions."
        ).format(name=model_name)

    template = (
        "Error: From the pt_operator_library rule for Rule: {name}, at least one entry for the "
        + "following fields was expected -- Model: {model_name} Expected Assets: {expected_assets}, Expected Versions: "
        + "{expected_versions}. {name_warning} In all_mobile_models.yaml either no assets were on one of the "
        + "specified versions, one of the specified assets was not present on any of the specified "
        + "versions, or both. Assets not found: {missing_assets}, Versions not found: {missing_versions} "
        + "For questions please ask in https://fb.workplace.com/groups/2148543255442743/"
    )
    if model_assets:
        shown_assets = model_assets
    else:
        shown_assets = "<All model assets present on specified versions>"
    if missing_versions:
        shown_missing_versions = missing_versions
    else:
        shown_missing_versions = "<All specified versions had at least one asset>"
    if missing_assets:
        shown_missing_assets = missing_assets
    else:
        shown_missing_assets = "<All specified assets are present on at least 1 version>"

    raise RuntimeError(
        template.format(
            name=rule_name,
            model_name=model_name,
            expected_versions=model_versions,
            expected_assets=shown_assets,
            name_warning=name_warning,
            missing_versions=shown_missing_versions,
            missing_assets=shown_missing_assets,
        )
    )
181

182

183
def create_debug_info_from_selected_models(
    output: dict[str, object],
    selected_models: list[dict],
    new_style_rule: bool,
) -> None:
    """Store a JSON summary of the selected models as ``output["debug_info"]``.

    The summary is a one-element list holding a JSON object with two keys:
    ``asset_info`` maps each asset name to ``{"md5_hash": [...]}``, the hashes
    of every selected model entry for that asset in input order.
    ``is_new_style_rule`` echoes *new_style_rule*.  Any previous
    ``debug_info`` value is overwritten.  (Per the original note, gen_oplist
    later consumes this for model/version/asset checking.)
    """
    assets_by_name: dict = {}
    for entry in selected_models:
        meta = entry["model"]
        asset_name, digest = meta["asset"], meta["md5_hash"]
        per_asset = assets_by_name.setdefault(asset_name, {})
        per_asset.setdefault("md5_hash", []).append(digest)

    summary = {
        "asset_info": assets_by_name,
        "is_new_style_rule": new_style_rule,
    }
    output["debug_info"] = [json.dumps(summary)]
206

207

208
def _split_op_list(ops_csv: str | None) -> set[str]:
    """Return the non-empty entries of a comma-separated operator list as a set.

    ``None`` and ``""`` both yield an empty set; empty items produced by stray
    commas (e.g. ``"a,,b,"``) are dropped.
    """
    return {op for op in (ops_csv or "").split(",") if op}


def _make_ops_bucket(
    op_names,
    *,
    is_root_operator: bool,
    is_used_for_training: bool,
    include_all_overloads: bool,
    model_name: str,
) -> dict[str, SelectiveBuildOperator]:
    """Build one ``{op_name: SelectiveBuildOperator}`` bucket whose ops share the same flags."""
    return {
        op_name: SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": include_all_overloads,
                "debug_info": [model_name],
            },
        )
        for op_name in op_names
    }


def fill_output(output: dict[str, object], options: argparse.Namespace) -> None:
    """Populate the output dict with the information required to serialize
    the YAML file used for selective build.

    Args:
        output: dict mutated in place.  This function sets ``debug_info``,
            ``operators``, ``custom_classes``, ``build_features`` and
            ``include_all_non_op_selectives``.  It sets ``kernel_metadata``
            only when at least one selected model YAML provides it.
        options: parsed command-line options (see ``add_arguments_parser``).

    Raises:
        RuntimeError: from ``verify_all_specified_present`` when a new-style
            rule names versions/assets that no selected model YAML provides.
    """
    dep_graph = load_op_dep_graph(options.dep_graph_yaml_path)

    model_versions = (
        options.model_versions.split(",") if options.model_versions is not None else []
    )
    model_assets = (
        options.model_assets.split(",") if options.model_assets is not None else None
    )

    all_models_yaml = []
    for yaml_path in options.models_yaml_path or []:
        with open(yaml_path, "rb") as f:
            all_models_yaml.append(yaml.safe_load(f))

    model_filter_func = make_filter_from_options(
        options.model_name, model_versions, model_assets, options.model_backends
    )
    selected_models_yaml = list(filter(model_filter_func, all_models_yaml))

    new_style_rule = is_new_style_rule(options.model_name, options.model_versions)
    verify_all_specified_present(
        model_assets=model_assets,
        model_versions=model_versions,
        selected_models_yaml=selected_models_yaml,
        rule_name=options.rule_name,
        model_name=options.model_name,
        new_style_rule=new_style_rule,
    )
    create_debug_info_from_selected_models(output, selected_models_yaml, new_style_rule)

    # Static-build root ops given directly on the pt_operator_library rule.
    # All training ops are also considered for inference, so they are merged
    # into the inference root set as well.
    static_training_root_ops = _split_op_list(options.training_root_ops)
    static_root_ops = _split_op_list(options.root_ops) | static_training_root_ops

    root_ops_unexpand: set = set()
    traced_ops: set = set()
    training_root_ops_unexpand: set = set()
    traced_training_ops: set = set()
    all_kernel_metadata = []
    all_custom_classes: set = set()
    all_build_features: set = set()

    # Go through each yaml file and retrieve operator information.
    for model_info in selected_models_yaml:
        if "traced_operators" not in model_info:
            # No traced operators: static-analysis selective build.  These root
            # ops keep all overloads and are later expanded through the
            # transitive closure of operator dependencies.
            static_root_ops.update(model_info["root_operators"])
        elif model_info["train"]:
            # Tracing-based build for training: used as-is, never expanded.
            training_root_ops_unexpand.update(model_info["root_operators"])
            traced_training_ops.update(model_info["traced_operators"])
        else:
            # Tracing-based build for inference: used as-is, never expanded.
            root_ops_unexpand.update(model_info["root_operators"])
            traced_ops.update(model_info["traced_operators"])

        if "kernel_metadata" in model_info:
            all_kernel_metadata.append(model_info["kernel_metadata"])
        if "custom_classes" in model_info:
            all_custom_classes.update(model_info["custom_classes"])
        if "build_features" in model_info:
            all_build_features.update(model_info["build_features"])

    # Transitive closure is relevant to static build only.  It is skipped when
    # there are no root ops; otherwise the closure would pull in __BASE__ and
    # __ROOT__ ops and mark them as required.
    canonical_root_ops = canonical_opnames(static_root_ops)
    closure_op_list = (
        gen_transitive_closure(dep_graph, canonical_root_ops)
        if canonical_root_ops
        else set()
    )
    canonical_training_root_ops = canonical_opnames(static_training_root_ops)
    closure_training_op_list = (
        gen_transitive_closure(dep_graph, canonical_training_root_ops, train=True)
        if canonical_training_root_ops
        else set()
    )

    # getattr() keeps callers that construct ``options`` without these two
    # flags working; the fallback matches the argparse default (False).
    include_all_static_root_overloads = not getattr(
        options, "not_include_all_overloads_static_root_ops", False
    )
    include_all_closure_overloads = not getattr(
        options, "not_include_all_overloads_closure_ops", False
    )

    # Each bucket is one combination of (is root?, used for training?,
    # include all overloads?).  Order is kept identical to the historical
    # merge order below.
    # Columns: op names, is_root_operator, is_used_for_training, include_all_overloads.
    bucket_specs = [
        # Static build ops.
        (static_root_ops, True, False, include_all_static_root_overloads),
        (closure_op_list, False, False, include_all_closure_overloads),
        (static_training_root_ops, True, True, True),
        (closure_training_op_list, False, True, True),
        # Tracing based build ops.
        (root_ops_unexpand, True, False, False),
        (traced_ops, False, False, False),
        (training_root_ops_unexpand, True, True, False),
        (traced_training_ops, False, True, False),
    ]
    bucketed_ops = [
        _make_ops_bucket(
            op_names,
            is_root_operator=is_root,
            is_used_for_training=is_training,
            include_all_overloads=include_all,
            model_name=options.model_name,
        )
        for op_names, is_root, is_training, include_all in bucket_specs
    ]

    # Merge dictionaries together to remove op duplication.
    operators: dict[str, SelectiveBuildOperator] = {}
    for ops_dict in bucketed_ops:
        operators = merge_operator_dicts(operators, ops_dict)

    # If any operator requires all of its overloads, the list came from
    # something other than a traced operator list.  In that case include all
    # non-op selectives as well.
    include_all_non_op_selectives = any(
        op_info.include_all_overloads for op_info in operators.values()
    )

    output["operators"] = {name: op.to_dict() for name, op in operators.items()}
    output["custom_classes"] = all_custom_classes
    output["build_features"] = all_build_features
    output["include_all_non_op_selectives"] = include_all_non_op_selectives
    if all_kernel_metadata:
        kernel_metadata = {}
        for kt in all_kernel_metadata:
            kernel_metadata = merge_kernel_metadata(kernel_metadata, kt)
        output["kernel_metadata"] = kernel_metadata
495

496

497
def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register this script's command-line flags on *parser* and return it.

    Every flag is accepted in both dash-separated and underscore-separated
    spelling (e.g. ``--root-ops`` / ``--root_ops``).  argparse derives the
    attribute name from the first long option, so both spellings set the same
    ``Namespace`` attribute.
    """
    parser.add_argument(
        "--root-ops",
        "--root_ops",
        help="A comma separated list of root operators used by the model",
        required=False,
    )
    parser.add_argument(
        "--training-root-ops",
        "--training_root_ops",
        help="A comma separated list of root operators used for training",
        required=False,
    )
    parser.add_argument(
        "--output-path",
        "--output_path",
        help="The location of the output yaml file.",
        required=True,
    )
    parser.add_argument(
        "--dep-graph-yaml-path",
        "--dep_graph_yaml_path",
        type=str,
        help="A path to the Operator Dependency Graph YAML file.",
        required=True,
    )
    parser.add_argument(
        "--model-name",
        "--model_name",
        type=str,
        help="The name of the model that uses the specified root operators.",
        required=True,
    )
    parser.add_argument(
        "--model-versions",
        "--model_versions",
        type=str,
        help="A comma separated list of model versions.",
        required=False,
    )
    parser.add_argument(
        "--model-assets",
        "--model_assets",
        type=str,
        help="A comma separated list of model asset names (if absent, defaults to all assets for this model).",
        required=False,
    )
    # NOTE: backends are parsed but make_filter_from_options() does not filter
    # on them yet.
    parser.add_argument(
        "--model-backends",
        "--model_backends",
        type=str,
        default="CPU",
        help="A comma separated list of model backends.",
        required=False,
    )
    parser.add_argument(
        "--models-yaml-path",
        "--models_yaml_path",
        type=str,
        help="The paths to the mobile model config YAML files.",
        required=False,
        nargs="+",
    )
    parser.add_argument(
        "--include-all-operators",
        "--include_all_operators",
        action="store_true",
        default=False,
        help="Set this flag to request inclusion of all operators (i.e. build is not selective).",
        required=False,
    )
    parser.add_argument(
        "--rule-name",
        "--rule_name",
        type=str,
        help="The name of pt_operator_library rule resulting in this generation",
        required=True,
    )
    parser.add_argument(
        "--not-include-all-overloads-static-root-ops",
        "--not_include_all_overloads_static_root_ops",
        action="store_true",
        default=False,
        help="Set this flag to not include all overloaded operators for static root ops bucket in fill_output() subroutine",
        required=False,
    )
    parser.add_argument(
        "--not-include-all-overloads-closure-ops",
        "--not_include_all_overloads_closure_ops",
        action="store_true",
        default=False,
        help="Set this flag to not include all overloaded operators for closure ops bucket in fill_output() subroutine",
        required=False,
    )
    return parser
592

593

594
def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Parse the process command line (``sys.argv[1:]``) with *parser*."""
    namespace = parser.parse_args(args=None)
    return namespace
596

597

598
def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Register this script's flags on *parser*, then parse the command line."""
    return parse_options(add_arguments_parser(parser))
601

602

603
def main(argv) -> None:
    """Command-line entry point: build the operators YAML and write it out.

    *argv* is accepted but not used; options are parsed from ``sys.argv`` by
    ``parse_options``.  With ``--include-all-operators`` the output is a
    non-selective placeholder.  Otherwise ``fill_output`` computes the
    operator selection.
    """
    options = get_parser_options(
        argparse.ArgumentParser(description="Generate used operators YAML")
    )

    # Initial debug info.  fill_output() overwrites this entry via
    # create_debug_info_from_selected_models().
    placeholder_debug = {
        "model_name": options.model_name,
        "asset_info": {},
        "is_new_style_rule": False,
    }
    output = {"debug_info": [json.dumps(placeholder_debug)]}

    if not options.include_all_operators:
        fill_output(output, options)
    else:
        output.update(
            include_all_operators=True,
            operators={},
            kernel_metadata={},
        )

    with open(options.output_path, "wb") as sink:
        payload = yaml.safe_dump(output, default_flow_style=False)
        sink.write(payload.encode("utf-8"))
630

631

632
if __name__ == "__main__":
    # Equivalent to sys.exit(...): main() returns None, so the exit status is 0.
    raise SystemExit(main(sys.argv))
634

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.