deepspeed
82 lines · 2.9 KB
1# Copyright (c) Microsoft Corporation.
2# SPDX-License-Identifier: Apache-2.0
3
4# DeepSpeed Team
5
6from .builder import OpBuilder7
8try:9from packaging import version as pkg_version10except ImportError:11pkg_version = None12
13
class SparseAttnBuilder(OpBuilder):
    """Op builder for DeepSpeed's sparse attention C++/triton kernels.

    Compatibility is gated on: CUDA-enabled (non-ROCm) PyTorch in the
    [1.5, 2.0) range built against CUDA 10.1+ or 11+, and triton==1.0.0
    being installed.
    """

    # Environment variable that forces/skips building this op.
    BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
    NAME = "sparse_attn"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        """Return the fully qualified import path of the compiled op module."""
        return f'deepspeed.ops.sparse_attention.{self.NAME}_op'

    def sources(self):
        """Return the C++ source files compiled into this op."""
        return ['csrc/sparse_attention/utils.cpp']

    def cxx_args(self):
        """Return extra compiler flags used for the C++ sources."""
        return ['-O2', '-fopenmp']

    def is_compatible(self, verbose=True):
        """Return True when the current environment can build and run sparse attention.

        Checks, in order: not ROCm, torch importable, torch built against
        CUDA 10.1+ (or any 11+), torch version in [1.5, 2.0), and exactly
        triton 1.0.0 installed. A warning is emitted for each failed check.
        """
        # NOTE: llvm/cmake dependency probing was deliberately disabled;
        # the original commands were:
        #   required_commands = ['llvm-config|llvm-config-9', 'cmake']
        #   deps_compatible = all(map(self.command_exists, required_commands))

        if self.is_rocm_pytorch():
            self.warning(f'{self.NAME} is not compatible with ROCM')
            return False

        try:
            import torch
        except ImportError:
            # Plain string: no placeholders, so no f-prefix (fixes F541).
            self.warning("unable to import torch, please install it first")
            return False

        # torch-cpu builds report torch.version.cuda as None.
        if torch.version.cuda is None:
            cuda_compatible = False
            self.warning(f"{self.NAME} cuda is not available from torch")
        else:
            major, minor = torch.version.cuda.split('.')[:2]
            cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11)
            if not cuda_compatible:
                self.warning(f"{self.NAME} requires CUDA version 10.1+")

        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        # TORCH_MAJOR == 1 already enforces the "< 2.0" half of the message.
        torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5)
        if not torch_compatible:
            self.warning(
                f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}')

        try:
            import triton
        except ImportError:
            # auto-install of triton is broken on some systems, reverting to
            # manual install for now; see:
            # https://github.com/microsoft/DeepSpeed/issues/1710
            self.warning("please install triton==1.0.0 if you want to use sparse attention")
            return False

        # packaging may be unavailable (pkg_version is None then); fall back
        # to an exact string comparison of the version.
        if pkg_version:
            installed_triton = pkg_version.parse(triton.__version__)
            triton_mismatch = installed_triton != pkg_version.parse("1.0.0")
        else:
            installed_triton = triton.__version__
            triton_mismatch = installed_triton != "1.0.0"

        if triton_mismatch:
            # NOTE(review): the message reads like a soft warning, but any
            # non-1.0.0 triton is treated as incompatible — kept as-is since
            # callers may depend on the hard False here.
            self.warning(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible")
            return False

        return super().is_compatible(verbose) and torch_compatible and cuda_compatible