pytorch

Форк
0
/
_triton.py 
30 строк · 873.0 Байт
1
import functools
2

3
from torch._dynamo.device_interface import get_interface_for_device
4

5

6
@functools.lru_cache(None)
7
def has_triton_package() -> bool:
8
    try:
9
        import triton
10

11
        return triton is not None
12
    except ImportError:
13
        return False
14

15

16
@functools.lru_cache(None)
17
def has_triton() -> bool:
18
    def cuda_extra_check(device_interface):
19
        return device_interface.Worker.get_device_properties().major >= 7
20

21
    triton_supported_devices = {"cuda": cuda_extra_check}
22

23
    def is_device_compatible_with_triton():
24
        for device, extra_check in triton_supported_devices.items():
25
            device_interface = get_interface_for_device(device)
26
            if device_interface.is_available() and extra_check(device_interface):
27
                return True
28
        return False
29

30
    return is_device_compatible_with_triton() and has_triton_package()
31

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.