# gpt-neox
# 90 lines · 2.9 KB
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15from setuptools import setup, find_packages
16from torch.utils import cpp_extension
17from torch.utils.cpp_extension import BuildExtension, CUDAExtension
18from pathlib import Path
19import subprocess
20
21
def _get_cuda_bare_metal_version(cuda_dir):
    """Run ``nvcc -V`` from *cuda_dir* and parse the CUDA release version.

    Args:
        cuda_dir: Root directory of the CUDA toolkit; ``nvcc`` is invoked as
            ``<cuda_dir>/bin/nvcc``.

    Returns:
        A tuple ``(raw_output, bare_metal_major, bare_metal_minor)``:
        the full text printed by ``nvcc -V`` and the major/minor release
        components as strings (e.g. ``"11"`` and ``"8"`` for release 11.8).

    Raises:
        RuntimeError: if *cuda_dir* is ``None`` (no CUDA toolkit located).
        subprocess.CalledProcessError: if ``nvcc -V`` exits with an error.
        ValueError: if the output contains no ``release`` token.
        IndexError: if the release token has no ``.<minor>`` part.
    """
    if cuda_dir is None:
        # Fail with an actionable message instead of a TypeError from path building.
        raise RuntimeError(
            "CUDA toolkit directory is None; set CUDA_HOME so that nvcc can be "
            "located to determine the CUDA version"
        )
    nvcc = str(Path(cuda_dir) / "bin" / "nvcc")
    raw_output = subprocess.check_output([nvcc, "-V"], universal_newlines=True)
    output = raw_output.split()
    # nvcc prints e.g. "Cuda compilation tools, release 11.8, V11.8.89"; the
    # token after "release" is "<major>.<minor>," with a trailing comma.
    release_idx = output.index("release") + 1
    release = output[release_idx].rstrip(",").split(".")
    bare_metal_major = release[0]
    # Keep every digit of the minor version. Taking only the first character
    # would turn a two-digit minor such as "10" into "1".
    bare_metal_minor = release[1]

    return raw_output, bare_metal_major, bare_metal_minor
33
34
# Absolute directory of this setup script; kernel sources are resolved against it.
srcpath = Path(__file__).parent.absolute()

# Query the toolkit that torch's cpp_extension builds with; only the major
# release number is used below.
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)

# Extra code-generation target (sm_80), emitted only when nvcc reports a
# major release of 11 or newer.
cc_flag = []
if int(bare_metal_major) >= 11:
    cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]

# nvcc flags shared by every CUDA translation unit.
# NOTE(review): sm_70 is always targeted here; confirm the installed toolkit
# still supports that architecture.
nvcc_flags = [
    "-O3",
    # baseline target: compute_70 / sm_70
    "-gencode",
    "arch=compute_70,code=sm_70",
    "--use_fast_math",
    # undefine the macros that disable CUDA's built-in half operators/conversions
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
]

# Per-compiler argument maps passed as CUDAExtension(extra_compile_args=...).
cuda_ext_args = dict(cxx=["-O3"], nvcc=[*nvcc_flags, *cc_flag])
# Same flags plus a per-thread register cap.
# NOTE(review): not referenced by any extension in this file's setup() call.
layernorm_cuda_args = dict(
    cxx=["-O3"],
    nvcc=[*nvcc_flags, *cc_flag, "-maxrregcount=50"],
)
# Register the fused-kernel package. Each extension is built from a C++
# binding source plus a CUDA source, both located next to this script, and
# all of them share the same cuda_ext_args compile flags.
setup(
    name="fused_kernels",
    version="0.0.2",
    author="EleutherAI",
    author_email="contact@eleuther.ai",
    include_package_data=False,
    ext_modules=[
        CUDAExtension(
            name=ext_name,
            sources=[str(srcpath / cpp_source), str(srcpath / cuda_source)],
            extra_compile_args=cuda_ext_args,
        )
        for ext_name, cpp_source, cuda_source in (
            (
                "scaled_upper_triang_masked_softmax_cuda",
                "scaled_upper_triang_masked_softmax.cpp",
                "scaled_upper_triang_masked_softmax_cuda.cu",
            ),
            (
                "scaled_masked_softmax_cuda",
                "scaled_masked_softmax.cpp",
                "scaled_masked_softmax_cuda.cu",
            ),
            (
                "fused_rotary_positional_embedding",
                "fused_rotary_positional_embedding.cpp",
                "fused_rotary_positional_embedding_cuda.cu",
            ),
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
91