colossalai

Форк
0
/
triton_extension.py 
21 строка · 589.0 Байт
1
from .base_extension import _Extension
2

3
__all__ = ["_TritonExtension"]
4

5

6
class _TritonExtension(_Extension):
7
    def __init__(self, name: str, priority: int = 1):
8
        super().__init__(name, support_aot=False, support_jit=True, priority=priority)
9

10
    def is_hardware_compatible(self) -> bool:
11
        # cuda extension can only be built if cuda is available
12
        try:
13
            import torch
14

15
            cuda_available = torch.cuda.is_available()
16
        except:
17
            cuda_available = False
18
        return cuda_available
19

20
    def load(self):
21
        return self.build_jit()
22

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.