pytorch
1import functools2
3from torch._dynamo.device_interface import get_interface_for_device4
5
@functools.lru_cache(None)
def has_triton_package() -> bool:
    """Report whether the ``triton`` package can be imported.

    The check performs a real ``import triton``. An ``ImportError`` raised
    during that import yields ``False``. The answer is memoized by
    ``functools.lru_cache``, so the import is attempted at most once per
    process.
    """
    try:
        import triton  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    # Referencing the module keeps linters from flagging an unused import.
    return triton is not None
15
@functools.lru_cache(None)
def has_triton() -> bool:
    """Return True if Triton is installed and a Triton-capable device is present.

    Both of the following must hold:

    * :func:`has_triton_package` reports that ``import triton`` succeeds, and
    * at least one device type listed in ``triton_supported_devices`` is
      available and passes that device type's extra check.

    The result is memoized by ``functools.lru_cache``, so it is computed at
    most once per process.
    """
    # Check the package first. If Triton cannot be imported, the answer is
    # False regardless, so skip querying device interfaces and properties
    # (which may touch the device runtime).
    if not has_triton_package():
        return False

    def cuda_extra_check(device_interface):
        # NOTE(review): `major` is presumably the CUDA compute-capability major
        # version, so this would require sm_70 or newer. Confirm.
        return device_interface.Worker.get_device_properties().major >= 7

    # Maps a device-type name (as accepted by get_interface_for_device) to an
    # extra predicate the device must satisfy beyond being available.
    triton_supported_devices = {"cuda": cuda_extra_check}

    def is_device_compatible_with_triton():
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()