pytorch

Форк
0
/
gen_oplist.py 
187 строк · 6.3 Кб
1
#!/usr/bin/env python3
2

3
from __future__ import annotations
4

5
import argparse
6
import json
7
import os
8
import sys
9
from functools import reduce
10
from typing import Any
11

12
import yaml
13
from tools.lite_interpreter.gen_selected_mobile_ops_header import (
14
    write_selected_mobile_ops,
15
)
16

17
from torchgen.selective_build.selector import (
18
    combine_selective_builders,
19
    SelectiveBuilder,
20
)
21

22

23
def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of every operator present in ``selective_builder``.

    Args:
        selective_builder: Builder whose ``operators`` mapping is keyed by
            operator name.

    Returns:
        A new set containing every key of ``selective_builder.operators``.
    """
    return {op_name for op_name in selective_builder.operators}
25

26

27
def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of operators that are flagged as used for training.

    Args:
        selective_builder: Builder whose ``operators`` mapping (operator name
            -> operator record) is inspected. Each record's
            ``is_used_for_training`` attribute decides whether its name is
            included.

    Returns:
        The set of operator names whose record has a truthy
        ``is_used_for_training``.
    """
    # Build the set directly instead of appending to a temporary list first.
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_used_for_training
    }
33

34

35
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
    """Raise if any selected operator is registered with all of its overloads.

    ``main`` calls this only when ``--allow-include-all-overloads`` was NOT
    passed. In that mode every operator must be selected per-overload.

    Args:
        selective_builder: Builder whose ``operators`` records are checked
            for a truthy ``include_all_overloads`` attribute.

    Raises:
        Exception: Listing every offending operator name, comma-separated, in
            the mapping's iteration order.
    """
    offending = [
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.include_all_overloads
    ]
    if offending:
        # The previous message said the flag "was specified", which is the
        # opposite of the situation in which this error is raised.
        raise Exception(  # noqa: TRY002
            "Operators that include all overloads are "
            "not allowed since --allow-include-all-overloads "
            "was not specified: " + ", ".join(offending)
        )
48

49

50
def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
    """Write ``SupportedMobileModelsRegistration.cpp`` into ``output_dir``.

    The generated C++ defines a global object whose constructor passes the
    collected MD5 hashes to
    ``SupportedMobileModelChecker::singleton().set_supported_md5_hashes``.

    Args:
        model_dicts: Parsed model YAML dicts. Only dicts with a
            ``"debug_info"`` key contribute. The first element of that value
            is decoded as JSON and must contain ``is_new_style_rule``. When
            that is truthy, it must also contain an ``asset_info`` mapping
            whose values each carry ``md5_hash``.
        output_dir: Existing directory to write the .cpp file into. An
            existing file of the same name is overwritten.
    """
    # Doubled braces are literal C++ braces; ``{supported_hashes_template}``
    # is the only str.format placeholder.
    supported_mobile_models_source = """/*
 * Generated by gen_oplist.py
 */
#include "fb/supported_mobile_models/SupportedMobileModels.h"


struct SupportedMobileModelCheckerRegistry {{
  SupportedMobileModelCheckerRegistry() {{
    auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton();
    ref.set_supported_md5_hashes(std::unordered_set<std::string>{{
      {supported_hashes_template}
    }});
  }}
}};

// This is a global object, initializing which causes the registration to happen.
SupportedMobileModelCheckerRegistry register_model_versions;


"""

    md5_hashes: set[str] = set()
    for model_dict in model_dicts:
        if "debug_info" not in model_dict:
            continue
        debug_info = json.loads(model_dict["debug_info"][0])
        if not debug_info["is_new_style_rule"]:
            continue
        for asset_info in debug_info["asset_info"].values():
            # NOTE(review): set.update() iterates its argument, so this assumes
            # md5_hash is a collection of hash strings. A bare str would be
            # split into single characters; confirm against the producer.
            md5_hashes.update(asset_info["md5_hash"])

    # Sort so the generated source is reproducible. Iteration order of a set
    # of str depends on per-process hash randomization, so unsorted output
    # would differ between otherwise identical builds.
    supported_hashes = "".join(f'"{md5}",\n' for md5 in sorted(md5_hashes))

    source = supported_mobile_models_source.format(
        supported_hashes_template=supported_hashes
    )
    with open(
        os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb"
    ) as out_file:
        out_file.write(source.encode("utf-8"))
91

92

93
def main(argv: list[Any]) -> None:
    """Generate selective-build artifacts from one or more model YAML files.

    This binary generates 3 files in ``--output-dir``:

    1. selected_mobile_ops.h: Primary operators used by templated selective
       build and Kernel Function dtypes captured by tracing.
    2. selected_operators.yaml: Selected root and non-root operators (either
       via tracing or static analysis).
    3. SupportedMobileModelsRegistration.cpp: Registers the MD5 hashes
       gathered from the models' ``debug_info``.

    Args:
        argv: Command-line arguments, excluding the program name.

    Raises:
        Exception: If an operator includes all overloads and
            ``--allow-include-all-overloads`` was not passed.
        SystemExit: On invalid command-line arguments (raised by argparse).
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help=(
            "The directory to store the output files (selected_mobile_ops.h, "
            + "selected_operators.yaml, SupportedMobileModelsRegistration.cpp)"
        ),
        required=True,
    )
    parser.add_argument(
        "--model-file-list-path",
        "--model_file_list_path",
        help=(
            "Path to a file that contains the locations of individual "
            + "model YAML files that contain the set of used operators. This "
            + "file path must have a leading @-symbol, which will be stripped "
            + "out before processing."
        ),
        required=True,
    )
    parser.add_argument(
        "--allow-include-all-overloads",
        "--allow_include_all_overloads",
        help=(
            "Flag to allow operators that include all overloads. "
            + "If not set, operators registered without using the traced style will "
            + "break the build."
        ),
        action="store_true",
        default=False,
        required=False,
    )
    options = parser.parse_args(argv)

    model_file_list_path = options.model_file_list_path
    model_dicts: list[Any] = []
    if os.path.isfile(model_file_list_path):
        # The argument names a single model YAML file directly.
        print("Processing model file: ", model_file_list_path)
        with open(model_file_list_path) as model_file:
            model_dicts.append(yaml.safe_load(model_file))
    else:
        print("Processing model directory: ", model_file_list_path)
        # Otherwise the argument must be "@<list file>", where the list file
        # holds whitespace-separated paths to model YAML files. An explicit
        # check is used (not assert) so it survives `python -O` and also
        # rejects an empty string cleanly.
        if not model_file_list_path.startswith("@"):
            parser.error(
                "--model-file-list-path must be an existing file or an "
                "@-prefixed path to a file listing model YAML files"
            )
        with open(model_file_list_path[1:]) as model_list_file:
            model_file_names = model_list_file.read().split()
        for model_file_name in model_file_names:
            with open(model_file_name, "rb") as model_file:
                model_dicts.append(yaml.safe_load(model_file))

    selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

    # While we have the model_dicts generate the supported mobile models api
    gen_supported_mobile_models(model_dicts, options.output_dir)

    # We may have 0 selective builders since there may not be any viable
    # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
    # This is potentially an error, and we should probably raise an assertion
    # failure here. However, this needs to be investigated further.
    if selective_builders:
        selective_builder = reduce(combine_selective_builders, selective_builders)
    else:
        selective_builder = SelectiveBuilder.from_yaml_dict({})

    if not options.allow_include_all_overloads:
        throw_if_any_op_includes_overloads(selective_builder)

    with open(
        os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
    ) as out_file:
        out_file.write(
            yaml.safe_dump(
                selective_builder.to_dict(), default_flow_style=False
            ).encode("utf-8"),
        )

    write_selected_mobile_ops(
        os.path.join(options.output_dir, "selected_mobile_ops.h"),
        selective_builder,
    )
184

185

186
if __name__ == "__main__":
    # Forward everything after the program name to the CLI entry point.
    cli_args = sys.argv[1:]
    main(cli_args)
188

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.