pytorch

Форк
0
/
aten.bzl 
93 строки · 2.8 Кб
1
load("@bazel_skylib//lib:paths.bzl", "paths")
2
load("@rules_cc//cc:defs.bzl", "cc_library")
3

4
CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]
5
CAPABILITY_COMPILER_FLAGS = {
6
    "AVX2": ["-mavx2", "-mfma", "-mf16c"],
7
    "DEFAULT": [],
8
}
9

10
PREFIX = "aten/src/ATen/native/"
11
EXTRA_PREFIX = "aten/src/ATen/"
12

13
def intern_build_aten_ops(copts, deps, extra_impls):
14
    for cpu_capability in CPU_CAPABILITY_NAMES:
15
        srcs = []
16
        for impl in native.glob(
17
            [
18
                PREFIX + "cpu/*.cpp",
19
                PREFIX + "quantized/cpu/kernels/*.cpp",
20
            ],
21
        ):
22
            name = impl.replace(PREFIX, "")
23
            out = PREFIX + name + "." + cpu_capability + ".cpp"
24
            native.genrule(
25
                name = name + "_" + cpu_capability + "_cp",
26
                srcs = [impl],
27
                outs = [out],
28
                cmd = "cp $< $@",
29
            )
30
            srcs.append(out)
31

32
        for impl in extra_impls:
33
            name = impl.replace(EXTRA_PREFIX, "")
34
            out = EXTRA_PREFIX + name + "." + cpu_capability + ".cpp"
35
            native.genrule(
36
                name = name + "_" + cpu_capability + "_cp",
37
                srcs = [impl],
38
                outs = [out],
39
                cmd = "cp $< $@",
40
            )
41
            srcs.append(out)
42

43
        cc_library(
44
            name = "ATen_CPU_" + cpu_capability,
45
            srcs = srcs,
46
            copts = copts + [
47
                "-DCPU_CAPABILITY=" + cpu_capability,
48
                "-DCPU_CAPABILITY_" + cpu_capability,
49
            ] + CAPABILITY_COMPILER_FLAGS[cpu_capability],
50
            deps = deps,
51
            linkstatic = 1,
52
        )
53
    cc_library(
54
        name = "ATen_CPU",
55
        deps = [":ATen_CPU_" + cpu_capability for cpu_capability in CPU_CAPABILITY_NAMES],
56
        linkstatic = 1,
57
    )
58

59
def generate_aten_impl(ctx):
60
    # Declare the entire ATen/ops/ directory as an output
61
    ops_dir = ctx.actions.declare_directory("aten/src/ATen/ops")
62
    outputs = [ops_dir] + ctx.outputs.outs
63

64
    install_dir = paths.dirname(ops_dir.path)
65
    ctx.actions.run(
66
        outputs = outputs,
67
        inputs = ctx.files.srcs,
68
        executable = ctx.executable.generator,
69
        arguments = [
70
            "--source-path",
71
            "aten/src/ATen",
72
            "--per-operator-headers",
73
            "--install_dir",
74
            install_dir,
75
        ],
76
        use_default_shell_env = True,
77
        mnemonic = "GenerateAten",
78
    )
79
    return [DefaultInfo(files = depset(outputs))]
80

81
generate_aten = rule(
82
    implementation = generate_aten_impl,
83
    attrs = {
84
        "generator": attr.label(
85
            executable = True,
86
            allow_files = True,
87
            mandatory = True,
88
            cfg = "exec",
89
        ),
90
        "outs": attr.output_list(),
91
        "srcs": attr.label_list(allow_files = True),
92
    },
93
)
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.