# pytorch / aten.bzl (96 lines, 2.9 KB)
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@rules_cc//cc:defs.bzl", "cc_library")

# Capabilities for which intern_build_aten_ops declares a separate
# ATen_CPU_<CAPABILITY> library. The order here is also the order of the
# deps of the aggregate ATen_CPU library.
CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]

# Extra compiler flags per capability. intern_build_aten_ops indexes this
# dict with every entry of CPU_CAPABILITY_NAMES, so each one needs a key.
CAPABILITY_COMPILER_FLAGS = {
    # x86 ISA-extension flags (GCC/Clang): AVX2, FMA and F16C.
    "AVX2": [
        "-mavx2",
        "-mfma",
        "-mf16c",
    ],
    "DEFAULT": [],
}

# Package-relative root of the ATen sources; extra_impls paths are
# made relative to this.
EXTRA_PREFIX = "aten/src/ATen/"

# Package-relative root of the native kernels globbed by
# intern_build_aten_ops (equals "aten/src/ATen/native/").
PREFIX = EXTRA_PREFIX + "native/"

def _copy_for_capability(src, prefix, cpu_capability):
    """Declare a genrule that copies `src` to a capability-suffixed .cpp file.

    Args:
      src: package-relative path of the source file to copy.
      prefix: directory prefix stripped from `src` to form the genrule name
        and prepended again to form the output path.
      cpu_capability: capability name appended to the output file name.

    Returns:
      The package-relative output path `<prefix><rel>.<cpu_capability>.cpp`.
    """

    # Strip `prefix` only when it is a leading prefix; `str.replace` would
    # also delete any later occurrence of it inside the path.
    rel = src[len(prefix):] if src.startswith(prefix) else src
    out = prefix + rel + "." + cpu_capability + ".cpp"
    native.genrule(
        name = rel + "_" + cpu_capability + "_cp",
        srcs = [src],
        outs = [out],
        cmd = "cp $< $@",
    )
    return out

def intern_build_aten_ops(copts, deps, extra_impls):
    """Declare per-CPU-capability ATen kernel libraries and an aggregate.

    For each capability in CPU_CAPABILITY_NAMES, this macro does two things.
    First, every source under PREFIX matching `cpu/*.cpp` or
    `quantized/cpu/kernels/*.cpp`, plus every path in `extra_impls`, is copied
    by a genrule (`cp $< $@`) to `<path>.<CAPABILITY>.cpp`. Presumably this
    gives each capability its own distinctly named translation units
    (TODO confirm).
    Second, the copies are compiled into a static `ATen_CPU_<CAPABILITY>`
    cc_library. It uses `-DCPU_CAPABILITY=<CAPABILITY>`,
    `-DCPU_CAPABILITY_<CAPABILITY>` and that capability's entry in
    CAPABILITY_COMPILER_FLAGS.
    Finally, a static `ATen_CPU` library depending on all of them is declared.

    Args:
      copts: list of compiler options shared by every capability library.
      deps: deps of every per-capability library.
      extra_impls: additional source paths to copy and compile; presumably
        under EXTRA_PREFIX (the prefix is stripped only when present).
    """

    # The glob does not depend on the capability, so evaluate it once.
    native_impls = native.glob([
        PREFIX + "cpu/*.cpp",
        PREFIX + "quantized/cpu/kernels/*.cpp",
    ])

    for cpu_capability in CPU_CAPABILITY_NAMES:
        srcs = []
        for impl in native_impls:
            srcs.append(_copy_for_capability(impl, PREFIX, cpu_capability))
        for impl in extra_impls:
            srcs.append(_copy_for_capability(impl, EXTRA_PREFIX, cpu_capability))

        cc_library(
            name = "ATen_CPU_" + cpu_capability,
            srcs = srcs,
            copts = copts + [
                "-DCPU_CAPABILITY=" + cpu_capability,
                "-DCPU_CAPABILITY_" + cpu_capability,
            ] + CAPABILITY_COMPILER_FLAGS[cpu_capability],
            deps = deps,
            linkstatic = 1,
        )

    cc_library(
        name = "ATen_CPU",
        deps = [":ATen_CPU_" + cpu_capability for cpu_capability in CPU_CAPABILITY_NAMES],
        linkstatic = 1,
    )

def generate_aten_impl(ctx):
    """Run the ATen code generator as a single shell action.

    The action declares the directory `aten/src/ATen/ops` plus every file
    listed in the `outs` attribute as outputs. It consumes the files of the
    `srcs` attribute and runs the `generator` executable with
    `--source-path aten/src/ATen --per-operator-headers --install_dir <dir>`.
    `<dir>` is the parent of the declared ops directory.

    Returns:
      A DefaultInfo whose files are the ops directory and the `outs` files.
    """

    # Tree artifact for the generated per-operator files; the named `outs`
    # are declared alongside it.
    ops_tree = ctx.actions.declare_directory("aten/src/ATen/ops")
    generated = [ops_tree] + ctx.outputs.outs

    # Inputs and runfiles manifests needed to execute the generator tool.
    gen_inputs, gen_manifests = ctx.resolve_tools(tools = [ctx.attr.generator])

    # Passed to the generator through "$@" in the shell command below.
    # --install_dir is the parent of ops_tree; presumably the generator
    # writes ops/ beneath it (confirm against the generator).
    gen_args = [
        "--source-path",
        "aten/src/ATen",
        "--per-operator-headers",
        "--install_dir",
        paths.dirname(ops_tree.path),
    ]

    ctx.actions.run_shell(
        mnemonic = "GenerateAten",
        command = ctx.executable.generator.path + " $@",
        arguments = gen_args,
        inputs = ctx.files.srcs,
        outputs = generated,
        tools = gen_inputs,
        input_manifests = gen_manifests,
        use_default_shell_env = True,
    )
    return [DefaultInfo(files = depset(generated))]

# Attributes of the generate_aten rule:
#   generator: required executable, built for the exec configuration, that
#              performs the code generation.
#   outs:      individually named output files, declared by the
#              implementation in addition to the aten/src/ATen/ops directory.
#   srcs:      files passed to the generator action as inputs.
_GENERATE_ATEN_ATTRS = {
    "generator": attr.label(
        mandatory = True,
        executable = True,
        cfg = "exec",
        allow_files = True,
    ),
    "outs": attr.output_list(),
    "srcs": attr.label_list(allow_files = True),
}

# Rule wrapping generate_aten_impl: runs the ATen generator and provides
# its outputs through DefaultInfo.
generate_aten = rule(
    attrs = _GENERATE_ATEN_ATTRS,
    implementation = generate_aten_impl,
)