# pytorch
# 392 lines · 12.4 KB
1#!/usr/bin/env python3
2import os3from enum import Enum4from pathlib import Path5from typing import Any, Dict, List6
7import torch8from torch.jit.generate_bytecode import generate_upgraders_bytecode9
10from torchgen.code_template import CodeTemplate11from torchgen.operator_versions.gen_mobile_upgraders_constant import (12MOBILE_UPGRADERS_HEADER_DESCRIPTION,13)
14
15
class ByteCode(Enum):
    """Names of the tables found in one upgrader's bytecode dictionary.

    ``write_cpp`` resolves each table with ``ByteCode[table_name]``, so the
    member names have to match the keys of that dictionary.
    """

    instructions = 1
    constants = 2
    types = 3
    operators = 4
    register_size = 5
23
# Operator names that construct_version_maps leaves out of the generated
# operator version map.
EXCLUDED_OP_SET = [
    "aten::full.names",
    "aten::full.out",
    "aten::full",
]

# Upgrader names that are skipped both when indices are assigned and when the
# bytecode list is emitted. (The spelling "EXCLUE" is kept because other
# functions in this file reference the name.)
EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]
32ONE_INSTRUCTION = CodeTemplate(33"""34Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
35)
36
37INSTRUCTION_LIST = CodeTemplate(38"""std::vector<Instruction>({39${instruction_list}
40}), // instructions list"""
41)
42
43ONE_CONSTANT = CodeTemplate(44"""45c10::IValue(${constant}),"""
46)
47
48CONSTANT_LIST = CodeTemplate(49"""std::vector<c10::IValue>({50${constant_list}
51}), // constants list"""
52)
53
54CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""55
56ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")57
58TYPE_LIST = CodeTemplate(59"""std::vector<c10::TypePtr>({60${type_list}
61}), // types list"""
62)
63
64TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list"""65
66ONE_OPERATOTR_STRING = CodeTemplate(67"""68OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
69)
70
71OPERATOR_STRING_LIST = CodeTemplate(72"""73std::vector<OperatorString>({
74${operator_string_list}
75}), // operators list"""
76)
77
78ONE_UPGRADER_FUNCTION = CodeTemplate(79"""80mobile::Function::registerFunc(
81"${upgrader_name}",
82${instruction_list},
83${constant_list},
84${type_list},
85${register_size}
86)"""
87)
88
89ONE_UPGRADER_SRC = CodeTemplate(90"""91ByteCodeFunctionWithOperator({
92${bytecode_function},
93${operator_string_list}
94}),"""
95)
96
97
98ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(99"""Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""100) # noqa: E501101
102ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(103"""104{std::string("${operator_name}"),
105std::vector<Upgrader>({
106${upgrader_list_in_version_map}
107})},"""
108)
109
110
111OPERATOR_VERSION_MAP = CodeTemplate(112"""113const std::unordered_map<std::string, std::vector<Upgrader>>
114getOperatorVersionMapForMobile() {
115static std::unordered_map<std::string, std::vector<Upgrader>>
116operatorVersionMapForMobile({
117${operator_list_in_version_map}
118});
119return operatorVersionMapForMobile;
120}
121"""
122)
123
124
125UPGRADER_CPP_SRC = CodeTemplate(126MOBILE_UPGRADERS_HEADER_DESCRIPTION
127+ """128#include <caffe2/serialize/versions.h>
129#include <torch/csrc/jit/mobile/upgrader_mobile.h>
130
131namespace c10 {
132TypePtr parseType(const std::string& pythonStr);
133} // namespace c10
134
135namespace torch {
136namespace jit {
137
138// clang-format off
139
140// From operator_versions_map
141${operator_version_map}
142
143const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
144auto generate_upgrader_bytecode_list = []() {
145std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
146${upgrader_bytecode}
147});
148for (const auto& upgrader_function : upgrader_function_list) {
149for (const auto& op : upgrader_function.operators) {
150upgrader_function.function.append_operator(
151op.name,
152op.overload_name,
153op.num_specified_args);
154}
155}
156return upgrader_function_list;
157};
158static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
159generate_upgrader_bytecode_list();
160return upgraderBytecodeList;
161}
162
163// clang-format on
164
165} // namespace jit
166} // namespace torch
167"""
168)
169
170UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"171
172UPGRADER_ELEMENT = CodeTemplate(173"""\174Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
175"""
176)
177
178PER_OPERATOR_UPGRADER_LIST = CodeTemplate(179"""\180{
181std::string(${operator_name}),
182std::vector<Upgrader>({${upgrader_list}});
183}
184"""
185)
186
187
def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
    """Render bytecode instructions with the ``INSTRUCTION_LIST`` template.

    Args:
        instruction_list_from_yaml: Instructions to render. Each entry is
            indexed as ``[0]`` (substituted as ``operator_name``), ``[1]``
            (``X``) and ``[2]`` (``N``).

    Returns:
        ``INSTRUCTION_LIST`` filled with one ``ONE_INSTRUCTION`` rendering per
        entry, concatenated in input order.
    """
    rendered_entries = [
        ONE_INSTRUCTION.substitute(
            operator_name=entry[0],
            X=entry[1],
            N=entry[2],
        )
        for entry in instruction_list_from_yaml
    ]
    # Leading newlines of the concatenated text are dropped before it is
    # placed inside the list template.
    list_body = "".join(rendered_entries).lstrip("\n")
    return INSTRUCTION_LIST.substitute(instruction_list=list_body)
202
def construct_constants(constants_list_from_yaml: List[Any]) -> str:
    """Render bytecode constants with the ``CONSTANT_LIST`` template.

    Args:
        constants_list_from_yaml: Constants to render. Supported values are
            ``str``, ``bool``, ``None`` and ``int``.

    Returns:
        ``CONSTANTS_LIST_EMPTY`` when there are no constants, otherwise
        ``CONSTANT_LIST`` filled with one ``ONE_CONSTANT`` rendering per value.

    Raises:
        ValueError: If a constant has an unsupported type.
    """

    def as_cpp_literal(value: Any) -> str:
        # Strings are wrapped in double quotes; the content is not escaped.
        if isinstance(value, str):
            return f'"{value}"'
        # bool has to be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return "true" if value else "false"
        # None becomes an empty argument list.
        if value is None:
            return ""
        if isinstance(value, int):
            return str(value)
        raise ValueError(
            f"The type of {value} is {type(value)}. "
            "Please add change in construct_constants function in gen_mobile_upgraders.py."
        )

    literals = [as_cpp_literal(value) for value in constants_list_from_yaml]
    if not literals:
        return CONSTANTS_LIST_EMPTY
    rendered = [ONE_CONSTANT.substitute(constant=literal) for literal in literals]
    return CONSTANT_LIST.substitute(constant_list="".join(rendered).lstrip("\n"))
228
def construct_operators(operator_list_from_yaml: List[Any]) -> str:
    """Render operator descriptions with the ``OPERATOR_STRING_LIST`` template.

    Args:
        operator_list_from_yaml: Operators to render. Each entry is indexed as
            ``[0]`` (``operator_name``), ``[1]`` (``overload_name``) and
            ``[2]`` (``num_of_args``).

    Returns:
        ``OPERATOR_STRING_LIST`` filled with one ``ONE_OPERATOTR_STRING``
        rendering per entry, concatenated in input order.
    """
    rendered_entries = [
        ONE_OPERATOTR_STRING.substitute(
            operator_name=entry[0],
            overload_name=entry[1],
            num_of_args=entry[2],
        )
        for entry in operator_list_from_yaml
    ]
    # Leading newlines of the concatenated text are dropped before it is
    # placed inside the list template.
    list_body = "".join(rendered_entries).lstrip("\n")
    return OPERATOR_STRING_LIST.substitute(operator_string_list=list_body)
243
def construct_types(types_tr_list_from_yaml: List[Any]) -> str:
    """Render type strings with the ``TYPE_LIST`` template.

    Args:
        types_tr_list_from_yaml: Values substituted as ``type_str`` into
            ``ONE_TYPE``, one rendering per value.

    Returns:
        ``TYPE_LIST_EMPTY`` when there are no types, otherwise ``TYPE_LIST``
        filled with the concatenated ``ONE_TYPE`` renderings.
    """
    rendered_types = [
        ONE_TYPE.substitute(type_str=type_text)
        for type_text in types_tr_list_from_yaml
    ]
    if not rendered_types:
        return TYPE_LIST_EMPTY
    list_body = "".join(rendered_types).lstrip("\n")
    return TYPE_LIST.substitute(type_list=list_body)
252
def construct_register_size(register_size_from_yaml: int) -> str:
    """Return the register size as text for the generated source.

    Args:
        register_size_from_yaml: Content of the ``register_size`` bytecode
            table.

    Returns:
        ``str(register_size_from_yaml)``.

    Raises:
        ValueError: If the value is not an ``int``. Note that ``bool`` is a
            subclass of ``int`` and is therefore accepted.
    """
    if not isinstance(register_size_from_yaml, int):
        # Both halves must be f-strings; otherwise the type placeholder is
        # printed literally instead of being interpolated.
        raise ValueError(
            f"Input register size is {register_size_from_yaml} and "
            f"its type is {type(register_size_from_yaml)}. An int type is expected."
        )
    return str(register_size_from_yaml)
261
def construct_version_maps(
    upgrader_bytecode_function_to_index_map: Dict[str, Any]
) -> str:
    """Render the operator version map with the ``OPERATOR_VERSION_MAP`` template.

    The map is read from ``torch._C._get_operator_version_map()`` and the
    version range of every upgrader from ``torch._C._get_upgrader_ranges``.

    Args:
        upgrader_bytecode_function_to_index_map: Maps an upgrader name to the
            index that is substituted as ``bytecode_func_index``.

    Returns:
        ``OPERATOR_VERSION_MAP`` filled with one ``ONE_OPERATOR_IN_VERSION_MAP``
        rendering per operator, in operator-name order. Operators listed in
        ``EXCLUDED_OP_SET`` are left out.

    Raises:
        KeyError: If an upgrader name is missing from
            ``upgrader_bytecode_function_to_index_map``.
        AssertionError: If the number of version ranges differs from the
            number of upgrader entries for an operator.
    """
    version_map = torch._C._get_operator_version_map()
    operator_list_in_version_map_part = []
    # Walk the operators in name order rather than in the map's own order.
    for op_name, upgrader_entries in sorted(
        version_map.items(), key=lambda item: item[0]  # type: ignore[no-any-return]
    ):
        # TODO: remove the skip after these two operators schemas are fixed
        if op_name in EXCLUDED_OP_SET:
            continue
        upgrader_ranges = torch._C._get_upgrader_ranges(op_name)
        assert len(upgrader_ranges) == len(upgrader_entries)
        upgraders_in_version_map_part = [
            ONE_UPGRADER_IN_VERSION_MAP.substitute(
                upgrader_min_version=version_range.min_version,
                upgrader_max_version=version_range.max_version,
                upgrader_name=upgrader_entry.upgrader_name,
                bytecode_func_index=upgrader_bytecode_function_to_index_map[
                    upgrader_entry.upgrader_name
                ],
            )
            for version_range, upgrader_entry in zip(upgrader_ranges, upgrader_entries)
        ]
        operator_list_in_version_map_part.append(
            ONE_OPERATOR_IN_VERSION_MAP.substitute(
                operator_name=op_name,
                # The per-upgrader template has no trailing comma, so the
                # entries of one operator must be comma separated to form a
                # valid initializer list. A single entry renders as before.
                upgrader_list_in_version_map=",\n".join(
                    upgraders_in_version_map_part
                ),
            )
        )
    return OPERATOR_VERSION_MAP.substitute(
        operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip(
            "\n"
        )
    )
303
def get_upgrader_bytecode_function_to_index_map(
    upgrader_dict: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Assign a running index to every upgrader name that is not excluded.

    Args:
        upgrader_dict: List of dictionaries keyed by upgrader name.

    Returns:
        A dictionary from upgrader name to its position among the names that
        are not in ``EXCLUE_UPGRADER_SET``, counted in iteration order
        starting at 0. A name that occurs more than once keeps the index of
        its last occurrence.
    """
    kept_names = (
        upgrader_name
        for upgrader_bytecode in upgrader_dict
        for upgrader_name in upgrader_bytecode
        if upgrader_name not in EXCLUE_UPGRADER_SET
    )
    return {upgrader_name: position for position, upgrader_name in enumerate(kept_names)}
317
def _construct_one_upgrader_src(upgrader_name: str, bytecode: Dict[str, Any]) -> str:
    """Render one upgrader with the ``ONE_UPGRADER_SRC`` template."""
    table_renderers = {
        ByteCode.instructions: construct_instruction,
        ByteCode.constants: construct_constants,
        ByteCode.types: construct_types,
        ByteCode.operators: construct_operators,
        ByteCode.register_size: construct_register_size,
    }
    # A table that is absent from ``bytecode`` is substituted as an empty string.
    rendered = {element: "" for element in ByteCode}
    for table_name, contents in bytecode.items():
        # ByteCode[...] raises KeyError for a table name the enum does not define.
        element = ByteCode[table_name]
        rendered[element] = table_renderers[element](contents)

    bytecode_function = ONE_UPGRADER_FUNCTION.substitute(
        upgrader_name=upgrader_name,
        instruction_list=rendered[ByteCode.instructions],
        constant_list=rendered[ByteCode.constants],
        type_list=rendered[ByteCode.types],
        register_size=rendered[ByteCode.register_size],
    )
    return ONE_UPGRADER_SRC.substitute(
        bytecode_function=bytecode_function.lstrip("\n"),
        operator_string_list=rendered[ByteCode.operators].lstrip("\n"),
    )


def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
    """Generate the upgrader source text and write it into ``cpp_path``.

    Args:
        cpp_path: Directory that receives ``UPGRADER_MOBILE_FILE_NAME``.
        upgrader_dict: List of dictionaries mapping an upgrader name to its
            bytecode tables (a dictionary keyed by ``ByteCode`` member names).

    Raises:
        KeyError: If a bytecode table name is not a ``ByteCode`` member.
    """
    upgrader_bytecode_function_to_index_map = (
        get_upgrader_bytecode_function_to_index_map(upgrader_dict)
    )
    version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map)
    all_upgrader_src_string = [
        _construct_one_upgrader_src(upgrader_name, bytecode)
        for upgrader_bytecode in upgrader_dict
        for upgrader_name, bytecode in upgrader_bytecode.items()
        # TODO: remove the skip after these two operators schemas are fixed
        if upgrader_name not in EXCLUE_UPGRADER_SET
    ]

    upgrader_file_content = UPGRADER_CPP_SRC.substitute(
        operator_version_map=version_map_src,
        upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
    )
    # Build the path once so the logged path is the one that is opened.
    output_path = os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME)
    print("writing file to : ", output_path)
    # Binary mode writes the UTF-8 bytes without newline translation.
    with open(output_path, "wb") as out_file:
        out_file.write(upgrader_file_content.encode("utf-8"))
372
def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a new list of the upgrader dictionaries ordered by first key.

    Args:
        upgrader_list: Dictionaries to order; each must have at least one key.

    Returns:
        A new list sorted by the first key of each dictionary. The input list
        is not modified.
    """

    def first_key(one_upgrader: Dict[str, Any]) -> Any:
        return next(iter(one_upgrader))

    ordered = list(upgrader_list)
    ordered.sort(key=first_key)
    return ordered
379
def main() -> None:
    """Generate the upgrader bytecode, sort it and write the C++ source file.

    The output directory is ``torch/csrc/jit/mobile`` below the directory two
    levels above this file's resolved location.
    """
    ordered_upgraders = sort_upgrader(generate_upgraders_bytecode())
    for one_upgrader in ordered_upgraders:
        print("after sort upgrader : ", next(iter(one_upgrader)))

    repository_root = Path(__file__).resolve().parents[2]
    mobile_dir = repository_root.joinpath("torch", "csrc", "jit", "mobile")
    write_cpp(str(mobile_dir), ordered_upgraders)


if __name__ == "__main__":
    main()