colossalai

Форк
0
/
cpp_extension.py 
134 строки · 4.6 Кб
1
import importlib
2
import os
3
import time
4
from abc import abstractmethod
5
from pathlib import Path
6
from typing import List
7

8
from .base_extension import _Extension
9

10
__all__ = ["_CppExtension"]
11

12

13
class _CppExtension(_Extension):
14
    def __init__(self, name: str, priority: int = 1):
15
        super().__init__(name, support_aot=True, support_jit=True, priority=priority)
16

17
        # we store the op as an attribute to avoid repeated building and loading
18
        self.cached_op = None
19

20
        # build-related variables
21
        self.prebuilt_module_path = "colossalai._C"
22
        self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
23
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
24

25
    def csrc_abs_path(self, path):
26
        return os.path.join(self.relative_to_abs_path("csrc"), path)
27

28
    def relative_to_abs_path(self, code_path: str) -> str:
29
        """
30
        This function takes in a path relative to the colossalai root directory and return the absolute path.
31
        """
32

33
        # get the current file path
34
        # iteratively check the parent directory
35
        # if the parent directory is "extensions", then the current file path is the root directory
36
        # otherwise, the current file path is inside the root directory
37
        current_file_path = Path(__file__)
38
        while True:
39
            if current_file_path.name == "extensions":
40
                break
41
            else:
42
                current_file_path = current_file_path.parent
43
        extension_module_path = current_file_path
44
        code_abs_path = extension_module_path.joinpath(code_path)
45
        return str(code_abs_path)
46

47
    # functions must be overrided over
48
    def strip_empty_entries(self, args):
49
        """
50
        Drop any empty strings from the list of compile and link flags
51
        """
52
        return [x for x in args if len(x) > 0]
53

54
    def import_op(self):
55
        """
56
        This function will import the op module by its string name.
57
        """
58
        return importlib.import_module(self.prebuilt_import_path)
59

60
    def build_aot(self) -> "CppExtension":
61
        from torch.utils.cpp_extension import CppExtension
62

63
        return CppExtension(
64
            name=self.prebuilt_import_path,
65
            sources=self.strip_empty_entries(self.sources_files()),
66
            include_dirs=self.strip_empty_entries(self.include_dirs()),
67
            extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
68
        )
69

70
    def build_jit(self) -> None:
71
        from torch.utils.cpp_extension import load
72

73
        build_directory = _Extension.get_jit_extension_folder_path()
74
        build_directory = Path(build_directory)
75
        build_directory.mkdir(parents=True, exist_ok=True)
76

77
        # check if the kernel has been built
78
        compiled_before = False
79
        kernel_file_path = build_directory.joinpath(f"{self.name}.o")
80
        if kernel_file_path.exists():
81
            compiled_before = True
82

83
        # load the kernel
84
        if compiled_before:
85
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
86
        else:
87
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
88

89
        build_start = time.time()
90
        op_kernel = load(
91
            name=self.name,
92
            sources=self.strip_empty_entries(self.sources_files()),
93
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
94
            extra_cflags=self.cxx_flags(),
95
            extra_ldflags=[],
96
            build_directory=str(build_directory),
97
        )
98
        build_duration = time.time() - build_start
99

100
        if compiled_before:
101
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
102
        else:
103
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
104

105
        return op_kernel
106

107
    # functions must be overrided begin
108
    @abstractmethod
109
    def sources_files(self) -> List[str]:
110
        """
111
        This function should return a list of source files for extensions.
112
        """
113

114
    @abstractmethod
115
    def include_dirs(self) -> List[str]:
116
        """
117
        This function should return a list of include files for extensions.
118
        """
119

120
    @abstractmethod
121
    def cxx_flags(self) -> List[str]:
122
        """
123
        This function should return a list of cxx compilation flags for extensions.
124
        """
125

126
    def load(self):
127
        try:
128
            op_kernel = self.import_op()
129
        except (ImportError, ModuleNotFoundError):
130
            # if import error occurs, it means that the kernel is not pre-built
131
            # so we build it jit
132
            op_kernel = self.build_jit()
133

134
        return op_kernel
135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.