# pytorch/pt_template_srcs.bzl (246 lines, 11.1 KB)
# This file keeps a list of PyTorch source files that are used for templated selective build.
# NB: as this is PyTorch Edge selective build, we assume only CPU targets are
# being built.
4
5load("@bazel_skylib//lib:paths.bzl", "paths")
6load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
7load(":build_variables.bzl", "aten_native_source_list")
8load(
9":ufunc_defs.bzl",
10"aten_ufunc_generated_cpu_kernel_sources",
11"aten_ufunc_generated_cpu_sources",
12)
13
# Sources that are compiled separately for every app, because each app has its
# own operator allow list. The two JIT registration files come first, followed
# by every ATen native source from aten_native_source_list.
TEMPLATE_SOURCE_LIST = [
    "torch/csrc/jit/runtime/register_prim_ops.cpp",
    "torch/csrc/jit/runtime/register_special_ops.cpp",
] + aten_native_source_list
20
21# For selective build, we can lump the CPU and CPU kernel sources altogether
22# because there is only ever one vectorization variant that is compiled
23def aten_ufunc_generated_all_cpu_sources(gencode_pattern = "{}"):
24return (
25aten_ufunc_generated_cpu_sources(gencode_pattern) +
26aten_ufunc_generated_cpu_kernel_sources(gencode_pattern)
27)
28
# Templated Mask R-CNN operator registration source. It is only included by
# the template helpers below when is_oss is False.
TEMPLATE_MASKRCNN_SOURCE_LIST = ["register_maskrcnn_ops.cpp"]
32
# Templated batch Box-Cox operator registration source. It is only included by
# the template helpers below when is_oss is False.
TEMPLATE_BATCH_BOX_COX_SOURCE_LIST = ["register_batch_box_cox_ops.cpp"]
36
# Core Metal backend sources, in build order. The .cpp entries are routed to
# "cxx" rules and everything else to "objc" rules by the Metal rule helpers.
METAL_SOURCE_LIST = [
    # aten/src/ATen/native/metal/
    "aten/src/ATen/native/metal/MetalAten.mm",
    "aten/src/ATen/native/metal/MetalGuardImpl.cpp",
    "aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp",
    "aten/src/ATen/native/metal/MetalCommandBuffer.mm",
    "aten/src/ATen/native/metal/MetalContext.mm",
    "aten/src/ATen/native/metal/MetalConvParams.mm",
    "aten/src/ATen/native/metal/MetalTensorImplStorage.mm",
    "aten/src/ATen/native/metal/MetalTensorUtils.mm",

    # aten/src/ATen/native/metal/mpscnn/
    "aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm",

    # aten/src/ATen/native/metal/ops/
    "aten/src/ATen/native/metal/ops/MetalAddmm.mm",
    "aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm",
    "aten/src/ATen/native/metal/ops/MetalChunk.mm",
    "aten/src/ATen/native/metal/ops/MetalClamp.mm",
    "aten/src/ATen/native/metal/ops/MetalConcat.mm",
    "aten/src/ATen/native/metal/ops/MetalConvolution.mm",
    "aten/src/ATen/native/metal/ops/MetalCopy.mm",
    "aten/src/ATen/native/metal/ops/MetalHardswish.mm",
    "aten/src/ATen/native/metal/ops/MetalHardshrink.mm",
    "aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm",
    "aten/src/ATen/native/metal/ops/MetalNeurons.mm",
    "aten/src/ATen/native/metal/ops/MetalPadding.mm",
    "aten/src/ATen/native/metal/ops/MetalPooling.mm",
    "aten/src/ATen/native/metal/ops/MetalReduce.mm",
    "aten/src/ATen/native/metal/ops/MetalReshape.mm",
    "aten/src/ATen/native/metal/ops/MetalSoftmax.mm",
    "aten/src/ATen/native/metal/ops/MetalTranspose.mm",
    "aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm",
]
73
# UNet Metal prepack sources: a C++ file and an Objective-C++ (.mm) file.
UNET_METAL_PREPACK_SOURCE_LIST = ["unet_metal_prepack.cpp", "unet_metal_prepack.mm"]
78
# Metal implementations of the Mask R-CNN helper ops.
METAL_MASKRCNN_SOURCE_LIST = ["maskrcnn/srcs/GenerateProposals.mm", "maskrcnn/srcs/RoIAlign.mm"]
83
# Groups TEMPLATE_SOURCE_LIST by directory for templated selective build.
#
# Returns a dict that maps each directory (paths.dirname of a source) to the
# sources from TEMPLATE_SOURCE_LIST located there. Directories appear in
# first-seen order, and sources keep their list order. These sources contain
# operator definitions and registrations that templated selective build
# filters; the selected top-level operators are listed in selected_mobile_ops.h.
#
# NB: generated files are not included; copy_template_registration_files
# handles those specially.
def get_template_source_dict():
    by_directory = {}
    for src in TEMPLATE_SOURCE_LIST:
        by_directory.setdefault(paths.dirname(src), []).append(src)
    return by_directory
99
# Named outputs of the operator-list generation rule. Each output name maps to
# a one-element list holding that same file name.
def get_gen_oplist_outs():
    out_names = [
        "SupportedMobileModelsRegistration.cpp",
        "selected_mobile_ops.h",
        "selected_operators.yaml",
    ]
    return {name: [name] for name in out_names}
112
# Named outputs of the generate_code binary rule. Each generated file path
# (relative, under autograd/generated/) maps to a one-element list holding
# that same path. When is_arvr_mode() is true, the python_* generated sources
# are listed as well, after the base set.
def get_generate_code_bin_outs():
    generated_files = [
        "autograd/generated/ADInplaceOrViewTypeEverything.cpp",
        "autograd/generated/ADInplaceOrViewType_0.cpp",
        "autograd/generated/ADInplaceOrViewType_1.cpp",
        "autograd/generated/Functions.cpp",
        "autograd/generated/Functions.h",
        "autograd/generated/TraceTypeEverything.cpp",
        "autograd/generated/TraceType_0.cpp",
        "autograd/generated/TraceType_1.cpp",
        "autograd/generated/TraceType_2.cpp",
        "autograd/generated/TraceType_3.cpp",
        "autograd/generated/TraceType_4.cpp",
        "autograd/generated/VariableType.h",
        "autograd/generated/VariableTypeEverything.cpp",
        "autograd/generated/VariableType_0.cpp",
        "autograd/generated/VariableType_1.cpp",
        "autograd/generated/VariableType_2.cpp",
        "autograd/generated/VariableType_3.cpp",
        "autograd/generated/VariableType_4.cpp",
        "autograd/generated/variable_factories.h",
        "autograd/generated/ViewFuncs.cpp",
        "autograd/generated/ViewFuncs.h",
    ]

    if is_arvr_mode():
        generated_files.extend([
            "autograd/generated/python_enum_tag.cpp",
            "autograd/generated/python_fft_functions.cpp",
            "autograd/generated/python_functions.h",
            "autograd/generated/python_functions_0.cpp",
            "autograd/generated/python_functions_1.cpp",
            "autograd/generated/python_functions_2.cpp",
            "autograd/generated/python_functions_3.cpp",
            "autograd/generated/python_functions_4.cpp",
            "autograd/generated/python_linalg_functions.cpp",
            "autograd/generated/python_nested_functions.cpp",
            "autograd/generated/python_nn_functions.cpp",
            "autograd/generated/python_return_types.h",
            "autograd/generated/python_return_types.cpp",
            "autograd/generated/python_sparse_functions.cpp",
            "autograd/generated/python_special_functions.cpp",
            "autograd/generated/python_torch_functions_0.cpp",
            "autograd/generated/python_torch_functions_1.cpp",
            "autograd/generated/python_torch_functions_2.cpp",
            "autograd/generated/python_variable_methods.cpp",
        ])

    # Every entry is distinct, so the comprehension yields exactly one key per
    # path, in list order.
    return {generated: [generated] for generated in generated_files}
161
# Named outputs for the templated registration sources. Each source path maps
# to a one-element list holding that same path.
#
# Sources, in order:
#   - TEMPLATE_MASKRCNN_SOURCE_LIST and TEMPLATE_BATCH_BOX_COX_SOURCE_LIST
#     (only when is_oss is False)
#   - TEMPLATE_SOURCE_LIST
#   - every generated ufunc CPU source, prefixed with "aten/src/ATen/"
#
# Args:
#   is_oss: when True, the Mask R-CNN and batch Box-Cox sources are omitted.
def get_template_registration_files_outs(is_oss = False):
    sources = []
    if not is_oss:
        sources.extend(TEMPLATE_MASKRCNN_SOURCE_LIST)
        sources.extend(TEMPLATE_BATCH_BOX_COX_SOURCE_LIST)
    sources.extend(TEMPLATE_SOURCE_LIST)
    sources.extend([
        "aten/src/ATen/{}".format(base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ])

    outs = {}
    for src in sources:
        outs[src] = [src]
    return outs
179
# Rule references (":<rule_name>[<path>]") to every templated registration
# source, followed by references to every generated ufunc CPU source under
# aten/src/ATen/.
#
# Args:
#   rule_name: name of the rule whose named outputs are being referenced.
#   is_oss: when True, only TEMPLATE_SOURCE_LIST is referenced; otherwise the
#     Mask R-CNN and batch Box-Cox sources are appended after it.
def get_template_registration_file_rules(rule_name, is_oss = False):
    if is_oss:
        template_sources = TEMPLATE_SOURCE_LIST
    else:
        template_sources = TEMPLATE_SOURCE_LIST + TEMPLATE_MASKRCNN_SOURCE_LIST + TEMPLATE_BATCH_BOX_COX_SOURCE_LIST

    template_rules = [
        ":{}[{}]".format(rule_name, src)
        for src in template_sources
    ]
    ufunc_rules = [
        ":{}[aten/src/ATen/{}]".format(rule_name, base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ]
    return template_rules + ufunc_rules
188
# ---------------------METAL RULES---------------------
# Groups METAL_SOURCE_LIST by directory. Returns a dict that maps each
# directory (paths.dirname of a source) to the Metal sources located there.
# Directories appear in first-seen order, and sources keep their list order.
def get_metal_source_dict():
    grouped = {}
    for src in METAL_SOURCE_LIST:
        directory = paths.dirname(src)
        if directory in grouped:
            grouped[directory].append(src)
        else:
            grouped[directory] = [src]
    return grouped
198
# Named outputs for the Metal sources. Covers METAL_SOURCE_LIST, then
# UNET_METAL_PREPACK_SOURCE_LIST, then METAL_MASKRCNN_SOURCE_LIST; each path
# maps to a one-element list holding that same path.
def get_metal_registration_files_outs():
    all_metal_sources = METAL_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST

    # The three lists share no entries, so each path yields exactly one key.
    return {src: [src] for src in all_metal_sources}
210
# Windows variant of the Metal output mapping: only METAL_SOURCE_LIST is
# included. The UNet prepack and Metal Mask R-CNN custom-op files break the
# arvr Windows builds (see https://fburl.com/za87443c), so they are skipped on
# that platform. This is a workaround that assumes those ops aren't needed there.
def get_metal_registration_files_outs_windows():
    return {src: [src] for src in METAL_SOURCE_LIST}
219
# Turns each path into a rule reference ":<rule_name>[<path>]" and buckets it
# by file type. Returns {"objc": [...], "cxx": [...]}: .cpp files go under
# "cxx", everything else (the .mm sources) goes under "objc". Input order is
# preserved within each bucket.
def _split_metal_rules_by_lang(rule_name, file_paths):
    objc_rules = []
    cxx_rules = []
    for file_path in file_paths:
        rule = ":{}[{}]".format(rule_name, file_path)

        # Classify by suffix rather than substring, so a ".cpp" appearing
        # elsewhere in a path cannot misroute a non-C++ source.
        if file_path.endswith(".cpp"):
            cxx_rules.append(rule)
        else:
            objc_rules.append(rule)
    return {
        "objc": objc_rules,
        "cxx": cxx_rules,
    }

# Returns {"objc": [...], "cxx": [...]} rule references for every Metal
# source: METAL_SOURCE_LIST, METAL_MASKRCNN_SOURCE_LIST and
# UNET_METAL_PREPACK_SOURCE_LIST.
#
# Args:
#   rule_name: name of the rule whose named outputs are being referenced.
def get_metal_registration_files_rules(rule_name):
    return _split_metal_rules_by_lang(
        rule_name,
        METAL_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST,
    )

# Windows variant: only METAL_SOURCE_LIST is referenced, which matches the
# Windows output mapping that leaves out the custom-op sources.
#
# Args:
#   rule_name: name of the rule whose named outputs are being referenced.
def get_metal_registration_files_rules_windows(rule_name):
    return _split_metal_rules_by_lang(rule_name, METAL_SOURCE_LIST)
247