colossalai
109 lines · 3.8 KB
1import os
2import time
3from abc import abstractmethod
4from pathlib import Path
5from typing import List
6
7from .base_extension import _Extension
8from .cpp_extension import _CppExtension
9from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
10
11__all__ = ["_CudaExtension"]
12
# Minimum PyTorch release (major.minor) supported by the CUDA extension build path.
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10
16
17
class _CudaExtension(_CppExtension):
    """Base class for CUDA kernel extensions.

    Extends the C++ extension base with CUDA-specific build logic: CUDA
    availability detection, CUDA_HOME validation, and both JIT (runtime)
    and AOT (setup.py) compilation paths for CUDA kernels.
    """

    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        """
        This function should return a list of nvcc compilation flags for extensions.
        """

    def is_hardware_available(self) -> bool:
        """Return True if PyTorch reports a usable CUDA device.

        Returns False both when no GPU is visible and when torch itself
        cannot be imported or queried.
        """
        # cuda extension can only be built if cuda is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        # FIX: was a bare `except:`, which would also swallow KeyboardInterrupt
        # and SystemExit; narrow to Exception so those still propagate.
        except Exception:
            cuda_available = False
        return cuda_available

    def assert_hardware_compatible(self) -> None:
        """Validate that the local toolchain can build/load CUDA extensions.

        Raises:
            AssertionError: if CUDA_HOME is not set (no CUDA Toolkit found),
                or if the downstream version checks fail.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        if not CUDA_HOME:
            raise AssertionError(
                "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
            )
        # Ensure the system CUDA toolkit matches the CUDA version PyTorch was
        # built against, and that the PyTorch version is recent enough.
        check_system_pytorch_cuda_match(CUDA_HOME)
        check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)

    def get_cuda_home_include(self) -> str:
        """Return the include directory inside CUDA_HOME.

        Raises:
            RuntimeError: if CUDA_HOME cannot be resolved by PyTorch.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        if CUDA_HOME is None:
            raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
        cuda_include = os.path.join(CUDA_HOME, "include")
        return cuda_include

    def build_jit(self):
        """JIT-compile (or reload a previously compiled) CUDA kernel at runtime.

        Returns:
            The loaded kernel module produced by
            ``torch.utils.cpp_extension.load``.

        FIX: the original signature was annotated ``-> None`` even though the
        method returns ``op_kernel``; the annotation has been corrected.
        """
        from torch.utils.cpp_extension import CUDA_HOME, load

        set_cuda_arch_list(CUDA_HOME)

        # get build dir
        build_directory = Path(_Extension.get_jit_extension_folder_path())
        build_directory.mkdir(parents=True, exist_ok=True)

        # A leftover object file means this kernel was compiled in a previous
        # run, so `load` will only need to relink/reload it.
        compiled_before = build_directory.joinpath(f"{self.name}.o").exists()

        # load the kernel
        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        build_start = time.time()
        # NOTE(review): unlike build_aot, cxx/nvcc flags are NOT passed through
        # strip_empty_entries here — confirm whether that asymmetry is intended.
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            extra_cflags=self.cxx_flags(),
            extra_cuda_cflags=self.nvcc_flags(),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.time() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    def build_aot(self) -> "CUDAExtension":
        """Build a ``CUDAExtension`` object for ahead-of-time compilation (setup.py).

        Returns:
            A ``torch.utils.cpp_extension.CUDAExtension`` describing the
            prebuilt module's sources, include dirs, and compile flags.
        """
        from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

        set_cuda_arch_list(CUDA_HOME)
        return CUDAExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args={
                "cxx": self.strip_empty_entries(self.cxx_flags()),
                "nvcc": self.strip_empty_entries(self.nvcc_flags()),
            },
        )
110