colossalai
21 строка · 589.0 Байт
1from .base_extension import _Extension
2
3__all__ = ["_TritonExtension"]
4
5
class _TritonExtension(_Extension):
    """Extension variant that is only ever built just-in-time.

    The constructor registers the extension with ahead-of-time support
    disabled and JIT support enabled, and ``load`` goes straight to the
    JIT build path.
    """

    def __init__(self, name: str, priority: int = 1):
        """Initialise the extension as JIT-only.

        Args:
            name: Identifier forwarded unchanged to the base class.
            priority: Forwarded unchanged to the base class; defaults to 1.
                NOTE(review): ordering semantics are defined by the base
                class, which is not visible here.
        """
        super().__init__(name, support_aot=False, support_jit=True, priority=priority)

    def is_hardware_compatible(self) -> bool:
        """Return whether PyTorch reports CUDA as available.

        The check is best-effort: if ``torch`` cannot be imported or the
        CUDA query itself fails, the extension is reported as incompatible
        instead of raising.
        """
        try:
            # Imported lazily so that merely importing this module does not
            # require torch to be installed.
            import torch

            cuda_available = torch.cuda.is_available()
        except Exception:
            # Deliberately not a bare ``except``: KeyboardInterrupt and
            # SystemExit must still propagate to the caller.
            cuda_available = False
        return cuda_available

    def load(self):
        """Build the extension via the JIT path and return the result.

        Returns:
            Whatever ``self.build_jit()`` returns. NOTE(review): presumably
            the built kernel/module; ``build_jit`` is not defined in this
            class, so confirm against its definition.
        """
        return self.build_jit()
22