pytorch

Форк
0
/
gen_jit_shape_functions.py 
182 строки · 5.1 Кб
1
#!/usr/bin/env python3
2
import os
3
import sys
4
from importlib.util import module_from_spec, spec_from_file_location
5
from itertools import chain
6
from pathlib import Path
7

8

9
# Load the shape-function module directly from its source file, located
# relative to the current working directory, instead of going through
# `import torch`. This avoids needing to recompile PyTorch before running
# the script.

file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
module_name = "torch.jit._shape_functions"

err_msg = """Could not find shape functions file, please make sure
you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
    # Report the missing file with the dedicated built-in exception; it is
    # still an `Exception` subclass, so broad handlers keep working.
    raise FileNotFoundError(err_msg)

spec = spec_from_file_location(module_name, file_path)
assert spec is not None
module = module_from_spec(spec)
# Register the module under its dotted name before its body is executed.
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None
spec.loader.exec_module(module)

# Mappings read by the generator functions in this script:
#   shape_compute_graph_mapping:   schema -> scripted function
#   bounded_compute_graph_mapping: schema -> (lower, upper) scripted functions
bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
shape_compute_graph_mapping = module.shape_compute_graph_mapping
31

32

33
# Opening part of the generated C++ file: the "@generated" banner, includes,
# namespace openers and the start of the `shape_funcs` definition. The
# definition is deliberately left unterminated; the serialized functions
# (and the closing ';') are appended right after this template.
SHAPE_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python
 * torchgen/shape_functions/gen_jit_shape_functions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>

// clang-format off

namespace torch {
namespace jit {


std::string shape_funcs = ""
"""


# Accessor emitted after the serialized functions.
DECOMP_CENTER = r"""


const std::string& GetSerializedShapeFunctions() {
  return shape_funcs;
}

"""

# Closing part of the generated C++ file: re-enables clang-format and closes
# both namespaces.
DECOMP_END = r"""
// clang-format on

} // namespace jit
} // namespace torch
"""


# Name of the generated C++ file inside the output directory.
SERIALIZED_SHAPE_UTIL_FILE_NAME = "serialized_shape_function_registry.cpp"
74

75

76
def gen_serialized_decompisitions() -> str:
    """Serialize every unique shape function into a C++ string expression.

    Functions are collected from ``shape_compute_graph_mapping`` and from
    both members of each ``bounded_compute_graph_mapping`` entry, then
    de-duplicated by ``name`` (first occurrence wins, order preserved).
    Their ``code`` is packed into chunks and each chunk is emitted as a
    separate ``+ std::string(R"=====(...)=====")`` term.

    Returns:
        The concatenation of all terms followed by ``";"``, meant to be
        appended to an unterminated C++ ``std::string`` definition.
    """
    # Technically the limit is higher, but leave a buffer because there are
    # odd rules around some characters.
    # TODO: this was the limit found by googling, but it seems way too short?
    MAX_MSFT_STR_LEN = 2000

    all_funcs = chain(
        shape_compute_graph_mapping.values(), *bounded_compute_graph_mapping.values()
    )
    # dict keeps insertion order; setdefault keeps the first function seen
    # for each name.
    unique_funcs = {}
    for scripted_func in all_funcs:
        unique_funcs.setdefault(scripted_func.name, scripted_func)

    output_strs = []
    curr_str = ""
    for scripted_func in unique_funcs.values():
        # Read `.code` once and reuse it for both the size check and the text.
        serialized_code = scripted_func.code
        if len(curr_str) + len(serialized_code) <= MAX_MSFT_STR_LEN:
            curr_str += "\n" + serialized_code
        else:
            # Current chunk is full: flush it and start a new one.
            output_strs.append(curr_str)
            curr_str = serialized_code
    output_strs.append(curr_str)

    # The Windows compiler doesn't correctly handle adjacent string literals,
    # so every chunk is wrapped in its own std::string and joined with '+'.
    wrapped = (
        f'+ std::string(R"=====({output_str}\n)=====")\n'
        for output_str in output_strs
    )
    return "".join(wrapped) + ";"
113

114

115
# Opening of the generated C++ function that returns the
# schema -> shape-function-name table.
SHAPE_SCHEMA_START = r"""
const OperatorMap<std::string>& GetShapeFunctionMappings() {
 static const OperatorMap<std::string> shape_mappings {
"""

# Closing of a generated mapping function: ends the initializer list and
# returns the static `shape_mappings` table.
SHAPE_SCHEMA_END = r"""
  };

  return shape_mappings;
}
"""
126

127

128
def gen_shape_mappings() -> str:
    """Render the C++ function mapping each schema to its shape function name.

    One ``{"<schema>", "<function name>"},`` initializer line is produced per
    entry of ``shape_compute_graph_mapping``; the lines are joined with
    newlines and wrapped in ``SHAPE_SCHEMA_START`` / ``SHAPE_SCHEMA_END``.

    Returns:
        The C++ source text of the mapping function.
    """
    shape_mappings = [
        f'    {{"{schema}", "{scripted_func.name}"}},'
        for schema, scripted_func in shape_compute_graph_mapping.items()
    ]
    return SHAPE_SCHEMA_START + "\n".join(shape_mappings) + SHAPE_SCHEMA_END
133

134

135
# Opening of the generated C++ function that returns the
# schema -> (lower, upper) shape-function-name table.
BOUNDED_SCHEMA_START = r"""
const OperatorMap<std::pair<std::string, std::string>>& GetBoundedShapeMappings() {
 static const OperatorMap<std::pair<std::string, std::string>> shape_mappings {
"""
139

140

141
def gen_bounded_mappings() -> str:
    """Render the C++ function mapping each schema to its bound functions.

    One ``{"<schema>", {"<lower name>", "<upper name>"}},`` initializer line
    is produced per entry of ``bounded_compute_graph_mapping``, whose values
    are unpacked as ``(lower_func, upper_func)`` pairs.

    Returns:
        The C++ source text of the mapping function.
    """
    bounded_mappings = [
        f'    {{"{schema}", {{"{lower_func.name}", "{upper_func.name}"}}}},'
        for schema, (lower_func, upper_func) in bounded_compute_graph_mapping.items()
    ]
    # The footer is shared with the plain shape mapping function, hence
    # SHAPE_SCHEMA_END rather than a bounded-specific constant.
    return BOUNDED_SCHEMA_START + "\n".join(bounded_mappings) + SHAPE_SCHEMA_END
155

156

157
def write_decomposition_util_file(path: str) -> None:
    """Generate the serialized shape-function registry and write it to disk.

    The file content is the concatenation, in order, of: the header, the
    serialized functions, the accessor, the shape mappings, the bounded
    mappings and the footer.

    Args:
        path: Directory in which ``SERIALIZED_SHAPE_UTIL_FILE_NAME`` is
            created (or overwritten).
    """
    file_components = [
        SHAPE_HEADER,
        gen_serialized_decompisitions(),
        DECOMP_CENTER,
        gen_shape_mappings(),
        gen_bounded_mappings(),
        DECOMP_END,
    ]
    # Build the destination once so the path that is reported is exactly the
    # path that is opened.
    out_path = os.path.join(path, SERIALIZED_SHAPE_UTIL_FILE_NAME)
    print("writing file to : ", out_path)
    # Binary mode: the UTF-8 bytes are written as-is, with no newline
    # translation.
    with open(out_path, "wb") as out_file:
        out_file.write("".join(file_components).encode("utf-8"))
173

174

175
def main() -> None:
    """Write the generated registry into ``torch/csrc/jit/runtime``.

    The repository root is taken to be the third ancestor directory of this
    script's resolved path.
    """
    repo_root = Path(__file__).resolve().parents[2]
    runtime_dir = repo_root / "torch" / "csrc" / "jit" / "runtime"
    write_decomposition_util_file(str(runtime_dir))


if __name__ == "__main__":
    main()
183

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.