colossalai

Форк
0
/
base_extension.py 
82 строки · 2.5 Кб
1
import hashlib
2
import os
3
from abc import ABC, abstractmethod
4
from typing import Callable, Union
5

6
__all__ = ["_Extension"]
7

8

9
class _Extension(ABC):
10
    def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1):
11
        self._name = name
12
        self._support_aot = support_aot
13
        self._support_jit = support_jit
14
        self.priority = priority
15

16
    @property
17
    def name(self):
18
        return self._name
19

20
    @property
21
    def support_aot(self):
22
        return self._support_aot
23

24
    @property
25
    def support_jit(self):
26
        return self._support_jit
27

28
    @staticmethod
29
    def get_jit_extension_folder_path():
30
        """
31
        Kernels which are compiled during runtime will be stored in the same cache folder for reuse.
32
        The folder is in the path ~/.cache/colossalai/torch_extensions/<cache-folder>.
33
        The name of the <cache-folder> follows a common format:
34
            torch<torch_version_major>.<torch_version_minor>_<device_name><device_version>-<hash>
35

36
        The <hash> suffix is the hash value of the path of the `colossalai` file.
37
        """
38
        import torch
39

40
        import colossalai
41
        from colossalai.accelerator import get_accelerator
42

43
        # get torch version
44
        torch_version_major = torch.__version__.split(".")[0]
45
        torch_version_minor = torch.__version__.split(".")[1]
46

47
        # get device version
48
        device_name = get_accelerator().name
49
        device_version = get_accelerator().get_version()
50

51
        # use colossalai's file path as hash
52
        hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest()
53

54
        # concat
55
        home_directory = os.path.expanduser("~")
56
        extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}"
57
        cache_directory = os.path.join(home_directory, extension_directory)
58
        return cache_directory
59

60
    @abstractmethod
61
    def is_hardware_available(self) -> bool:
62
        """
63
        Check if the hardware required by the kernel is available.
64
        """
65

66
    @abstractmethod
67
    def assert_hardware_compatible(self) -> None:
68
        """
69
        Check if the hardware required by the kernel is compatible.
70
        """
71

72
    @abstractmethod
73
    def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
74
        pass
75

76
    @abstractmethod
77
    def build_jit(self) -> Callable:
78
        pass
79

80
    @abstractmethod
81
    def load(self) -> Callable:
82
        pass
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.