pytorch

Форк
0
/
gen_mobile_upgraders.py 
392 строки · 12.4 Кб
1
#!/usr/bin/env python3
2
import os
3
from enum import Enum
4
from pathlib import Path
5
from typing import Any, Dict, List
6

7
import torch
8
from torch.jit.generate_bytecode import generate_upgraders_bytecode
9

10
from torchgen.code_template import CodeTemplate
11
from torchgen.operator_versions.gen_mobile_upgraders_constant import (
12
    MOBILE_UPGRADERS_HEADER_DESCRIPTION,
13
)
14

15

16
class ByteCode(Enum):
    """Names of the tables that make up one upgrader's bytecode function.

    ``write_cpp`` looks members up by name (``ByteCode[table_name]``) using the
    keys of each upgrader's bytecode dict, so member names must match those
    keys exactly; an unknown key raises ``KeyError``.
    """

    instructions = 1
    constants = 2
    types = 3
    operators = 4
    register_size = 5
22

23

24
# Operators whose entries in the operator version map are skipped by
# construct_version_maps.
# TODO: remove the skip once these operators' schemas are fixed.
EXCLUDED_OP_SET = [
    "aten::full.names",
    "aten::full.out",
    "aten::full",
]

# Upgrader functions skipped both when computing bytecode indices and when
# emitting bytecode. The misspelled name is kept because other code in this
# module refers to it.
EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]
31

32
# C++ templates for the tables of one bytecode function. The ``ONE_*``
# templates render a single entry (each starts with a newline so entries stack
# one per line); the ``*_LIST`` templates wrap the joined entries.

# One instruction: op code, X, N.
ONE_INSTRUCTION = CodeTemplate(
    """
    Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
)

INSTRUCTION_LIST = CodeTemplate(
    """std::vector<Instruction>({
        ${instruction_list}
    }), // instructions list"""
)

# One constant; ``${constant}`` is an already-rendered C++ expression.
ONE_CONSTANT = CodeTemplate(
    """
    c10::IValue(${constant}),"""
)

CONSTANT_LIST = CodeTemplate(
    """std::vector<c10::IValue>({
        ${constant_list}
    }), // constants list"""
)

# Used instead of CONSTANT_LIST when an upgrader has no constants.
CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""

ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")

TYPE_LIST = CodeTemplate(
    """std::vector<c10::TypePtr>({
        ${type_list}
    }), // types list"""
)

# Used instead of TYPE_LIST when an upgrader has no types.
TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list"""
65

66
# One operator referenced by an upgrader: name, overload name, argument count.
# (The misspelled name is kept because other code in this module refers to it.)
ONE_OPERATOTR_STRING = CodeTemplate(
    """
    OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
)

OPERATOR_STRING_LIST = CodeTemplate(
    """
    std::vector<OperatorString>({
        ${operator_string_list}
    }), // operators list"""
)

# Registration call for one upgrader's bytecode function; the list arguments
# are the rendered INSTRUCTION_LIST / CONSTANT_LIST / TYPE_LIST snippets.
ONE_UPGRADER_FUNCTION = CodeTemplate(
    """
    mobile::Function::registerFunc(
        "${upgrader_name}",
        ${instruction_list},
        ${constant_list},
        ${type_list},
        ${register_size}
    )"""
)

# Pairs a rendered bytecode function with its rendered operator list.
ONE_UPGRADER_SRC = CodeTemplate(
    """
    ByteCodeFunctionWithOperator({
        ${bytecode_function},
        ${operator_string_list}
    }),"""
)
96

97

98
# One upgrader entry: version range, upgrader name, and index of its bytecode
# function in getUpgraderBytecodeList().
ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(
    """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""
)  # noqa: E501

# All upgrader entries for one operator, keyed by the operator name.
ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(
    """
    {std::string("${operator_name}"),
        std::vector<Upgrader>({
            ${upgrader_list_in_version_map}
        })},"""
)

# Definition of getOperatorVersionMapForMobile(), built from the rendered
# ONE_OPERATOR_IN_VERSION_MAP entries.
OPERATOR_VERSION_MAP = CodeTemplate(
    """
const std::unordered_map<std::string, std::vector<Upgrader>>
getOperatorVersionMapForMobile() {
  static std::unordered_map<std::string, std::vector<Upgrader>>
        operatorVersionMapForMobile({
            ${operator_list_in_version_map}
      });
  return operatorVersionMapForMobile;
}
"""
)
123

124

125
# Complete contents of the generated C++ file: the header description, the
# operator version map, and getUpgraderBytecodeList(), which appends each
# upgrader's operators to its function the first time the list is built.
UPGRADER_CPP_SRC = CodeTemplate(
    MOBILE_UPGRADERS_HEADER_DESCRIPTION
    + """
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {

// clang-format off

// From operator_versions_map
${operator_version_map}

const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
  auto generate_upgrader_bytecode_list = []() {
    std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
               ${upgrader_bytecode}
            });
    for (const auto& upgrader_function : upgrader_function_list) {
      for (const auto& op : upgrader_function.operators) {
        upgrader_function.function.append_operator(
            op.name,
            op.overload_name,
            op.num_specified_args);
      }
    }
    return upgrader_function_list;
  };
  static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
      generate_upgrader_bytecode_list();
  return upgraderBytecodeList;
}

// clang-format on

} // namespace jit
} // namespace torch
"""
)
169

170
# Name of the generated file, written into the directory passed to write_cpp.
UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"

# NOTE(review): the two templates below are not referenced by the code visible
# in this module; they are kept unchanged in case other code imports them.
UPGRADER_ELEMENT = CodeTemplate(
    """\
Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
"""
)

PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
    """\
{
  std::string(${operator_name}),
  std::vector<Upgrader>({${upgrader_list}});
}
"""
)
186

187

188
def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
    """Render the C++ ``std::vector<Instruction>`` initializer for one upgrader.

    Args:
        instruction_list_from_yaml: sequence of instructions; each entry is
            indexed as ``(op_code, X, N)``.

    Returns:
        The rendered INSTRUCTION_LIST snippet.
    """
    instruction_list_part = [
        ONE_INSTRUCTION.substitute(
            operator_name=instruction[0],
            X=instruction[1],
            N=instruction[2],
        )
        for instruction in instruction_list_from_yaml
    ]
    # Each rendered entry starts with a newline; drop the first one so the
    # first entry lands on the template's own indented line.
    return INSTRUCTION_LIST.substitute(
        instruction_list="".join(instruction_list_part).lstrip("\n")
    )
201

202

203
def construct_constants(constants_list_from_yaml: List[Any]) -> str:
    """Render the C++ ``std::vector<c10::IValue>`` initializer for one upgrader.

    Supported constant types are ``str`` (emitted as an escaped C++ string
    literal), ``bool`` (``true``/``false``), ``None`` (empty argument, i.e.
    ``c10::IValue()``) and ``int``.

    Args:
        constants_list_from_yaml: the upgrader's constants table.

    Returns:
        CONSTANTS_LIST_EMPTY when there are no constants, otherwise the
        rendered CONSTANT_LIST snippet.

    Raises:
        ValueError: for a constant of any other type.
    """
    constants_list_part = []
    for constant_from_yaml in constants_list_from_yaml:
        if isinstance(constant_from_yaml, str):
            # Escape backslashes, double quotes and newlines so the generated
            # C++ string literal stays well-formed.
            escaped = (
                constant_from_yaml.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\n", "\\n")
            )
            convert_constant = f'"{escaped}"'
        elif isinstance(constant_from_yaml, bool):
            # Checked before int because bool is a subclass of int.
            convert_constant = "true" if constant_from_yaml else "false"
        elif constant_from_yaml is None:
            # Renders as ``c10::IValue()``, presumably the None IValue.
            convert_constant = ""
        elif isinstance(constant_from_yaml, int):
            convert_constant = str(constant_from_yaml)
        else:
            raise ValueError(
                f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. "
                "Please add change in construct_constants function in gen_mobile_upgraders.py."
            )
        constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant))
    if not constants_list_part:
        return CONSTANTS_LIST_EMPTY
    return CONSTANT_LIST.substitute(
        constant_list="".join(constants_list_part).lstrip("\n")
    )
227

228

229
def construct_operators(operator_list_from_yaml: List[Any]) -> str:
    """Render the C++ ``std::vector<OperatorString>`` initializer for one upgrader.

    Args:
        operator_list_from_yaml: sequence of operators; each entry is indexed
            as ``(name, overload_name, num_of_args)``.

    Returns:
        The rendered OPERATOR_STRING_LIST snippet. Unlike constants and types,
        an empty operator table still renders the list wrapper.
    """
    operator_list_part = [
        ONE_OPERATOTR_STRING.substitute(
            operator_name=op[0],
            overload_name=op[1],
            num_of_args=op[2],
        )
        for op in operator_list_from_yaml
    ]
    return OPERATOR_STRING_LIST.substitute(
        operator_string_list="".join(operator_list_part).lstrip("\n")
    )
242

243

244
def construct_types(types_tr_list_from_yaml: List[Any]) -> str:
    """Render the C++ ``std::vector<c10::TypePtr>`` initializer for one upgrader.

    Args:
        types_tr_list_from_yaml: the upgrader's type strings; each is wrapped
            in a ``c10::parseType("...")`` call.

    Returns:
        TYPE_LIST_EMPTY when there are no types, otherwise the rendered
        TYPE_LIST snippet.
    """
    type_entries = [
        ONE_TYPE.substitute(type_str=type_str) for type_str in types_tr_list_from_yaml
    ]
    if not type_entries:
        return TYPE_LIST_EMPTY
    return TYPE_LIST.substitute(type_list="".join(type_entries).lstrip("\n"))
251

252

253
def construct_register_size(register_size_from_yaml: int) -> str:
    """Render an upgrader's register size as C++ source text.

    Args:
        register_size_from_yaml: the register size; must be a plain ``int``.

    Returns:
        The decimal string for the register size.

    Raises:
        ValueError: if the value is not an int. A bool is rejected too:
            it is an int subclass, but would render as ``True``/``False``,
            which is not valid C++.
    """
    if isinstance(register_size_from_yaml, bool) or not isinstance(
        register_size_from_yaml, int
    ):
        raise ValueError(
            f"Input register size is {register_size_from_yaml} and "
            f"its type is {type(register_size_from_yaml)}. An int type is expected."
        )
    return str(register_size_from_yaml)
260

261

262
def construct_version_maps(
    upgrader_bytecode_function_to_index_map: Dict[str, Any]
) -> str:
    """Render the C++ ``getOperatorVersionMapForMobile()`` definition.

    Walks the operator version map from ``torch._C`` in operator-name order
    and skips operators in EXCLUDED_OP_SET. Each upgrader entry is paired with
    its version range (from ``torch._C._get_upgrader_ranges``) and with the
    index of its bytecode function.

    Args:
        upgrader_bytecode_function_to_index_map: upgrader name -> position of
            its bytecode function in the generated upgrader bytecode list.

    Returns:
        The rendered OPERATOR_VERSION_MAP snippet.

    Raises:
        RuntimeError: if an operator's upgrader ranges and upgrader entries
            differ in length.
        KeyError: if an upgrader has no entry in the index map.
    """
    version_map = torch._C._get_operator_version_map()
    sorted_version_map = sorted(version_map.items(), key=lambda item: item[0])  # type: ignore[no-any-return]

    operator_list_in_version_map_part = []
    for op_name, upgrader_entries in sorted_version_map:
        # TODO: remove the skip after these operators' schemas are fixed
        if op_name in EXCLUDED_OP_SET:
            continue
        upgrader_ranges = torch._C._get_upgrader_ranges(op_name)
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(upgrader_ranges) != len(upgrader_entries):
            raise RuntimeError(
                f"Operator {op_name} has {len(upgrader_ranges)} upgrader ranges "
                f"but {len(upgrader_entries)} upgrader entries."
            )
        upgraders_in_version_map_part = []
        for upgrader_range, upgrader_entry in zip(upgrader_ranges, upgrader_entries):
            upgrader_name = upgrader_entry.upgrader_name
            upgraders_in_version_map_part.append(
                ONE_UPGRADER_IN_VERSION_MAP.substitute(
                    upgrader_min_version=upgrader_range.min_version,
                    upgrader_max_version=upgrader_range.max_version,
                    upgrader_name=upgrader_name,
                    bytecode_func_index=upgrader_bytecode_function_to_index_map[
                        upgrader_name
                    ],
                )
            )
        operator_list_in_version_map_part.append(
            ONE_OPERATOR_IN_VERSION_MAP.substitute(
                operator_name=op_name,
                # ONE_UPGRADER_IN_VERSION_MAP has no trailing comma, so entries
                # must be comma-separated; joining with "" produced invalid C++
                # for an operator with more than one upgrader. The output for a
                # single upgrader is unchanged.
                upgrader_list_in_version_map=",\n".join(upgraders_in_version_map_part),
            )
        )
    return OPERATOR_VERSION_MAP.substitute(
        operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip(
            "\n"
        )
    )
302

303

304
def get_upgrader_bytecode_function_to_index_map(
    upgrader_dict: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Map each emitted upgrader name to its position in the bytecode list.

    Upgraders are numbered in iteration order (list order, then dict key
    order), skipping names in EXCLUE_UPGRADER_SET. write_cpp emits bytecode in
    the same order with the same skips, so the indices line up.

    Args:
        upgrader_dict: list of dicts mapping upgrader name -> bytecode tables.

    Returns:
        Dict of upgrader name -> 0-based index. If a name repeats, the later
        index wins.
    """
    emitted_names = (
        upgrader_name
        for upgrader_bytecode in upgrader_dict
        for upgrader_name in upgrader_bytecode
        if upgrader_name not in EXCLUE_UPGRADER_SET
    )
    return {upgrader_name: index for index, upgrader_name in enumerate(emitted_names)}
316

317

318
def _construct_upgrader_src(upgrader_name: str, bytecode: Dict[str, Any]) -> str:
    """Render one ``ByteCodeFunctionWithOperator`` initializer for an upgrader.

    Tables missing from ``bytecode`` render as empty strings. A table name that
    is not a ByteCode member raises KeyError.
    """
    instruction_list_str = ""
    constant_list_str = ""
    type_list_str = ""
    register_size_str = ""
    operator_list_str = ""
    for table_name, contents in bytecode.items():
        element = ByteCode[table_name]
        if element is ByteCode.instructions:
            instruction_list_str = construct_instruction(contents)
        elif element is ByteCode.constants:
            constant_list_str = construct_constants(contents)
        elif element is ByteCode.operators:
            operator_list_str = construct_operators(contents)
        elif element is ByteCode.types:
            type_list_str = construct_types(contents)
        elif element is ByteCode.register_size:
            register_size_str = construct_register_size(contents)

    one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute(
        upgrader_name=upgrader_name,
        instruction_list=instruction_list_str,
        constant_list=constant_list_str,
        type_list=type_list_str,
        register_size=register_size_str,
    )
    return ONE_UPGRADER_SRC.substitute(
        bytecode_function=one_upgrader_function_string.lstrip("\n"),
        operator_string_list=operator_list_str.lstrip("\n"),
    )


def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
    """Generate UPGRADER_MOBILE_FILE_NAME inside ``cpp_path``.

    Renders the operator version map and the bytecode of every upgrader not in
    EXCLUE_UPGRADER_SET into UPGRADER_CPP_SRC, then writes the result as UTF-8,
    overwriting any existing file.

    Args:
        cpp_path: directory to write the generated file into.
        upgrader_dict: list of dicts mapping upgrader name -> bytecode tables.
    """
    upgrader_bytecode_function_to_index_map = (
        get_upgrader_bytecode_function_to_index_map(upgrader_dict)
    )
    version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map)

    # Same iteration order and same skips as
    # get_upgrader_bytecode_function_to_index_map, so the bytecode indices
    # referenced by the version map match positions in this list.
    # TODO: remove the skip after these two operators schemas are fixed
    all_upgrader_src_string = [
        _construct_upgrader_src(upgrader_name, bytecode)
        for upgrader_bytecode in upgrader_dict
        for upgrader_name, bytecode in upgrader_bytecode.items()
        if upgrader_name not in EXCLUE_UPGRADER_SET
    ]

    upgrader_file_content = UPGRADER_CPP_SRC.substitute(
        operator_version_map=version_map_src,
        upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
    )
    output_path = os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME)
    print("writing file to : ", output_path)
    # Binary mode with explicit UTF-8 encoding keeps "\n" line endings on every
    # platform (no newline translation).
    with open(output_path, "wb") as out_file:
        out_file.write(upgrader_file_content.encode("utf-8"))
371

372

373
def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a new list of the upgrader dicts, ordered by each dict's first key.

    The first key is presumably the upgrader name (each dict appears to hold
    one upgrader). The input list is not modified. An empty dict in the list
    raises StopIteration.
    """
    return sorted(upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader)))
378

379

380
def main() -> None:
    """Generate UPGRADER_MOBILE_FILE_NAME under <repo root>/torch/csrc/jit/mobile.

    Upgrader bytecode comes from ``generate_upgraders_bytecode()`` and is
    sorted by upgrader name before being written.
    """
    sorted_upgrader_list = sort_upgrader(generate_upgraders_bytecode())
    for up in sorted_upgrader_list:
        print("after sort upgrader : ", next(iter(up)))

    # parents[2] is two directories above this script's own directory, which
    # is assumed to be the repository root.
    pytorch_dir = Path(__file__).resolve().parents[2]
    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile"
    write_cpp(str(upgrader_path), sorted_upgrader_list)


if __name__ == "__main__":
    main()
393

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.