colossalai

Форк
0
/
cuda_extension.py 
109 строк · 3.8 Кб
1
import os
2
import time
3
from abc import abstractmethod
4
from pathlib import Path
5
from typing import List
6

7
from .base_extension import _Extension
8
from .cpp_extension import _CppExtension
9
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
10

11
__all__ = ["_CudaExtension"]
12

13
# Some constants for installation checks
14
MIN_PYTORCH_VERSION_MAJOR = 1
15
MIN_PYTORCH_VERSION_MINOR = 10
16

17

18
class _CudaExtension(_CppExtension):
19
    @abstractmethod
20
    def nvcc_flags(self) -> List[str]:
21
        """
22
        This function should return a list of nvcc compilation flags for extensions.
23
        """
24

25
    def is_hardware_available(self) -> bool:
26
        # cuda extension can only be built if cuda is available
27
        try:
28
            import torch
29

30
            cuda_available = torch.cuda.is_available()
31
        except:
32
            cuda_available = False
33
        return cuda_available
34

35
    def assert_hardware_compatible(self) -> None:
36
        from torch.utils.cpp_extension import CUDA_HOME
37

38
        if not CUDA_HOME:
39
            raise AssertionError(
40
                "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
41
            )
42
        check_system_pytorch_cuda_match(CUDA_HOME)
43
        check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
44

45
    def get_cuda_home_include(self):
46
        """
47
        return include path inside the cuda home.
48
        """
49
        from torch.utils.cpp_extension import CUDA_HOME
50

51
        if CUDA_HOME is None:
52
            raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
53
        cuda_include = os.path.join(CUDA_HOME, "include")
54
        return cuda_include
55

56
    def build_jit(self) -> None:
57
        from torch.utils.cpp_extension import CUDA_HOME, load
58

59
        set_cuda_arch_list(CUDA_HOME)
60

61
        # get build dir
62
        build_directory = _Extension.get_jit_extension_folder_path()
63
        build_directory = Path(build_directory)
64
        build_directory.mkdir(parents=True, exist_ok=True)
65

66
        # check if the kernel has been built
67
        compiled_before = False
68
        kernel_file_path = build_directory.joinpath(f"{self.name}.o")
69
        if kernel_file_path.exists():
70
            compiled_before = True
71

72
        # load the kernel
73
        if compiled_before:
74
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
75
        else:
76
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
77

78
        build_start = time.time()
79
        op_kernel = load(
80
            name=self.name,
81
            sources=self.strip_empty_entries(self.sources_files()),
82
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
83
            extra_cflags=self.cxx_flags(),
84
            extra_cuda_cflags=self.nvcc_flags(),
85
            extra_ldflags=[],
86
            build_directory=str(build_directory),
87
        )
88
        build_duration = time.time() - build_start
89

90
        if compiled_before:
91
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
92
        else:
93
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
94

95
        return op_kernel
96

97
    def build_aot(self) -> "CUDAExtension":
98
        from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension
99

100
        set_cuda_arch_list(CUDA_HOME)
101
        return CUDAExtension(
102
            name=self.prebuilt_import_path,
103
            sources=self.strip_empty_entries(self.sources_files()),
104
            include_dirs=self.strip_empty_entries(self.include_dirs()),
105
            extra_compile_args={
106
                "cxx": self.strip_empty_entries(self.cxx_flags()),
107
                "nvcc": self.strip_empty_entries(self.nvcc_flags()),
108
            },
109
        )
110

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.