# pytorch/ufunc_defs.bzl (25 lines, 800 bytes)
1load("@bazel_skylib//lib:paths.bzl", "paths")
2load(":build_variables.bzl", "aten_ufunc_headers")
3
# Bare ufunc names derived from `aten_ufunc_headers` (loaded from
# build_variables.bzl; presumably a list of header paths — confirm there).
# Each entry is the header's basename with its extension removed, as computed
# by skylib's `paths.split_extension`. The generator functions below turn each
# name into one generated source file name per template.
aten_ufunc_names = [
    stem
    for stem, _ in [
        paths.split_extension(paths.basename(header))
        for header in aten_ufunc_headers
    ]
]
8
def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"):
    """Lists the generated CPU ufunc source files, one per ufunc name.

    Args:
      gencode_pattern: format string into which each bare file name is
        substituted via `str.format`; the default "{}" returns the names
        unchanged.

    Returns:
      A list containing `gencode_pattern.format("UfuncCPU_<name>.cpp")` for
      every `<name>` in `aten_ufunc_names`, in the same order.
    """
    return [
        gencode_pattern.format("UfuncCPU_{}.cpp".format(ufunc))
        for ufunc in aten_ufunc_names
    ]
14
def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"):
    """Lists the generated CPU ufunc kernel source files, one per ufunc name.

    Args:
      gencode_pattern: format string into which each bare file name is
        substituted via `str.format`; the default "{}" returns the names
        unchanged.

    Returns:
      A list containing `gencode_pattern.format("UfuncCPUKernel_<name>.cpp")`
      for every `<name>` in `aten_ufunc_names`, in the same order.
    """
    return [
        gencode_pattern.format("UfuncCPUKernel_{}.cpp".format(ufunc))
        for ufunc in aten_ufunc_names
    ]
20
def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
    """Lists the generated CUDA ufunc source files, one per ufunc name.

    Args:
      gencode_pattern: format string into which each bare file name is
        substituted via `str.format`; the default "{}" returns the names
        unchanged.

    Returns:
      A list containing `gencode_pattern.format("UfuncCUDA_<name>.cu")` for
      every `<name>` in `aten_ufunc_names`, in the same order.
    """
    return [
        gencode_pattern.format("UfuncCUDA_{}.cu".format(ufunc))
        for ufunc in aten_ufunc_names
    ]
26