pytorch

Форк
0
/
pt_template_srcs.bzl 
246 строк · 11.1 Кб
# This file keeps a list of PyTorch source files that are used for templated selective build.
# NB: as this is PyTorch Edge selective build, we assume only CPU targets are
# being built.

load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
load(":build_variables.bzl", "aten_native_source_list")
load(
    ":ufunc_defs.bzl",
    "aten_ufunc_generated_cpu_kernel_sources",
    "aten_ufunc_generated_cpu_sources",
)

# Files in this list are supposed to be built separately for each app,
# for different operator allow lists.
# The two JIT registration files are listed explicitly; every entry of
# aten_native_source_list (loaded from build_variables.bzl) is appended.
TEMPLATE_SOURCE_LIST = [
    "torch/csrc/jit/runtime/register_prim_ops.cpp",
    "torch/csrc/jit/runtime/register_special_ops.cpp",
] + aten_native_source_list

def aten_ufunc_generated_all_cpu_sources(gencode_pattern = "{}"):
    """Returns all generated ufunc CPU sources, kernel sources included.

    For selective build we can lump the CPU and CPU kernel sources together
    because there is only ever one vectorization variant that is compiled.

    Args:
      gencode_pattern: format pattern forwarded unchanged to both
        ufunc_defs.bzl helpers. Defaults to "{}".

    Returns:
      aten_ufunc_generated_cpu_sources(gencode_pattern) concatenated with
      aten_ufunc_generated_cpu_kernel_sources(gencode_pattern), in that order.
    """
    return (
        aten_ufunc_generated_cpu_sources(gencode_pattern) +
        aten_ufunc_generated_cpu_kernel_sources(gencode_pattern)
    )

# Registration file for the maskrcnn custom ops. Only used for non-OSS builds:
# get_template_registration_files_outs / get_template_registration_file_rules
# skip it when is_oss is True.
TEMPLATE_MASKRCNN_SOURCE_LIST = [
    "register_maskrcnn_ops.cpp",
]

# Registration file for the batch_box_cox custom ops; gets the same
# non-OSS-only treatment as TEMPLATE_MASKRCNN_SOURCE_LIST.
TEMPLATE_BATCH_BOX_COX_SOURCE_LIST = [
    "register_batch_box_cox_ops.cpp",
]

# Metal backend sources. The metal rule helpers further down put ".cpp"
# entries in the "cxx" rule list and everything else in the "objc" list.
METAL_SOURCE_LIST = [
    "aten/src/ATen/native/metal/MetalAten.mm",
    "aten/src/ATen/native/metal/MetalGuardImpl.cpp",
    "aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp",
    "aten/src/ATen/native/metal/MetalCommandBuffer.mm",
    "aten/src/ATen/native/metal/MetalContext.mm",
    "aten/src/ATen/native/metal/MetalConvParams.mm",
    "aten/src/ATen/native/metal/MetalTensorImplStorage.mm",
    "aten/src/ATen/native/metal/MetalTensorUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm",
    "aten/src/ATen/native/metal/ops/MetalAddmm.mm",
    "aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm",
    "aten/src/ATen/native/metal/ops/MetalChunk.mm",
    "aten/src/ATen/native/metal/ops/MetalClamp.mm",
    "aten/src/ATen/native/metal/ops/MetalConcat.mm",
    "aten/src/ATen/native/metal/ops/MetalConvolution.mm",
    "aten/src/ATen/native/metal/ops/MetalCopy.mm",
    "aten/src/ATen/native/metal/ops/MetalHardswish.mm",
    "aten/src/ATen/native/metal/ops/MetalHardshrink.mm",
    "aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm",
    "aten/src/ATen/native/metal/ops/MetalNeurons.mm",
    "aten/src/ATen/native/metal/ops/MetalPadding.mm",
    "aten/src/ATen/native/metal/ops/MetalPooling.mm",
    "aten/src/ATen/native/metal/ops/MetalReduce.mm",
    "aten/src/ATen/native/metal/ops/MetalReshape.mm",
    "aten/src/ATen/native/metal/ops/MetalSoftmax.mm",
    "aten/src/ATen/native/metal/ops/MetalTranspose.mm",
    "aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm",
]

# UNet Metal prepack sources; left out of the Windows variants of the metal
# helpers.
UNET_METAL_PREPACK_SOURCE_LIST = [
    "unet_metal_prepack.cpp",
    "unet_metal_prepack.mm",
]

# maskrcnn Metal sources; also left out of the Windows variants.
METAL_MASKRCNN_SOURCE_LIST = [
    "maskrcnn/srcs/GenerateProposals.mm",
    "maskrcnn/srcs/RoIAlign.mm",
]

def get_template_source_dict():
    """Groups TEMPLATE_SOURCE_LIST by directory.

    The grouped files are the source files containing operator definitions
    and registrations that should get selected via templated selective
    build. The file selected_mobile_ops.h has the list of selected top-level
    operators.

    NB: doesn't include generated files; copy_template_registration_files
    handles those specially.

    Returns:
      A dict mapping each path prefix (paths.dirname of a source path) to the
      list of full source paths under it, in TEMPLATE_SOURCE_LIST order.
    """
    ret = {}
    for file_path in TEMPLATE_SOURCE_LIST:
        ret.setdefault(paths.dirname(file_path), []).append(file_path)
    return ret

def get_gen_oplist_outs():
    """Returns the output map for the generated operator-list files.

    Returns:
      A fresh dict mapping each output file name to a one-element list
      holding that same name (presumably used as a rule's named `outs` —
      confirm at the call site).
    """
    return {
        name: [name]
        for name in [
            "SupportedMobileModelsRegistration.cpp",
            "selected_mobile_ops.h",
            "selected_operators.yaml",
        ]
    }

def get_generate_code_bin_outs():
    """Returns the output map for the generated autograd sources.

    Every output path maps to a one-element list holding that same path. The
    python_* binding files are only included when is_arvr_mode() is true.

    Returns:
      A fresh dict: the always-present autograd outputs first, then (ARVR mode
      only) the python_* outputs, each group in the order listed below.
    """
    names = [
        "autograd/generated/ADInplaceOrViewTypeEverything.cpp",
        "autograd/generated/ADInplaceOrViewType_0.cpp",
        "autograd/generated/ADInplaceOrViewType_1.cpp",
        "autograd/generated/Functions.cpp",
        "autograd/generated/Functions.h",
        "autograd/generated/TraceTypeEverything.cpp",
        "autograd/generated/TraceType_0.cpp",
        "autograd/generated/TraceType_1.cpp",
        "autograd/generated/TraceType_2.cpp",
        "autograd/generated/TraceType_3.cpp",
        "autograd/generated/TraceType_4.cpp",
        "autograd/generated/VariableType.h",
        "autograd/generated/VariableTypeEverything.cpp",
        "autograd/generated/VariableType_0.cpp",
        "autograd/generated/VariableType_1.cpp",
        "autograd/generated/VariableType_2.cpp",
        "autograd/generated/VariableType_3.cpp",
        "autograd/generated/VariableType_4.cpp",
        "autograd/generated/variable_factories.h",
        "autograd/generated/ViewFuncs.cpp",
        "autograd/generated/ViewFuncs.h",
    ]

    if is_arvr_mode():
        names.extend([
            "autograd/generated/python_enum_tag.cpp",
            "autograd/generated/python_fft_functions.cpp",
            "autograd/generated/python_functions.h",
            "autograd/generated/python_functions_0.cpp",
            "autograd/generated/python_functions_1.cpp",
            "autograd/generated/python_functions_2.cpp",
            "autograd/generated/python_functions_3.cpp",
            "autograd/generated/python_functions_4.cpp",
            "autograd/generated/python_linalg_functions.cpp",
            "autograd/generated/python_nested_functions.cpp",
            "autograd/generated/python_nn_functions.cpp",
            "autograd/generated/python_return_types.h",
            "autograd/generated/python_return_types.cpp",
            "autograd/generated/python_sparse_functions.cpp",
            "autograd/generated/python_special_functions.cpp",
            "autograd/generated/python_torch_functions_0.cpp",
            "autograd/generated/python_torch_functions_1.cpp",
            "autograd/generated/python_torch_functions_2.cpp",
            "autograd/generated/python_variable_methods.cpp",
        ])

    return {name: [name] for name in names}

def get_template_registration_files_outs(is_oss = False):
    """Returns the output map for the template registration files.

    Args:
      is_oss: when True, the maskrcnn and batch_box_cox registration files
        are left out.

    Returns:
      A dict mapping each path to a one-element list holding that same path.
      Keys come in this order: (non-OSS only) TEMPLATE_MASKRCNN_SOURCE_LIST
      and TEMPLATE_BATCH_BOX_COX_SOURCE_LIST, then TEMPLATE_SOURCE_LIST, then
      the generated ufunc CPU sources prefixed with "aten/src/ATen/".
    """
    file_paths = []
    if not is_oss:
        file_paths.extend(TEMPLATE_MASKRCNN_SOURCE_LIST)
        file_paths.extend(TEMPLATE_BATCH_BOX_COX_SOURCE_LIST)
    file_paths.extend(TEMPLATE_SOURCE_LIST)

    # Generated ufunc sources are reported relative to aten/src/ATen/.
    file_paths.extend([
        "aten/src/ATen/{}".format(base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ])

    return {file_path: [file_path] for file_path in file_paths}

def get_template_registration_file_rules(rule_name, is_oss = False):
    """Returns references to each template registration output of a rule.

    Args:
      rule_name: name of the rule whose named outputs are referenced.
      is_oss: when True, the maskrcnn and batch_box_cox registration files
        are left out.

    Returns:
      A list of ":<rule_name>[<path>]" strings, in this order:
      TEMPLATE_SOURCE_LIST, then (non-OSS only) TEMPLATE_MASKRCNN_SOURCE_LIST
      and TEMPLATE_BATCH_BOX_COX_SOURCE_LIST, then the generated ufunc CPU
      sources prefixed with "aten/src/ATen/".
    """
    # Copy so the module-level list is never extended in place.
    file_paths = list(TEMPLATE_SOURCE_LIST)
    if not is_oss:
        file_paths.extend(TEMPLATE_MASKRCNN_SOURCE_LIST)
        file_paths.extend(TEMPLATE_BATCH_BOX_COX_SOURCE_LIST)
    file_paths.extend([
        "aten/src/ATen/{}".format(base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ])

    return [":{}[{}]".format(rule_name, file_path) for file_path in file_paths]

# ---------------------METAL RULES---------------------
def get_metal_source_dict():
    """Groups METAL_SOURCE_LIST by directory.

    Returns:
      A dict mapping each path prefix (paths.dirname of a source path) to the
      list of full Metal source paths under it, in METAL_SOURCE_LIST order.
    """
    ret = {}
    for file_path in METAL_SOURCE_LIST:
        ret.setdefault(paths.dirname(file_path), []).append(file_path)
    return ret

def get_metal_registration_files_outs():
    """Returns the output map for all Metal registration files.

    Returns:
      A dict mapping each path to a one-element list holding that same path,
      covering METAL_SOURCE_LIST, UNET_METAL_PREPACK_SOURCE_LIST and
      METAL_MASKRCNN_SOURCE_LIST, in that order.
    """
    return {
        file_path: [file_path]
        for file_path in METAL_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST
    }

# There is a really weird issue with the arvr windows builds where
# the custom op files are breaking them. See https://fburl.com/za87443c
# The hack is just to not build them for that platform and pray they aren't needed.
def get_metal_registration_files_outs_windows():
    """Windows variant of the Metal output map: METAL_SOURCE_LIST only.

    The UNet prepack and maskrcnn custom-op sources are deliberately omitted
    because they break the arvr Windows builds.

    Returns:
      A dict mapping each METAL_SOURCE_LIST path to a one-element list
      holding that same path.
    """
    return {file_path: [file_path] for file_path in METAL_SOURCE_LIST}

def _metal_rules_by_language(rule_name, file_paths):
    """Splits ":<rule_name>[<path>]" references into Objective-C and C++ lists.

    Args:
      rule_name: name of the rule whose named outputs are referenced.
      file_paths: output paths of `rule_name` to reference.

    Returns:
      A dict with key "objc" (paths not ending in ".cpp", e.g. the .mm files)
      followed by key "cxx" (paths ending in ".cpp"). Each value is a list of
      ":<rule_name>[<path>]" strings in input order.
    """
    objc_rules = []
    cxx_rules = []
    for file_path in file_paths:
        rule = ":{}[{}]".format(rule_name, file_path)

        # Classify by suffix. A substring test (".cpp" in path) would also
        # send paths that merely contain ".cpp" (e.g. "foo.cpp.mm") to C++.
        if file_path.endswith(".cpp"):
            cxx_rules.append(rule)
        else:
            objc_rules.append(rule)
    return {"objc": objc_rules, "cxx": cxx_rules}

def get_metal_registration_files_rules(rule_name):
    """Returns Metal registration output references, split by language.

    Covers METAL_SOURCE_LIST, METAL_MASKRCNN_SOURCE_LIST and
    UNET_METAL_PREPACK_SOURCE_LIST, in that order.

    Args:
      rule_name: name of the rule whose named outputs are referenced.

    Returns:
      {"objc": [...], "cxx": [...]} lists of ":<rule_name>[<path>]" strings.
    """
    return _metal_rules_by_language(
        rule_name,
        METAL_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST,
    )

def get_metal_registration_files_rules_windows(rule_name):
    """Windows variant of get_metal_registration_files_rules.

    Only METAL_SOURCE_LIST is referenced; the maskrcnn and UNet prepack
    custom-op sources are omitted because they break the arvr Windows builds.

    Args:
      rule_name: name of the rule whose named outputs are referenced.

    Returns:
      {"objc": [...], "cxx": [...]} lists of ":<rule_name>[<path>]" strings.
    """
    return _metal_rules_by_language(rule_name, METAL_SOURCE_LIST)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.