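# pytorch / Fork: 0 / BUILD.bazel — 1098 lines · 27.8 KB
# (Page header from the web viewer this file was copied from, kept as a comment
# so the file remains valid Starlark.)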
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("@pytorch//third_party:substitution.bzl", "header_template_rule", "template_rule")
load("@pytorch//:tools/bazel.bzl", "rules")
load("@pytorch//tools/rules:cu.bzl", "cu_library")
load("@pytorch//tools/config:defs.bzl", "if_cuda")
load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops")
load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets")
load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources")
load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources")
# NOTE(review): `rules` is also loaded from "@pytorch//:tools/bazel.bzl" above;
# confirm this second load is still needed.
load("//:tools/bazel.bzl", "rules")

# Instantiate the targets declared by define_targets() in :build.bzl, using the
# rule wrappers loaded from tools/bazel.bzl.
define_targets(rules = rules)

# Compiler flags shared by every C++ library in this package: ATEN_COPTS,
# CAFFE2_COPTS and TORCH_COPTS extend this list, and :torch_python uses it as-is.
# CUDA-only defines are appended through if_cuda().
COMMON_COPTS = [
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    "-DUSE_DISTRIBUTED",
    "-DAT_PER_OPERATOR_HEADERS",
    "-DATEN_THREADING=NATIVE",
    "-DNO_CUDNN_DESTROY_HANDLE",
] + if_cuda([
    "-DUSE_CUDA",
    "-DUSE_CUDNN",
    # TODO: pass this only when building against CUDA 11.5 or newer.
    # Wrapping CUB in its own namespace lets it be used safely; see
    # https://github.com/pytorch/pytorch/pull/55292
    "-DCUB_WRAPPED_NAMESPACE=at_cuda_detail",
])

# Inputs read by the ATen code generator (:generated_aten_cpp below): the
# operator schema and tag YAML files plus every template file.
aten_generation_srcs = [
    "aten/src/ATen/native/native_functions.yaml",
    "aten/src/ATen/native/tags.yaml",
] + glob(["aten/src/ATen/templates/**"])

# Outputs of :generated_aten_cpp that are compiled into (or included by) the
# CPU library :aten.
generated_cpu_cpp = [
    "aten/src/ATen/RegisterBackendSelect.cpp",
    "aten/src/ATen/RegisterCPU.cpp",
    "aten/src/ATen/RegisterFunctionalization_0.cpp",
    "aten/src/ATen/RegisterFunctionalization_1.cpp",
    "aten/src/ATen/RegisterFunctionalization_2.cpp",
    "aten/src/ATen/RegisterFunctionalization_3.cpp",
    # NOTE(review): disabled here with no recorded reason; confirm whether the
    # generator still emits this file before re-enabling or deleting the entry.
    # "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
    "aten/src/ATen/RegisterMkldnnCPU.cpp",
    "aten/src/ATen/RegisterNestedTensorCPU.cpp",
    "aten/src/ATen/RegisterQuantizedCPU.cpp",
    "aten/src/ATen/RegisterSparseCPU.cpp",
    "aten/src/ATen/RegisterSparseCsrCPU.cpp",
    "aten/src/ATen/RegisterZeroTensor.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
    "aten/src/ATen/RegisterMeta.cpp",
    "aten/src/ATen/RegisterSparseMeta.cpp",
    "aten/src/ATen/RegisterQuantizedMeta.cpp",
    "aten/src/ATen/RegisterNestedTensorMeta.cpp",
    "aten/src/ATen/RegisterSchema.cpp",
    "aten/src/ATen/CPUFunctions.h",
    "aten/src/ATen/CPUFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
    "aten/src/ATen/CompositeViewCopyKernels.cpp",
    "aten/src/ATen/FunctionalInverses.h",
    "aten/src/ATen/Functions.h",
    "aten/src/ATen/Functions.cpp",
    "aten/src/ATen/RedispatchFunctions.h",
    "aten/src/ATen/Operators.h",
    "aten/src/ATen/Operators_0.cpp",
    "aten/src/ATen/Operators_1.cpp",
    "aten/src/ATen/Operators_2.cpp",
    "aten/src/ATen/Operators_3.cpp",
    "aten/src/ATen/Operators_4.cpp",
    "aten/src/ATen/NativeFunctions.h",
    "aten/src/ATen/MetaFunctions.h",
    "aten/src/ATen/MetaFunctions_inl.h",
    "aten/src/ATen/MethodOperators.h",
    "aten/src/ATen/NativeMetaFunctions.h",
    "aten/src/ATen/RegistrationDeclarations.h",
    "aten/src/ATen/VmapGeneratedPlumbing.h",
    "aten/src/ATen/core/aten_interned_strings.h",
    "aten/src/ATen/core/enum_tag.h",
    "aten/src/ATen/core/TensorBody.h",
    "aten/src/ATen/core/TensorMethods.cpp",
    "aten/src/ATen/core/ATenOpList.cpp",
]

# Outputs of :generated_aten_cpp that belong to the CUDA library :aten_cuda_cpp.
generated_cuda_cpp = [
    "aten/src/ATen/CUDAFunctions.h",
    "aten/src/ATen/CUDAFunctions_inl.h",
    "aten/src/ATen/RegisterCUDA.cpp",
    "aten/src/ATen/RegisterNestedTensorCUDA.cpp",
    "aten/src/ATen/RegisterQuantizedCUDA.cpp",
    "aten/src/ATen/RegisterSparseCUDA.cpp",
    "aten/src/ATen/RegisterSparseCsrCUDA.cpp",
]

# Runs the //torchgen:gen code generator over aten_generation_srcs. Every file
# it produces is declared in `outs`, so other targets can list those paths
# directly as sources.
generate_aten(
    name = "generated_aten_cpp",
    srcs = aten_generation_srcs,
    outs = (
        generated_cpu_cpp +
        generated_cuda_cpp +
        aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [
            "aten/src/ATen/Declarations.yaml",
        ]
    ),
    generator = "//torchgen:gen",
)

# The generated autograd C++ sources as one filegroup. :libtorch_headers uses
# this filegroup rather than the raw list so Bazel can de-duplicate files.
filegroup(
    name = "cpp_generated_code",
    srcs = GENERATED_AUTOGRAD_CPP,
    data = [":generate-code"],
)

# ATen
# Source groups consumed by the :aten, :aten_cuda_cpp and :aten_cuda libraries.
filegroup(
    name = "aten_base_cpp",
    srcs = glob([
        "aten/src/ATen/*.cpp",
        "aten/src/ATen/functorch/*.cpp",
        "aten/src/ATen/detail/*.cpp",
        "aten/src/ATen/cpu/*.cpp",
    ]),
)

filegroup(
    name = "ATen_CORE_SRCS",
    srcs = glob(
        [
            "aten/src/ATen/core/**/*.cpp",
        ],
        exclude = [
            "aten/src/ATen/core/**/*_test.cpp",
        ],
    ),
)

filegroup(
    name = "aten_native_cpp",
    srcs = glob(["aten/src/ATen/native/*.cpp"]),
)

filegroup(
    name = "aten_native_sparse_cpp",
    srcs = glob(["aten/src/ATen/native/sparse/*.cpp"]),
)

filegroup(
    name = "aten_native_nested_cpp",
    srcs = glob(["aten/src/ATen/native/nested/*.cpp"]),
)

filegroup(
    name = "aten_native_quantized_cpp",
    srcs = glob(
        [
            "aten/src/ATen/native/quantized/*.cpp",
            "aten/src/ATen/native/quantized/cpu/*.cpp",
        ],
    ),
)

filegroup(
    name = "aten_native_transformers_cpp",
    srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]),
)

filegroup(
    name = "aten_native_mkl_cpp",
    srcs = glob([
        "aten/src/ATen/native/mkl/*.cpp",
        "aten/src/ATen/mkl/*.cpp",
    ]),
)

filegroup(
    name = "aten_native_mkldnn_cpp",
    srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]),
)

filegroup(
    name = "aten_native_xnnpack",
    srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]),
)

filegroup(
    name = "aten_base_vulkan",
    srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
)

filegroup(
    name = "aten_base_metal",
    srcs = glob(["aten/src/ATen/metal/*.cpp"]),
)

filegroup(
    name = "ATen_QUANTIZED_SRCS",
    srcs = glob(
        [
            "aten/src/ATen/quantized/**/*.cpp",
        ],
        exclude = [
            "aten/src/ATen/quantized/**/*_test.cpp",
        ],
    ),
)

filegroup(
    name = "aten_cuda_cpp_srcs",
    srcs = glob(
        [
            "aten/src/ATen/cuda/*.cpp",
            "aten/src/ATen/cuda/detail/*.cpp",
            "aten/src/ATen/cuda/tunable/*.cpp",
            "aten/src/ATen/cudnn/*.cpp",
            "aten/src/ATen/native/cuda/*.cpp",
            "aten/src/ATen/native/cuda/linalg/*.cpp",
            "aten/src/ATen/native/cudnn/*.cpp",
            "aten/src/ATen/native/miopen/*.cpp",
            "aten/src/ATen/native/nested/cuda/*.cpp",
            "aten/src/ATen/native/quantized/cuda/*.cpp",
            "aten/src/ATen/native/quantized/cudnn/*.cpp",
            "aten/src/ATen/native/sparse/cuda/*.cpp",
            "aten/src/ATen/native/transformers/cuda/*.cpp",
        ],
    ),
)

filegroup(
    name = "aten_cu_srcs",
    srcs = glob([
        "aten/src/ATen/cuda/*.cu",
        "aten/src/ATen/cuda/detail/*.cu",
        "aten/src/ATen/native/cuda/*.cu",
        "aten/src/ATen/native/nested/cuda/*.cu",
        "aten/src/ATen/native/quantized/cuda/*.cu",
        "aten/src/ATen/native/sparse/cuda/*.cu",
        "aten/src/ATen/native/transformers/cuda/*.cu",
    ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
    # The ufunc CUDA sources are declared outputs of :generated_aten_cpp (see
    # its `outs`). Listing their paths is therefore enough: Bazel resolves each
    # path to that output file and depends on the rule that generates it.
)

# Renders aten/src/ATen/Config.h from Config.h.in by replacing each @...@
# placeholder below. MKL, MKL-DNN, BLAS and LAPACK are enabled. The native
# parallel backend is selected, which matches -DATEN_THREADING=NATIVE in
# COMMON_COPTS.
header_template_rule(
    name = "aten_src_ATen_config",
    src = "aten/src/ATen/Config.h.in",
    out = "aten/src/ATen/Config.h",
    include = "aten/src",
    substitutions = {
        "@AT_MKLDNN_ENABLED@": "1",
        "@AT_MKLDNN_ACL_ENABLED@": "0",
        "@AT_MKL_ENABLED@": "1",
        "@AT_MKL_SEQUENTIAL@": "0",
        "@AT_POCKETFFT_ENABLED@": "0",
        "@AT_NNPACK_ENABLED@": "0",
        "@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
        "@AT_BUILD_WITH_BLAS@": "1",
        "@AT_BUILD_WITH_LAPACK@": "1",
        "@AT_PARALLEL_OPENMP@": "0",
        "@AT_PARALLEL_NATIVE@": "1",
        "@AT_BLAS_F2C@": "0",
        "@AT_BLAS_USE_CBLAS_DOT@": "1",
    },
)

# Renders aten/src/ATen/cuda/CUDAConfig.h: cuDNN enabled; cuSPARSELt, ROCm and
# MAGMA disabled; no extra NVCC flags.
header_template_rule(
    name = "aten_src_ATen_cuda_config",
    src = "aten/src/ATen/cuda/CUDAConfig.h.in",
    out = "aten/src/ATen/cuda/CUDAConfig.h",
    include = "aten/src",
    substitutions = {
        "@AT_CUDNN_ENABLED@": "1",
        "@AT_CUSPARSELT_ENABLED@": "0",
        "@AT_ROCM_ENABLED@": "0",
        "@AT_MAGMA_ENABLED@": "0",
        "@NVCC_FLAGS_EXTRA@": "",
    },
)

# Header-only target for ATen. It bundles the aten/src headers, the generated
# Config.h and all :generated_aten_cpp outputs, and puts aten/src on the
# include path.
cc_library(
    name = "aten_headers",
    hdrs = [
        "torch/csrc/Export.h",
        "torch/csrc/jit/frontend/function_schema_parser.h",
    ] + glob(
        [
            "aten/src/**/*.h",
            "aten/src/**/*.hpp",
            "aten/src/ATen/cuda/**/*.cuh",
            "aten/src/ATen/native/**/*.cuh",
            # NOTE(review): if aten/src/THC no longer exists, this pattern
            # matches nothing. Bazel rejects that when glob allow_empty is off,
            # so confirm the pattern is still needed.
            "aten/src/THC/*.cuh",
        ],
    ) + [
        ":aten_src_ATen_config",
        ":generated_aten_cpp",
    ],
    includes = [
        "aten/src",
    ],
    deps = [
        "//c10",
    ],
)

# Compiler flags for the ATen libraries: COMMON_COPTS plus the main-lib and AVX
# CPU-dispatch definitions.
# NOTE(review): CAFFE2_COPTS defines CAFFE2_BUILD_MAIN_LIB (no trailing "S");
# confirm which spelling the sources actually check.
ATEN_COPTS = COMMON_COPTS + [
    "-DCAFFE2_BUILD_MAIN_LIBS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]

# Macro from aten.bzl that builds the per-operator CPU kernels, including the
# generated ufunc kernel sources. It presumably defines the :ATen_CPU target
# that :aten depends on; confirm in aten.bzl.
intern_build_aten_ops(
    copts = ATEN_COPTS,
    extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"),
    deps = [
        ":aten_headers",
        "@fbgemm",
        "@mkl",
        "@sleef",
    ],
)

# The main CPU ATen library. It combines the ATen filegroups above with the
# CPU-side generated sources. In CUDA builds libcaffe2_nvrtc.so is attached as
# runtime data.
cc_library(
    name = "aten",
    srcs = [
        ":ATen_CORE_SRCS",
        ":ATen_QUANTIZED_SRCS",
        ":aten_base_cpp",
        ":aten_base_metal",
        ":aten_base_vulkan",
        ":aten_native_cpp",
        ":aten_native_mkl_cpp",
        ":aten_native_mkldnn_cpp",
        ":aten_native_nested_cpp",
        ":aten_native_quantized_cpp",
        ":aten_native_sparse_cpp",
        ":aten_native_transformers_cpp",
        ":aten_native_xnnpack",
        ":aten_src_ATen_config",
    ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
    copts = ATEN_COPTS,
    linkopts = [
        "-ldl",
    ],
    data = if_cuda(
        [":libcaffe2_nvrtc.so"],
        [],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":ATen_CPU",
        ":aten_headers",
        ":caffe2_for_aten_headers",
        ":torch_headers",
        "@fbgemm",
        "@ideep",
    ],
    # Link every object file even if nothing references it directly. This is
    # presumably required for static operator-registration initializers.
    alwayslink = True,
)

# Compiles aten/src/ATen/cuda/nvrtc_stub against the CUDA driver and NVRTC
# libraries. The result is linked into libcaffe2_nvrtc.so below.
cc_library(
    name = "aten_nvrtc",
    srcs = glob([
        "aten/src/ATen/cuda/nvrtc_stub/*.cpp",
    ]),
    copts = ATEN_COPTS,
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_headers",
        "//c10",
        "@cuda",
        "@cuda//:cuda_driver",
        "@cuda//:nvrtc",
    ],
    alwayslink = True,
)

# Shared library built from :aten_nvrtc. :aten lists it as runtime data in CUDA
# builds.
cc_binary(
    name = "libcaffe2_nvrtc.so",
    linkshared = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_nvrtc",
    ],
)

# Host-side (.cpp) part of the CUDA ATen backend, including the generated CUDA
# registrations and the rendered CUDAConfig.h.
cc_library(
    name = "aten_cuda_cpp",
    srcs = [":aten_cuda_cpp_srcs"] + generated_cuda_cpp,
    hdrs = [":aten_src_ATen_cuda_config"],
    copts = ATEN_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda",
        "@cuda//:cusolver",
        "@cuda//:nvrtc",
        "@cudnn",
        "@cudnn_frontend",
    ],
    alwayslink = True,
)

# Half-precision defines for every .cu compilation in this file. The
# __CUDA_NO_*__ macros turn off the CUDA headers' built-in half/bfloat16
# operators and conversions.
torch_cuda_half_options = [
    "-DCUDA_HAS_FP16=1",
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]

# Device-side (.cu) part of the CUDA ATen backend, built with cu_library.
cu_library(
    name = "aten_cuda",
    srcs = [":aten_cu_srcs"],
    copts = ATEN_COPTS + torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_cuda_cpp",
        "//c10/util:bit_cast",
        "@cuda//:cublas",
        "@cuda//:cufft",
        "@cuda//:cusparse",
        "@cutlass",
    ],
    alwayslink = True,
)

# caffe2
# Compiler flags for the caffe2 libraries (COMMON_COPTS plus caffe2 export and
# main-lib definitions).
CAFFE2_COPTS = COMMON_COPTS + [
    "-Dcaffe2_EXPORTS",
    "-DCAFFE2_USE_CUDNN",
    "-DCAFFE2_BUILD_MAIN_LIB",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]

# Source groups compiled into the :caffe2 library.
filegroup(
    name = "caffe2_core_srcs",
    srcs = [
        "caffe2/core/common.cc",
    ],
)

filegroup(
    name = "caffe2_perfkernels_srcs",
    srcs = [
        "caffe2/perfkernels/embedding_lookup_idx.cc",
    ],
)

filegroup(
    name = "caffe2_serialize_srcs",
    srcs = [
        "caffe2/serialize/file_adapter.cc",
        "caffe2/serialize/inline_container.cc",
        "caffe2/serialize/istream_adapter.cc",
        "caffe2/serialize/read_adapter_interface.cc",
    ],
)

filegroup(
    name = "caffe2_utils_srcs",
    srcs = [
        "caffe2/utils/proto_wrap.cc",
        "caffe2/utils/string_utils.cc",
        "caffe2/utils/threadpool/ThreadPool.cc",
        "caffe2/utils/threadpool/pthreadpool.cc",
        "caffe2/utils/threadpool/pthreadpool_impl.cc",
        "caffe2/utils/threadpool/thread_pool_guard.cpp",
    ],
)

# For finer granularity and easier debugging, caffe2 is split into three
# libraries: ATen, caffe2 and caffe2_for_aten_headers. The ATen library groups
# the sources under the aten/ directory, and caffe2 contains most files under
# the caffe2/ directory. Because the ATen and caffe2 libraries would otherwise
# depend on each other, caffe2_for_aten_headers is split out of caffe2 to avoid
# a dependency cycle.
cc_library(
    name = "caffe2_for_aten_headers",
    hdrs = [
        "caffe2/core/common.h",
        "caffe2/perfkernels/common.h",
        "caffe2/perfkernels/embedding_lookup_idx.h",
        "caffe2/utils/fixed_divisor.h",
    ] + glob([
        "caffe2/utils/threadpool/*.h",
    ]),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        "//c10",
    ],
)

# caffe2 headers, with .cuh headers added only in CUDA builds. The exclude is
# defensive: none of the include patterns match caffe2/core/. macros.h is
# presumably provided by :caffe2_core_macros instead.
cc_library(
    name = "caffe2_headers",
    hdrs = glob(
        [
            "caffe2/perfkernels/*.h",
            "caffe2/serialize/*.h",
            "caffe2/utils/*.h",
            "caffe2/utils/threadpool/*.h",
            "modules/**/*.h",
        ],
        exclude = [
            "caffe2/core/macros.h",
        ],
    ) + if_cuda(glob([
        "caffe2/**/*.cuh",
    ])),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_for_aten_headers",
    ],
)

# The caffe2 library proper. It links the AVX/AVX2 perfkernels and, depending on
# the CUDA setting, either :aten_cuda with TensorPipe-CUDA or :aten with
# TensorPipe-CPU.
cc_library(
    name = "caffe2",
    srcs = [
        ":caffe2_core_srcs",
        ":caffe2_perfkernels_srcs",
        ":caffe2_serialize_srcs",
        ":caffe2_utils_srcs",
    ],
    copts = CAFFE2_COPTS + ["-mf16c"],
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_headers",
        ":caffe2_perfkernels_avx",
        ":caffe2_perfkernels_avx2",
        "//third_party/miniz-2.1.0:miniz",
        "@com_google_protobuf//:protobuf",
        "@eigen",
        "@fbgemm//:fbgemm_src_headers",
        "@fmt",
        "@onnx",
    ] + if_cuda(
        [
            ":aten_cuda",
            "@tensorpipe//:tensorpipe_cuda",
        ],
        [
            ":aten",
            "@tensorpipe//:tensorpipe_cpu",
        ],
    ),
    alwayslink = True,
)

# CUDA (.cu) sources from torch/csrc/distributed/c10d. :torch excludes these
# files from its own srcs and depends on this target in CUDA builds.
# NOTE(review): :torch also excludes CUDASymmetricMemory.cu and
# CUDASymmetricMemoryOps.cu, but they are not built here either; confirm that
# leaving them out of the Bazel build is intended.
cu_library(
    name = "torch_cuda",
    srcs = [
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/NanCheck.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
    copts = torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda//:cublas",
        "@cuda//:curand",
        "@cudnn",
        "@eigen",
        "@tensorpipe//:tensorpipe_cuda",
    ],
    alwayslink = True,
)

# Flags for the caffe2 perfkernels libraries. They are spelled out here rather
# than derived from COMMON_COPTS; each library below appends its own ISA flags.
PERF_COPTS = [
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DENABLE_ALIAS=1",
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-DSLEEF_STATIC_LIBS=1",
    # Spelling fixed (was -DTH_BALS_MKL). NOTE(review): confirm that no
    # perfkernels source changes behavior now that TH_BLAS_MKL is defined.
    "-DTH_BLAS_MKL",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    "-fvisibility-inlines-hidden",
    "-Wunused-parameter",
    "-fno-math-errno",
    "-fno-trapping-math",
    "-mf16c",
]

PERF_HEADERS = glob([
    "caffe2/perfkernels/*.h",
    "caffe2/core/*.h",
])

# AVX variants of the perfkernels (*_avx.cc), compiled with -mavx.
cc_library(
    name = "caffe2_perfkernels_avx",
    srcs = glob([
        "caffe2/perfkernels/*_avx.cc",
    ]),
    hdrs = PERF_HEADERS,
    copts = PERF_COPTS + [
        "-mavx",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_headers",
        "//c10",
    ],
    alwayslink = True,
)

# AVX2 variants of the perfkernels (*_avx2.cc), compiled with AVX2 + FMA.
cc_library(
    name = "caffe2_perfkernels_avx2",
    srcs = glob([
        "caffe2/perfkernels/*_avx2.cc",
    ]),
    hdrs = PERF_HEADERS,
    copts = PERF_COPTS + [
        "-mavx2",
        "-mfma",
        "-mavx",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_headers",
        "//c10",
    ],
    alwayslink = True,
)

# torch
torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])

# All torch C++ headers, plus the generated autograd sources and :version_h.
# The torch/csrc/cuda headers are excluded from the main glob and added back
# only in CUDA builds. Headers under torch/csrc/*/generated/ are excluded and
# presumably come from GENERATED_AUTOGRAD_CPP instead.
cc_library(
    name = "torch_headers",
    hdrs = if_cuda(
        torch_cuda_headers,
    ) + glob(
        [
            "torch/*.h",
            "torch/csrc/**/*.h",
            "torch/csrc/distributed/c10d/**/*.hpp",
            "torch/lib/libshm/*.h",
        ],
        exclude = [
            "torch/csrc/*/generated/*.h",
        ] + torch_cuda_headers,
    ) + GENERATED_AUTOGRAD_CPP + [":version_h"],
    includes = [
        "third_party/kineto/libkineto/include",
        "torch/csrc",
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":aten_headers",
        ":caffe2_headers",
        "//c10",
        "@com_github_google_flatbuffers//:flatbuffers",
        "@local_config_python//:python_headers",
        "@onnx",
    ],
    alwayslink = True,
)

# Compiler flags for the :torch library.
TORCH_COPTS = COMMON_COPTS + [
    "-Dtorch_EXPORTS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DCAFFE2_USE_GLOO",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
    "-Wno-error=unused-function",
]

# Union of the torch source lists without duplicates, in first-seen order.
# The dict acts as an ordered set; .keys() returns the unique file names.
torch_sources = {
    k: ""
    for k in (
        libtorch_core_sources +
        libtorch_distributed_sources +
        torch_cpp_srcs +
        libtorch_extra_sources +
        jit_core_sources +
        lazy_tensor_ts_sources +
        GENERATED_AUTOGRAD_CPP
    )
}.keys()

cc_library(
    name = "torch",
    # In CUDA builds, add the CUDA sources except the NCCL wrappers and the .cu
    # files. Some of those .cu files are compiled by :torch_cuda instead.
    srcs = if_cuda(glob(
        libtorch_cuda_sources,
        exclude = [
            "torch/csrc/cuda/python_nccl.cpp",
            "torch/csrc/cuda/nccl.cpp",
            "torch/csrc/distributed/c10d/intra_node_comm.cu",
            "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
            "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
            "torch/csrc/distributed/c10d/NanCheck.cu",
            "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
        ],
    )) + torch_sources,
    copts = TORCH_COPTS,
    linkopts = [
        "-lrt",
    ],
    defines = [
        "CAFFE2_NIGHTLY_VERSION=20200115",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2",
        ":torch_headers",
        "@kineto",
        "@cpp-httplib",
        "@nlohmann",
    ] + if_cuda([
        "@cuda//:nvToolsExt",
        "@cutlass",
        ":torch_cuda",
    ]),
    alwayslink = True,
)

# Shared-memory helper library from torch/lib/libshm, used by :torch_python and
# :torch/bin/torch_shm_manager. manager.cpp is excluded because it holds the
# torch_shm_manager program's main(); that binary compiles it itself, so it
# must not also be compiled into this library.
cc_library(
    name = "shm",
    srcs = glob(
        ["torch/lib/libshm/*.cpp"],
        exclude = ["torch/lib/libshm/manager.cpp"],
    ),
    linkopts = [
        "-lrt",
    ],
    deps = [
        ":torch",
    ],
)

# Every header in the repository plus the generated autograd code, for
# consumers of libtorch.
cc_library(
    name = "libtorch_headers",
    hdrs = glob([
        "**/*.h",
        "**/*.cuh",
    ]) + [
        # We need the filegroup here because the raw list causes Bazel
        # to see duplicate files. It knows how to deduplicate with the
        # filegroup.
        ":cpp_generated_code",
    ],
    includes = [
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":torch_headers",
    ],
)

# Python bindings for torch: core binding sources, optional CUDA/distributed
# bindings and the generated autograd Python code.
# NOTE(review): the distributed bindings are only compiled in CUDA builds,
# although COMMON_COPTS always defines USE_DISTRIBUTED; confirm this is intended.
cc_library(
    name = "torch_python",
    srcs = libtorch_python_core_sources +
           if_cuda(libtorch_python_cuda_sources) +
           if_cuda(libtorch_python_distributed_sources) +
           GENERATED_AUTOGRAD_PYTHON,
    # torch/csrc/generic/*.cpp are listed as headers, presumably because they
    # are #included as templates rather than compiled on their own.
    hdrs = glob([
        "torch/csrc/generic/*.cpp",
    ]),
    # COMMON_COPTS already defines USE_CUDA (value 1) in CUDA builds, so the
    # previous extra if_cuda(["-DUSE_CUDA=1"]) was redundant.
    copts = COMMON_COPTS,
    deps = [
        ":torch",
        ":shm",
        "@pybind11",
    ],
)

# The torch._C extension module.
pybind_extension(
    name = "torch/_C",
    srcs = ["torch/csrc/stub.c"],
    deps = [
        ":torch_python",
        ":aten_nvrtc",
    ],
)

# functorch first-class-dimension ("dim") C++ sources.
cc_library(
    name = "functorch",
    srcs = glob([
        "functorch/csrc/dim/*.cpp",
    ]),
    hdrs = glob([
        "functorch/csrc/dim/*.h",
    ]),
    deps = [
        ":aten_nvrtc",
        ":torch_python",
        "@pybind11",
    ],
)

# The functorch._C extension module.
pybind_extension(
    name = "functorch/_C",
    srcs = [
        "functorch/csrc/init_dim_only.cpp",
    ],
    copts = [
        "-DTORCH_EXTENSION_NAME=_C",
    ],
    deps = [
        ":functorch",
        ":torch_python",
        ":aten_nvrtc",
    ],
)

# Standalone shared-memory manager process shipped as torch/bin/torch_shm_manager.
cc_binary(
    name = "torch/bin/torch_shm_manager",
    srcs = [
        "torch/lib/libshm/manager.cpp",
    ],
    linkstatic = False,
    deps = [
        ":shm",
    ],
)

# Version string written into torch/version.py, defined once so the CUDA and CPU
# branches below cannot drift apart.
_TORCH_VERSION = "2.0.0"

# Renders torch/version.py from its template.
template_rule(
    name = "gen_version_py",
    src = ":torch/version.py.tpl",
    out = "torch/version.py",
    substitutions = if_cuda({
        # Default to 11.2; otherwise torchvision complains about incompatibility.
        "{{CUDA_VERSION}}": "11.2",
        "{{VERSION}}": _TORCH_VERSION,
    }, {
        "{{CUDA_VERSION}}": "None",
        "{{VERSION}}": _TORCH_VERSION,
    }),
)

# The torch and functorch Python packages. Any checked-in torch/version.py is
# replaced by the one :gen_version_py generates. The native extensions and the
# shm manager binary are attached as runtime data.
py_library(
    name = "pytorch_py",
    srcs = glob(
        ["torch/**/*.py"],
        exclude = ["torch/version.py"],
    ) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
    data = [
        ":torch/_C.so",
        ":functorch/_C.so",
        ":torch/bin/torch_shm_manager",
    ],
    visibility = ["//visibility:public"],
    deps = [
        rules.requirement("numpy"),
        rules.requirement("pyyaml"),
        rules.requirement("requests"),
        rules.requirement("setuptools"),
        rules.requirement("sympy"),
        rules.requirement("typing_extensions"),
        "//torchgen",
    ],
)

# cpp api tests
# Shared helpers linked into every C++ API test.
cc_library(
    name = "test_support",
    testonly = True,
    srcs = [
        "test/cpp/api/support.cpp",
    ],
    hdrs = [
        "test/cpp/api/init_baseline.h",
        "test/cpp/api/optim_baseline.h",
        "test/cpp/api/support.h",
        "test/cpp/common/support.h",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# One cc_test per file in test/cpp/api. Excluded files:
# - integration.cpp: built as :integration_test, which needs the MNIST data.
# - imethod.cpp: no reason recorded in this file.
# - support.cpp: already the source of :test_support, so it must not become a
#   standalone support_test as well.
cpp_api_tests = glob(
    ["test/cpp/api/*.cpp"],
    exclude = [
        "test/cpp/api/imethod.cpp",
        "test/cpp/api/integration.cpp",
        "test/cpp/api/support.cpp",
    ],
)

# Torch integration tests rely on a labeled data set from the MNIST database.
# http://yann.lecun.com/exdb/mnist/
cc_test(
    name = "integration_test",
    size = "medium",
    srcs = ["test/cpp/api/integration.cpp"],
    data = [
        ":download_mnist",
    ],
    tags = [
        "gpu-required",
    ],
    deps = [
        ":test_support",
        "@com_google_googletest//:gtest_main",
    ],
)

# "test/cpp/api/foo-bar.cpp" becomes the target ":foo_bar_test".
[
    cc_test(
        name = paths.split_extension(paths.basename(filename))[0].replace("-", "_") + "_test",
        size = "medium",
        srcs = [filename],
        deps = [
            ":test_support",
            "@com_google_googletest//:gtest_main",
        ],
    )
    for filename in cpp_api_tests
]

test_suite(
    name = "api_tests",
    tests = [
        "any_test",
        "autograd_test",
        "dataloader_test",
        "enum_test",
        "expanding_array_test",
        "functional_test",
        "init_test",
        "integration_test",
        "jit_test",
        "memory_test",
        "misc_test",
        "module_test",
        "modulelist_test",
        "modules_test",
        "nn_utils_test",
        "optim_test",
        "ordered_dict_test",
        "rnn_test",
        "sequential_test",
        "serialize_test",
        "static_test",
        "tensor_options_test",
        "tensor_test",
        "torch_include_test",
    ],
)

# dist autograd tests
cc_test(
    name = "torch_dist_autograd_test",
    size = "small",
    srcs = ["test/cpp/dist_autograd/test_dist_autograd.cpp"],
    tags = [
        "exclusive",
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# jit tests
# Because these individual unit tests require custom registration, it is
# easier to mimic the CMake build by globbing them together into a single test.
cc_test(
    name = "jit_tests",
    size = "small",
    srcs = glob(
        [
            "test/cpp/jit/*.cpp",
            "test/cpp/jit/*.h",
            "test/cpp/tensorexpr/*.cpp",
            "test/cpp/tensorexpr/*.h",
        ],
        exclude = [
            # skip this since <pybind11/embed.h> is not found in OSS build
            "test/cpp/jit/test_exception.cpp",
        ],
    ),
    linkstatic = True,
    tags = [
        "exclusive",
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "lazy_tests",
    size = "small",
    srcs = glob(
        [
            "test/cpp/lazy/*.cpp",
            "test/cpp/lazy/*.h",
        ],
        exclude = [
            # skip these since they depend on generated LazyIr.h which isn't available in bazel yet
            "test/cpp/lazy/test_ir.cpp",
            "test/cpp/lazy/test_lazy_ops.cpp",
            "test/cpp/lazy/test_lazy_ops_util.cpp",
        ],
    ),
    linkstatic = True,
    tags = [
        "exclusive",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# python api tests
py_test(
    name = "test_bazel",
    srcs = ["test/_test_bazel.py"],
    main = "test/_test_bazel.py",
    deps = [":pytorch_py"],
)

# all tests
# NOTE(review): :lazy_tests and :test_bazel are not part of this suite despite
# its name; confirm whether that is intentional before adding them.
test_suite(
    name = "all_tests",
    tests = [
        "api_tests",
        "jit_tests",
        "torch_dist_autograd_test",
        "//c10/test:tests",
    ],
)

# An internal genrule that we are converging with refers to these files as if
# they were in this package, so we alias each one by its basename for
# compatibility.
[
    alias(
        name = paths.basename(path),
        actual = path,
    )
    for path in [
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
        "aten/src/ATen/templates/LazyIr.h",
        "aten/src/ATen/templates/LazyNonNativeIr.h",
        "aten/src/ATen/templates/RegisterDispatchKey.cpp",
        "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
        "aten/src/ATen/native/native_functions.yaml",
        "aten/src/ATen/native/tags.yaml",
        "aten/src/ATen/native/ts_native_functions.yaml",
        "torch/csrc/lazy/core/shape_inference.h",
        "torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
    ]
]

# Produces the MNIST data files that :integration_test uses, under mnist/ in
# this package's output directory. The script presumably fetches them over the
# network; confirm before running in a sandbox without network access.
genrule(
    name = "download_mnist",
    srcs = ["//:tools/download_mnist.py"],
    outs = [
        "mnist/train-images-idx3-ubyte",
        "mnist/train-labels-idx1-ubyte",
        "mnist/t10k-images-idx3-ubyte",
        "mnist/t10k-labels-idx1-ubyte",
    ],
    # $(location) expands to the script's actual execroot path. A hard-coded
    # "tools/..." path breaks when this repository is used as an external
    # workspace (the file then lives under external/<repo>/).
    cmd = "python3 $(location //:tools/download_mnist.py) -d $(RULEDIR)/mnist",
)

# Web-page footer captured together with this file (GitVerse cookie notice),
# kept as a comment so the file remains valid Starlark:
#
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal
# data in order to improve our website and the GitVerse service and to make them
# more convenient to use.
#
# You can disable cookies yourself in your browser settings.