gpt-neox

Форк
0
90 строк · 2.9 Кб
1
# Copyright (c) 2024, EleutherAI
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from setuptools import setup, find_packages
16
from torch.utils import cpp_extension
17
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
18
from pathlib import Path
19
import subprocess
20

21

22
def _get_cuda_bare_metal_version(cuda_dir):
23
    raw_output = subprocess.check_output(
24
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
25
    )
26
    output = raw_output.split()
27
    release_idx = output.index("release") + 1
28
    release = output[release_idx].split(".")
29
    bare_metal_major = release[0]
30
    bare_metal_minor = release[1][0]
31

32
    return raw_output, bare_metal_major, bare_metal_minor
33

34

35
srcpath = Path(__file__).parent.absolute()
36
cc_flag = []
37
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
38
if int(bare_metal_major) >= 11:
39
    cc_flag.append("-gencode")
40
    cc_flag.append("arch=compute_80,code=sm_80")
41

42
nvcc_flags = [
43
    "-O3",
44
    "-gencode",
45
    "arch=compute_70,code=sm_70",
46
    "--use_fast_math",
47
    "-U__CUDA_NO_HALF_OPERATORS__",
48
    "-U__CUDA_NO_HALF_CONVERSIONS__",
49
    "--expt-relaxed-constexpr",
50
    "--expt-extended-lambda",
51
]
52
cuda_ext_args = {"cxx": ["-O3"], "nvcc": nvcc_flags + cc_flag}
53
layernorm_cuda_args = {
54
    "cxx": ["-O3"],
55
    "nvcc": nvcc_flags + cc_flag + ["-maxrregcount=50"],
56
}
57
setup(
58
    name="fused_kernels",
59
    version="0.0.2",
60
    author="EleutherAI",
61
    author_email="contact@eleuther.ai",
62
    include_package_data=False,
63
    ext_modules=[
64
        CUDAExtension(
65
            name="scaled_upper_triang_masked_softmax_cuda",
66
            sources=[
67
                str(srcpath / "scaled_upper_triang_masked_softmax.cpp"),
68
                str(srcpath / "scaled_upper_triang_masked_softmax_cuda.cu"),
69
            ],
70
            extra_compile_args=cuda_ext_args,
71
        ),
72
        CUDAExtension(
73
            name="scaled_masked_softmax_cuda",
74
            sources=[
75
                str(srcpath / "scaled_masked_softmax.cpp"),
76
                str(srcpath / "scaled_masked_softmax_cuda.cu"),
77
            ],
78
            extra_compile_args=cuda_ext_args,
79
        ),
80
        CUDAExtension(
81
            name="fused_rotary_positional_embedding",
82
            sources=[
83
                str(srcpath / "fused_rotary_positional_embedding.cpp"),
84
                str(srcpath / "fused_rotary_positional_embedding_cuda.cu"),
85
            ],
86
            extra_compile_args=cuda_ext_args,
87
        ),
88
    ],
89
    cmdclass={"build_ext": BuildExtension},
90
)
91

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.