pytorch
/
buckbuild.bzl
2209 lines · 81.8 KB
1# NOTE: This file is shared by internal and OSS BUCK build.
2# These load paths point to different files in internal and OSS environment
3
4load("@bazel_skylib//lib:paths.bzl", "paths")
5load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
6load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
7load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
8load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
9load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
10load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX")
11load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
12load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build")
13load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build")
14load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags")
15load(
16":build_variables.bzl",
17"aten_cpu_source_list",
18"aten_native_source_list",
19"core_sources_common",
20"core_sources_full_mobile_no_backend_interface_xplat",
21"core_trainer_sources",
22"jit_core_headers",
23"jit_core_sources",
24"libtorch_profiler_sources",
25"torch_cpp_srcs",
26"torch_mobile_tracer_sources",
27)
28load(
29":pt_ops.bzl",
30"USED_PT_BACKENDS",
31)
32load(
33":pt_template_srcs.bzl",
34"METAL_MASKRCNN_SOURCE_LIST",
35"METAL_SOURCE_LIST",
36"TEMPLATE_MASKRCNN_SOURCE_LIST",
37"TEMPLATE_SOURCE_LIST",
38"aten_ufunc_generated_all_cpu_sources",
39"get_gen_oplist_outs",
40"get_generate_code_bin_outs",
41"get_metal_registration_files_outs",
42"get_metal_registration_files_outs_windows",
43"get_metal_source_dict",
44"get_template_registration_file_rules",
45"get_template_registration_files_outs",
46"get_template_source_dict",
47)
48load(
49":ufunc_defs.bzl",
50"aten_ufunc_generated_cpu_kernel_sources",
51"aten_ufunc_generated_cpu_sources",
52"aten_ufunc_generated_cuda_sources",
53)
54
def read_bool(section, field, default, required = True):
    """Read a boolean from .buckconfig `section.field`.

    Accepts 0/1/true/false/True/False. Falls back to `default` when the key is
    unset, then to None when not `required`; otherwise fails the build.
    """
    # @lint-ignore BUCKRESTRICTEDSYNTAX
    raw = read_config(section, field)
    if raw != None:
        if raw in ("true", "True", "1"):
            return True
        if raw in ("false", "False", "0"):
            return False
        fail(
            "`{}:{}`: must be one of (0, 1, true, false, True, False), but was {}".format(section, field, raw),
        )
    if default != None:
        return default
    if not required:
        return None
    fail("`{}:{}`: no value set".format(section, field))
73
def _is_build_mode_dev():
    """True unless this is an Android or iOS production build."""
    # Either mobile platform being in production mode makes this a non-dev build.
    return not (is_production_build_android() or is_production_build_ios())
83
def _get_enable_lightweight_dispatch():
    """pt.enable_lightweight_dispatch: build codegen'd unboxing wrappers (default: off)."""
    return read_bool("pt", "enable_lightweight_dispatch", False)

def _get_enable_record_kernel_dtype():
    """pt.enable_record_kernel_dtype: gates -DENABLE_RECORD_KERNEL_FUNCTION_DTYPE (default: off)."""
    return read_bool("pt", "enable_record_kernel_dtype", False)

def get_enable_mobile_dispatch_keys_trimming():
    """pt.enable_mobile_dispatch_keys_trimming: gates -DC10_MOBILE_TRIM_DISPATCH_KEYS (default: off)."""
    return read_bool("pt", "enable_mobile_dispatch_keys_trimming", False)

def get_disable_per_op_profiling():
    """pt.disable_per_op_profiling: gates -DPYTORCH_DISABLE_PER_OP_PROFILING (default: on)."""
    return read_bool("pt", "disable_per_op_profiling", True)

def get_strip_error_messages():
    """pt.strip_error_messages: gates -DSTRIP_ERROR_MESSAGES (default: strip in prod builds)."""
    if IS_OSS:
        return True  # always strip in OSS CI to expose potential issues
    return read_bool("pt", "strip_error_messages", not _is_build_mode_dev())

def get_disable_warn():
    """pt.disable_warn: gates -DDISABLE_WARN (default: off)."""
    return read_bool("pt", "disable_warn", False)

def get_enable_eager_symbolication():
    """pt.enable_eager_symbolication (default: off, not required)."""
    return read_bool("pt", "enable_eager_symbolication", default = False, required = False)
106
def get_static_dispatch_backend():
    """List of static dispatch backends from pt.static_dispatch_backend (';'-separated), or []."""
    raw = native.read_config("pt", "static_dispatch_backend", None)
    return raw.split(";") if raw != None else []
112
def get_glsl_image_format():
    """Vulkan GLSL image format: rgba32f when pt.vulkan_full_precision is set, else rgba16f."""
    full_precision = read_config("pt", "vulkan_full_precision", "0") != "0"
    return "rgba32f" if full_precision else "rgba16f"
117
def get_glsl_paths():
    """Build the space-joined "$(location target)/subdir" GLSL search paths.

    The path list is (target, relative-subdir) pairs flattened into a flat
    list, optionally extended via gen_vulkan_spv.additional_glsl_paths, so its
    length must be even.
    """
    extra_paths = [
        p
        for p in read_config("gen_vulkan_spv", "additional_glsl_paths", "").split(" ")
        if p
    ]
    flat_pairs = [
        "//xplat/caffe2:aten_vulkan_glsl_src_path",
        "aten/src/ATen/native/vulkan/glsl",
    ] + extra_paths

    if len(flat_pairs) % 2 != 0:
        fail(
            "gen_vulkan_spv.additional_glsl_paths must contain an even number of elements",
        )

    # Consume the flat list two entries at a time: (buck target, subdirectory).
    return " ".join([
        "$(location {})/{}".format(flat_pairs[i], flat_pairs[i + 1])
        for i in range(0, len(flat_pairs), 2)
    ])
146
# @lint-ignore BUCKRESTRICTEDSYNTAX
IS_OSS = read_config("pt", "is_oss", "0") == "1"  # True for OSS BUCK build, and False for internal BUCK build

NOT_OSS = not IS_OSS

# for targets in caffe2 root path
ROOT = "//" if IS_OSS else "//xplat/caffe2"

# for targets in subfolders
ROOT_PATH = "//" if IS_OSS else "//xplat/caffe2/"

# canonical c10 target for the current (OSS vs internal) environment
C10 = "//c10:c10" if IS_OSS else "//xplat/caffe2/c10:c10"
159
# a dictionary maps third party library name to fbsource and oss target
# Value format: [internal/fbsource target, OSS target]; resolved via third_party().
THIRD_PARTY_LIBS = {
    "FP16": ["//xplat/third-party/FP16:FP16", "//third_party:FP16"],
    "FXdiv": ["//xplat/third-party/FXdiv:FXdiv", "//third_party:FXdiv"],
    "XNNPACK": ["//xplat/third-party/XNNPACK:XNNPACK", "//third_party:XNNPACK"],
    "clog": ["//xplat/third-party/clog:clog", "//third_party:clog"],
    "cpuinfo": ["//third-party/cpuinfo:cpuinfo", "//third_party:cpuinfo"],
    "flatbuffers-api": ["//third-party/flatbuffers/fbsource_namespace:flatbuffers-api", "//third_party:flatbuffers-api"],
    "flatc": ["//third-party/flatbuffers/fbsource_namespace:flatc", "//third_party:flatc"],
    "fmt": ["//third-party/fmt:fmt", "//third_party:fmt"],
    "glog": ["//third-party/glog:glog", "//third_party:glog"],
    "gmock": ["//third-party/googletest:gmock_main", "//third_party:gmock"],
    "gtest": ["//third-party/googletest:gtest_main", "//third_party:gtest"],
    "kineto": ["//xplat/kineto/libkineto:libkineto", "//third_party:libkineto"],
    "libkineto_headers": ["//xplat/kineto/libkineto:libkineto_headers", "//third_party:libkineto_headers"],
    "omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
    "pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
    "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
    "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
    "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
    "pyyaml": ["//third-party/pyyaml:pyyaml", "//third_party:pyyaml"],
    "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
    "ruy": ["//third-party/ruy:ruy_xplat_lib", "//third_party:ruy_lib"],
    "sleef_arm": ["//third-party/sleef:sleef_arm", "//third_party:sleef_arm"],
    "typing-extensions": ["//third-party/typing-extensions:typing-extensions", "//third_party:typing-extensions"],
}
186
def third_party(name):
    """Resolve a third-party library name to the right target for this build flavor."""
    if name not in THIRD_PARTY_LIBS:
        fail("Cannot find third party library " + name + ", please register it in THIRD_PARTY_LIBS first!")
    fbsource_target, oss_target = THIRD_PARTY_LIBS[name]
    return oss_target if IS_OSS else fbsource_target
191
def get_pt_compiler_flags():
    """PyTorch compiler flags; converted to MSVC-style when compiling with cl.exe."""
    return select({
        "DEFAULT": _PT_COMPILER_FLAGS,
        "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(_PT_COMPILER_FLAGS),
    })
197
# Base gcc/clang-style flags for PyTorch targets; consumed via
# get_pt_compiler_flags(), which also handles the MSVC translation.
_PT_COMPILER_FLAGS = [
    "-fexceptions",
    "-frtti",
    "-Os",
    "-Wno-unknown-pragmas",
    "-Wno-write-strings",
    "-Wno-unused-variable",
    "-Wno-unused-function",
    "-Wno-deprecated-declarations",
    "-Wno-shadow",
    "-Wno-global-constructors",
    "-Wno-missing-prototypes",
]
211
# Compiler flags for ATen targets (exceptions + RTTI on, size-optimized,
# with warnings ATen sources currently trip suppressed).
ATEN_COMPILER_FLAGS = [
    "-fexceptions",
    "-frtti",
    "-fPIC",
    "-Os",
    "-Wno-absolute-value",
    "-Wno-deprecated-declarations",
    "-Wno-macro-redefined",
    "-Wno-tautological-constant-out-of-range-compare",
    "-Wno-unknown-pragmas",
    "-Wno-unknown-warning-option",
    "-Wno-unused-function",
    "-Wno-unused-variable",
    "-Wno-pass-failed",
    "-Wno-shadow",
]

def get_aten_compiler_flags():
    """Compiler flags for ATen targets."""
    return ATEN_COMPILER_FLAGS
231
# Preprocessor defines shared by ATen and PyTorch mobile targets; the optional
# entries are driven by .buckconfig via the pt.* getter functions.
_COMMON_PREPROCESSOR_FLAGS = [
    "-DC10_MOBILE",
    "-DNO_EXPORT",
] + (
    ["-DC10_MOBILE_TRIM_DISPATCH_KEYS"] if get_enable_mobile_dispatch_keys_trimming() else []
) + (
    ["-DSTRIP_ERROR_MESSAGES"] if get_strip_error_messages() else []
) + (
    ["-DDISABLE_WARN"] if get_disable_warn() else []
)
242
def get_aten_preprocessor_flags():
    """Preprocessor defines for ATen mobile targets (config-dependent, hence a function)."""
    # read_config is not allowed outside of function in Starlark
    flags = _COMMON_PREPROCESSOR_FLAGS + [
        "-DCPU_CAPABILITY_DEFAULT",
        "-DCPU_CAPABILITY=DEFAULT",
        "-DCAFFE2_USE_LITE_PROTO",
        "-DATEN_CUDNN_ENABLED_FBXPLAT=0",
        "-DATEN_MKLDNN_ENABLED_FBXPLAT=0",
        "-DATEN_MKLDNN_ACL_ENABLED_FBXPLAT=0",
        "-DATEN_NNPACK_ENABLED_FBXPLAT=0",
        "-DATEN_MKL_ENABLED_FBXPLAT=0",
        "-DATEN_MKL_SEQUENTIAL_FBXPLAT=0",
        "-DUSE_PYTORCH_METAL",
        "-DUSE_PYTORCH_QNNPACK",
        "-DUSE_XNNPACK",
        "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
        "-DAT_PARALLEL_OPENMP_FBXPLAT=0",
        "-DAT_PARALLEL_NATIVE_FBXPLAT=1",
        "-DAT_PARALLEL_NATIVE_TBB_FBXPLAT=0",
        "-DUSE_LAPACK_FBXPLAT=0",
        "-DAT_BLAS_F2C_FBXPLAT=0",
        "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0",
        "-DUSE_RUY_QMATMUL",
    ]
    if get_disable_per_op_profiling():
        flags.append("-DPYTORCH_DISABLE_PER_OP_PROFILING")
    if _get_enable_record_kernel_dtype():
        flags.append("-DENABLE_RECORD_KERNEL_FUNCTION_DTYPE")
    return flags
272
def get_pt_preprocessor_flags():
    """Preprocessor defines for PyTorch (torch/csrc) mobile targets."""
    # read_config is not allowed outside of function in Starlark
    flags = _COMMON_PREPROCESSOR_FLAGS + [
        "-D_THP_CORE",
        "-DUSE_SCALARS",
        "-DNO_CUDNN_DESTROY_HANDLE",
        "-DBUILD_CAFFE2",
    ]
    if _is_build_mode_dev():
        flags.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS")
    return flags
285
# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
# Backends for which gen_aten produces {Backend}Functions.h / _inl.h headers
# (see get_aten_derived_type_srcs).
PT_BACKEND_HEADERS = [
    "CPU",
    "CUDA",
    "CompositeExplicitAutograd",
    "CompositeExplicitAutogradNonFunctional",
    "CompositeImplicitAutograd",
    "CompositeImplicitAutogradNestedTensor",
    "Meta",
]
296
def get_aten_static_dispatch_backend_headers(existing_headers):
    """Add {Backend}Functions(.h/_inl.h) entries for each non-CPU static dispatch backend.

    Mutates and returns `existing_headers`.
    """
    for backend in get_static_dispatch_backend():
        if backend == "CPU":
            continue  # CPU headers are always present; only add the extras
        existing_headers["{}Functions.h".format(backend)] = ":gen_aten[{}Functions.h]".format(backend)
        existing_headers["{}Functions_inl.h".format(backend)] = ":gen_aten[{}Functions_inl.h]".format(backend)
    return existing_headers
304
def get_aten_codegen_extra_params(backends):
    """Extra gen_aten flags; static dispatch backends (when configured) override `backends`."""
    static_backends = get_static_dispatch_backend()
    if static_backends:
        return {
            "force_schema_registration": True,
            "static_dispatch_backend": static_backends,
            "enabled_backends": static_backends,
        }
    return {
        "force_schema_registration": True,
        "enabled_backends": backends,
    }
316
def get_jit_codegen_params():
    """Extra parameters for JIT codegen; currently none."""
    return []
319
def get_unboxing_generated_files():
    """Outputs of the unboxing codegen as {file name: [file name]}.

    Empty unless lightweight dispatch is enabled; otherwise the header plus the
    5 sharded UnboxingFunctions and 10 sharded RegisterCodegenUnboxedKernels
    sources.
    """
    if not _get_enable_lightweight_dispatch():
        return {}
    file_names = ["UnboxingFunctions.h"] + [
        "UnboxingFunctions_{}.cpp".format(i)
        for i in range(5)
    ] + [
        "RegisterCodegenUnboxedKernels_{}.cpp".format(i)
        for i in range(10)
    ]
    return {f: [f] for f in file_names}
345
def get_aten_generated_files(enabled_backends):
    """Outputs of the gen_aten genrule as {relative path: [relative path]}.

    Covers the always-generated core/composite files, the per-backend
    Register*.cpp / *Functions*.h files for `enabled_backends`, plus the CPU
    (and, when enabled, CUDA) ufunc sources.
    """
    # NB: RegisterMeta counts as an optionally enabled backend,
    # and is intentionally omitted from here
    src_files = [
        "RegisterBackendSelect.cpp",
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "CompositeViewCopyKernels.cpp",
        "RegisterSchema.cpp",
        "Declarations.yaml",
        "Functions.cpp",
        "Functions.h",
        "RedispatchFunctions.h",
        "NativeFunctions.h",
        "NativeMetaFunctions.h",
        "MethodOperators.h",
        "FunctionalInverses.h",
        "Operators.h",
        "Operators_0.cpp",
        "Operators_1.cpp",
        "Operators_2.cpp",
        "Operators_3.cpp",
        "Operators_4.cpp",
        "CompositeImplicitAutogradFunctions.h",
        "CompositeImplicitAutogradFunctions_inl.h",
        "CompositeImplicitAutogradNestedTensorFunctions.h",
        "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
        "CompositeExplicitAutogradFunctions.h",
        "CompositeExplicitAutogradFunctions_inl.h",
        "CompositeExplicitAutogradNonFunctionalFunctions.h",
        "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
        "VmapGeneratedPlumbing.h",
        "core/ATenOpList.cpp",
        "core/TensorBody.h",
        "core/TensorMethods.cpp",
        "core/aten_interned_strings.h",
        "core/enum_tag.h",
    ] + get_aten_derived_type_srcs(enabled_backends)

    # This is tiresome. A better strategy would be to unconditionally
    # generate these files, and then only actually COMPILE them depending
    # on the generated set. C'est la vie...
    if "CPU" in enabled_backends:
        src_files.extend(aten_ufunc_generated_cpu_sources())
        src_files.extend(aten_ufunc_generated_cpu_kernel_sources())
    if "CUDA" in enabled_backends:
        # Cannot unconditionally include this, because in the Edge selective
        # build CUDA is not enabled and thus the ufunc codegen for CUDA gets
        # skipped
        src_files.extend(aten_ufunc_generated_cuda_sources())

    # Each output maps to itself: the genrule exposes every generated file
    # under its own relative path.
    res = {}
    for file_name in src_files:
        res[file_name] = [file_name]
    return res
403
def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
    """One ":rule[Register<Backend>.cpp]" source label per enabled backend."""
    rules = []
    for backend in enabled_backends:
        rules.append(":{}[Register{}.cpp]".format(aten_rule_name, backend))
    return rules
409
def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
    """Source labels for a selective ATen build: fixed composite/schema files plus per-backend registrations."""
    fixed_files = [
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "RegisterSchema.cpp",
        "RegisterBackendSelect.cpp",
        "CompositeViewCopyKernels.cpp",
    ]
    labels = [":{}[{}]".format(aten_rule_name, f) for f in fixed_files]
    return labels + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)
415
def get_aten_derived_type_srcs(enabled_backends):
    """Register<Backend>.cpp for every enabled backend, plus Functions(.h/_inl.h) headers
    for backends that torchgen actually emits headers for."""
    # Hoisted: the static dispatch backend list is loop-invariant.
    static_backends = get_static_dispatch_backend()
    backends_with_headers = [
        b
        for b in enabled_backends
        if b in PT_BACKEND_HEADERS or b in static_backends
    ]
    srcs = ["Register{}.cpp".format(b) for b in enabled_backends]
    srcs += ["{}Functions.h".format(b) for b in backends_with_headers]
    srcs += ["{}Functions_inl.h".format(b) for b in backends_with_headers]
    return srcs
429
def gen_aten_files(
        name,
        extra_flags = {},
        visibility = [],
        compatible_with = [],
        apple_sdks = None):
    """Define the genrule that runs torchgen:gen to produce the ATen codegen outputs.

    `extra_flags` keys (all optional): force_schema_registration (bool),
    op_registration_allowlist (str), op_selection_yaml_path (str),
    enabled_backends (list), static_dispatch_backend (list).
    """
    extra_params = []
    force_schema_registration = extra_flags.get("force_schema_registration", False)
    op_registration_allowlist = extra_flags.get("op_registration_allowlist", None)
    op_selection_yaml_path = extra_flags.get("op_selection_yaml_path", None)
    enabled_backends = extra_flags.get("enabled_backends", None)
    static_dispatch_backend = extra_flags.get("static_dispatch_backend", None)

    if force_schema_registration:
        extra_params.append("--force_schema_registration")
    if op_registration_allowlist != None and is_string(op_registration_allowlist):
        extra_params.append("--op_registration_whitelist")
        extra_params.append(op_registration_allowlist)
    if op_selection_yaml_path != None and is_string(op_selection_yaml_path):
        extra_params.append("--op_selection_yaml_path")
        extra_params.append(op_selection_yaml_path)
    if enabled_backends != None and is_list(enabled_backends):
        extra_params.append("--backend_whitelist")
        extra_params.extend(enabled_backends)
    if _get_enable_lightweight_dispatch():
        extra_params.append("--skip_dispatcher_op_registration")
    # Static dispatch backends, when present, determine the output file set.
    if static_dispatch_backend:
        extra_params.append("--static_dispatch_backend")
        extra_params.extend(static_dispatch_backend)
        backends = static_dispatch_backend
    else:
        backends = enabled_backends
    fb_xplat_genrule(
        name = name,
        default_outs = ["."],
        outs = get_aten_generated_files(backends),
        cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([
            "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
            "--install_dir $OUT",
        ] + extra_params),
        visibility = visibility,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )
474
def gen_aten_unboxing_files(
        genrule_name,
        extra_flags = {}):
    """Define the genrule that runs tools:gen_unboxing_bin for lightweight dispatch.

    `extra_flags` keys (optional): op_selection_yaml_path (str),
    op_registration_allowlist (str).
    """
    extra_params = []
    op_selection_yaml_path = extra_flags.get("op_selection_yaml_path", None)
    op_registration_allowlist = extra_flags.get("op_registration_allowlist", None)
    if op_selection_yaml_path != None and is_string(op_selection_yaml_path):
        extra_params.append("--op_selection_yaml_path")
        extra_params.append(op_selection_yaml_path)
    if op_registration_allowlist != None and is_string(op_registration_allowlist):
        extra_params.append("--op_registration_allowlist")
        extra_params.append(op_registration_allowlist)

    fb_xplat_genrule(
        name = genrule_name,
        default_outs = ["."],
        outs = get_unboxing_generated_files(),
        cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([
            "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
            "--install_dir $OUT",
        ] + extra_params),
        visibility = ["PUBLIC"],
    )
498
def copy_template_registration_files(name, apple_sdks = None):
    """Genrule that copies templated selective-build sources (and codegen'd CPU
    ufunc sources) into $OUT, with parallel bash (`cmd`) and Windows
    (`cmd_exe`) command lists.
    """
    cmd = []
    cmd_exe = []

    template_source_dict = get_template_source_dict()

    # Ideally, we would run one copy command for a single source directory along
    # with all its child directories, but it's somewhat hard to know if a directory
    # is a child of another just by looking at the metadata (directory relative
    # path) that we currently have since 1 directory could look like a parent of
    # another and yet come from a different filegroup() rule.
    #
    for (path_prefix, file_paths) in template_source_dict.items():
        cmd.append("mkdir -p $OUT/{}".format(path_prefix))
        cmd_exe.append("md $OUT/{}".format(path_prefix))

        # Adding *.cpp is a workaround to prevent cp from throwing an error when it
        # encounters a directory (since -r was not specified). If files with an
        # extension other than .cpp need to be copied, then the command below
        # will not work and will need to be updated.
        #
        cmd.append("cp -f $(location {0}:templated_selective_build_srcs)/{1}/*.cpp $OUT/{1}/".format(ROOT, path_prefix))
        cmd_exe.append("robocopy /E $(location {0}:templated_selective_build_srcs)/{1} $OUT/{1}".format(ROOT, path_prefix))

    # Internal-only MaskRCNN custom-op sources.
    if NOT_OSS:
        for file_path in TEMPLATE_MASKRCNN_SOURCE_LIST:
            maskrcnn_file = "$(location //xplat/caffe2/fb/custom_ops/maskrcnn:templated_selective_build_srcs)/" + file_path
            cmd.append("cp -f " + maskrcnn_file + " $OUT")
            cmd_exe.append("copy " + maskrcnn_file + " $OUT")

    cmd.append("mkdir -p $OUT/aten/src/ATen")
    cmd_exe.append("md $OUT/aten/src/ATen")

    # NB: CUDA is skipped here because this is selective build and CUDA is not
    # supported for selective build
    for ufunc_file in aten_ufunc_generated_all_cpu_sources("$(location " + ROOT + ":gen_aten[{}])"):
        cmd.append("cp -f " + ufunc_file + " $OUT/aten/src/ATen")
        cmd_exe.append("copy " + ufunc_file + " $OUT/aten/src/ATen")

    # Internal-only batch_box_cox custom-op registration source.
    if NOT_OSS:
        pvd_batch_box_cox_file = "$(location //xplat/caffe2/fb/custom_ops/batch_box_cox:templated_selective_build_srcs)/register_batch_box_cox_ops.cpp"
        cmd.append("cp -f " + pvd_batch_box_cox_file + " $OUT")
        cmd_exe.append("copy " + pvd_batch_box_cox_file + " $OUT")

    fb_xplat_genrule(
        name = name,
        cmd = " && ".join(cmd),
        cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)),
        outs = get_template_registration_files_outs(IS_OSS),
        default_outs = ["."],
        apple_sdks = apple_sdks,
    )
551
def get_feature_tracer_source_list():
    """
    Return just the Feature specific handlers used in the model tracer.
    """
    return [
        src
        for src in torch_mobile_tracer_sources
        if src.endswith("Tracer.cpp")
    ]
561
def pt_operator_query_codegen(
        name,
        deps = [],
        train = False,
        enforce_traced_op_list = False,
        pt_allow_forced_schema_registration = True,
        compatible_with = [],
        apple_sdks = None):
    """Wire up every codegen genrule for a selective-build operator registry.

    Creates (in order): the gen_oplist genrule that queries `deps` for
    pt_operator_library targets, the ATen codegen (and, when lightweight
    dispatch is on, the unboxing codegen), the autograd/libtorch codegen, the
    template-registration copy rule, and (internal only) the metal copy rule.
    Returns {"headers": {...}, "srcs": [...]} for the registry library.
    """
    oplist_dir_name = name + "_pt_oplist"

    # Produce selected_operators.yaml (and friends) from the operator deps.
    # @lint-ignore BUCKLINT
    fb_native.genrule(
        name = oplist_dir_name,
        cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
               "--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
               ("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
               "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
        outs = get_gen_oplist_outs(),
        default_outs = ["."],
        compatible_with = compatible_with,
    )

    # Aten files
    aten_genrule = name + "_aten"
    extra_flags = {
        "enabled_backends": USED_PT_BACKENDS,
        "op_selection_yaml_path": "$(location :{}[selected_operators.yaml])".format(oplist_dir_name),
    }

    if train and pt_allow_forced_schema_registration:
        extra_flags["force_schema_registration"] = True

    # NOTE: the unboxing codegen deliberately uses extra_flags BEFORE
    # static_dispatch_backend is added below.
    unboxing_genrule = name + "_unboxing"
    if _get_enable_lightweight_dispatch():
        gen_aten_unboxing_files(
            unboxing_genrule,
            extra_flags = extra_flags,
        )

    static_dispatch_backend = get_static_dispatch_backend()
    if static_dispatch_backend:
        extra_flags["static_dispatch_backend"] = static_dispatch_backend

    gen_aten_files(
        aten_genrule,
        extra_flags = extra_flags,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )

    # unboxing_wrappers files
    extra_params = [
        "--operators_yaml_path",
        "$(location :" + oplist_dir_name + "[selected_operators.yaml])",
    ]
    unboxing_and_autograd_genrule = name + "_unboxing_and_autograd"
    gen_aten_libtorch_files(
        unboxing_and_autograd_genrule,
        extra_params,
        compatible_with,
        apple_sdks = apple_sdks,
    )

    # Template runtime files (prim ops, etc)
    template_registration_genrule = name + "_template_registration"
    copy_template_registration_files(template_registration_genrule, apple_sdks = apple_sdks)

    # Files needed for metal
    if NOT_OSS:
        metal_genrule = name + "_metal"
        copy_metal(metal_genrule, apple_sdks = apple_sdks)

    # Assemble the source list: selective ATen sources, template registration
    # files, autograd VariableType/ADInplaceOrView shards (training only), and
    # the internal-only supported-models registration.
    srcs = get_aten_selective_cpp_rules(
        aten_genrule,
        static_dispatch_backend if static_dispatch_backend else USED_PT_BACKENDS,
    ) + get_template_registration_file_rules(template_registration_genrule, IS_OSS) + ([
        ":{}[autograd/generated/VariableType_0.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_1.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_2.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_3.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_4.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/ADInplaceOrViewType_0.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/ADInplaceOrViewType_1.cpp]".format(unboxing_and_autograd_genrule),
    ] if train else []) + ([
        ":{}[SupportedMobileModelsRegistration.cpp]".format(oplist_dir_name),
    ] if NOT_OSS else [])

    headers = {
        "selected_mobile_ops.h": ":{}[selected_mobile_ops.h]".format(oplist_dir_name),
    }

    # Lightweight dispatch adds the sharded unboxing sources and header.
    if _get_enable_lightweight_dispatch():
        srcs.extend([
            ":{}[UnboxingFunctions_0.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_1.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_2.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_3.cpp]".format(unboxing_genrule),
            ":{}[UnboxingFunctions_4.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_0.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_1.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_2.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_3.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_4.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_5.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_6.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_7.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_8.cpp]".format(unboxing_genrule),
            ":{}[RegisterCodegenUnboxedKernels_9.cpp]".format(unboxing_genrule),
        ])
        headers["UnboxingFunctions.h"] = ":{}[UnboxingFunctions.h]".format(unboxing_genrule)
    return {"headers": headers, "srcs": srcs}
673
def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple_sdks = None):
    """Define the genrule that runs tools:generate_code_bin (autograd/libtorch codegen)."""
    # Mobile build only needs libtorch - skip python bindings for now, except
    # for ovrsource, which needs Python bindings.
    # The argument string is identical for the bash and Windows commands, so
    # build it once.
    generate_code_args = " ".join(
        (["--subset libtorch"] if not is_arvr_mode() else []) + [
            "--native-functions-path $(location {}:aten_src_path)/aten/src/ATen/native/native_functions.yaml".format(ROOT),
            "--tags-path $(location {}:aten_src_path)/aten/src/ATen/native/tags.yaml".format(ROOT),
            "--install_dir $OUT",
        ] + extra_params,
    )
    generate_code_cmd = "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + generate_code_args
    fb_xplat_genrule(
        name = name,
        outs = get_generate_code_bin_outs(),
        default_outs = ["."],
        bash = "mkdir -p tools && " + generate_code_cmd,
        cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " + generate_code_cmd,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )
702
def copy_metal(name, apple_sdks = None):
    """Genrule that copies Metal (.mm/.cpp) sources into $OUT for the per-app build,
    including the MaskRCNN and Unet prepack custom-op sources (internal targets)."""
    cmd = []
    cmd_exe = []
    metal_source_dict = get_metal_source_dict()

    # Copy all source files over to bring them into the per app build
    for path_prefix in sorted(metal_source_dict.keys()):
        cmd.append("mkdir -p $OUT/{}".format(path_prefix))
        cmd_exe.append("mkdir -Force $OUT/{0}".format(path_prefix))

        # Not every directory has a mm or cpp file so '2>/dev/null || :' are tricks to suppress the error messages and codes.
        cmd.append("cp -f {0}/{1}/*.mm $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix))
        cmd.append("cp -f {0}/{1}/*.cpp $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix))

        # Robocopy has a default success code of 1 which buck treats as failure so the echo masks that problem
        cmd_exe.append("(robocopy /E /NFL /NDL /NJH /NJS {0}/{1} $OUT/{1}) || ECHO robocopy failed".format("$(location //xplat/caffe2:metal_build_srcs)", path_prefix))

    # Metal custom ops currently have to be brought into selective build because they directly reference metal ops instead of
    # going through the dispatcher. There are some weird issues with the genrule and these files' locations on windows though, so
    # for now we simply skip building them for windows where they very likely aren't needed anyway.
    # Metal MaskRCNN custom op
    for full_path in METAL_MASKRCNN_SOURCE_LIST:
        path_prefix = paths.dirname(full_path)
        cmd.append("mkdir -p $OUT/{}".format(path_prefix))
        cmd.append("cp -f {0}/{1}/*.mm $OUT/{1}/ 2>/dev/null || :".format("$(location //xplat/caffe2/fb/metal:metal_maskrcnn_sources)", path_prefix))

    # Unet Metal Prepack Custom op
    unet_metal_prepack_file = "$(location //xplat/caffe2/fb/custom_ops/unet_metal_prepack:unet_metal_prepack_sources)"
    cmd.append("cp -f " + unet_metal_prepack_file + "/unet_metal_prepack.cpp" + " $OUT")
    cmd.append("cp -f " + unet_metal_prepack_file + "/unet_metal_prepack.mm" + " $OUT")

    fb_xplat_genrule(
        name = name,
        cmd = " && ".join(cmd),
        cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)),
        # due to an obscure bug certain custom ops weren't being copied correctly on windows. ARVR also sometimes builds android targets on windows,
        # so we just exclude those targets from being copied for those platforms (They end up uncompiled anyway).
        outs = select({
            "DEFAULT": get_metal_registration_files_outs(),
            "ovr_config//os:android": get_metal_registration_files_outs_windows(),
            "ovr_config//os:windows": get_metal_registration_files_outs_windows(),
        }),
        default_outs = ["."],
        apple_sdks = apple_sdks,
    )
748
def get_pt_operator_registry_dict(
        name,
        deps = [],
        train = False,
        labels = [],
        env = [],
        template_select = True,
        enforce_traced_op_list = False,
        pt_allow_forced_schema_registration = True,
        enable_flatbuffer = False,
        **kwargs):
    """Build the kwargs dict for a selective-build operator-registry cxx_library.

    Runs pt_operator_query_codegen() to produce the selected-operator sources
    and headers, then returns cxx_library arguments wiring them up.

    BUGFIX: `labels` is a declared parameter, so it can never appear in
    **kwargs; the previous `kwargs.pop("labels", [])` therefore always yielded
    [] and silently dropped any labels passed by callers. Use the parameter
    directly instead.
    """
    code_gen_files = pt_operator_query_codegen(
        name,
        deps = deps,
        train = train,
        enforce_traced_op_list = enforce_traced_op_list,
        pt_allow_forced_schema_registration = pt_allow_forced_schema_registration,
        compatible_with = kwargs.get("compatible_with", []),
        apple_sdks = kwargs.get("apple_sdks"),
    )

    return dict(
        srcs = code_gen_files["srcs"],
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        soname = "libtorch-code-gen.$(ext)",
        header_namespace = "ATen",
        compiler_flags = get_aten_compiler_flags(),
        exported_headers = code_gen_files["headers"],
        exported_preprocessor_flags = get_aten_preprocessor_flags() + (["-DTEMPLATE_SELECTIVE_BUILD"] if template_select else []),
        headers = kwargs.pop("headers", []),
        labels = labels + [
            # This library has multiple sources with the same file name
            # and does not work with Buck filegroup used in bad practices.
            # Opt out of the bad practices check with the below label.
            "bad_practices_ignore_override",
            "pt_operator_registry",
        ],
        deps = [
            # need absolute path here
            ROOT + ":torch_mobile_core",
            ROOT + ":aten_cpu",
            ROOT + ":aten_metal_prepack_header",
            third_party("glog"),
            C10,
        ] + ([ROOT + ":torch_mobile_train"] if train else []),
        **kwargs
    )
800
801# these targets are shared by internal and OSS BUCK
802def define_buck_targets(
803aten_default_args = dict(),
804pt_xplat_cxx_library = fb_xplat_cxx_library,
805c2_fbandroid_xplat_compiler_flags = [],
806labels = []):
807# @lint-ignore BUCKLINT
808fb_native.filegroup(
809name = "metal_build_srcs",
810# @lint-ignore BUCKRESTRICTEDSYNTAX
811srcs = glob(METAL_SOURCE_LIST),
812visibility = [
813"PUBLIC",
814],
815)
816
817# @lint-ignore BUCKLINT
818fb_native.filegroup(
819name = "templated_selective_build_srcs",
820# NB: no glob here, there are generated targets in this list!
821# @lint-ignore BUCKRESTRICTEDSYNTAX
822srcs = glob(TEMPLATE_SOURCE_LIST) + aten_ufunc_generated_all_cpu_sources(":gen_aten[{}]"),
823visibility = [
824"PUBLIC",
825],
826)
827
828fb_xplat_cxx_library(
829name = "aten_header",
830header_namespace = "",
831exported_headers = subdir_glob([
832# ATen Core
833("aten/src", "ATen/core/**/*.h"),
834("aten/src", "ATen/ops/*.h"),
835# ATen Base
836("aten/src", "ATen/*.h"),
837("aten/src", "ATen/cpu/**/*.h"),
838("aten/src", "ATen/detail/*.h"),
839("aten/src", "ATen/functorch/**/*.h"),
840("aten/src", "ATen/quantized/*.h"),
841("aten/src", "ATen/vulkan/*.h"),
842("aten/src", "ATen/metal/*.h"),
843("aten/src", "ATen/nnapi/*.h"),
844# ATen Native
845("aten/src", "ATen/native/*.h"),
846("aten/src", "ATen/native/ao_sparse/quantized/cpu/*.h"),
847("aten/src", "ATen/native/cpu/**/*.h"),
848("aten/src", "ATen/native/sparse/*.h"),
849("aten/src", "ATen/native/nested/*.h"),
850("aten/src", "ATen/native/quantized/*.h"),
851("aten/src", "ATen/native/quantized/cpu/*.h"),
852("aten/src", "ATen/native/transformers/*.h"),
853("aten/src", "ATen/native/ufunc/*.h"),
854("aten/src", "ATen/native/utils/*.h"),
855("aten/src", "ATen/native/vulkan/ops/*.h"),
856("aten/src", "ATen/native/xnnpack/*.h"),
857("aten/src", "ATen/mps/*.h"),
858("aten/src", "ATen/native/mps/*.h"),
859# Remove the following after modifying codegen for mobile.
860("aten/src", "ATen/mkl/*.h"),
861("aten/src", "ATen/native/mkl/*.h"),
862("aten/src", "ATen/native/mkldnn/*.h"),
863]),
864visibility = ["PUBLIC"],
865labels = labels,
866)
867
868fb_xplat_cxx_library(
869name = "aten_vulkan_header",
870header_namespace = "",
871exported_headers = subdir_glob([
872("aten/src", "ATen/native/vulkan/*.h"),
873("aten/src", "ATen/native/vulkan/ops/*.h"),
874("aten/src", "ATen/vulkan/*.h"),
875]),
876labels = labels,
877visibility = ["PUBLIC"],
878)
879
880fb_xplat_cxx_library(
881name = "jit_core_headers",
882header_namespace = "",
883exported_headers = subdir_glob([("", x) for x in jit_core_headers]),
884labels = labels,
885)
886
887fb_xplat_cxx_library(
888name = "torch_headers",
889header_namespace = "",
890exported_headers = subdir_glob(
891[
892("torch/csrc/api/include", "torch/**/*.h"),
893("", "torch/csrc/**/*.h"),
894("", "torch/script.h"),
895("", "torch/library.h"),
896("", "torch/custom_class.h"),
897("", "torch/custom_class_detail.h"),
898# Add again due to namespace difference from aten_header.
899("", "aten/src/ATen/*.h"),
900("", "aten/src/ATen/functorch/**/*.h"),
901("", "aten/src/ATen/quantized/*.h"),
902],
903exclude = [
904# Don't need on mobile.
905"torch/csrc/Exceptions.h",
906"torch/csrc/python_headers.h",
907"torch/csrc/jit/serialization/mobile_bytecode_generated.h",
908],
909),
910labels = labels,
911visibility = ["PUBLIC"],
912deps = [
913":generated-version-header",
914],
915)
916
917fb_xplat_cxx_library(
918name = "aten_test_header",
919header_namespace = "",
920exported_headers = subdir_glob([
921("aten/src", "ATen/test/*.h"),
922]),
923)
924
925fb_xplat_cxx_library(
926name = "aten_metal_prepack_header",
927header_namespace = "",
928exported_headers = subdir_glob([
929("aten/src", "ATen/native/metal/MetalPrepackOpContext.h"),
930]),
931labels = labels,
932visibility = ["PUBLIC"],
933)
934
935fb_xplat_cxx_library(
936name = "torch_mobile_headers",
937header_namespace = "",
938exported_headers = subdir_glob(
939[
940("", "torch/csrc/jit/mobile/*.h"),
941],
942),
943labels = labels,
944visibility = ["PUBLIC"],
945)
946
947fb_xplat_cxx_library(
948name = "generated_aten_config_header",
949header_namespace = "ATen",
950exported_headers = {
951"Config.h": ":generate_aten_config[Config.h]",
952},
953labels = labels,
954)
955
956fb_xplat_cxx_library(
957name = "generated-autograd-headers",
958header_namespace = "torch/csrc/autograd/generated",
959exported_headers = {
960"Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]",
961"VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]",
962"variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]",
963"ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]",
964# Don't build python bindings on mobile.
965#"python_functions.h",
966},
967labels = labels,
968visibility = ["PUBLIC"],
969)
970
971fb_xplat_cxx_library(
972name = "generated-version-header",
973header_namespace = "torch",
974exported_headers = {
975"version.h": ":generate-version-header[version.h]",
976},
977labels = labels,
978)
979
980# @lint-ignore BUCKLINT
981fb_native.genrule(
982name = "generate-version-header",
983srcs = [
984"torch/csrc/api/include/torch/version.h.in",
985"version.txt",
986],
987cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
988"--template-path",
989"torch/csrc/api/include/torch/version.h.in",
990"--version-path",
991"version.txt",
992"--output-path",
993"$OUT/version.h",
994]),
995outs = {
996"version.h": ["version.h"],
997},
998default_outs = ["."],
999)
1000
1001# @lint-ignore BUCKLINT
1002fb_native.filegroup(
1003name = "aten_src_path",
1004srcs = [
1005"aten/src/ATen/native/native_functions.yaml",
1006"aten/src/ATen/native/tags.yaml",
1007# @lint-ignore BUCKRESTRICTEDSYNTAX
1008] + glob(["aten/src/ATen/templates/*"]),
1009visibility = [
1010"PUBLIC",
1011],
1012)
1013
1014fb_xplat_cxx_library(
1015name = "common_core",
1016srcs = [
1017"caffe2/core/common.cc",
1018],
1019apple_sdks = (IOS, MACOSX, APPLETVOS),
1020compiler_flags = get_pt_compiler_flags(),
1021labels = labels,
1022# @lint-ignore BUCKLINT link_whole
1023link_whole = True,
1024visibility = ["PUBLIC"],
1025windows_preferred_linkage = "static" if is_arvr_mode() else None,
1026deps = [
1027":caffe2_headers",
1028C10,
1029],
1030)
1031
1032# @lint-ignore BUCKLINT
1033fb_native.genrule(
1034name = "generate_aten_config",
1035srcs = [
1036"aten/src/ATen/Config.h.in",
1037],
1038cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
1039"--install_dir",
1040"$OUT",
1041"--input-file",
1042"aten/src/ATen/Config.h.in",
1043"--output-file",
1044"Config.h",
1045"--replace",
1046"@AT_MKLDNN_ENABLED@",
1047"ATEN_MKLDNN_ENABLED_FBXPLAT",
1048"--replace",
1049"@AT_MKLDNN_ACL_ENABLED@",
1050"ATEN_MKLDNN_ACL_ENABLED_FBXPLAT",
1051"--replace",
1052"@AT_MKL_ENABLED@",
1053"ATEN_MKL_ENABLED_FBXPLAT",
1054"--replace",
1055"@AT_MKL_SEQUENTIAL@",
1056"ATEN_MKL_SEQUENTIAL_FBXPLAT",
1057"--replace",
1058"@AT_POCKETFFT_ENABLED@",
1059"1",
1060"--replace",
1061"@AT_NNPACK_ENABLED@",
1062"ATEN_NNPACK_ENABLED_FBXPLAT",
1063"--replace",
1064"@CAFFE2_STATIC_LINK_CUDA_INT@",
1065"CAFFE2_STATIC_LINK_CUDA_FBXPLAT",
1066"--replace",
1067"@AT_BUILD_WITH_BLAS@",
1068"USE_BLAS_FBXPLAT",
1069"--replace",
1070"@AT_PARALLEL_OPENMP@",
1071"AT_PARALLEL_OPENMP_FBXPLAT",
1072"--replace",
1073"@AT_PARALLEL_NATIVE@",
1074"AT_PARALLEL_NATIVE_FBXPLAT",
1075"--replace",
1076"@AT_PARALLEL_NATIVE_TBB@",
1077"AT_PARALLEL_NATIVE_TBB_FBXPLAT",
1078"--replace",
1079"@AT_BUILD_WITH_LAPACK@",
1080"USE_LAPACK_FBXPLAT",
1081"--replace",
1082"@AT_BLAS_F2C@",
1083"AT_BLAS_F2C_FBXPLAT",
1084"--replace",
1085"@AT_BLAS_USE_CBLAS_DOT@",
1086"AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
1087]),
1088outs = {
1089"Config.h": ["Config.h"],
1090},
1091default_outs = ["."],
1092)
1093
1094gen_aten_files(
1095name = "gen_aten",
1096extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
1097visibility = ["PUBLIC"],
1098)
1099
1100gen_aten_libtorch_files(name = "gen_aten_libtorch")
1101
1102gen_aten_libtorch_files(
1103name = "gen_aten_libtorch_lite",
1104extra_params = get_jit_codegen_params(),
1105)
1106
1107fb_xplat_cxx_library(
1108name = "generated_aten_headers_cpu",
1109header_namespace = "ATen",
1110exported_headers = get_aten_static_dispatch_backend_headers({
1111"CPUFunctions.h": ":gen_aten[CPUFunctions.h]",
1112"CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]",
1113"CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]",
1114"CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]",
1115"CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]",
1116"CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
1117"CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
1118"CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
1119"CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]",
1120"CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]",
1121"FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
1122"Functions.h": ":gen_aten[Functions.h]",
1123"MethodOperators.h": ":gen_aten[MethodOperators.h]",
1124"NativeFunctions.h": ":gen_aten[NativeFunctions.h]",
1125"NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]",
1126"Operators.h": ":gen_aten[Operators.h]",
1127"RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
1128"core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
1129"core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
1130"core/enum_tag.h": ":gen_aten[core/enum_tag.h]",
1131}),
1132labels = labels,
1133)
1134
1135fb_xplat_cxx_library(
1136name = "torch_mobile_observer",
1137srcs = [
1138"torch/csrc/jit/mobile/observer.cpp",
1139] + ([] if IS_OSS else ["torch/fb/observers/MobileObserverUtil.cpp"]),
1140compiler_flags = ["-fexceptions"],
1141header_namespace = "",
1142exported_headers = subdir_glob(
1143[
1144("", "torch/csrc/jit/mobile/observer.h"),
1145] + ([] if IS_OSS else [
1146("", "torch/fb/observers/ObserverUtil.h"),
1147("", "torch/fb/observers/MobileObserverUtil.h"),
1148]),
1149),
1150fbobjc_compiler_flags = [
1151"-Wno-missing-prototypes",
1152],
1153labels = labels,
1154visibility = ["PUBLIC"],
1155deps = [
1156C10,
1157],
1158)
1159
1160# Base library shared by lite-interpreter and full-jit.
1161pt_xplat_cxx_library(
1162name = "torch_common",
1163srcs = core_sources_common,
1164compiler_flags = get_pt_compiler_flags(),
1165exported_preprocessor_flags = get_pt_preprocessor_flags(),
1166# @lint-ignore BUCKLINT link_whole
1167link_whole = True,
1168visibility = ["PUBLIC"],
1169deps = [
1170":aten_cpu",
1171":generated-autograd-headers",
1172":torch_headers",
1173C10,
1174third_party("libkineto_headers"),
1175],
1176)
1177
1178pt_xplat_cxx_library(
1179name = "torch_mobile_deserialize_common",
1180srcs = [
1181"torch/csrc/jit/mobile/parse_bytecode.cpp",
1182"torch/csrc/jit/mobile/parse_operators.cpp",
1183"torch/csrc/jit/mobile/upgrader_mobile.cpp",
1184"torch/csrc/jit/serialization/import_read.cpp",
1185"torch/csrc/jit/serialization/unpickler.cpp",
1186],
1187header_namespace = "",
1188exported_headers = [
1189"torch/csrc/jit/serialization/import_read.h",
1190"torch/csrc/jit/serialization/unpickler.h",
1191],
1192compiler_flags = get_pt_compiler_flags(),
1193exported_preprocessor_flags = get_pt_preprocessor_flags(),
1194extra_flags = {
1195"fbandroid_compiler_flags": ["-frtti"],
1196},
1197# torch_mobile_deserialize brings in the sources necessary to read a module,
1198# which depends on the mobile module definition.
1199# link_whole is enabled so that all symbols necessary for the mobile module are compiled
1200# instead of only the symbols used while loading; this prevents symbol
1201# not-found errors at runtime.
1202# @lint-ignore BUCKLINT link_whole
1203link_whole = True,
1204linker_flags = ["-Wl,--no-as-needed"],
1205visibility = ["PUBLIC"],
1206exported_deps = [
1207":aten_cpu",
1208":caffe2_headers",
1209":caffe2_serialize",
1210":torch_common",
1211":torch_headers",
1212":torch_mobile_headers",
1213":torch_mobile_module",
1214":torch_mobile_observer",
1215C10,
1216],
1217)
1218
1219pt_xplat_cxx_library(
1220name = "torch_mobile_module",
1221srcs = [
1222"torch/csrc/jit/mobile/function.cpp",
1223"torch/csrc/jit/mobile/interpreter.cpp",
1224"torch/csrc/jit/mobile/module.cpp",
1225],
1226header_namespace = "",
1227exported_headers = [
1228],
1229compiler_flags = get_pt_compiler_flags(),
1230exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1231extra_flags = {
1232"fbandroid_compiler_flags": ["-frtti"],
1233},
1234# @lint-ignore BUCKLINT link_whole
1235link_whole = True,
1236linker_flags = [
1237"-Wl,--no-as-needed",
1238],
1239visibility = ["PUBLIC"],
1240exported_deps = [
1241":aten_cpu",
1242":caffe2_headers",
1243":torch_common",
1244":torch_headers",
1245":torch_mobile_headers",
1246":torch_mobile_observer",
1247C10,
1248],
1249)
1250
1251pt_xplat_cxx_library(
1252name = "torch_mobile_debug_symbolication",
1253srcs = [
1254# included in aten_cpu "torch/csrc/jit/frontend/source_range.cpp",
1255"torch/csrc/jit/ir/scope.cpp",
1256"torch/csrc/jit/mobile/debug_info.cpp",
1257"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
1258"torch/csrc/jit/serialization/source_range_serialization.cpp",
1259"torch/csrc/jit/serialization/pickle.cpp",
1260# pickler.cpp doesn't seem to be needed.
1261# "torch/csrc/jit/serialization/pickler.cpp",
1262# included in core_sources_common "torch/csrc/jit/serialization/unpickler.cpp",
1263],
1264compiler_flags = get_pt_compiler_flags(),
1265exported_preprocessor_flags = get_pt_preprocessor_flags(),
1266header_namespace = "",
1267# @lint-ignore BUCKLINT link_whole
1268link_whole = True,
1269linker_flags = [
1270"-Wl,--no-as-needed",
1271],
1272visibility = ["PUBLIC"],
1273deps = [
1274":torch_mobile_deserialize",
1275],
1276exported_deps = [
1277":torch_common",
1278],
1279)
1280
1281pt_xplat_cxx_library(
1282name = "torch_model_tracer",
1283srcs = [
1284"torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp",
1285] + get_feature_tracer_source_list(),
1286header_namespace = "",
1287compiler_flags = get_pt_compiler_flags(),
1288exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1289# @lint-ignore BUCKLINT link_whole
1290link_whole = True,
1291linker_flags = [
1292"-Wl,--no-as-needed",
1293],
1294visibility = ["PUBLIC"],
1295deps = [
1296":generated-autograd-headers",
1297":torch_mobile_deserialize",
1298":torch_mobile_headers",
1299":torch_mobile_observer",
1300] + ([] if IS_OSS else ["//xplat/folly:molly"]),
1301exported_deps = [
1302":aten_cpu",
1303":torch_common",
1304] + ([] if IS_OSS else [
1305"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1306"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1307]),
1308)
1309
1310pt_xplat_cxx_library(
1311name = "torch_mobile_deserialize",
1312srcs = [
1313"torch/csrc/jit/mobile/import.cpp",
1314"torch/csrc/jit/mobile/flatbuffer_loader.cpp",
1315],
1316compiler_flags = get_pt_compiler_flags(),
1317exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1318header_namespace = "",
1319exported_headers = [
1320"torch/csrc/jit/mobile/import.h",
1321"torch/csrc/jit/mobile/flatbuffer_loader.h",
1322],
1323# torch_mobile_deserialize brings in the sources necessary to read a module,
1324# which depends on the mobile module definition.
1325# link_whole is enabled so that all symbols necessary for the mobile module are compiled
1326# instead of only the symbols used while loading; this prevents symbol
1327# not-found errors at runtime.
1328# @lint-ignore BUCKLINT link_whole
1329link_whole = True,
1330linker_flags = [
1331"-Wl,--no-as-needed",
1332],
1333visibility = ["PUBLIC"],
1334exported_deps = [
1335":aten_cpu",
1336":caffe2_headers",
1337":caffe2_serialize",
1338":torch_common",
1339":torch_headers",
1340":torch_mobile_headers",
1341":torch_mobile_module",
1342":torch_mobile_observer",
1343":torch_mobile_deserialize_common",
1344":mobile_bytecode",
1345C10,
1346],
1347)
1348
1349pt_xplat_cxx_library(
1350name = "torch_mobile_core",
1351srcs = [],
1352header_namespace = "",
1353exported_headers = [],
1354compiler_flags = get_pt_compiler_flags(),
1355exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1356# torch_mobile_core brings in the sources necessary to read and run a module.
1357# link_whole is enabled so that all symbols are linked;
1358# operators, registrations and a few other symbols are needed at runtime.
1359# @lint-ignore BUCKLINT link_whole
1360link_whole = True,
1361linker_flags = [
1362"-Wl,--no-as-needed",
1363],
1364visibility = ["PUBLIC"],
1365deps = [
1366":generated-autograd-headers",
1367":torch_mobile_headers",
1368":torch_mobile_observer",
1369],
1370exported_deps = [
1371":aten_cpu",
1372":torch_common",
1373":torch_mobile_deserialize",
1374":torch_supported_mobile_models",
1375],
1376)
1377
1378pt_xplat_cxx_library(
1379name = "torch_mobile_core_pickle_and_flatbuffer",
1380compiler_flags = get_pt_compiler_flags(),
1381exported_preprocessor_flags = get_pt_preprocessor_flags(),
1382visibility = ["PUBLIC"],
1383exported_deps = [
1384":flatbuffers_mobile",
1385":torch_mobile_core",
1386],
1387)
1388
1389pt_xplat_cxx_library(
1390name = "torch_cpp_cpu",
1391srcs = torch_cpp_srcs,
1392headers = native.glob(["torch/csrc/api/include/**/*.h"]) + ["torch/script.h"],
1393compiler_flags = get_pt_compiler_flags(),
1394exported_preprocessor_flags = get_pt_preprocessor_flags(),
1395visibility = ["PUBLIC"],
1396exported_deps = [
1397":torch",
1398":torch_mobile_deserialize_common", # for torch/csrc/api/src/serialize/input-archive.cpp
1399],
1400)
1401
1402pt_xplat_cxx_library(
1403name = "torch_core",
1404srcs = core_sources_full_mobile_no_backend_interface_xplat,
1405compiler_flags = get_pt_compiler_flags(),
1406exported_preprocessor_flags = get_pt_preprocessor_flags(),
1407visibility = [
1408"//xplat/caffe2/android/...",
1409"//xplat/caffe2/fb/...",
1410"//xplat/caffe2/fb/model_tracer/...",
1411],
1412deps = [
1413":aten_cpu",
1414":backend_interface_lib",
1415":generated-autograd-headers",
1416":torch_headers",
1417":torch_mobile_deserialize",
1418third_party("glog"),
1419third_party("rt"),
1420C10,
1421] + ([] if IS_OSS else [
1422"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1423"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1424]),
1425exported_deps = [
1426":torch_common",
1427":torch_mobile_train",
1428],
1429)
1430
1431pt_xplat_cxx_library(
1432name = "torch_train",
1433srcs = [
1434"torch/csrc/api/src/data/samplers/random.cpp",
1435"torch/csrc/api/src/data/samplers/sequential.cpp",
1436"torch/csrc/api/src/optim/optimizer.cpp",
1437"torch/csrc/api/src/optim/serialize.cpp",
1438"torch/csrc/api/src/optim/sgd.cpp",
1439"torch/csrc/api/src/serialize/input-archive.cpp",
1440"torch/csrc/api/src/serialize/output-archive.cpp",
1441"torch/csrc/jit/api/module_save.cpp",
1442],
1443compiler_flags = get_pt_compiler_flags(),
1444exported_preprocessor_flags = get_pt_preprocessor_flags(),
1445visibility = ["PUBLIC"],
1446deps = [
1447":aten_cpu",
1448":torch_headers",
1449":torch",
1450":torch_core",
1451":torch_mobile_deserialize",
1452":torch_mobile_train",
1453":jit_module_saving",
1454C10,
1455],
1456)
1457
1458pt_xplat_cxx_library(
1459name = "torch_mobile_train",
1460srcs = core_trainer_sources + [
1461"torch/csrc/autograd/VariableTypeManual.cpp",
1462"torch/csrc/autograd/FunctionsManual.cpp",
1463"torch/csrc/api/src/data/datasets/mnist.cpp",
1464"torch/csrc/jit/mobile/quantization.cpp",
1465"torch/csrc/jit/mobile/train/export_data.cpp",
1466"torch/csrc/jit/mobile/train/optim/sgd.cpp",
1467"torch/csrc/jit/mobile/train/random.cpp",
1468"torch/csrc/jit/mobile/train/sequential.cpp",
1469":gen_aten_libtorch[autograd/generated/Functions.cpp]",
1470":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]",
1471],
1472compiler_flags = get_pt_compiler_flags(),
1473exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
1474# torch_mobile_train brings in the sources necessary to read and run a mobile module
1475# and to save and load mobile params, along with autograd.
1476# link_whole is enabled so that all symbols are linked;
1477# operators, registrations and autograd-related symbols are needed at runtime.
1478# @lint-ignore BUCKLINT link_whole
1479link_whole = True,
1480visibility = ["PUBLIC"],
1481deps = [
1482":aten_cpu",
1483":generated-autograd-headers",
1484":torch_headers",
1485":torch_mobile_deserialize",
1486":flatbuffers_serializer_mobile",
1487C10,
1488],
1489)
1490
1491pt_xplat_cxx_library(
1492name = "torch",
1493srcs = [
1494"torch/csrc/jit/runtime/register_c10_ops.cpp",
1495"torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp",
1496],
1497compiler_flags = get_pt_compiler_flags(),
1498exported_preprocessor_flags = get_pt_preprocessor_flags(),
1499# torch brings in all sources necessary to read and run a mobile module / JIT module.
1500# link_whole is enabled so that all symbols are linked;
1501# operators, registrations and a few other symbols are needed at runtime.
1502# @lint-ignore BUCKLINT link_whole
1503link_whole = True,
1504visibility = ["PUBLIC"],
1505deps = [
1506# This is to have the autograd profiler available
1507# in xplat/caffe2:torch, which some builds are using,
1508# notably xplat/facegen:testsAndroid.
1509":torch_headers",
1510":torch_kineto_profiling",
1511],
1512exported_deps = [
1513":aten_cpu",
1514":torch_core",
1515C10,
1516],
1517)
1518
1519pt_xplat_cxx_library(
1520name = "torch_mobile_train_import_data",
1521srcs = [
1522"torch/csrc/jit/mobile/import_data.cpp",
1523],
1524compiler_flags = get_pt_compiler_flags(),
1525exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
1526# torch_mobile_train_import_data brings in the sources necessary to read a mobile module.
1527# link_whole is enabled so that all symbols are linked;
1528# operators and a few other symbols are needed at runtime.
1529# @lint-ignore BUCKLINT link_whole
1530link_whole = True,
1531visibility = ["PUBLIC"],
1532deps = [
1533":torch_headers",
1534":torch_mobile_observer",
1535":torch_mobile_core",
1536":torch_mobile_train",
1537],
1538)
1539
1540fb_xplat_cxx_library(
1541name = "torch_mobile_compatibility",
1542srcs = [
1543# These .cpp files are brought in through core_sources_common:
1544# "torch/csrc/jit/mobile/compatibility/runtime_compatibility.cpp",
1545# "torch/csrc/jit/serialization/unpickler.cpp",
1546"torch/csrc/jit/mobile/compatibility/model_compatibility.cpp",
1547],
1548header_namespace = "",
1549exported_headers = [
1550"torch/csrc/jit/mobile/compatibility/backport.h",
1551"torch/csrc/jit/mobile/compatibility/backport_manager.h",
1552"torch/csrc/jit/mobile/compatibility/model_compatibility.h",
1553"torch/csrc/jit/mobile/compatibility/runtime_compatibility.h",
1554],
1555compiler_flags = [
1556"-fexceptions",
1557"-frtti",
1558"-Wno-deprecated-declarations",
1559"-Wno-global-constructors",
1560],
1561labels = labels,
1562visibility = ["PUBLIC"],
1563deps = [
1564":torch_mobile_deserialize",
1565],
1566)
1567
1568pt_xplat_cxx_library(
1569name = "jit_module_saving",
1570srcs = [
1571"torch/csrc/jit/api/module_save.cpp",
1572"torch/csrc/jit/serialization/export_bytecode.cpp",
1573"torch/csrc/jit/serialization/export_module.cpp",
1574],
1575compiler_flags = get_pt_compiler_flags(),
1576exported_preprocessor_flags = get_pt_preprocessor_flags() +
1577(["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1578exported_headers = [
1579"torch/csrc/jit/serialization/export.h",
1580],
1581visibility = ["PUBLIC"],
1582deps = [
1583":torch",
1584":torch_mobile_core",
1585":flatbuffers_serializer_mobile",
1586],
1587)
1588
1589pt_xplat_cxx_library(
1590name = "torch_mobile_model_tracer",
1591srcs = [
1592"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp",
1593"torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
1594],
1595headers = [
1596"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h",
1597"torch/csrc/jit/mobile/model_tracer/TensorUtils.h",
1598],
1599header_namespace = "",
1600exported_headers = [
1601"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h",
1602],
1603compiler_flags = get_pt_compiler_flags(),
1604exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1605# torch_mobile_model_tracer brings in the sources necessary to read and run a JIT module
1606# and trace the ops.
1607# link_whole is enabled so that all symbols are linked;
1608# operators, registrations and a few other symbols are needed at runtime.
1609# @lint-ignore BUCKLINT link_whole
1610link_whole = True,
1611linker_flags = [
1612"-Wl,--no-as-needed",
1613],
1614visibility = ["PUBLIC"],
1615deps = [
1616":caffe2_serialize",
1617":generated-autograd-headers",
1618":torch_mobile_headers",
1619":torch_mobile_observer",
1620":torch_mobile_core",
1621] + ([] if IS_OSS else ["//xplat/folly:molly"]),
1622exported_deps = [
1623":aten_cpu",
1624":torch_common",
1625] + ([] if IS_OSS else [
1626"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1627"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1628"//xplat/caffe2/fb/custom_ops/sparsenn:sparsenn-all",
1629]),
1630)
1631
1632#TODO(qihan) delete
1633pt_xplat_cxx_library(
1634name = "torch_mobile_core_flatbuffer",
1635srcs = [],
1636header_namespace = "",
1637exported_headers = [],
1638compiler_flags = get_pt_compiler_flags(),
1639exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1640# @lint-ignore BUCKLINT link_whole
1641link_whole = True,
1642linker_flags = [
1643"-Wl,--no-as-needed",
1644],
1645visibility = ["PUBLIC"],
1646deps = [
1647":generated-autograd-headers",
1648":torch_mobile_headers",
1649":torch_mobile_observer",
1650],
1651exported_deps = [
1652":aten_cpu",
1653":torch_common",
1654],
1655)
1656
1657fb_xplat_cxx_library(
1658name = "backend_interface_lib",
1659srcs = [
1660"torch/csrc/jit/backends/backend_debug_info.cpp",
1661"torch/csrc/jit/backends/backend_interface.cpp",
1662],
1663compiler_flags = get_pt_compiler_flags(),
1664fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
1665# @lint-ignore BUCKLINT link_whole
1666link_whole = True,
1667linker_flags = [
1668"-Wl,--no-as-needed",
1669],
1670visibility = ["PUBLIC"],
1671exported_deps = [
1672":aten_cpu",
1673":torch_common",
1674],
1675)
1676
1677pt_xplat_cxx_library(
1678name = "torch_kineto_profiling",
1679srcs = libtorch_profiler_sources,
1680compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1681exported_preprocessor_flags = get_pt_preprocessor_flags() + [
1682"-DUSE_KINETO",
1683# Need this, otherwise USE_KINETO is undefined
1684# for mobile builds.
1685"-DEDGE_PROFILER_USE_KINETO",
1686],
1687# @lint-ignore BUCKLINT link_whole
1688link_whole = True,
1689linker_flags = [
1690"-Wl,--no-as-needed",
1691],
1692visibility = ["PUBLIC"],
1693deps = [
1694third_party("glog"),
1695third_party("kineto"),
1696],
1697exported_deps = [
1698":aten_cpu",
1699":torch_common",
1700],
1701)
1702
1703pt_xplat_cxx_library(
1704name = "torch_edge_profiling",
1705srcs = ["torch/csrc/jit/mobile/profiler_edge.cpp"],
1706compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1707exported_preprocessor_flags = get_pt_preprocessor_flags() + [
1708"-DUSE_KINETO",
1709"-DEDGE_PROFILER_USE_KINETO",
1710],
1711# @lint-ignore BUCKLINT link_whole
1712link_whole = True,
1713linker_flags = [
1714"-Wl,--no-as-needed",
1715],
1716visibility = ["PUBLIC"],
1717exported_deps = [
1718":torch_common",
1719":torch_kineto_profiling",
1720":torch_mobile_core",
1721],
1722)
1723
1724fb_xplat_genrule(
1725name = "mobile_bytecode_header",
1726srcs = [
1727"torch/csrc/jit/serialization/mobile_bytecode.fbs",
1728],
1729outs = {
1730"mobile_bytecode_generated_fbsource.h": ["mobile_bytecode_generated.h"],
1731},
1732cmd = "$(exe {})".format(third_party("flatc")) +
1733" --cpp --gen-mutable --scoped-enums -o ${OUT} ${SRCS}",
1734default_outs = ["."],
1735visibility = [
1736"{}:mobile_bytecode".format(ROOT),
1737],
1738)
1739
1740# Users of this target will need to add third_party("flatbuffers-api") as a
1741# dep.
1742fb_xplat_cxx_library(
1743name = "mobile_bytecode",
1744header_namespace = "",
1745exported_headers = {
1746("torch/csrc/jit/serialization/mobile_bytecode_generated.h" if IS_OSS else "torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h"): ":mobile_bytecode_header[mobile_bytecode_generated_fbsource.h]",
1747},
1748# Avoid leaking implementation details by only exposing this header to
1749# the internals of the loader/serializer layer.
1750visibility = [
1751"{}:flatbuffer_loader".format(ROOT),
1752"{}:flatbuffers_serializer_mobile".format(ROOT),
1753],
1754exported_deps = [
1755third_party("flatbuffers-api"),
1756],
1757)
1758
1759fb_xplat_cxx_library(
1760name = "flatbuffers_serializer_mobile",
1761srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer.cpp"],
1762exported_headers = [
1763"torch/csrc/jit/serialization/flatbuffer_serializer.h",
1764],
1765compiler_flags = [
1766"-g0",
1767"-O3",
1768"-fexceptions",
1769"-frtti",
1770"-Wno-deprecated-declarations",
1771] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1772visibility = ["PUBLIC"],
1773deps = [
1774":mobile_bytecode",
1775":torch_mobile_module",
1776C10,
1777],
1778exported_deps = [
1779":torch_mobile_deserialize",
1780":mobile_bytecode",
1781],
1782)
1783
1784# TODO (qihan) delete
1785pt_xplat_cxx_library(
1786name = "flatbuffer_loader",
1787srcs = [
1788],
1789exported_headers = [
1790"torch/csrc/jit/mobile/flatbuffer_loader.h",
1791],
1792compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1793exported_preprocessor_flags = get_pt_preprocessor_flags() + [
1794"-DUSE_KINETO",
1795# Need this, otherwise USE_KINETO is undefined
1796# for mobile builds.
1797"-DEDGE_PROFILER_USE_KINETO",
1798] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1799extra_flags = {
1800"fbandroid_compiler_flags": ["-frtti"],
1801},
1802# torch_mobile_deserialize brings in the sources necessary to read a module,
1803# which depends on the mobile module definition.
1804# link_whole is enabled so that all symbols necessary for the mobile module are compiled
1805# instead of only the symbols used while loading; this prevents symbol
1806# not-found errors at runtime.
1807# @lint-ignore BUCKLINT link_whole
1808link_whole = True,
1809linker_flags = [
1810"-Wl,--no-as-needed",
1811],
1812visibility = ["PUBLIC"],
1813deps = [
1814":mobile_bytecode",
1815],
1816exported_deps = [
1817C10,
1818],
1819)
1820
1821# TODO(qihan) delete
1822fb_xplat_cxx_library(
1823name = "flatbuffers_serializer_jit",
1824compiler_flags = [
1825"-g0",
1826"-O3",
1827"-fexceptions",
1828"-frtti",
1829"-Wno-deprecated-declarations",
1830],
1831headers = [
1832"torch/csrc/jit/serialization/flatbuffer_serializer_jit.h",
1833],
1834srcs = [
1835"torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp",
1836],
1837linker_flags = [
1838"-Wl,--no-as-needed",
1839],
1840visibility = ["PUBLIC"],
1841deps = [
1842":flatbuffer_loader",
1843":flatbuffers_serializer_mobile",
1844":torch_core",
1845":torch_mobile_module",
1846C10,
1847],
1848)
1849
1850fb_xplat_cxx_library(
1851name = "flatbuffers_jit",
1852visibility = ["PUBLIC"],
1853exported_deps = [
1854":flatbuffer_loader",
1855":flatbuffers_serializer_mobile",
1856":flatbuffers_serializer_jit",
1857],
1858)
1859
1860fb_xplat_cxx_library(
1861name = "flatbuffers_mobile",
1862visibility = ["PUBLIC"],
1863exported_deps = [
1864":flatbuffer_loader",
1865":flatbuffers_serializer_mobile",
1866":torch_mobile_train",
1867],
1868)
1869
1870pt_xplat_cxx_library(
1871name = "torch_supported_mobile_models",
1872srcs = [
1873"fb/supported_mobile_models/SupportedMobileModels.cpp",
1874] if NOT_OSS else [],
1875header_namespace = "",
1876exported_headers = ["fb/supported_mobile_models/SupportedMobileModels.h"] if NOT_OSS else [],
1877compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
1878exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1879# @lint-ignore BUCKLINT link_whole
1880link_whole = True,
1881linker_flags = [
1882"-Wl,--no-as-needed",
1883],
1884visibility = ["PUBLIC"],
1885deps = [],
1886exported_deps = [
1887"//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1888"//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1889] if NOT_OSS else [],
1890)
1891
# Static Runtime: the torch/csrc/jit/runtime/static interpreter and its ops,
# fusion passes, and memory planner, built as one whole-archive library.
fb_xplat_cxx_library(
    name = "static_runtime",
    srcs = [
        "torch/csrc/jit/runtime/static/fusion.cpp",
        "torch/csrc/jit/runtime/static/generated_ops.cpp",
        "torch/csrc/jit/runtime/static/impl.cpp",
        "torch/csrc/jit/runtime/static/memory_planner.cpp",
        "torch/csrc/jit/runtime/static/native_ops.cpp",
        "torch/csrc/jit/runtime/static/ops.cpp",
        "torch/csrc/jit/runtime/static/passes.cpp",
        "torch/csrc/jit/runtime/static/te_wrapper.cpp",
    ],
    compiler_flags = ["-fexceptions"],
    labels = labels,
    # @lint-ignore BUCKLINT link_whole
    link_whole = True,
    visibility = ["PUBLIC"],
    # ARVR Windows builds force static linkage; otherwise keep the default.
    windows_preferred_linkage = "static" if is_arvr_mode() else None,
    deps = [
        ":aten_cpu",
        ":caffe2_headers",
        ":torch_core",
        C10,
    ],
)
1917
# aten_cpu and aten_native_cpu
# Two sibling libraries stamped out from the same rule template: aten_cpu
# combines the JIT core sources, the ATen CPU source list, and generated
# operator/Function sources from :gen_aten; aten_native_cpu holds the native
# kernel implementations.
for name, srcs in [
    ("aten_cpu", jit_core_sources + aten_cpu_source_list + [
        # Generated
        ":gen_aten[Functions.cpp]",
        ":gen_aten[Operators_0.cpp]",
        ":gen_aten[Operators_1.cpp]",
        ":gen_aten[Operators_2.cpp]",
        ":gen_aten[Operators_3.cpp]",
        ":gen_aten[Operators_4.cpp]",
        ":gen_aten[core/ATenOpList.cpp]",
        ":gen_aten[core/TensorMethods.cpp]",
        # Needed by ATen/native/EmbeddingBag.cpp
        "caffe2/perfkernels/embedding_lookup_idx.cc",
    ]),
    ("aten_native_cpu", aten_native_source_list),
]:
    fb_xplat_cxx_library(
        name = name,
        srcs = srcs,
        header_namespace = "",
        # @lint-ignore BUCKLINT
        link_whole = True,
        visibility = ["PUBLIC"],
        deps = [
            third_party("omp"),
            third_party("cpuinfo"),
            third_party("glog"),
            third_party("XNNPACK"),
            third_party("pocketfft"),
        ] + select({
            "DEFAULT": [],
            # sleef's ARM build is only pulled in for fbcode arm64.
            "ovr_config//runtime:fbcode-arm64": [
                third_party("sleef_arm"),
            ],
        }),
        compiler_flags = get_aten_compiler_flags(),
        exported_preprocessor_flags = get_aten_preprocessor_flags(),
        exported_deps = [
            ":aten_header",
            ":caffe2_headers",
            ":common_core",
            ":generated_aten_config_header",
            ":generated_aten_headers_cpu",
            ":jit_core_headers",
            ":pthreadpool",
            third_party("fmt"),
            third_party("ruy"),
            C10,
            ROOT_PATH + "aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack",
        ],
        labels = labels,
        # Shared keyword defaults for both libraries (defined elsewhere in
        # this file).
        **aten_default_args
    )
1972
# Top layer of the "lean" (minimal edge) runtime stack: adds mobile module
# import/observer support on top of :lean_runtime_with_tensor and exports the
# broad header surface consumers need.
fb_xplat_cxx_library(
    name = "lean_runtime_with_flatbuffer",
    srcs = [
        "aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp",
        "torch/csrc/jit/mobile/import.cpp",
        "torch/csrc/jit/mobile/module.cpp",
        "torch/csrc/jit/mobile/observer.cpp",
        "torch/csrc/jit/serialization/import_read.cpp",
    ],
    header_namespace = "",
    exported_headers = subdir_glob(
        [
            ("", "torch/csrc/jit/ir/*.h"),
            ("", "caffe2/serialize/*.h"),
            ("", "caffe2/utils/*.h"),
            ("", "caffe2/core/*.h"),
            ("", "torch/csrc/*.h"),
            ("", "torch/csrc/api/include/torch/*.h"),
            ("", "torch/csrc/autograd/*.h"),
            ("", "torch/csrc/autograd/*/*.h"),
            ("", "torch/csrc/jit/api/*.h"),
            ("", "torch/csrc/jit/backends/*.h"),
            ("", "torch/csrc/jit/mobile/*.h"),
            ("", "torch/csrc/jit/runtime/*.h"),
            ("", "torch/csrc/jit/passes/*.h"),
            ("", "torch/csrc/jit/python/*.h"),
            ("", "torch/csrc/jit/frontend/*.h"),
            ("", "torch/csrc/jit/serialization/*.h"),
            ("", "torch/csrc/profiler/**/*.h"),
            ("", "torch/csrc/utils/*.h"),
            ("", "aten/src/ATen/quantized/*.h"),
        ] + ([
            # Bundled miniz headers exist only in the internal tree.
            ("third_party/miniz-2.1.0", "*.h"),
        ] if NOT_OSS else []),
        exclude = [
            # Generated header; not part of the exported surface here.
            "torch/csrc/jit/serialization/mobile_bytecode_generated.h",
        ],
    ),
    compiler_flags = get_pt_compiler_flags() + select({
        "DEFAULT": [],
        # Section-per-symbol so the xtensa linker can garbage-collect.
        "ovr_config//os:xtensa-xos": [
            "-fdata-sections",
            "-ffunction-sections",
        ],
    }),
    exported_preprocessor_flags = get_pt_preprocessor_flags() + [
        "-DMIN_EDGE_RUNTIME",
    ],
    linker_flags = [
        "-Wl,--no-as-needed",
    ] + select({
        "DEFAULT": [],
        "ovr_config//os:macos": [
            "-dead_strip",
        ],
        "ovr_config//os:xtensa-xos": [
            "-Wl,--gc-sections",
        ],
    }),
    visibility = ["PUBLIC"],
    exported_deps = [
        ":lean_runtime_with_tensor",
    ],
)
2037
# Middle layer of the lean runtime: tensor construction plus the generated
# operator stubs and TensorMethods, layered on :lean_runtime_with_op.
pt_xplat_cxx_library(
    name = "lean_runtime_with_tensor",
    srcs = [
        "aten/src/ATen/Context.cpp",
        "aten/src/ATen/EmptyTensor.cpp",
        "aten/src/ATen/Utils.cpp",
        "aten/src/ATen/detail/CUDAHooksInterface.cpp",
        "aten/src/ATen/detail/PrivateUse1HooksInterface.cpp",
        ":gen_aten[Operators_0.cpp]",
        ":gen_aten[Operators_1.cpp]",
        ":gen_aten[Operators_2.cpp]",
        ":gen_aten[Operators_3.cpp]",
        ":gen_aten[Operators_4.cpp]",
        ":gen_aten[core/TensorMethods.cpp]",
    ],
    header_namespace = "",
    exported_headers = [
        "torch/csrc/jit/runtime/custom_operator.h",
        ":gen_aten[core/TensorBody.h]",
    ],
    compiler_flags = get_pt_compiler_flags() + select({
        "DEFAULT": [],
        "ovr_config//os:xtensa-xos": [
            "-fdata-sections",
            "-ffunction-sections",
        ],
    }),
    exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
        "DEFAULT": [],
        # Compile away thread_local on xtensa-xos (presumably the toolchain
        # lacks TLS support -- TODO confirm).
        "ovr_config//os:xtensa-xos": [
            "-Dthread_local=",
        ],
    }),
    # @lint-ignore BUCKLINT link_whole
    link_whole = True,
    linker_flags = [
        "-Wl,--no-as-needed",
    ],
    visibility = ["PUBLIC"],
    exported_deps = [
        ":generated_aten_config_header",
        ":lean_runtime_with_op",
        ":aten_header",
        C10,
    ] + (["//xplat/caffe2/fb/embedded:experimental"] if NOT_OSS else []),
)
2084
# Lean runtime layer with operator registration/dispatch: the ATen dispatcher,
# schema parsing frontend, and mobile operator parsing, on top of
# :min_runtime_lib.
pt_xplat_cxx_library(
    name = "lean_runtime_with_op",
    srcs = [
        "aten/src/ATen/SequenceNumber.cpp",
        "aten/src/ATen/core/boxing/KernelFunction.cpp",
        "aten/src/ATen/core/custom_class.cpp",
        "aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp",
        "aten/src/ATen/core/dispatch/Dispatcher.cpp",
        "aten/src/ATen/core/dispatch/ObservedOperators.cpp",
        "aten/src/ATen/core/dispatch/OperatorEntry.cpp",
        "aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp",
        "aten/src/ATen/core/interned_strings.cpp",
        "aten/src/ATen/core/library.cpp",
        "aten/src/ATen/core/op_registration/infer_schema.cpp",
        "aten/src/ATen/core/function_schema.cpp",
        "aten/src/ATen/core/operator_name.cpp",
        "aten/src/ATen/core/register_symbols.cpp",
        "aten/src/ATen/core/tensor_type.cpp",
        "aten/src/ATen/core/union_type.cpp",
        "aten/src/ATen/record_function.cpp",
        "torch/csrc/jit/frontend/edit_distance.cpp",
        "torch/csrc/jit/frontend/error_report.cpp",
        "torch/csrc/jit/frontend/function_schema_parser.cpp",
        "torch/csrc/jit/frontend/lexer.cpp",
        "torch/csrc/jit/frontend/schema_type_parser.cpp",
        "torch/csrc/jit/frontend/source_range.cpp",
        "torch/csrc/jit/frontend/strtod.cpp",
        "torch/csrc/jit/mobile/parse_operators.cpp",
        "torch/csrc/jit/mobile/prim_ops_registery.cpp",
        "torch/csrc/jit/runtime/operator.cpp",
        "torch/csrc/jit/runtime/slice_indices_adjust.cpp",
    ],
    header_namespace = "",
    exported_headers = [
        "torch/csrc/jit/frontend/edit_distance.h",
        "torch/csrc/jit/runtime/slice_indices_adjust.h",
    ],
    compiler_flags = get_pt_compiler_flags() + select({
        "DEFAULT": [],
        "ovr_config//os:xtensa-xos": [
            "-fdata-sections",
            "-ffunction-sections",
        ],
    }),
    exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
        "DEFAULT": [],
        "ovr_config//os:xtensa-xos": [
            "-Dthread_local=",
        ],
    }),
    # @lint-ignore BUCKLINT link_whole
    link_whole = True,
    linker_flags = [
        "-Wl,--no-as-needed",
    ],
    visibility = ["PUBLIC"],
    exported_deps = [
        ":min_runtime_lib",
        C10,
    ],
)
2146
# Bottom layer of the lean runtime stack: IValue/type system and the mobile
# bytecode interpreter, with no dispatcher or tensor factories.
pt_xplat_cxx_library(
    name = "min_runtime_lib",
    srcs = [
        "aten/src/ATen/ScalarOps.cpp",
        "aten/src/ATen/core/Dict.cpp",
        "aten/src/ATen/core/List.cpp",
        "aten/src/ATen/core/class_type.cpp",
        "aten/src/ATen/core/dynamic_type.cpp",
        "aten/src/ATen/core/ivalue.cpp",
        "aten/src/ATen/core/type.cpp",
        "aten/src/ATen/core/type_factory.cpp",
        "aten/src/ATen/native/prim_native_functions.cpp",
        "torch/csrc/jit/mobile/function.cpp",
        "torch/csrc/jit/mobile/interpreter.cpp",
        "torch/csrc/jit/mobile/parse_bytecode.cpp",
        "torch/csrc/jit/mobile/promoted_prim_ops.cpp",
        "torch/csrc/jit/mobile/register_ops_common_utils.cpp",
        "torch/csrc/jit/mobile/type_parser.cpp",
        "torch/csrc/jit/runtime/instruction.cpp",
        "torch/csrc/jit/runtime/jit_exception.cpp",
        "torch/csrc/jit/runtime/vararg_functions.cpp",
    ],
    header_namespace = "",
    exported_headers = [
        "caffe2/serialize/versions.h",
        "torch/csrc/jit/backends/backend_exception.h",
        "torch/csrc/jit/mobile/register_ops_common_utils.h",
        "torch/csrc/jit/runtime/instruction.h",
        "torch/csrc/jit/runtime/jit_exception.h",
        "torch/csrc/jit/runtime/operator.h",
        "torch/csrc/jit/runtime/operator_options.h",
        "torch/csrc/jit/runtime/vararg_functions.h",
        "torch/csrc/jit/serialization/import_export_constants.h",
        "torch/csrc/jit/serialization/import_export_functions.h",
    ],
    compiler_flags = get_pt_compiler_flags() + select({
        "DEFAULT": [],
        # Unlike the layers above, xtensa also needs -fexceptions here.
        "ovr_config//os:xtensa-xos": [
            "-fexceptions",
            "-fdata-sections",
            "-ffunction-sections",
        ],
    }),
    exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
        "DEFAULT": [],
        "ovr_config//os:xtensa-xos": [
            "-Dthread_local=",
        ],
    }),
    # @lint-ignore BUCKLINT link_whole
    link_whole = True,
    linker_flags = [
        "-Wl,--no-as-needed",
    ],
    visibility = ["PUBLIC"],
    exported_deps = [
        ":aten_header",
        ":generated_aten_headers_cpu",
        ":jit_core_headers",
        ":torch_mobile_headers",
        C10,
    ],
)
2210