# colossalai


Форк
0
229 строк · 8.1 Кб
1
import os
import re
import subprocess
import warnings
from typing import List


def print_rank_0(message: str) -> None:
    """
    Print on only one process to avoid spamming.
    """
    try:
        import torch.distributed as dist

        if not dist.is_initialized():
            is_main_rank = True
        else:
            is_main_rank = dist.get_rank() == 0
    except ImportError:
        is_main_rank = True

    if is_main_rank:
        print(message)


def get_cuda_version_in_pytorch() -> List[int]:
    """
    This function returns the CUDA version in the PyTorch build.

    Returns:
        The CUDA version required by PyTorch, in the form of tuple (major, minor).
    """
    import torch

    try:
        torch_cuda_major = torch.version.cuda.split(".")[0]
        torch_cuda_minor = torch.version.cuda.split(".")[1]
    except Exception:
        raise ValueError(
            "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
        )
    return torch_cuda_major, torch_cuda_minor


45
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
46
    """
47
    Get the System CUDA version from nvcc.
48

49
    Args:
50
        cuda_dir (str): the directory for CUDA Toolkit.
51

52
    Returns:
53
        The CUDA version required by PyTorch, in the form of tuple (major, minor).
54
    """
55
    nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
56

57
    if cuda_dir is None:
58
        raise ValueError(
59
            f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
60
        )
61

62
    # check for nvcc path
63
    if not os.path.exists(nvcc_path):
64
        raise FileNotFoundError(
65
            f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
66
        )
67

68
    # parse the nvcc -v output to obtain the system cuda version
69
    try:
70
        raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
71
        output = raw_output.split()
72
        release_idx = output.index("release") + 1
73
        release = output[release_idx].split(".")
74
        bare_metal_major = release[0]
75
        bare_metal_minor = release[1][0]
76
    except:
77
        raise ValueError(
78
            f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
79
        )
80

81
    return bare_metal_major, bare_metal_minor
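
# Illustrative note (not part of the original module): on a machine with CUDA 11.8,
# 'nvcc -V' typically prints a line such as
#     Cuda compilation tools, release 11.8, V11.8.89
# and the parser above would return ("11", "8"). Exact output formatting can vary
# between toolkit releases, so treat this as an example rather than a guarantee.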
82

83

84
def check_system_pytorch_cuda_match(cuda_dir):
85
    bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
86
    torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
87

88
    if bare_metal_major != torch_cuda_major:
89
        raise Exception(
90
            f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
91
            f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
92
            "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
93
        )
94

95
    if bare_metal_minor != torch_cuda_minor:
96
        warnings.warn(
97
            f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
98
            "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
99
            "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
100
        )
101
    return True
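
# Example (illustrative, not from the original file): a typical ahead-of-time build
# would feed torch's detected toolkit path into the check above, for instance:
#
#     from torch.utils.cpp_extension import CUDA_HOME
#     check_system_pytorch_cuda_match(CUDA_HOME)
#
# This raises on a major-version mismatch and only warns on a minor-version mismatch.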


def get_pytorch_version() -> List[int]:
    """
    This function finds the PyTorch version.

    Returns:
        A tuple of integers in the form of (major, minor, patch).
    """
    import torch

    torch_version = torch.__version__.split("+")[0]
    TORCH_MAJOR = int(torch_version.split(".")[0])
    TORCH_MINOR = int(torch_version.split(".")[1])
    # base 16 is used so that patch strings containing hex-like letters
    # (e.g. "0a0" from a nightly build such as "2.1.0a0") still parse as integers
    TORCH_PATCH = int(torch_version.split(".")[2], 16)
    return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH


def check_pytorch_version(min_major_version, min_minor_version) -> bool:
    """
    Compare the current PyTorch version with the minimum required version.

    Args:
        min_major_version (int): the minimum major version of PyTorch required
        min_minor_version (int): the minimum minor version of PyTorch required

    Returns:
        A boolean value. The value is True if the current PyTorch version is acceptable; a RuntimeError is raised otherwise.
    """
    # get pytorch version
    torch_major, torch_minor, _ = get_pytorch_version()

    # raise an error if the installed version is older than the required minimum
    if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
        raise RuntimeError(
            f"[extension] Colossal-AI requires PyTorch {min_major_version}.{min_minor_version} or newer.\n"
            "The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
        )
    return True
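
# Example (illustrative): a setup script that requires at least PyTorch 1.12 could call
#
#     check_pytorch_version(1, 12)
#
# which returns True on a new enough installation and raises a RuntimeError otherwise.
# The minimum version used by the real build script may differ; 1.12 is only a placeholder.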


def check_cuda_availability():
    """
    Check if CUDA is available on the system.

    Returns:
        A boolean value. True if CUDA is available and False otherwise.
    """
    import torch

    return torch.cuda.is_available()


def set_cuda_arch_list(cuda_dir):
    """
    This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
    Ahead-of-time compilation occurs when BUILD_EXT=1 is set when running 'pip install'.
    """
    cuda_available = check_cuda_availability()

    # we only need to set this when CUDA is not available, i.e. for cross-compilation
    if not cuda_available:
        warnings.warn(
            "\n[extension] PyTorch did not find available GPUs on this system.\n"
            "If your intention is to cross-compile, this is not an error.\n"
            "By default, Colossal-AI will cross-compile for\n"
            "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
            "2. Volta (compute capability 7.0),\n"
            "3. Turing (compute capability 7.5),\n"
            "4. Ampere (compute capabilities 8.0, 8.6) if the CUDA version is >= 11.0\n"
            "\nIf you wish to cross-compile for a single specific architecture,\n"
            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
        )

        if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
            bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)

            arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]

            # Ampere architectures are only added for CUDA 11.x: 8.0 for CUDA 11.0, and 8.0/8.6 for CUDA 11.1 and later
            if int(bare_metal_major) == 11:
                arch_list.append("8.0")
                if int(bare_metal_minor) > 0:
                    arch_list.append("8.6")

            arch_list_str = ";".join(arch_list)
            os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
        return False
    return True
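
# Illustrative example (not part of the original file): on a GPU-less build machine with
# CUDA 11.8 installed, and assuming TORCH_CUDA_ARCH_LIST is not already set, calling
# set_cuda_arch_list(CUDA_HOME) would leave
#     os.environ["TORCH_CUDA_ARCH_LIST"] == "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
# so that the extension is cross-compiled for Pascal through Ampere.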
191

192

193
def get_cuda_cc_flag() -> List[str]:
194
    """
195
    This function produces the cc flags for your GPU arch
196

197
    Returns:
198
        The CUDA cc flags for compilation.
199
    """
200

201
    # only import torch when needed
202
    # this is to avoid importing torch when building on a machine without torch pre-installed
203
    # one case is to build wheel for pypi release
204
    import torch
205

206
    cc_flag = []
207
    max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
208
    for arch in torch.cuda.get_arch_list():
209
        res = re.search(r"sm_(\d+)", arch)
210
        if res:
211
            arch_cap = res[1]
212
            if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
213
                cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
214
    return cc_flag
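
# Illustrative example (not from the original file): on a machine with an A100 (compute
# capability 8.0) and a PyTorch build whose arch list includes sm_60 through sm_80,
# get_cuda_cc_flag() would return flags along the lines of
#     ["-gencode", "arch=compute_60,code=sm_60", ..., "-gencode", "arch=compute_80,code=sm_80"]
# i.e. one -gencode pair per supported architecture up to the capability of the local GPU.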
215

216

217
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
218
    """
219
    This function appends the threads flag to your nvcc args.
220

221
    Returns:
222
        The nvcc compilation flags including the threads flag.
223
    """
224
    from torch.utils.cpp_extension import CUDA_HOME
225

226
    bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
227
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
228
        return nvcc_extra_args + ["--threads", "4"]
229
    return nvcc_extra_args
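
# Putting it together (illustrative sketch, not part of the original module): a setup.py
# performing ahead-of-time compilation might combine these helpers roughly as follows.
# Names such as "my_kernel" and the source list are placeholders, not real project files,
# and the minimum PyTorch version shown is only an example.
#
#     from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
#
#     check_pytorch_version(1, 12)
#     check_system_pytorch_cuda_match(CUDA_HOME)
#     set_cuda_arch_list(CUDA_HOME)
#
#     nvcc_flags = ["-O3"] + get_cuda_cc_flag()
#     ext = CUDAExtension(
#         name="my_kernel",
#         sources=["my_kernel.cu"],
#         extra_compile_args={"cxx": ["-O3"], "nvcc": append_nvcc_threads(nvcc_flags)},
#     )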