pytorch
/
buckbuild.bzl
2241 lines · 82.9 KB
1# NOTE: This file is shared by internal and OSS BUCK build.
2# These load paths point to different files in internal and OSS environment
3
4load("@bazel_skylib//lib:paths.bzl", "paths")
5load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
6load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
7load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
8load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
9load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
10load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX")
11load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
12load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build")
13load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build")
14load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags")
15load(
16":build_variables.bzl",
17"aten_cpu_source_list",
18"aten_native_source_list",
19"core_sources_common",
20"core_sources_full_mobile_no_backend_interface_xplat",
21"core_trainer_sources",
22"jit_core_headers",
23"jit_core_sources",
24"libtorch_profiler_sources",
25"torch_cpp_srcs",
26"torch_mobile_tracer_sources",
27)
28load(
29":pt_ops.bzl",
30"USED_PT_BACKENDS",
31)
32load(
33":pt_template_srcs.bzl",
34"METAL_MASKRCNN_SOURCE_LIST",
35"METAL_SOURCE_LIST",
36"TEMPLATE_MASKRCNN_SOURCE_LIST",
37"TEMPLATE_SOURCE_LIST",
38"aten_ufunc_generated_all_cpu_sources",
39"get_gen_oplist_outs",
40"get_generate_code_bin_outs",
41"get_metal_registration_files_outs",
42"get_metal_registration_files_outs_windows",
43"get_metal_source_dict",
44"get_template_registration_file_rules",
45"get_template_registration_files_outs",
46"get_template_source_dict",
47)
48load(
49":ufunc_defs.bzl",
50"aten_ufunc_generated_cpu_kernel_sources",
51"aten_ufunc_generated_cpu_sources",
52"aten_ufunc_generated_cuda_sources",
53)
54
def read_bool(section, field, default, required = True):
    """Read a boolean from buckconfig `section.field`.

    Accepts 0/1/true/false/True/False (fails on anything else). When the key
    is unset, falls back to `default`; with no default, returns None if not
    required, otherwise fails.
    """
    val = read_config(section, field)

    # Unset key: default > optional-None > hard failure.
    if val == None:
        if default != None:
            return default
        if not required:
            return None
        fail("`{}:{}`: no value set".format(section, field))

    if val in ("true", "True", "1"):
        return True
    if val in ("false", "False", "0"):
        return False
    fail(
        "`{}:{}`: must be one of (0, 1, true, false, True, False), but was {}".format(section, field, val),
    )
72
def _is_build_mode_dev():
    """True for dev builds; False for Android or iOS production builds."""
    # Dev mode unless either mobile platform is building for production.
    return not (is_production_build_android() or is_production_build_ios())
82
def _get_enable_lightweight_dispatch():
    """Whether codegen unboxing (lightweight dispatch) is enabled via `pt.enable_lightweight_dispatch` (default False)."""
    return read_bool("pt", "enable_lightweight_dispatch", False)
85
def _get_enable_record_kernel_dtype():
    """Whether to record kernel function dtypes, from `pt.enable_record_kernel_dtype` (default False)."""
    return read_bool("pt", "enable_record_kernel_dtype", False)
88
def get_enable_mobile_dispatch_keys_trimming():
    """Whether to trim mobile dispatch keys, from `pt.enable_mobile_dispatch_keys_trimming` (default False)."""
    return read_bool("pt", "enable_mobile_dispatch_keys_trimming", False)
91
def get_disable_per_op_profiling():
    """Whether per-op profiling is disabled, from `pt.disable_per_op_profiling` (default True)."""
    return read_bool("pt", "disable_per_op_profiling", True)
94
def get_strip_error_messages():
    """Whether to strip error-message strings from the binary.

    Always True for OSS builds; otherwise read from
    `pt.strip_error_messages`, defaulting to stripping in production builds.
    """
    if IS_OSS:
        return True  # always strip in OSS CI to expose potential issues
    return read_bool("pt", "strip_error_messages", not _is_build_mode_dev())
99
def get_disable_warn():
    """Whether to define DISABLE_WARN, from `pt.disable_warn` (default False)."""
    return read_bool("pt", "disable_warn", False)
102
def get_enable_eager_symbolication():
    """Whether eager symbolication is enabled, from `pt.enable_eager_symbolication` (default False, optional)."""
    return read_bool("pt", "enable_eager_symbolication", default = False, required = False)
105
def get_static_dispatch_backend():
    """List of static dispatch backends from `pt.static_dispatch_backend` (semicolon-separated); [] when unset."""
    raw = native.read_config("pt", "static_dispatch_backend", None)
    return raw.split(";") if raw != None else []
111
def get_glsl_image_format():
    """Vulkan image format: rgba32f when `pt.vulkan_full_precision` is set, else rgba16f."""
    full_precision = read_config("pt", "vulkan_full_precision", "0") != "0"
    return "rgba32f" if full_precision else "rgba16f"
116
def get_glsl_paths():
    """Build the space-joined "$(location target)/subdir" argument for GLSL codegen.

    The flat list alternates (target, relative-path) pairs; additional pairs
    may be supplied through `gen_vulkan_spv.additional_glsl_paths`.
    """
    # NOTE: deliberately not named `paths` — that would shadow the skylib
    # `paths` module loaded at the top of this file.
    glsl_paths = [
        "//xplat/caffe2:aten_vulkan_glsl_src_path",
        "aten/src/ATen/native/vulkan/glsl",
    ]
    for extra in read_config("gen_vulkan_spv", "additional_glsl_paths", "").split(" "):
        if extra:
            glsl_paths.append(extra)

    # The list must pair up evenly, or the location/path zip below is invalid.
    if len(glsl_paths) % 2 != 0:
        fail(
            "gen_vulkan_spv.additional_glsl_paths must contain an even number of elements",
        )

    formatted = []
    for i in range(0, len(glsl_paths), 2):
        formatted.append("$(location {})/{}".format(glsl_paths[i], glsl_paths[i + 1]))
    return " ".join(formatted)
145
def spv_shader_library():
    """No-op placeholder in this shared file; see vulkan_spv_shader_library below for the real rule."""
    pass
148
# True for OSS BUCK build, and False for internal BUCK build
IS_OSS = read_config("pt", "is_oss", "0") == "1"

NOT_OSS = not IS_OSS

# for targets in caffe2 root path
ROOT = "//" if IS_OSS else "//xplat/caffe2"

# for targets in subfolders
ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/"

# canonical c10 target for the current (OSS vs internal) environment
C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10"
160
# Maps a third-party library name to a pair of targets:
# index 0 = internal (fbsource) target, index 1 = OSS target.
# Resolved via third_party() below.
THIRD_PARTY_LIBS = {
    "FP16": ["//xplat/third-party/FP16:FP16", "//third_party:FP16"],
    "FXdiv": ["//xplat/third-party/FXdiv:FXdiv", "//third_party:FXdiv"],
    "XNNPACK": ["//xplat/third-party/XNNPACK:XNNPACK", "//third_party:XNNPACK"],
    "clog": ["//xplat/third-party/clog:clog", "//third_party:clog"],
    "cpuinfo": ["//third-party/cpuinfo:cpuinfo", "//third_party:cpuinfo"],
    "flatbuffers-api": ["//third-party/flatbuffers/fbsource_namespace:flatbuffers-api", "//third_party:flatbuffers-api"],
    "flatc": ["//third-party/flatbuffers/fbsource_namespace:flatc", "//third_party:flatc"],
    "fmt": ["//third-party/fmt:fmt", "//third_party:fmt"],
    "glog": ["//third-party/glog:glog", "//third_party:glog"],
    "gmock": ["//third-party/googletest:gmock_main", "//third_party:gmock"],
    "gtest": ["//third-party/googletest:gtest_main", "//third_party:gtest"],
    "kineto": ["//xplat/kineto/libkineto:libkineto", "//third_party:libkineto"],
    "libkineto_headers": ["//xplat/kineto/libkineto:libkineto_headers", "//third_party:libkineto_headers"],
    "omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
    "pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
    "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
    "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
    "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
    "pyyaml": ["//third-party/pyyaml:pyyaml", "//third_party:pyyaml"],
    "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
    "ruy": ["//third-party/ruy:ruy_xplat_lib", "//third_party:ruy_lib"],
    "sleef_arm": ["//third-party/sleef:sleef_arm", "//third_party:sleef_arm"],
    "typing-extensions": ["//third-party/typing-extensions:typing-extensions", "//third_party:typing-extensions"],
}
187
def third_party(name):
    """Resolve a third-party library name to the correct target for this build (OSS vs internal)."""
    targets = THIRD_PARTY_LIBS.get(name)
    if targets == None:
        fail("Cannot find third party library " + name + ", please register it in THIRD_PARTY_LIBS first!")
    return targets[1] if IS_OSS else targets[0]
192
def get_pt_compiler_flags():
    """PyTorch compiler flags; translated to MSVC-style flags for cl.exe builds."""
    return select({
        "DEFAULT": _PT_COMPILER_FLAGS,
        "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(_PT_COMPILER_FLAGS),
    })
198
# Base gcc/clang-style compiler flags for PyTorch targets
# (consumed via get_pt_compiler_flags, which also converts them for MSVC).
_PT_COMPILER_FLAGS = [
    "-fexceptions",
    "-frtti",
    "-Os",
    "-Wno-unknown-pragmas",
    "-Wno-write-strings",
    "-Wno-unused-variable",
    "-Wno-unused-function",
    "-Wno-deprecated-declarations",
    "-Wno-shadow",
    "-Wno-global-constructors",
    "-Wno-missing-prototypes",
]
212
# Compiler flags for ATen targets (see get_aten_compiler_flags).
ATEN_COMPILER_FLAGS = [
    "-fexceptions",
    "-frtti",
    "-fPIC",
    "-Os",
    "-Wno-absolute-value",
    "-Wno-deprecated-declarations",
    "-Wno-macro-redefined",
    "-Wno-tautological-constant-out-of-range-compare",
    "-Wno-unknown-pragmas",
    "-Wno-unknown-warning-option",
    "-Wno-unused-function",
    "-Wno-unused-variable",
    "-Wno-pass-failed",
    "-Wno-shadow",
]
229
def get_aten_compiler_flags():
    """Compiler flags for ATen targets."""
    return ATEN_COMPILER_FLAGS
232
# Preprocessor flags shared by ATen and libtorch mobile targets.
# The trailing conditional defines are driven by buckconfig via the
# get_* helpers above.
_COMMON_PREPROCESSOR_FLAGS = [
    "-DC10_MOBILE",
    "-DNO_EXPORT",
] + (
    ["-DC10_MOBILE_TRIM_DISPATCH_KEYS"] if get_enable_mobile_dispatch_keys_trimming() else []
) + (
    ["-DSTRIP_ERROR_MESSAGES"] if get_strip_error_messages() else []
) + (
    ["-DDISABLE_WARN"] if get_disable_warn() else []
)
243
def get_aten_preprocessor_flags():
    """Preprocessor flags for ATen mobile targets, plus config-dependent extras."""
    # read_config is not allowed outside of function in Starlark
    flags = _COMMON_PREPROCESSOR_FLAGS + [
        "-DCPU_CAPABILITY_DEFAULT",
        "-DCPU_CAPABILITY=DEFAULT",
        "-DCAFFE2_USE_LITE_PROTO",
        "-DATEN_CUDNN_ENABLED_FBXPLAT=0",
        "-DATEN_MKLDNN_ENABLED_FBXPLAT=0",
        "-DATEN_MKLDNN_ACL_ENABLED_FBXPLAT=0",
        "-DATEN_NNPACK_ENABLED_FBXPLAT=0",
        "-DATEN_MKL_ENABLED_FBXPLAT=0",
        "-DATEN_MKL_SEQUENTIAL_FBXPLAT=0",
        "-DUSE_PYTORCH_METAL",
        "-DUSE_PYTORCH_QNNPACK",
        "-DUSE_XNNPACK",
        "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
        "-DAT_PARALLEL_OPENMP_FBXPLAT=0",
        "-DAT_PARALLEL_NATIVE_FBXPLAT=1",
        "-DUSE_LAPACK_FBXPLAT=0",
        "-DAT_BLAS_F2C_FBXPLAT=0",
        "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0",
        "-DUSE_RUY_QMATMUL",
    ]
    if get_disable_per_op_profiling():
        flags += ["-DPYTORCH_DISABLE_PER_OP_PROFILING"]
    if _get_enable_record_kernel_dtype():
        flags += ["-DENABLE_RECORD_KERNEL_FUNCTION_DTYPE"]
    return flags
272
def get_pt_preprocessor_flags():
    """Preprocessor flags for libtorch mobile targets; adds the non-production define in dev builds."""
    # read_config is not allowed outside of function in Starlark
    flags = _COMMON_PREPROCESSOR_FLAGS + [
        "-D_THP_CORE",
        "-DUSE_SCALARS",
        "-DNO_CUDNN_DESTROY_HANDLE",
    ]
    if _is_build_mode_dev():
        flags += ["-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS"]
    return flags
284
# Backends for which {Backend}Functions(.h/_inl.h) headers are generated
# (see get_aten_derived_type_srcs below).
# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
PT_BACKEND_HEADERS = [
    "CPU",
    "CUDA",
    "CompositeExplicitAutograd",
    "CompositeExplicitAutogradNonFunctional",
    "CompositeImplicitAutograd",
    "CompositeImplicitAutogradNestedTensor",
    "Meta",
]
295
def get_aten_static_dispatch_backend_headers(existing_headers):
    """Add {Backend}Functions(.h/_inl.h) entries for each non-CPU static dispatch backend.

    Mutates and returns `existing_headers`.
    """
    for backend in get_static_dispatch_backend():
        if backend == "CPU":
            continue
        for suffix in ("Functions.h", "Functions_inl.h"):
            existing_headers[backend + suffix] = ":gen_aten[{}{}]".format(backend, suffix)
    return existing_headers
303
def get_aten_codegen_extra_params(backends):
    """Extra flags for aten codegen: schema registration is always forced, and
    configured static dispatch backends take precedence over `backends`."""
    static_backends = get_static_dispatch_backend()
    if static_backends:
        return {
            "enabled_backends": static_backends,
            "force_schema_registration": True,
            "static_dispatch_backend": static_backends,
        }
    return {
        "enabled_backends": backends,
        "force_schema_registration": True,
    }
315
def get_jit_codegen_params():
    """Extra parameters for JIT codegen; currently none."""
    return []
318
def get_unboxing_generated_files():
    """Map of gen_unboxing output file -> [file]; empty unless lightweight dispatch is on."""
    if not _get_enable_lightweight_dispatch():
        return {}

    # Sharded outputs: 5 UnboxingFunctions shards, 10 registration shards.
    srcs = ["UnboxingFunctions.h"]
    srcs += ["UnboxingFunctions_{}.cpp".format(i) for i in range(5)]
    srcs += ["RegisterCodegenUnboxedKernels_{}.cpp".format(i) for i in range(10)]
    return {file_name: [file_name] for file_name in srcs}
344
def get_aten_generated_files(enabled_backends):
    """Map of torchgen output file -> [file] for the given enabled backends."""
    # NB: RegisterMeta counts as an optionally enabled backend,
    # and is intentionally omitted from here
    src_files = [
        "RegisterBackendSelect.cpp",
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "CompositeViewCopyKernels.cpp",
        "RegisterSchema.cpp",
        "Declarations.yaml",
        "Functions.cpp",
        "Functions.h",
        "RedispatchFunctions.h",
        "NativeFunctions.h",
        "NativeMetaFunctions.h",
        "MethodOperators.h",
        "FunctionalInverses.h",
        "Operators.h",
        "Operators_0.cpp",
        "Operators_1.cpp",
        "Operators_2.cpp",
        "Operators_3.cpp",
        "Operators_4.cpp",
        "CompositeImplicitAutogradFunctions.h",
        "CompositeImplicitAutogradFunctions_inl.h",
        "CompositeImplicitAutogradNestedTensorFunctions.h",
        "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
        "CompositeExplicitAutogradFunctions.h",
        "CompositeExplicitAutogradFunctions_inl.h",
        "CompositeExplicitAutogradNonFunctionalFunctions.h",
        "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
        "VmapGeneratedPlumbing.h",
        "core/ATenOpList.cpp",
        "core/TensorBody.h",
        "core/TensorMethods.cpp",
        "core/aten_interned_strings.h",
        "core/enum_tag.h",
        "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp",
    ] + get_aten_derived_type_srcs(enabled_backends)

    # This is tiresome. A better strategy would be to unconditionally
    # generate these files, and then only actually COMPILE them depending
    # on the generated set. C'est la vie...
    if "CPU" in enabled_backends:
        src_files += aten_ufunc_generated_cpu_sources()
        src_files += aten_ufunc_generated_cpu_kernel_sources()
    if "CUDA" in enabled_backends:
        # Cannot unconditionally include this, because in the Edge selective
        # build CUDA is not enabled and thus the ufunc codegen for CUDA gets
        # skipped
        src_files += aten_ufunc_generated_cuda_sources()

    return {file_name: [file_name] for file_name in src_files}
403
def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
    """Per-backend Register{Backend}.cpp output references on the given aten codegen rule."""
    rules = []
    for backend in enabled_backends:
        rules.append(":{}[Register{}.cpp]".format(aten_rule_name, backend))
    return rules
409
def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
    """Selective-build .cpp output references: fixed composite/schema files plus per-backend Register files."""
    common_files = [
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "RegisterSchema.cpp",
        "RegisterBackendSelect.cpp",
        "CompositeViewCopyKernels.cpp",
    ]
    rules = [":{}[{}]".format(aten_rule_name, f) for f in common_files]
    return rules + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)
415
def get_aten_derived_type_srcs(enabled_backends):
    """Filenames generated per backend: Register*.cpp for every enabled backend,
    plus {Backend}Functions(.h/_inl.h) for backends that get headers.

    Headers are emitted only for backends in PT_BACKEND_HEADERS or configured
    as static dispatch backends.
    """
    # Hoisted out of the comprehensions: the original evaluated
    # get_static_dispatch_backend() (a read_config call) once per backend
    # per comprehension condition.
    static_backends = get_static_dispatch_backend()
    header_backends = [
        derived_type
        for derived_type in enabled_backends
        if derived_type in PT_BACKEND_HEADERS or derived_type in static_backends
    ]
    return [
        "Register" + derived_type + ".cpp"
        for derived_type in enabled_backends
    ] + [
        derived_type + "Functions.h"
        for derived_type in header_backends
    ] + [
        derived_type + "Functions_inl.h"
        for derived_type in header_backends
    ]
429
def gen_aten_files(
        name,
        extra_flags = {},
        visibility = [],
        compatible_with = [],
        apple_sdks = None):
    """Define a genrule that runs torchgen to produce the ATen generated files.

    Recognized extra_flags keys (all optional): force_schema_registration
    (bool), op_registration_allowlist (string), op_selection_yaml_path
    (string), enabled_backends (list), static_dispatch_backend (list).
    When static_dispatch_backend is set it replaces enabled_backends for
    output computation.
    """
    extra_params = []
    force_schema_registration = extra_flags.get("force_schema_registration", False)
    op_registration_allowlist = extra_flags.get("op_registration_allowlist", None)
    op_selection_yaml_path = extra_flags.get("op_selection_yaml_path", None)
    enabled_backends = extra_flags.get("enabled_backends", None)
    static_dispatch_backend = extra_flags.get("static_dispatch_backend", None)

    if force_schema_registration:
        extra_params.append("--force_schema_registration")
    if op_registration_allowlist != None and is_string(op_registration_allowlist):
        # torchgen still spells this flag "whitelist"
        extra_params.append("--op_registration_whitelist")
        extra_params.append(op_registration_allowlist)
    if op_selection_yaml_path != None and is_string(op_selection_yaml_path):
        extra_params.append("--op_selection_yaml_path")
        extra_params.append(op_selection_yaml_path)
    if enabled_backends != None and is_list(enabled_backends):
        extra_params.append("--backend_whitelist")
        extra_params.extend(enabled_backends)
    if _get_enable_lightweight_dispatch():
        extra_params.append("--skip_dispatcher_op_registration")
    if static_dispatch_backend:
        extra_params.append("--static_dispatch_backend")
        extra_params.extend(static_dispatch_backend)
        backends = static_dispatch_backend
    else:
        backends = enabled_backends
    fb_xplat_genrule(
        name = name,
        default_outs = ["."],
        outs = get_aten_generated_files(backends),
        cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([
            "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
            "--install_dir $OUT",
            "--aoti_install_dir $OUT/torch/csrc/inductor/aoti_torch/generated"
        ] + extra_params),
        visibility = visibility,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )
475
def gen_aten_unboxing_files(
        genrule_name,
        extra_flags = {}):
    """Define a genrule running gen_unboxing_bin to produce codegen-unboxing sources.

    Recognized extra_flags keys (both optional strings):
    op_selection_yaml_path, op_registration_allowlist.
    """
    extra_params = []

    # (extra_flags key, command-line flag) pairs, in the order the tool
    # arguments are emitted.
    for flag_key, arg in (
        ("op_selection_yaml_path", "--op_selection_yaml_path"),
        ("op_registration_allowlist", "--op_registration_allowlist"),
    ):
        value = extra_flags.get(flag_key, None)
        if value != None and is_string(value):
            extra_params.append(arg)
            extra_params.append(value)

    fb_xplat_genrule(
        name = genrule_name,
        default_outs = ["."],
        outs = get_unboxing_generated_files(),
        cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([
            "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
            "--install_dir $OUT",
        ] + extra_params),
        visibility = ["PUBLIC"],
    )
499
def copy_template_registration_files(name, apple_sdks = None):
    """Define a genrule that copies templated selective-build sources (and, internally,
    MaskRCNN / batch_box_cox custom-op sources and CPU ufunc files) into $OUT.

    Builds parallel POSIX (`cmd`) and Windows (`cmd_exe`) command lists.
    """
    cmd = []
    cmd_exe = []

    template_source_dict = get_template_source_dict()

    # Ideally, we would run one copy command for a single source directory along
    # with all its child directories, but it's somewhat hard to know if a directory
    # is a child of another just by looking at the metadata (directory relative
    # path) that we currently have since 1 directory could look like a parent of
    # another and yet come from a different filegroup() rule.
    #
    for (path_prefix, file_paths) in template_source_dict.items():
        cmd.append("mkdir -p $OUT/{}".format(path_prefix))
        cmd_exe.append("md $OUT/{}".format(path_prefix))

        # Adding *.cpp is a workaround to prevent cp from throwing an error when it
        # encounters a directory (since -r was not specified). If files with an
        # extension other than .cpp need to be copied, then the command below
        # will not work and will need to be updated.
        #
        cmd.append("cp -f $(location {0}:templated_selective_build_srcs)/{1}/*.cpp $OUT/{1}/".format(ROOT, path_prefix))
        cmd_exe.append("robocopy /E $(location {0}:templated_selective_build_srcs)/{1} $OUT/{1}".format(ROOT, path_prefix))

    if NOT_OSS:
        for file_path in TEMPLATE_MASKRCNN_SOURCE_LIST:
            maskrcnn_file = "$(location //xplat/caffe2/fb/custom_ops/maskrcnn:templated_selective_build_srcs)/" + file_path
            cmd.append("cp -f " + maskrcnn_file + " $OUT")
            cmd_exe.append("copy " + maskrcnn_file + " $OUT")

    cmd.append("mkdir -p $OUT/aten/src/ATen")
    cmd_exe.append("md $OUT/aten/src/ATen")

    # NB: CUDA is skipped here because this is selective build and CUDA is not
    # supported for selective build
    for ufunc_file in aten_ufunc_generated_all_cpu_sources("$(location " + ROOT + ":gen_aten[{}])"):
        cmd.append("cp -f " + ufunc_file + " $OUT/aten/src/ATen")
        cmd_exe.append("copy " + ufunc_file + " $OUT/aten/src/ATen")

    if NOT_OSS:
        pvd_batch_box_cox_file = "$(location //xplat/caffe2/fb/custom_ops/batch_box_cox:templated_selective_build_srcs)/register_batch_box_cox_ops.cpp"
        cmd.append("cp -f " + pvd_batch_box_cox_file + " $OUT")
        cmd_exe.append("copy " + pvd_batch_box_cox_file + " $OUT")

    fb_xplat_genrule(
        name = name,
        cmd = " && ".join(cmd),
        cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)),
        outs = get_template_registration_files_outs(IS_OSS),
        default_outs = ["."],
        apple_sdks = apple_sdks,
    )
552
def get_feature_tracer_source_list():
    """Return just the Feature specific handlers used in the model tracer."""
    return [
        src
        for src in torch_mobile_tracer_sources
        if src.endswith("Tracer.cpp")
    ]
562
def pt_operator_query_codegen(
        name,
        deps = [],
        train = False,
        enforce_traced_op_list = False,
        pt_allow_forced_schema_registration = True,
        compatible_with = [],
        apple_sdks = None):
    """Set up the selective-build codegen pipeline for one operator registry.

    Defines a chain of genrules: gen_oplist (selected operators), aten codegen,
    optional codegen-unboxing, autograd/libtorch codegen, template registration
    copies, and (internal-only) metal copies. Returns
    {"headers": dict, "srcs": list} referencing their outputs.
    """
    oplist_dir_name = name + "_pt_oplist"

    # Collect selected_operators.yaml etc. from all pt_operator_library deps.
    # @lint-ignore BUCKLINT
    fb_native.genrule(
        name = oplist_dir_name,
        cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
               "--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
               ("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
               "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
        outs = get_gen_oplist_outs(),
        default_outs = ["."],
        compatible_with = compatible_with,
    )

    # Aten files
    aten_genrule = name + "_aten"
    extra_flags = {
        "enabled_backends": USED_PT_BACKENDS,
        "op_selection_yaml_path": "$(location :{}[selected_operators.yaml])".format(oplist_dir_name),
    }

    if train and pt_allow_forced_schema_registration:
        extra_flags["force_schema_registration"] = True

    unboxing_genrule = name + "_unboxing"
    if _get_enable_lightweight_dispatch():
        gen_aten_unboxing_files(
            unboxing_genrule,
            extra_flags = extra_flags,
        )

    # NOTE: static_dispatch_backend is added to extra_flags only AFTER the
    # unboxing genrule is defined above, so unboxing codegen does not see it.
    static_dispatch_backend = get_static_dispatch_backend()
    if static_dispatch_backend:
        extra_flags["static_dispatch_backend"] = static_dispatch_backend

    gen_aten_files(
        aten_genrule,
        extra_flags = extra_flags,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )

    # unboxing_wrappers files
    extra_params = [
        "--operators_yaml_path",
        "$(location :" + oplist_dir_name + "[selected_operators.yaml])",
    ]
    unboxing_and_autograd_genrule = name + "_unboxing_and_autograd"
    gen_aten_libtorch_files(
        unboxing_and_autograd_genrule,
        extra_params,
        compatible_with,
        apple_sdks = apple_sdks,
    )

    # Template runtime files (prim ops, etc)
    template_registration_genrule = name + "_template_registration"
    copy_template_registration_files(template_registration_genrule, apple_sdks = apple_sdks)

    # Files needed for metal
    if NOT_OSS:
        metal_genrule = name + "_metal"
        copy_metal(metal_genrule, apple_sdks = apple_sdks)

    # Source list: selective aten cpp rules + template registrations,
    # autograd shards only for training builds, and the internal-only
    # supported-models registration.
    srcs = get_aten_selective_cpp_rules(
        aten_genrule,
        static_dispatch_backend if static_dispatch_backend else USED_PT_BACKENDS,
    ) + get_template_registration_file_rules(template_registration_genrule, IS_OSS) + ([
        ":{}[autograd/generated/VariableType_0.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_1.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_2.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_3.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_4.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/ADInplaceOrViewType_0.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/ADInplaceOrViewType_1.cpp]".format(unboxing_and_autograd_genrule),
    ] if train else []) + ([
        ":{}[SupportedMobileModelsRegistration.cpp]".format(oplist_dir_name),
    ] if NOT_OSS else [])

    headers = {
        "selected_mobile_ops.h": ":{}[selected_mobile_ops.h]".format(oplist_dir_name),
    }

    if _get_enable_lightweight_dispatch():
        srcs.extend([
            ":{}[UnboxingFunctions_0.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_1.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_2.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_3.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_4.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_0.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_1.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_2.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_3.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_4.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_5.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_6.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_7.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_8.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_9.cpp]".format(unboxing_genrule),
        ])
        headers["UnboxingFunctions.h"] = ":{}[UnboxingFunctions.h]".format(unboxing_genrule)
    return {"headers": headers, "srcs": srcs}
674
def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple_sdks = None):
    """Define a genrule running generate_code_bin to produce autograd/libtorch codegen outputs."""

    # Mobile build only needs libtorch - skip python bindings for now, except
    # for ovrsource, which needs Python bindings.
    # The command is identical for bash and cmd_exe apart from the mkdir
    # prefix, so build it once.
    generate_code_cmd = "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
        (["--subset libtorch"] if not is_arvr_mode() else []) + [
            "--native-functions-path $(location {}:aten_src_path)/aten/src/ATen/native/native_functions.yaml".format(ROOT),
            "--tags-path $(location {}:aten_src_path)/aten/src/ATen/native/tags.yaml".format(ROOT),
            "--install_dir $OUT",
        ] + extra_params,
    )

    fb_xplat_genrule(
        name = name,
        outs = get_generate_code_bin_outs(),
        default_outs = ["."],
        bash = "mkdir -p tools && " + generate_code_cmd,
        cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " + generate_code_cmd,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )
703
def vulkan_spv_shader_library(name, spv_filegroup):
    """Compile Vulkan GLSL shaders from `spv_filegroup` and wrap the generated
    registration .cpp in a cxx_library named `name`."""
    genrule_cmd = [
        "$(exe //xplat/caffe2/tools:gen_aten_vulkan_spv_bin)",
        "--glsl-paths $(location {})".format(spv_filegroup),
        "--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
        "--glslc-path=$(exe //xplat/caffe2/fb/vulkan/dotslash:glslc)",
        "--tmp-dir-path=$TMP",
    ]

    genrule_name = "gen_{}_cpp".format(name)
    fb_xplat_genrule(
        # Consistency fix: reuse genrule_name instead of re-formatting the
        # same "gen_{}_cpp" string inline a second time.
        name = genrule_name,
        outs = {
            "{}.cpp".format(name): ["spv.cpp"],
        },
        cmd = " ".join(genrule_cmd),
        default_outs = ["."],
        labels = ["uses_dotslash"],
    )

    fb_xplat_cxx_library(
        name = name,
        srcs = [
            ":{}[{}.cpp]".format(genrule_name, name),
        ],
        # Static initialization is used to register shaders to the global shader registry,
        # therefore link_whole must be True to make sure unused symbols are not discarded.
        # @lint-ignore BUCKLINT: Avoid `link_whole=True`
        link_whole = True,
        # Define a soname that can be used for dynamic loading in Java, Python, etc.
        soname = "lib{}.$(ext)".format(name),
        visibility = ["PUBLIC"],
        exported_deps = [
            "//xplat/caffe2:torch_vulkan_api",
        ],
    )
740
def copy_metal(name, apple_sdks = None):
    """Define a genrule copying Metal (.mm/.cpp) sources, plus Metal custom-op
    sources, into $OUT for the per-app selective build (internal only)."""
    cmd = []
    cmd_exe = []
    metal_source_dict = get_metal_source_dict()

    # Copy all source files over to bring them into the per app build
    for path_prefix in sorted(metal_source_dict.keys()):
        cmd.append("mkdir -p $OUT/{}".format(path_prefix))
        cmd_exe.append("mkdir -Force $OUT/{0}".format(path_prefix))

        # Not every directory has a mm or cpp file so '2>/dev/null || :' are tricks to suppress the error messages and codes.
        cmd.append("cp -f {0}/{1}/*.mm $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix))
        cmd.append("cp -f {0}/{1}/*.cpp $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix))

        # Robocopy has a default success code of 1 which buck treats as failure so the echo masks that problem
        cmd_exe.append("(robocopy /E /NFL /NDL /NJH /NJS {0}/{1} $OUT/{1}) || ECHO robocopy failed".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix))

    # Metal custom ops currently have to be brought into selective build because they directly reference metal ops instead of
    # going through the dispatcher. There are some weird issues with the genrule and these files' locations on windows though, so
    # for now we simply skip building them for windows where they very likely aren't needed anyway.
    # Metal MaskRCNN custom op
    for full_path in METAL_MASKRCNN_SOURCE_LIST:
        path_prefix = paths.dirname(full_path)
        cmd.append("mkdir -p $OUT/{}".format(path_prefix))
        cmd.append("cp -f {0}/{1}/*.mm $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2/fb/metal:metal_maskrcnn_sources)", path_prefix))

    # Unet Metal Prepack Custom op
    unet_metal_prepack_file = "$(location //xplat/caffe2/fb/custom_ops/unet_metal_prepack:unet_metal_prepack_sources)"
    cmd.append("cp -f " + unet_metal_prepack_file + "/unet_metal_prepack.cpp" + " $OUT")
    cmd.append("cp -f " + unet_metal_prepack_file + "/unet_metal_prepack.mm" + " $OUT")

    fb_xplat_genrule(
        name = name,
        cmd = " && ".join(cmd),
        cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)),
        # due to an obscure bug certain custom ops weren't being copied correctly on windows. ARVR also sometimes builds android targets on windows,
        # so we just exclude those targets from being copied for those platforms (They end up uncompiled anyway).
        outs = select({
            "DEFAULT": get_metal_registration_files_outs(),
            "ovr_config//os:android": get_metal_registration_files_outs_windows(),
            "ovr_config//os:windows": get_metal_registration_files_outs_windows(),
        }),
        default_outs = ["."],
        apple_sdks = apple_sdks,
    )
786
def get_pt_operator_registry_dict(
        name,
        deps = [],
        train = False,
        labels = [],
        env = [],
        template_select = True,
        enforce_traced_op_list = False,
        pt_allow_forced_schema_registration = True,
        enable_flatbuffer = False,
        **kwargs):
    """Build the kwargs dict for a pt_operator_registry cxx_library.

    Runs the selective-build codegen pipeline for `deps` and wires its
    generated srcs/headers into a library definition. `env` and
    `enable_flatbuffer` are accepted for interface compatibility but are
    not used in this function body.
    """
    code_gen_files = pt_operator_query_codegen(
        name,
        deps = deps,
        train = train,
        enforce_traced_op_list = enforce_traced_op_list,
        pt_allow_forced_schema_registration = pt_allow_forced_schema_registration,
        compatible_with = kwargs.get("compatible_with", []),
        apple_sdks = kwargs.get("apple_sdks"),
    )

    return dict(
        srcs = code_gen_files["srcs"],
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        soname = "libtorch-code-gen.$(ext)",
        header_namespace = "ATen",
        compiler_flags = get_aten_compiler_flags(),
        exported_headers = code_gen_files["headers"],
        exported_preprocessor_flags = get_aten_preprocessor_flags() + (["-DTEMPLATE_SELECTIVE_BUILD"] if template_select else []),
        headers = kwargs.pop("headers", []),
        # BUG FIX: `labels` is an explicit keyword parameter, so a caller's
        # labels bind to it and can never appear in **kwargs. The previous
        # `kwargs.pop("labels", [])` therefore always returned [] and silently
        # dropped caller-supplied labels; use the parameter directly.
        labels = labels + [
            # This library has multiple sources with the same file name
            # and does not work with Buck filegroup used in bad practices.
            # Opt out of the bad practices check with the below label.
            "bad_practices_ignore_override",
            "pt_operator_registry",
        ],
        deps = [
            # need absolute path here
            ROOT + ":torch_mobile_core",
            ROOT + ":aten_cpu",
            ROOT + ":aten_metal_prepack_header",
            third_party("glog"),
            C10,
        ] + ([ROOT + ":torch_mobile_train"] if train else []),
        **kwargs
    )
838
839# these targets are shared by internal and OSS BUCK
840def define_buck_targets(
841aten_default_args = dict(),
842pt_xplat_cxx_library = fb_xplat_cxx_library,
843c2_fbandroid_xplat_compiler_flags = [],
844labels = []):
845# @lint-ignore BUCKLINT
846fb_native.filegroup(
847name = "metal_build_srcs",
848srcs = glob(METAL_SOURCE_LIST),
849visibility = [
850"PUBLIC",
851],
852)
853
854# @lint-ignore BUCKLINT
855fb_native.filegroup(
856name = "templated_selective_build_srcs",
857# NB: no glob here, there are generated targets in this list!
858srcs = glob(TEMPLATE_SOURCE_LIST) + aten_ufunc_generated_all_cpu_sources(":gen_aten[{}]"),
859visibility = [
860"PUBLIC",
861],
862)
863
864fb_xplat_cxx_library(
865name = "aten_header",
866header_namespace = "",
867exported_headers = subdir_glob([
868# ATen Core
869("aten/src", "ATen/core/**/*.h"),
870("aten/src", "ATen/ops/*.h"),
871# ATen Base
872("aten/src", "ATen/*.h"),
873("aten/src", "ATen/cpu/**/*.h"),
874("aten/src", "ATen/detail/*.h"),
875("aten/src", "ATen/functorch/**/*.h"),
876("aten/src", "ATen/quantized/*.h"),
877("aten/src", "ATen/vulkan/*.h"),
878("aten/src", "ATen/metal/*.h"),
879("aten/src", "ATen/nnapi/*.h"),
880# ATen Native
881("aten/src", "ATen/native/*.h"),
882("aten/src", "ATen/native/ao_sparse/quantized/cpu/*.h"),
883("aten/src", "ATen/native/cpu/**/*.h"),
884("aten/src", "ATen/native/sparse/*.h"),
885("aten/src", "ATen/native/nested/*.h"),
886("aten/src", "ATen/native/quantized/*.h"),
887("aten/src", "ATen/native/quantized/cpu/*.h"),
888("aten/src", "ATen/native/transformers/*.h"),
889("aten/src", "ATen/native/ufunc/*.h"),
890("aten/src", "ATen/native/utils/*.h"),
891("aten/src", "ATen/native/vulkan/ops/*.h"),
892("aten/src", "ATen/native/xnnpack/*.h"),
893("aten/src", "ATen/mps/*.h"),
894("aten/src", "ATen/native/mps/*.h"),
895# Remove the following after modifying codegen for mobile.
896("aten/src", "ATen/mkl/*.h"),
897("aten/src", "ATen/native/mkl/*.h"),
898("aten/src", "ATen/native/mkldnn/*.h"),
899]),
900visibility = ["PUBLIC"],
901labels = labels,
902)
903
904fb_xplat_cxx_library(
905name = "aten_vulkan_header",
906header_namespace = "",
907exported_headers = subdir_glob([
908("aten/src", "ATen/native/vulkan/*.h"),
909("aten/src", "ATen/native/vulkan/ops/*.h"),
910("aten/src", "ATen/vulkan/*.h"),
911]),
912labels = labels,
913visibility = ["PUBLIC"],
914)
915
916fb_xplat_cxx_library(
917name = "jit_core_headers",
918header_namespace = "",
919exported_headers = subdir_glob([("", x) for x in jit_core_headers]),
920labels = labels,
921)
922
923fb_xplat_cxx_library(
924name = "torch_headers",
925header_namespace = "",
926exported_headers = subdir_glob(
927[
928("torch/csrc/api/include", "torch/**/*.h"),
929("", "torch/csrc/**/*.h"),
930("", "torch/script.h"),
931("", "torch/library.h"),
932("", "torch/custom_class.h"),
933("", "torch/custom_class_detail.h"),
934# Add again due to namespace difference from aten_header.
935("", "aten/src/ATen/*.h"),
936("", "aten/src/ATen/functorch/**/*.h"),
937("", "aten/src/ATen/quantized/*.h"),
938],
939exclude = [
940# Don't need on mobile.
941"torch/csrc/Exceptions.h",
942"torch/csrc/python_headers.h",
943"torch/csrc/jit/serialization/mobile_bytecode_generated.h",
944],
945),
946labels = labels,
947visibility = ["PUBLIC"],
948deps = [
949":generated-version-header",
950],
951)
952
953fb_xplat_cxx_library(
954name = "aten_test_header",
955header_namespace = "",
956exported_headers = subdir_glob([
957("aten/src", "ATen/test/*.h"),
958]),
959)
960
961fb_xplat_cxx_library(
962name = "aten_metal_prepack_header",
963header_namespace = "",
964exported_headers = subdir_glob([
965("aten/src", "ATen/native/metal/MetalPrepackOpContext.h"),
966]),
967labels = labels,
968visibility = ["PUBLIC"],
969)
970
971fb_xplat_cxx_library(
972name = "torch_mobile_headers",
973header_namespace = "",
974exported_headers = subdir_glob(
975[
976("", "torch/csrc/jit/mobile/*.h"),
977],
978),
979labels = labels,
980visibility = ["PUBLIC"],
981)
982
983fb_xplat_cxx_library(
984name = "generated_aten_config_header",
985header_namespace = "ATen",
986exported_headers = {
987"Config.h": ":generate_aten_config[Config.h]",
988},
989labels = labels,
990)
991
992fb_xplat_cxx_library(
993name = "generated-autograd-headers",
994header_namespace = "torch/csrc/autograd/generated",
995exported_headers = {
996"Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]",
997"VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]",
998"variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]",
999"ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]",
1000# Don't build python bindings on mobile.
1001#"python_functions.h",
1002},
1003labels = labels,
1004visibility = ["PUBLIC"],
1005)
1006
1007fb_xplat_cxx_library(
1008name = "generated-version-header",
1009header_namespace = "torch",
1010exported_headers = {
1011"version.h": ":generate-version-header[version.h]",
1012},
1013labels = labels,
1014)
1015
1016# @lint-ignore BUCKLINT
1017fb_native.genrule(
1018name = "generate-version-header",
1019srcs = [
1020"torch/csrc/api/include/torch/version.h.in",
1021"version.txt",
1022],
1023cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
1024"--template-path",
1025"torch/csrc/api/include/torch/version.h.in",
1026"--version-path",
1027"version.txt",
1028"--output-path",
1029"$OUT/version.h",
1030]),
1031outs = {
1032"version.h": ["version.h"],
1033},
1034default_outs = ["."],
1035)
1036
1037# @lint-ignore BUCKLINT
1038fb_native.filegroup(
1039name = "aten_src_path",
1040srcs = [
1041"aten/src/ATen/native/native_functions.yaml",
1042"aten/src/ATen/native/tags.yaml",
1043] + glob(["aten/src/ATen/templates/*"]),
1044visibility = [
1045"PUBLIC",
1046],
1047)
1048
1049fb_xplat_cxx_library(
1050name = "common_core",
1051srcs = [
1052"caffe2/core/common.cc",
1053],
1054apple_sdks = (IOS, MACOSX, APPLETVOS),
1055compiler_flags = get_pt_compiler_flags(),
1056labels = labels,
1057# @lint-ignore BUCKLINT link_whole
1058link_whole = True,
1059visibility = ["PUBLIC"],
1060windows_preferred_linkage = "static" if is_arvr_mode() else None,
1061deps = [
1062":caffe2_headers",
1063C10,
1064],
1065)
1066
1067# @lint-ignore BUCKLINT
1068fb_native.genrule(
1069name = "generate_aten_config",
1070srcs = [
1071"aten/src/ATen/Config.h.in",
1072],
1073cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
1074"--install_dir",
1075"$OUT",
1076"--input-file",
1077"aten/src/ATen/Config.h.in",
1078"--output-file",
1079"Config.h",
1080"--replace",
1081"@AT_MKLDNN_ENABLED@",
1082"ATEN_MKLDNN_ENABLED_FBXPLAT",
1083"--replace",
1084"@AT_MKLDNN_ACL_ENABLED@",
1085"ATEN_MKLDNN_ACL_ENABLED_FBXPLAT",
1086"--replace",
1087"@AT_MKL_ENABLED@",
1088"ATEN_MKL_ENABLED_FBXPLAT",
1089"--replace",
1090"@AT_MKL_SEQUENTIAL@",
1091"ATEN_MKL_SEQUENTIAL_FBXPLAT",
1092"--replace",
1093"@AT_POCKETFFT_ENABLED@",
1094"1",
1095"--replace",
1096"@AT_NNPACK_ENABLED@",
1097"ATEN_NNPACK_ENABLED_FBXPLAT",
1098"--replace",
1099"@CAFFE2_STATIC_LINK_CUDA_INT@",
1100"CAFFE2_STATIC_LINK_CUDA_FBXPLAT",
1101"--replace",
1102"@AT_BUILD_WITH_BLAS@",
1103"USE_BLAS_FBXPLAT",
1104"--replace",
1105"@AT_PARALLEL_OPENMP@",
1106"AT_PARALLEL_OPENMP_FBXPLAT",
1107"--replace",
1108"@AT_PARALLEL_NATIVE@",
1109"AT_PARALLEL_NATIVE_FBXPLAT",
1110"--replace",
1111"@AT_BUILD_WITH_LAPACK@",
1112"USE_LAPACK_FBXPLAT",
1113"--replace",
1114"@AT_BLAS_F2C@",
1115"AT_BLAS_F2C_FBXPLAT",
1116"--replace",
1117"@AT_BLAS_USE_CBLAS_DOT@",
1118"AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
1119]),
1120outs = {
1121"Config.h": ["Config.h"],
1122},
1123default_outs = ["."],
1124)
1125
1126gen_aten_files(
1127name = "gen_aten",
1128extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
1129visibility = ["PUBLIC"],
1130)
1131
1132gen_aten_libtorch_files(name = "gen_aten_libtorch")
1133
1134gen_aten_libtorch_files(
1135name = "gen_aten_libtorch_lite",
1136extra_params = get_jit_codegen_params(),
1137)
1138
1139fb_xplat_cxx_library(
1140name = "generated_aten_headers_cpu",
1141header_namespace = "ATen",
1142exported_headers = get_aten_static_dispatch_backend_headers({
1143"CPUFunctions.h": ":gen_aten[CPUFunctions.h]",
1144"CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]",
1145"CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]",
1146"CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]",
1147"CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]",
1148"CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
1149"CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
1150"CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
1151"CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]",
1152"CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]",
1153"FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
1154"Functions.h": ":gen_aten[Functions.h]",
1155"MethodOperators.h": ":gen_aten[MethodOperators.h]",
1156"NativeFunctions.h": ":gen_aten[NativeFunctions.h]",
1157"NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]",
1158"Operators.h": ":gen_aten[Operators.h]",
1159"RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
1160"core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
1161"core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
1162"core/enum_tag.h": ":gen_aten[core/enum_tag.h]",
1163}),
1164labels = labels,
1165)
1166
1167fb_xplat_cxx_library(
1168name = "torch_mobile_observer",
1169srcs = [
1170"torch/csrc/jit/mobile/observer.cpp",
1171] + ([] if IS_OSS else ["torch/fb/observers/MobileObserverUtil.cpp"]),
1172compiler_flags = ["-fexceptions"],
1173header_namespace = "",
1174exported_headers = subdir_glob(
1175[
1176("", "torch/csrc/jit/mobile/observer.h"),
1177] + ([] if IS_OSS else [
1178("", "torch/fb/observers/ObserverUtil.h"),
1179("", "torch/fb/observers/MobileObserverUtil.h"),
1180]),
1181),
1182fbobjc_compiler_flags = [
1183"-Wno-missing-prototypes",
1184],
1185labels = labels,
1186visibility = ["PUBLIC"],
1187deps = [
1188C10,
1189],
1190)
1191
1192# Base library shared by lite-interpreter and full-jit.
1193pt_xplat_cxx_library(
1194name = "torch_common",
1195srcs = core_sources_common,
1196compiler_flags = get_pt_compiler_flags(),
1197exported_preprocessor_flags = get_pt_preprocessor_flags(),
1198# @lint-ignore BUCKLINT link_whole
1199link_whole = True,
1200visibility = ["PUBLIC"],
1201deps = [
1202":aten_cpu",
1203":generated-autograd-headers",
1204":torch_headers",
1205C10,
1206third_party("libkineto_headers"),
1207],
1208)
1209
1210pt_xplat_cxx_library(
1211name = "torch_mobile_deserialize_common",
1212srcs = [
1213"torch/csrc/jit/mobile/parse_bytecode.cpp",
1214"torch/csrc/jit/mobile/parse_operators.cpp",
1215"torch/csrc/jit/mobile/upgrader_mobile.cpp",
1216"torch/csrc/jit/serialization/import_read.cpp",
1217"torch/csrc/jit/serialization/unpickler.cpp",
1218],
1219header_namespace = "",
1220exported_headers = [
1221"torch/csrc/jit/serialization/import_read.h",
1222"torch/csrc/jit/serialization/unpickler.h",
1223],
1224compiler_flags = get_pt_compiler_flags(),
1225exported_preprocessor_flags = get_pt_preprocessor_flags(),
1226extra_flags = {
1227"fbandroid_compiler_flags": ["-frtti"],
1228},
# torch_mobile_deserialize brings in the sources necessary to read a module,
# which depends on the mobile module definition.
# link_whole is enabled so that all symbols necessary for the mobile module are
# compiled, instead of only the symbols used while loading; this prevents
# "symbol not found" errors at runtime.
1234# @lint-ignore BUCKLINT link_whole
1235link_whole = True,
1236linker_flags = ["-Wl,--no-as-needed"],
1237visibility = ["PUBLIC"],
1238exported_deps = [
1239":aten_cpu",
1240":caffe2_headers",
1241":caffe2_serialize",
1242":torch_common",
1243":torch_headers",
1244":torch_mobile_headers",
1245":torch_mobile_module",
1246":torch_mobile_observer",
1247C10,
1248],
1249)
1250
1251pt_xplat_cxx_library(
1252name = "torch_mobile_module",
1253srcs = [
1254"torch/csrc/jit/mobile/function.cpp",
1255"torch/csrc/jit/mobile/interpreter.cpp",
1256"torch/csrc/jit/mobile/module.cpp",
1257],
1258header_namespace = "",
1259exported_headers = [
1260],
1261compiler_flags = get_pt_compiler_flags(),
1262exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1263extra_flags = {
1264"fbandroid_compiler_flags": ["-frtti"],
1265},
1266# @lint-ignore BUCKLINT link_whole
1267link_whole = True,
1268linker_flags = [
1269"-Wl,--no-as-needed",
1270],
1271visibility = ["PUBLIC"],
1272exported_deps = [
1273":aten_cpu",
1274":caffe2_headers",
1275":torch_common",
1276":torch_headers",
1277":torch_mobile_headers",
1278":torch_mobile_observer",
1279C10,
1280],
1281)
1282
1283pt_xplat_cxx_library(
1284name = "torch_mobile_debug_symbolication",
1285srcs = [
1286# included in aten_cpu "torch/csrc/jit/frontend/source_range.cpp",
1287"torch/csrc/jit/ir/scope.cpp",
1288"torch/csrc/jit/mobile/debug_info.cpp",
1289"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
1290"torch/csrc/jit/serialization/source_range_serialization.cpp",
1291"torch/csrc/jit/serialization/pickle.cpp",
1292# pickler.cpp doesn't seem to be needed.
1293# "torch/csrc/jit/serialization/pickler.cpp",
1294# included in core_sources_common "torch/csrc/jit/serialization/unpickler.cpp",
1295],
1296compiler_flags = get_pt_compiler_flags(),
1297exported_preprocessor_flags = get_pt_preprocessor_flags(),
1298header_namespace = "",
1299# @lint-ignore BUCKLINT link_whole
1300link_whole = True,
1301linker_flags = [
1302"-Wl,--no-as-needed",
1303],
1304visibility = ["PUBLIC"],
1305deps = [
1306":torch_mobile_deserialize",
1307],
1308exported_deps = [
1309":torch_common",
1310],
1311)
1312
1313pt_xplat_cxx_library(
1314name = "torch_model_tracer",
1315srcs = [
1316"torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp",
1317] + get_feature_tracer_source_list(),
1318header_namespace = "",
1319compiler_flags = get_pt_compiler_flags(),
1320exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1321# @lint-ignore BUCKLINT link_whole
1322link_whole = True,
1323linker_flags = [
1324"-Wl,--no-as-needed",
1325],
1326visibility = ["PUBLIC"],
1327deps = [
1328":generated-autograd-headers",
1329":torch_mobile_deserialize",
1330":torch_mobile_headers",
1331":torch_mobile_observer",
1332] + ([] if IS_OSS else ["//xplat/folly:molly"]),
1333exported_deps = [
1334":aten_cpu",
1335":torch_common",
1336] + ([] if IS_OSS else [
1337"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1338"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1339]),
1340)
1341
1342pt_xplat_cxx_library(
1343name = "torch_mobile_deserialize",
1344srcs = [
1345"torch/csrc/jit/mobile/import.cpp",
1346"torch/csrc/jit/mobile/flatbuffer_loader.cpp",
1347],
1348compiler_flags = get_pt_compiler_flags(),
1349exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1350header_namespace = "",
1351exported_headers = [
1352"torch/csrc/jit/mobile/import.h",
1353"torch/csrc/jit/mobile/flatbuffer_loader.h",
1354],
# torch_mobile_deserialize brings in the sources necessary to read a module,
# which depends on the mobile module definition.
# link_whole is enabled so that all symbols necessary for the mobile module are
# compiled, instead of only the symbols used while loading; this prevents
# "symbol not found" errors at runtime.
1360# @lint-ignore BUCKLINT link_whole
1361link_whole = True,
1362linker_flags = [
1363"-Wl,--no-as-needed",
1364],
1365visibility = ["PUBLIC"],
1366exported_deps = [
1367":aten_cpu",
1368":caffe2_headers",
1369":caffe2_serialize",
1370":torch_common",
1371":torch_headers",
1372":torch_mobile_headers",
1373":torch_mobile_module",
1374":torch_mobile_observer",
1375":torch_mobile_deserialize_common",
1376":mobile_bytecode",
1377C10,
1378],
1379)
1380
1381pt_xplat_cxx_library(
1382name = "torch_mobile_core",
1383srcs = [],
1384header_namespace = "",
1385exported_headers = [],
1386compiler_flags = get_pt_compiler_flags(),
1387exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
# torch_mobile_core brings in the sources necessary to read and run a module.
# link_whole is enabled so that all symbols are linked: operators,
# registrations, and a few other symbols are needed at runtime.
1391# @lint-ignore BUCKLINT link_whole
1392link_whole = True,
1393linker_flags = [
1394"-Wl,--no-as-needed",
1395],
1396visibility = ["PUBLIC"],
1397deps = [
1398":generated-autograd-headers",
1399":torch_mobile_headers",
1400":torch_mobile_observer",
1401],
1402exported_deps = [
1403":aten_cpu",
1404":torch_common",
1405":torch_mobile_deserialize",
1406":torch_supported_mobile_models",
1407],
1408)
1409
1410pt_xplat_cxx_library(
1411name = "torch_mobile_core_pickle_and_flatbuffer",
1412compiler_flags = get_pt_compiler_flags(),
1413exported_preprocessor_flags = get_pt_preprocessor_flags(),
1414visibility = ["PUBLIC"],
1415exported_deps = [
1416":flatbuffers_mobile",
1417":torch_mobile_core",
1418],
1419)
1420
1421pt_xplat_cxx_library(
1422name = "torch_cpp_cpu",
1423srcs = torch_cpp_srcs,
1424headers = native.glob(["torch/csrc/api/include/**/*.h"]) + ["torch/script.h"],
1425compiler_flags = get_pt_compiler_flags(),
1426exported_preprocessor_flags = get_pt_preprocessor_flags(),
1427visibility = ["PUBLIC"],
1428exported_deps = [
1429":torch",
1430":torch_mobile_deserialize_common", # for torch/csrc/api/src/serialize/input-archive.cpp
1431],
1432)
1433
1434pt_xplat_cxx_library(
1435name = "torch_core",
1436srcs = core_sources_full_mobile_no_backend_interface_xplat,
1437compiler_flags = get_pt_compiler_flags(),
1438exported_preprocessor_flags = get_pt_preprocessor_flags(),
1439visibility = [
1440"//xplat/caffe2/android/...",
1441"//xplat/caffe2/fb/...",
1442"//xplat/caffe2/fb/model_tracer/...",
1443],
1444deps = [
1445":aten_cpu",
1446":backend_interface_lib",
1447":generated-autograd-headers",
1448":torch_headers",
1449":torch_mobile_deserialize",
1450third_party("glog"),
1451third_party("rt"),
1452C10,
1453] + ([] if IS_OSS else [
1454"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1455"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1456]),
1457exported_deps = [
1458":torch_common",
1459":torch_mobile_train",
1460],
1461)
1462
1463pt_xplat_cxx_library(
1464name = "torch_train",
1465srcs = [
1466"torch/csrc/api/src/data/samplers/random.cpp",
1467"torch/csrc/api/src/data/samplers/sequential.cpp",
1468"torch/csrc/api/src/optim/optimizer.cpp",
1469"torch/csrc/api/src/optim/serialize.cpp",
1470"torch/csrc/api/src/optim/sgd.cpp",
1471"torch/csrc/api/src/serialize/input-archive.cpp",
1472"torch/csrc/api/src/serialize/output-archive.cpp",
1473"torch/csrc/jit/api/module_save.cpp",
1474],
1475compiler_flags = get_pt_compiler_flags(),
1476exported_preprocessor_flags = get_pt_preprocessor_flags(),
1477visibility = ["PUBLIC"],
1478deps = [
1479":aten_cpu",
1480":torch_headers",
1481":torch",
1482":torch_core",
1483":torch_mobile_deserialize",
1484":torch_mobile_train",
1485":jit_module_saving",
1486C10,
1487],
1488)
1489
1490pt_xplat_cxx_library(
1491name = "torch_mobile_train",
1492srcs = core_trainer_sources + [
1493"torch/csrc/autograd/VariableTypeManual.cpp",
1494"torch/csrc/autograd/FunctionsManual.cpp",
1495"torch/csrc/api/src/data/datasets/mnist.cpp",
1496"torch/csrc/jit/mobile/quantization.cpp",
1497"torch/csrc/jit/mobile/train/export_data.cpp",
1498"torch/csrc/jit/mobile/train/optim/sgd.cpp",
1499"torch/csrc/jit/mobile/train/random.cpp",
1500"torch/csrc/jit/mobile/train/sequential.cpp",
1501":gen_aten_libtorch[autograd/generated/Functions.cpp]",
1502":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]",
1503],
1504compiler_flags = get_pt_compiler_flags(),
1505exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
# torch_mobile_train brings in the sources necessary to read and run a mobile
# module and to save and load mobile params, along with autograd.
# link_whole is enabled so that all symbols are linked: operators,
# registrations, and autograd-related symbols are needed at runtime.
1510# @lint-ignore BUCKLINT link_whole
1511link_whole = True,
1512visibility = ["PUBLIC"],
1513deps = [
1514":aten_cpu",
1515":generated-autograd-headers",
1516":torch_headers",
1517":torch_mobile_deserialize",
1518":flatbuffers_serializer_mobile",
1519C10,
1520],
1521)
1522
1523pt_xplat_cxx_library(
1524name = "torch",
1525srcs = [
1526"torch/csrc/jit/runtime/register_c10_ops.cpp",
1527"torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp",
1528],
1529compiler_flags = get_pt_compiler_flags(),
1530exported_preprocessor_flags = get_pt_preprocessor_flags(),
# torch brings in all sources necessary to read and run a mobile module or a
# JIT module.
# link_whole is enabled so that all symbols are linked: operators,
# registrations, and a few other symbols are needed at runtime.
1534# @lint-ignore BUCKLINT link_whole
1535link_whole = True,
1536visibility = ["PUBLIC"],
1537deps = [
# This is to have the autograd profiler available
# in xplat/caffe2:torch, which some builds are using —
# notably xplat/facegen:testsAndroid.
1541":torch_headers",
1542":torch_kineto_profiling",
1543],
1544exported_deps = [
1545":aten_cpu",
1546":torch_core",
1547C10,
1548],
1549)
1550
1551pt_xplat_cxx_library(
1552name = "torch_mobile_train_import_data",
1553srcs = [
1554"torch/csrc/jit/mobile/import_data.cpp",
1555],
1556compiler_flags = get_pt_compiler_flags(),
1557exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
# torch_mobile_train_import_data brings in the sources necessary to read a
# mobile module.
# link_whole is enabled so that all symbols are linked: operators and a few
# other symbols are needed at runtime.
1561# @lint-ignore BUCKLINT link_whole
1562link_whole = True,
1563visibility = ["PUBLIC"],
1564deps = [
1565":torch_headers",
1566":torch_mobile_observer",
1567":torch_mobile_core",
1568":torch_mobile_train",
1569],
1570)
1571
1572fb_xplat_cxx_library(
1573name = "torch_mobile_compatibility",
1574srcs = [
1575# These .cpp brought in through core_sources_common
1576# "torch/csrc/jit/mobile/compatibility/runtime_compatibility.cpp",
1577# "torch/csrc/jit/serialization/unpickler.cpp",
1578"torch/csrc/jit/mobile/compatibility/model_compatibility.cpp",
1579],
1580header_namespace = "",
1581exported_headers = [
1582"torch/csrc/jit/mobile/compatibility/backport.h",
1583"torch/csrc/jit/mobile/compatibility/backport_manager.h",
1584"torch/csrc/jit/mobile/compatibility/model_compatibility.h",
1585"torch/csrc/jit/mobile/compatibility/runtime_compatibility.h",
1586],
1587compiler_flags = [
1588"-fexceptions",
1589"-frtti",
1590"-Wno-deprecated-declarations",
1591"-Wno-global-constructors",
1592],
1593labels = labels,
1594visibility = ["PUBLIC"],
1595deps = [
1596":torch_mobile_deserialize",
1597],
1598)
1599
1600pt_xplat_cxx_library(
1601name = "jit_module_saving",
1602srcs = [
1603"torch/csrc/jit/api/module_save.cpp",
1604"torch/csrc/jit/serialization/export_bytecode.cpp",
1605"torch/csrc/jit/serialization/export_module.cpp",
1606],
1607compiler_flags = get_pt_compiler_flags(),
1608exported_preprocessor_flags = get_pt_preprocessor_flags() +
1609(["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1610exported_headers = [
1611"torch/csrc/jit/serialization/export.h",
1612],
1613visibility = ["PUBLIC"],
1614deps = [
1615":torch",
1616":torch_mobile_core",
1617":flatbuffers_serializer_mobile",
1618],
1619)
1620
1621pt_xplat_cxx_library(
1622name = "torch_mobile_model_tracer",
1623srcs = [
1624"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp",
1625"torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
1626],
1627headers = [
1628"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h",
1629"torch/csrc/jit/mobile/model_tracer/TensorUtils.h",
1630],
1631header_namespace = "",
1632exported_headers = [
1633"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h",
1634],
1635compiler_flags = get_pt_compiler_flags(),
1636exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
# torch_mobile_model_tracer brings in the sources necessary to read and run a
# JIT module and trace the ops.
# link_whole is enabled so that all symbols are linked: operators,
# registrations, and a few other symbols are needed at runtime.
1641# @lint-ignore BUCKLINT link_whole
1642link_whole = True,
1643linker_flags = [
1644"-Wl,--no-as-needed",
1645],
1646visibility = ["PUBLIC"],
1647deps = [
1648":caffe2_serialize",
1649":generated-autograd-headers",
1650":torch_mobile_headers",
1651":torch_mobile_observer",
1652":torch_mobile_core",
1653] + ([] if IS_OSS else ["//xplat/folly:molly"]),
1654exported_deps = [
1655":aten_cpu",
1656":torch_common",
1657] + ([] if IS_OSS else [
1658"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1659"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1660"//xplat/caffe2/fb/custom_ops/sparsenn:sparsenn-all",
1661]),
1662)
1663
1664#TODO(qihan) delete
1665pt_xplat_cxx_library(
1666name = "torch_mobile_core_flatbuffer",
1667srcs = [],
1668header_namespace = "",
1669exported_headers = [],
1670compiler_flags = get_pt_compiler_flags(),
1671exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1672# @lint-ignore BUCKLINT link_whole
1673link_whole = True,
1674linker_flags = [
1675"-Wl,--no-as-needed",
1676],
1677visibility = ["PUBLIC"],
1678deps = [
1679":generated-autograd-headers",
1680":torch_mobile_headers",
1681":torch_mobile_observer",
1682],
1683exported_deps = [
1684":aten_cpu",
1685":torch_common",
1686],
1687)
1688
1689fb_xplat_cxx_library(
1690name = "backend_interface_lib",
1691srcs = [
1692"torch/csrc/jit/backends/backend_debug_info.cpp",
1693"torch/csrc/jit/backends/backend_interface.cpp",
1694],
1695compiler_flags = get_pt_compiler_flags(),
1696fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
1697# @lint-ignore BUCKLINT link_whole
1698link_whole = True,
1699linker_flags = [
1700"-Wl,--no-as-needed",
1701],
1702visibility = ["PUBLIC"],
1703exported_deps = [
1704":aten_cpu",
1705":torch_common",
1706],
1707)
1708
1709pt_xplat_cxx_library(
1710name = "torch_kineto_profiling",
1711srcs = libtorch_profiler_sources,
1712compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1713exported_preprocessor_flags = get_pt_preprocessor_flags() + [
1714"-DUSE_KINETO",
1715# Need this otherwise USE_KINETO is undefed
1716# for mobile
1717"-DEDGE_PROFILER_USE_KINETO",
1718],
1719# @lint-ignore BUCKLINT link_whole
1720link_whole = True,
1721linker_flags = [
1722"-Wl,--no-as-needed",
1723],
1724visibility = ["PUBLIC"],
1725deps = [
1726third_party("glog"),
1727third_party("kineto"),
1728],
1729exported_deps = [
1730":aten_cpu",
1731":torch_common",
1732],
1733)
1734
1735pt_xplat_cxx_library(
1736name = "torch_edge_profiling",
1737srcs = ["torch/csrc/jit/mobile/profiler_edge.cpp"],
1738compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1739exported_preprocessor_flags = get_pt_preprocessor_flags() + [
1740"-DUSE_KINETO",
1741"-DEDGE_PROFILER_USE_KINETO",
1742],
1743# @lint-ignore BUCKLINT link_whole
1744link_whole = True,
1745linker_flags = [
1746"-Wl,--no-as-needed",
1747],
1748visibility = ["PUBLIC"],
1749exported_deps = [
1750":torch_common",
1751":torch_kineto_profiling",
1752":torch_mobile_core",
1753],
1754)
1755
1756fb_xplat_genrule(
1757name = "mobile_bytecode_header",
1758srcs = [
1759"torch/csrc/jit/serialization/mobile_bytecode.fbs",
1760],
1761outs = {
1762"mobile_bytecode_generated_fbsource.h": ["mobile_bytecode_generated.h"],
1763},
1764cmd = "$(exe {})".format(third_party("flatc")) +
1765" --cpp --gen-mutable --scoped-enums -o ${OUT} ${SRCS}",
1766default_outs = ["."],
1767visibility = [
1768"{}:mobile_bytecode".format(ROOT),
1769],
1770)
1771
1772# Users of this target will need to add third_party("flatbuffers-api") as a
1773# dep.
1774fb_xplat_cxx_library(
1775name = "mobile_bytecode",
1776header_namespace = "",
1777exported_headers = {
1778("torch/csrc/jit/serialization/mobile_bytecode_generated.h" if IS_OSS else "torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h"): ":mobile_bytecode_header[mobile_bytecode_generated_fbsource.h]",
1779},
1780# Avoid leaking implementation details by only exposing this header to
1781# the internals of the loader/serializer layer.
1782visibility = [
1783"{}:flatbuffer_loader".format(ROOT),
1784"{}:flatbuffers_serializer_mobile".format(ROOT),
1785],
1786exported_deps = [
1787third_party("flatbuffers-api"),
1788],
1789)
1790
1791fb_xplat_cxx_library(
1792name = "flatbuffers_serializer_mobile",
1793srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer.cpp"],
1794exported_headers = [
1795"torch/csrc/jit/serialization/flatbuffer_serializer.h",
1796],
1797compiler_flags = [
1798"-g0",
1799"-O3",
1800"-fexceptions",
1801"-frtti",
1802"-Wno-deprecated-declarations",
1803] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1804visibility = ["PUBLIC"],
1805deps = [
1806":mobile_bytecode",
1807":torch_mobile_module",
1808C10,
1809],
1810exported_deps = [
1811":torch_mobile_deserialize",
1812":mobile_bytecode",
1813],
1814)
1815
1816# TODO (qihan) delete
1817pt_xplat_cxx_library(
1818name = "flatbuffer_loader",
1819srcs = [
1820],
1821exported_headers = [
1822"torch/csrc/jit/mobile/flatbuffer_loader.h",
1823],
1824compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1825exported_preprocessor_flags = get_pt_preprocessor_flags() + [
1826"-DUSE_KINETO",
1827# Need this otherwise USE_KINETO is undefed
1828# for mobile
1829"-DEDGE_PROFILER_USE_KINETO",
1830] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1831extra_flags = {
1832"fbandroid_compiler_flags": ["-frtti"],
1833},
# torch_mobile_deserialize brings in the sources necessary to read a module,
# which depends on the mobile module definition.
# link_whole is enabled so that all symbols necessary for the mobile module are
# compiled, instead of only the symbols used while loading; this prevents
# "symbol not found" errors at runtime.
1839# @lint-ignore BUCKLINT link_whole
1840link_whole = True,
1841linker_flags = [
1842"-Wl,--no-as-needed",
1843],
1844visibility = ["PUBLIC"],
1845deps = [
1846":mobile_bytecode",
1847],
1848exported_deps = [
1849C10,
1850],
1851)
1852
1853# TODO(qihan) delete
1854fb_xplat_cxx_library(
1855name = "flatbuffers_serializer_jit",
1856compiler_flags = [
1857"-g0",
1858"-O3",
1859"-fexceptions",
1860"-frtti",
1861"-Wno-deprecated-declarations",
1862],
1863headers = [
1864"torch/csrc/jit/serialization/flatbuffer_serializer_jit.h",
1865],
1866srcs = [
1867"torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp",
1868],
1869linker_flags = [
1870"-Wl,--no-as-needed",
1871],
1872visibility = ["PUBLIC"],
1873deps = [
1874":flatbuffer_loader",
1875":flatbuffers_serializer_mobile",
1876":torch_core",
1877":torch_mobile_module",
1878C10,
1879],
1880)
1881
# Umbrella target: flatbuffer load + serialize support for full-JIT builds.
# Contains no sources of its own; it only re-exports its deps.
fb_xplat_cxx_library(
name = "flatbuffers_jit",
visibility = ["PUBLIC"],
exported_deps = [
":flatbuffer_loader",
":flatbuffers_serializer_mobile",
":flatbuffers_serializer_jit",
],
)
1891
# Umbrella target: flatbuffer load + serialize support for mobile builds,
# including on-device training (torch_mobile_train). Re-exports deps only.
fb_xplat_cxx_library(
name = "flatbuffers_mobile",
visibility = ["PUBLIC"],
exported_deps = [
":flatbuffer_loader",
":flatbuffers_serializer_mobile",
":torch_mobile_train",
],
)
1901
# Registry of supported mobile models plus internal custom ops
# (batch_box_cox, maskrcnn). Internal-only: srcs/headers/deps collapse to
# empty lists in the OSS build.
pt_xplat_cxx_library(
name = "torch_supported_mobile_models",
srcs = [
"fb/supported_mobile_models/SupportedMobileModels.cpp",
] if NOT_OSS else [],
header_namespace = "",
exported_headers = ["fb/supported_mobile_models/SupportedMobileModels.h"] if NOT_OSS else [],
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = [
"-Wl,--no-as-needed",
],
visibility = ["PUBLIC"],
deps = [],
exported_deps = [
"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
] if NOT_OSS else [],
)
1923
# Static Runtime: the specialized inference interpreter under
# torch/csrc/jit/runtime/static/ (impl, ops, memory planner, fusion passes).
# Linked whole so its op registrations are not dropped by the linker.
fb_xplat_cxx_library(
name = "static_runtime",
srcs = [
"torch/csrc/jit/runtime/static/fusion.cpp",
"torch/csrc/jit/runtime/static/generated_ops.cpp",
"torch/csrc/jit/runtime/static/impl.cpp",
"torch/csrc/jit/runtime/static/memory_planner.cpp",
"torch/csrc/jit/runtime/static/native_ops.cpp",
"torch/csrc/jit/runtime/static/ops.cpp",
"torch/csrc/jit/runtime/static/passes.cpp",
"torch/csrc/jit/runtime/static/te_wrapper.cpp",
],
compiler_flags = ["-fexceptions"],
labels = labels,
# @lint-ignore BUCKLINT link_whole
link_whole = True,
visibility = ["PUBLIC"],
windows_preferred_linkage = "static" if is_arvr_mode() else None,
deps = [
":aten_cpu",
":caffe2_headers",
":torch_core",
C10,
],
)
1949
# aten_cpu and aten_native_cpu
# Generates two near-identical libraries that differ only in srcs:
#   - aten_cpu:        JIT core + ATen CPU sources + codegen'd operator files
#                      from :gen_aten, plus embedding_lookup_idx.cc which
#                      ATen/native/EmbeddingBag.cpp needs.
#   - aten_native_cpu: the ATen "native" kernel sources.
for name, srcs in [
("aten_cpu", jit_core_sources + aten_cpu_source_list + [
# Generated
":gen_aten[Functions.cpp]",
":gen_aten[Operators_0.cpp]",
":gen_aten[Operators_1.cpp]",
":gen_aten[Operators_2.cpp]",
":gen_aten[Operators_3.cpp]",
":gen_aten[Operators_4.cpp]",
":gen_aten[core/ATenOpList.cpp]",
":gen_aten[core/TensorMethods.cpp]",
# Needed by ATen/native/EmbeddingBag.cpp
"caffe2/perfkernels/embedding_lookup_idx.cc",
]),
("aten_native_cpu", aten_native_source_list),
]:
fb_xplat_cxx_library(
name = name,
srcs = srcs,
header_namespace = "",
# @lint-ignore BUCKLINT
link_whole = True,
visibility = ["PUBLIC"],
deps = [
third_party("omp"),
third_party("cpuinfo"),
third_party("glog"),
third_party("XNNPACK"),
third_party("pocketfft"),
# sleef_arm only on fbcode arm64; other configs add nothing extra.
] + select({
"DEFAULT": [],
"ovr_config//runtime:fbcode-arm64": [
third_party("sleef_arm"),
],
}),
compiler_flags = get_aten_compiler_flags(),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [
":aten_header",
":caffe2_headers",
":common_core",
":generated_aten_config_header",
":generated_aten_headers_cpu",
":jit_core_headers",
":pthreadpool",
third_party("fmt"),
third_party("ruy"),
C10,
ROOT_PATH + "aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack",
],
labels = labels,
**aten_default_args
)
2004
# Top layer of the lean/minimal edge runtime stack
# (lean_runtime_with_flatbuffer -> lean_runtime_with_tensor ->
# lean_runtime_with_op -> min_runtime_lib): adds mobile module import on top
# of the tensor layer, and re-exports a broad set of headers for consumers.
fb_xplat_cxx_library(
name = "lean_runtime_with_flatbuffer",
srcs = [
"aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp",
"torch/csrc/jit/mobile/import.cpp",
"torch/csrc/jit/mobile/module.cpp",
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/serialization/import_read.cpp",
],
header_namespace = "",
exported_headers = subdir_glob(
[
("", "torch/csrc/jit/ir/*.h"),
("", "caffe2/serialize/*.h"),
("", "caffe2/utils/*.h"),
("", "caffe2/core/*.h"),
("", "torch/csrc/*.h"),
("", "torch/csrc/api/include/torch/*.h"),
("", "torch/csrc/autograd/*.h"),
("", "torch/csrc/autograd/*/*.h"),
("", "torch/csrc/jit/api/*.h"),
("", "torch/csrc/jit/backends/*.h"),
("", "torch/csrc/jit/mobile/*.h"),
("", "torch/csrc/jit/runtime/*.h"),
("", "torch/csrc/jit/passes/*.h"),
("", "torch/csrc/jit/python/*.h"),
("", "torch/csrc/jit/frontend/*.h"),
("", "torch/csrc/jit/serialization/*.h"),
("", "torch/csrc/profiler/**/*.h"),
("", "torch/csrc/utils/*.h"),
("", "aten/src/ATen/quantized/*.h"),
# miniz headers are vendored internally only; OSS gets them elsewhere.
] + ([
("third_party/miniz-2.1.0", "*.h"),
] if NOT_OSS else []),
exclude = [
"torch/csrc/jit/serialization/mobile_bytecode_generated.h",
],
),
# xtensa-xos: section-per-symbol so --gc-sections below can strip unused code.
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-fdata-sections",
"-ffunction-sections",
],
}),
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
"-DMIN_EDGE_RUNTIME",
],
linker_flags = [
"-Wl,--no-as-needed",
] + select({
"DEFAULT": [],
"ovr_config//os:macos": [
"-dead_strip",
],
"ovr_config//os:xtensa-xos": [
"-Wl,--gc-sections",
],
}),
visibility = ["PUBLIC"],
exported_deps = [
":lean_runtime_with_tensor",
],
)
2069
# Tensor layer of the lean edge runtime: ATen context/empty-tensor machinery
# plus the codegen'd Operators_*.cpp and TensorMethods from :gen_aten,
# layered on lean_runtime_with_op.
pt_xplat_cxx_library(
name = "lean_runtime_with_tensor",
srcs = [
"aten/src/ATen/Context.cpp",
"aten/src/ATen/EmptyTensor.cpp",
"aten/src/ATen/Utils.cpp",
"aten/src/ATen/detail/CUDAHooksInterface.cpp",
"aten/src/ATen/detail/PrivateUse1HooksInterface.cpp",
":gen_aten[Operators_0.cpp]",
":gen_aten[Operators_1.cpp]",
":gen_aten[Operators_2.cpp]",
":gen_aten[Operators_3.cpp]",
":gen_aten[Operators_4.cpp]",
":gen_aten[core/TensorMethods.cpp]",
],
header_namespace = "",
exported_headers = [
"torch/csrc/jit/runtime/custom_operator.h",
":gen_aten[core/TensorBody.h]",
],
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-fdata-sections",
"-ffunction-sections",
],
}),
# xtensa-xos has no thread_local support, so it is compiled away.
exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-Dthread_local=",
],
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = [
"-Wl,--no-as-needed",
],
visibility = ["PUBLIC"],
exported_deps = [
":generated_aten_config_header",
":lean_runtime_with_op",
":aten_header",
C10,
] + (["//xplat/caffe2/fb/embedded:experimental"] if NOT_OSS else []),
)
2116
# Operator layer of the lean edge runtime: dispatcher, operator registration,
# schema parsing, and mobile operator parsing, layered on min_runtime_lib.
pt_xplat_cxx_library(
name = "lean_runtime_with_op",
srcs = [
"aten/src/ATen/SequenceNumber.cpp",
"aten/src/ATen/core/boxing/KernelFunction.cpp",
"aten/src/ATen/core/custom_class.cpp",
"aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp",
"aten/src/ATen/core/dispatch/Dispatcher.cpp",
"aten/src/ATen/core/dispatch/ObservedOperators.cpp",
"aten/src/ATen/core/dispatch/OperatorEntry.cpp",
"aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp",
"aten/src/ATen/core/interned_strings.cpp",
"aten/src/ATen/core/library.cpp",
"aten/src/ATen/core/op_registration/infer_schema.cpp",
"aten/src/ATen/core/function_schema.cpp",
"aten/src/ATen/core/operator_name.cpp",
"aten/src/ATen/core/register_symbols.cpp",
"aten/src/ATen/core/tensor_type.cpp",
"aten/src/ATen/core/union_type.cpp",
"aten/src/ATen/record_function.cpp",
"torch/csrc/jit/frontend/edit_distance.cpp",
"torch/csrc/jit/frontend/error_report.cpp",
"torch/csrc/jit/frontend/function_schema_parser.cpp",
"torch/csrc/jit/frontend/lexer.cpp",
"torch/csrc/jit/frontend/schema_type_parser.cpp",
"torch/csrc/jit/frontend/source_range.cpp",
"torch/csrc/jit/frontend/strtod.cpp",
"torch/csrc/jit/mobile/parse_operators.cpp",
"torch/csrc/jit/mobile/prim_ops_registery.cpp",
"torch/csrc/jit/runtime/operator.cpp",
"torch/csrc/jit/runtime/slice_indices_adjust.cpp",
],
header_namespace = "",
exported_headers = [
"torch/csrc/jit/frontend/edit_distance.h",
"torch/csrc/jit/runtime/slice_indices_adjust.h",
],
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-fdata-sections",
"-ffunction-sections",
],
}),
# xtensa-xos has no thread_local support, so it is compiled away.
exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-Dthread_local=",
],
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = [
"-Wl,--no-as-needed",
],
visibility = ["PUBLIC"],
exported_deps = [
":min_runtime_lib",
C10,
],
)
2178
# Base layer of the lean edge runtime: IValue/type system, the mobile
# bytecode interpreter, and the small set of prim ops it needs.
pt_xplat_cxx_library(
name = "min_runtime_lib",
srcs = [
"aten/src/ATen/ScalarOps.cpp",
"aten/src/ATen/core/Dict.cpp",
"aten/src/ATen/core/List.cpp",
"aten/src/ATen/core/class_type.cpp",
"aten/src/ATen/core/dynamic_type.cpp",
"aten/src/ATen/core/ivalue.cpp",
"aten/src/ATen/core/type.cpp",
"aten/src/ATen/core/type_factory.cpp",
"aten/src/ATen/native/prim_native_functions.cpp",
"torch/csrc/jit/mobile/function.cpp",
"torch/csrc/jit/mobile/interpreter.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/promoted_prim_ops.cpp",
"torch/csrc/jit/mobile/register_ops_common_utils.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",
"torch/csrc/jit/runtime/instruction.cpp",
"torch/csrc/jit/runtime/jit_exception.cpp",
"torch/csrc/jit/runtime/vararg_functions.cpp",
],
header_namespace = "",
exported_headers = [
"caffe2/serialize/versions.h",
"torch/csrc/jit/backends/backend_exception.h",
"torch/csrc/jit/mobile/register_ops_common_utils.h",
"torch/csrc/jit/runtime/instruction.h",
"torch/csrc/jit/runtime/jit_exception.h",
"torch/csrc/jit/runtime/operator.h",
"torch/csrc/jit/runtime/operator_options.h",
"torch/csrc/jit/runtime/vararg_functions.h",
"torch/csrc/jit/serialization/import_export_constants.h",
"torch/csrc/jit/serialization/import_export_functions.h",
],
# xtensa-xos additionally needs -fexceptions here (this layer throws,
# e.g. jit_exception.cpp).
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-fexceptions",
"-fdata-sections",
"-ffunction-sections",
],
}),
# xtensa-xos has no thread_local support, so it is compiled away.
exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
"DEFAULT": [],
"ovr_config//os:xtensa-xos": [
"-Dthread_local=",
],
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = [
"-Wl,--no-as-needed",
],
visibility = ["PUBLIC"],
exported_deps = [
":aten_header",
":generated_aten_headers_cpu",
":jit_core_headers",
":torch_mobile_headers",
C10,
],
)
2242