colossalai
134 строки · 4.6 Кб
1import importlib
2import os
3import time
4from abc import abstractmethod
5from pathlib import Path
6from typing import List
7
8from .base_extension import _Extension
9
# Only the C++ extension base class is exported from this module.
__all__ = [
    "_CppExtension",
]
11
12
class _CppExtension(_Extension):
    """Base class for extensions compiled from C++ sources with torch's cpp_extension tooling.

    The kernel can be produced ahead-of-time (``build_aot`` returns a
    ``torch.utils.cpp_extension.CppExtension`` description for a setup.py build)
    or just-in-time (``build_jit`` compiles and loads it at runtime). ``load``
    first tries to import a prebuilt module and falls back to a JIT build.

    Subclasses must implement ``sources_files``, ``include_dirs`` and ``cxx_flags``.
    """

    def __init__(self, name: str, priority: int = 1):
        super().__init__(name, support_aot=True, support_jit=True, priority=priority)

        # Loaded op module; filled in by ``load`` so repeated calls neither re-import nor rebuild.
        self.cached_op = None

        # build-related variables
        # Prebuilt (AOT) kernels are imported as ``colossalai._C.<extension name>``.
        self.prebuilt_module_path = "colossalai._C"
        self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
        # NOTE(review): not referenced in this class; presumably consumed by subclasses' flags — confirm.
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # ------------------------------------------------------------------
    # Path helpers
    # ------------------------------------------------------------------

    def csrc_abs_path(self, path: str) -> str:
        """Return the path of ``path`` inside the ``csrc`` directory of the extensions package."""
        return os.path.join(self.relative_to_abs_path("csrc"), path)

    def relative_to_abs_path(self, code_path: str) -> str:
        """Resolve ``code_path`` against the enclosing ``extensions`` directory.

        Starting from this file, the nearest path component named ``extensions``
        (the file itself first, then each ancestor) is taken as the root, and
        ``code_path`` is joined onto it.

        Args:
            code_path: Path relative to the ``extensions`` directory.

        Returns:
            The joined path as a string.

        Raises:
            RuntimeError: If no ancestor of this file is named ``extensions``.
        """
        current_file_path = Path(__file__)
        # ``parents`` is finite and stops at the root (where ``Path.parent`` returns
        # itself), so a missing ``extensions`` directory raises instead of looping forever.
        for candidate in (current_file_path, *current_file_path.parents):
            if candidate.name == "extensions":
                return str(candidate.joinpath(code_path))
        raise RuntimeError(f"Cannot locate an 'extensions' directory among the ancestors of {current_file_path}")

    # ------------------------------------------------------------------
    # Build / load helpers
    # ------------------------------------------------------------------

    def strip_empty_entries(self, args: List[str]) -> List[str]:
        """Drop any empty strings from a list of sources, include dirs, or compile/link flags."""
        return [x for x in args if len(x) > 0]

    def import_op(self):
        """Import the prebuilt op module by its dotted name ``self.prebuilt_import_path``.

        Raises:
            ImportError: If the prebuilt module cannot be imported.
        """
        return importlib.import_module(self.prebuilt_import_path)

    def build_aot(self) -> "CppExtension":
        """Describe this kernel as a ``torch.utils.cpp_extension.CppExtension`` for an AOT build.

        The extension is named ``self.prebuilt_import_path`` so that ``import_op``
        can import it once built. Empty entries are stripped from the sources,
        include directories and compile flags.
        """
        from torch.utils.cpp_extension import CppExtension

        return CppExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
        )

    def build_jit(self):
        """Compile (or reuse a previous build of) the kernel via ``torch.utils.cpp_extension.load``.

        Build artifacts are written to ``_Extension.get_jit_extension_folder_path()``,
        which is created if missing. Progress and elapsed time are printed.

        Returns:
            The module object returned by ``torch.utils.cpp_extension.load``.
        """
        from torch.utils.cpp_extension import load

        build_directory = Path(_Extension.get_jit_extension_folder_path())
        build_directory.mkdir(parents=True, exist_ok=True)

        # Only used to pick the log wording; ``load`` decides what actually needs compiling.
        # torch names the linked library ``<name>.so`` (``<name>.pyd`` on Windows);
        # ``<name>.o`` is still accepted as an indicator of a previous build.
        compiled_before = any(
            build_directory.joinpath(f"{self.name}{suffix}").exists() for suffix in (".o", ".so", ".pyd")
        )

        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        # perf_counter is monotonic, so the measured duration is unaffected by clock changes.
        build_start = time.perf_counter()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            # Stripped like in ``build_aot`` so empty flags never reach the compiler.
            extra_cflags=self.strip_empty_entries(self.cxx_flags()),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.perf_counter() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    # ------------------------------------------------------------------
    # Functions subclasses must override
    # ------------------------------------------------------------------

    @abstractmethod
    def sources_files(self) -> List[str]:
        """Return the list of C++ source files for this extension."""

    @abstractmethod
    def include_dirs(self) -> List[str]:
        """Return the list of include directories for this extension."""

    @abstractmethod
    def cxx_flags(self) -> List[str]:
        """Return the list of C++ compilation flags for this extension."""

    def load(self):
        """Return the op module, preferring a prebuilt module and falling back to a JIT build.

        The result is stored in ``self.cached_op``; later calls return it directly
        instead of importing or building again.
        """
        # getattr guards subclasses whose __init__ does not call ours.
        cached = getattr(self, "cached_op", None)
        if cached is not None:
            return cached

        try:
            op_kernel = self.import_op()
        except ImportError:
            # Covers ModuleNotFoundError (a subclass): the kernel is not pre-built,
            # so build it JIT instead.
            op_kernel = self.build_jit()

        self.cached_op = op_kernel
        return op_kernel
135