colossalai
229 lines · 8.1 KB
1import os2import re3import subprocess4import warnings5from typing import List6
7
def print_rank_0(message: str) -> None:
    """
    Print ``message`` from at most one process to avoid spamming.

    If torch.distributed is unavailable or not initialized, every process is
    treated as the main one and the message is printed.
    """
    try:
        import torch.distributed as dist

        # Once a process group exists, rank 0 is the designated printer.
        should_print = (not dist.is_initialized()) or dist.get_rank() == 0
    except ImportError:
        # torch is not installed at all: single-process, safe to print.
        should_print = True

    if should_print:
        print(message)
25
def get_cuda_version_in_pytorch() -> List[int]:
    """
    This function returns the CUDA version in the PyTorch build.

    Returns:
        The CUDA version required by PyTorch, in the form of tuple (major, minor).
        Each component is a string (e.g. ("11", "8")) so it compares cleanly
        against the strings returned by get_cuda_bare_metal_version.

    Raises:
        ValueError: if the PyTorch binary was built without CUDA support
            (torch.version.cuda is None) or the version string is malformed.
    """
    import torch

    try:
        # torch.version.cuda is None on CPU-only builds, which raises
        # AttributeError on .split; a malformed string raises ValueError on unpack.
        torch_cuda_major, torch_cuda_minor = torch.version.cuda.split(".")[:2]
    except (AttributeError, ValueError) as err:
        raise ValueError(
            "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
        ) from err
    return torch_cuda_major, torch_cuda_minor
44
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
    """
    Get the System CUDA version from nvcc.

    Args:
        cuda_dir (str): the directory for CUDA Toolkit.

    Returns:
        The system CUDA version, in the form of tuple (major, minor) of strings.

    Raises:
        ValueError: if cuda_dir is None, or the nvcc output cannot be parsed.
        FileNotFoundError: if nvcc is not found under cuda_dir.
    """
    # Validate cuda_dir BEFORE building the path: os.path.join(None, ...) would
    # raise an unhelpful TypeError instead of the message below.
    if cuda_dir is None:
        raise ValueError(
            "[extension] The argument cuda_dir is None, but expected to be a string. Please make sure you have exported the environment variable CUDA_HOME correctly."
        )

    nvcc_path = os.path.join(cuda_dir, "bin/nvcc")

    # check for nvcc path
    if not os.path.exists(nvcc_path):
        raise FileNotFoundError(
            f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
        )

    # Run nvcc outside the try-block: previously, a failure here left
    # raw_output unbound and the except clause crashed with a NameError.
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)

    # parse the `nvcc -V` output to obtain the system cuda version
    try:
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]
    except (ValueError, IndexError) as err:
        raise ValueError(
            f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -V' is \n{raw_output}"
        ) from err

    return bare_metal_major, bare_metal_minor
83
def check_system_pytorch_cuda_match(cuda_dir):
    """
    Check that the system CUDA (reported by nvcc) matches the CUDA version
    the installed PyTorch was compiled with.

    Args:
        cuda_dir (str): the CUDA Toolkit directory (usually CUDA_HOME).

    Returns:
        True if the major versions match. A minor-version mismatch only warns,
        since the CUDA APIs are compatible within a major release.

    Raises:
        Exception: if the major CUDA versions differ.
    """
    bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()

    if bare_metal_major != torch_cuda_major:
        raise Exception(
            f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
            f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}). "
            "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
        )

    if bare_metal_minor != torch_cuda_minor:
        warnings.warn(
            f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
            "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
            "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
        )
    return True
103
def get_pytorch_version() -> List[int]:
    """
    This functions finds the PyTorch version.

    Returns:
        A tuple of integers in the form of (major, minor, patch).
    """
    import torch

    # Drop the local build suffix, e.g. "2.1.0+cu118" -> "2.1.0".
    torch_version = torch.__version__.split("+")[0]
    major, minor, patch = torch_version.split(".")[:3]
    TORCH_MAJOR = int(major)
    TORCH_MINOR = int(minor)
    # Bug fix: the patch number is decimal, but it was parsed with base 16,
    # turning e.g. "10" into 16. Take the leading digits so pre-release tags
    # such as "0a0" still parse.
    patch_digits = re.match(r"\d+", patch)
    TORCH_PATCH = int(patch_digits.group()) if patch_digits else 0
    return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
119
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
    """
    Compare the current PyTorch version with the minimum required version.

    Args:
        min_major_version (int): the minimum major version of PyTorch required
        min_minor_version (int): the minimum minor version of PyTorch required

    Returns:
        True if the current pytorch version is acceptable.

    Raises:
        RuntimeError: if the installed PyTorch is older than the minimum.
    """
    # get pytorch version
    torch_major, torch_minor, _ = get_pytorch_version()

    # Tuple comparison orders (major, minor) correctly.
    if (torch_major, torch_minor) < (min_major_version, min_minor_version):
        raise RuntimeError(
            f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
            "The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
        )
    # Bug fix: the function is annotated -> bool and documented to return a
    # boolean, but previously fell off the end and returned None.
    return True
141
def check_cuda_availability():
    """
    Check if CUDA is available on the system.

    Returns:
        A boolean value. True if CUDA is available and False otherwise.
    """
    # Imported lazily so this module can be loaded on machines without torch.
    import torch

    cuda_is_available = torch.cuda.is_available()
    return cuda_is_available
153
def set_cuda_arch_list(cuda_dir):
    """
    This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
    Ahead-of-time compilation occurs when BUILD_EXT=1 is set when running 'pip install'.

    Args:
        cuda_dir (str): the CUDA Toolkit directory used to detect the system CUDA version.

    Returns:
        False if TORCH_CUDA_ARCH_LIST was set by this function; True if the
        user had already exported it and it was left untouched.
    """
    cuda_available = check_cuda_availability()

    # we only need to warn when CUDA is not available, i.e. cross-compilation
    if not cuda_available:
        warnings.warn(
            "\n[extension] PyTorch did not find available GPUs on this system.\n"
            "If your intention is to cross-compile, this is not an error.\n"
            "By default, Colossal-AI will cross-compile for \n"
            "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
            "2. Volta (compute capability 7.0)\n"
            "3. Turing (compute capability 7.5),\n"
            "4. Ampere (compute capability 8.0, 8.6) if the CUDA version is >= 11.0\n"
            "\nIf you wish to cross-compile for a single specific architecture,\n"
            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
        )

    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)

        arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]

        # Ampere (8.0) needs CUDA 11; 8.6 additionally needs CUDA >= 11.1.
        # (The original duplicated the "8.0" append in both branches.)
        if int(bare_metal_major) == 11:
            arch_list.append("8.0")
            if int(bare_metal_minor) > 0:
                arch_list.append("8.6")

        os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(arch_list)
        return False
    return True
192
def get_cuda_cc_flag() -> List[str]:
    """
    This function produces the cc flags for your GPU arch

    Returns:
        The CUDA cc flags for compilation.
    """

    # only import torch when needed
    # this is to avoid importing torch when building on a machine without torch pre-installed
    # one case is to build wheel for pypi release
    import torch

    # Highest compute capability of the local device, e.g. (8, 0) -> 80.
    device_cap = int("".join(str(part) for part in torch.cuda.get_device_capability()))

    flags: List[str] = []
    for arch_name in torch.cuda.get_arch_list():
        match = re.search(r"sm_(\d+)", arch_name)
        if match is None:
            continue
        capability = int(match[1])
        # Keep everything from Pascal (sm_60) up to the local device's capability.
        if 60 <= capability <= device_cap:
            flags += ["-gencode", f"arch=compute_{capability},code={arch_name}"]
    return flags
216
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
    """
    This function appends the threads flag to your nvcc args.

    The nvcc `--threads` option is only supported from CUDA 11.2 onwards.

    Returns:
        A new list with the nvcc compilation flags including the threads flag
        (the input list is not mutated).
    """
    from torch.utils.cpp_extension import CUDA_HOME

    bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # Bug fix: compare (major, minor) as a tuple. The previous check
    # `major >= 11 and minor >= 2` wrongly rejected e.g. CUDA 12.0 and 12.1.
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args