colossalai
82 lines · 2.5 KB
1import hashlib
2import os
3from abc import ABC, abstractmethod
4from typing import Callable, Union
5
__all__ = ["_Extension"]


class _Extension(ABC):
    """
    Abstract base class for a ColossalAI kernel extension.

    An extension records whether it can be built ahead of time (AOT) and/or
    compiled just in time (JIT). Subclasses supply the hardware checks, the
    build steps and the loading logic through the abstract methods below.

    Args:
        name: Identifier of the extension, exposed read-only via ``name``.
        support_aot: Whether the extension supports ahead-of-time building,
            exposed read-only via ``support_aot``.
        support_jit: Whether the extension supports runtime (JIT) compilation,
            exposed read-only via ``support_jit``.
        priority: Integer stored as the public ``priority`` attribute
            (default 1). How it is interpreted is decided by code outside this
            class (presumably ordering among alternative extensions - confirm
            with callers).
    """

    def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1):
        self._name = name
        self._support_aot = support_aot
        self._support_jit = support_jit
        # Unlike the fields above, priority is a plain mutable attribute rather
        # than a read-only property.
        self.priority = priority

    @property
    def name(self) -> str:
        """Name of the extension given at construction."""
        return self._name

    @property
    def support_aot(self) -> bool:
        """Whether the extension supports ahead-of-time (AOT) building."""
        return self._support_aot

    @property
    def support_jit(self) -> bool:
        """Whether the extension supports runtime (JIT) compilation."""
        return self._support_jit

    @staticmethod
    def get_jit_extension_folder_path() -> str:
        """
        Return the cache folder where runtime-compiled (JIT) kernels are stored for reuse.

        The folder is ``~/.cache/colossalai/torch_extensions/<cache-folder>``,
        where ``<cache-folder>`` has the format::

            torch<torch_version_major>.<torch_version_minor>_<device_name>-<device_version>-<hash>

        ``<hash>`` is the hex SHA-256 digest of the path stored in
        ``colossalai.__file__``. Installations of colossalai at different
        locations therefore map to different cache folders.

        Returns:
            Path of the cache folder under the user's home directory. The
            folder itself is not created by this method.
        """
        # Imported inside the function rather than at module level (presumably
        # to avoid import cycles / heavy imports at module load - confirm).
        import torch

        import colossalai
        from colossalai.accelerator import get_accelerator

        # torch version string, e.g. "2.1.0+cu118" -> major "2", minor "1"
        version_parts = torch.__version__.split(".")
        torch_version_major = version_parts[0]
        torch_version_minor = version_parts[1]

        # Device name and version, read from a single get_accelerator() result.
        accelerator = get_accelerator()
        device_name = accelerator.name
        device_version = accelerator.get_version()

        # Use colossalai's file path as the hash input.
        hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest()

        folder_name = (
            f"torch{torch_version_major}.{torch_version_minor}"
            f"_{device_name}-{device_version}-{hash_suffix}"
        )
        home_directory = os.path.expanduser("~")
        # Join component-wise so the platform separator is used throughout;
        # on POSIX this yields exactly ~/.cache/colossalai/torch_extensions/<folder>.
        return os.path.join(home_directory, ".cache", "colossalai", "torch_extensions", folder_name)

    @abstractmethod
    def is_hardware_available(self) -> bool:
        """
        Check if the hardware required by the kernel is available.

        Returns:
            True if the required hardware is present, False otherwise.
        """

    @abstractmethod
    def assert_hardware_compatible(self) -> None:
        """
        Check if the hardware required by the kernel is compatible.

        Returns nothing. Implementations are presumably expected to raise when
        the hardware is incompatible (inferred from the name and ``None``
        return type - confirm against subclasses).
        """

    @abstractmethod
    def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
        """
        Build the extension ahead of time.

        Returns:
            A ``CppExtension`` or ``CUDAExtension`` object, per the return
            annotation (presumably the ``torch.utils.cpp_extension`` types -
            confirm against subclasses).
        """

    @abstractmethod
    def build_jit(self) -> Callable:
        """
        Compile the extension at runtime (JIT).

        Returns:
            A callable, per the return annotation; its exact meaning is
            defined by subclasses.
        """

    @abstractmethod
    def load(self) -> Callable:
        """
        Load the extension.

        Returns:
            A callable, per the return annotation; its exact meaning is
            defined by subclasses.
        """