deepspeed

Форк
0
/
sparse_attn.py 
82 строки · 2.9 Кб
1
# Copyright (c) Microsoft Corporation.
2
# SPDX-License-Identifier: Apache-2.0
3

4
# DeepSpeed Team
5

6
from .builder import OpBuilder
7

8
try:
9
    from packaging import version as pkg_version
10
except ImportError:
11
    pkg_version = None
12

13

14
class SparseAttnBuilder(OpBuilder):
    """Op builder for the ``sparse_attn`` extension.

    Declares the extension's name, its single C++ source file, the compiler
    flags used to build it, and an environment compatibility check.
    """

    BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
    NAME = "sparse_attn"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        """Return the dotted module path under which the built op is exposed."""
        return f'deepspeed.ops.sparse_attention.{self.NAME}_op'

    def sources(self):
        """Return the list of source files that make up the extension."""
        return ['csrc/sparse_attention/utils.cpp']

    def cxx_args(self):
        """Return the C++ compiler flags used for the extension."""
        return ['-O2', '-fopenmp']

    def is_compatible(self, verbose=True):
        """Report whether the current environment can use sparse attention.

        The check fails immediately (returns ``False``) when running on a ROCm
        build of PyTorch, when ``torch`` or ``triton`` cannot be imported, or
        when the installed triton version is not exactly ``1.0.0``. Otherwise
        the result is the conjunction of the base-class check, the torch
        version check (1.x with minor >= 5) and the CUDA version check
        (10.1+ or any 11+).

        Args:
            verbose: When true, emit a warning describing each failed
                requirement. When false, the checks run silently. The flag is
                also forwarded to the base-class check.

        Returns:
            ``True`` if every requirement is met, ``False`` otherwise.
        """

        def warn(message):
            # Diagnostics are only emitted when the caller asked for them.
            if verbose:
                self.warning(message)

        # Check to see if llvm and cmake are installed since they are dependencies
        # (currently disabled)
        #required_commands = ['llvm-config|llvm-config-9', 'cmake']
        #command_status = list(map(self.command_exists, required_commands))
        #deps_compatible = all(command_status)

        if self.is_rocm_pytorch():
            warn(f'{self.NAME} is not compatible with ROCM')
            return False

        try:
            import torch
        except ImportError:
            warn("unable to import torch, please install it first")
            return False

        # torch-cpu will not have a cuda version
        if torch.version.cuda is None:
            cuda_compatible = False
            warn(f"{self.NAME} cuda is not available from torch")
        else:
            major, minor = torch.version.cuda.split('.')[:2]
            cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11)
            if not cuda_compatible:
                warn(f"{self.NAME} requires CUDA version 10.1+")

        # Only the first two dotted components of the torch version are inspected.
        torch_version_parts = torch.__version__.split('.')
        TORCH_MAJOR = int(torch_version_parts[0])
        TORCH_MINOR = int(torch_version_parts[1])
        torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5)
        if not torch_compatible:
            warn(f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}')

        try:
            import triton
        except ImportError:
            # auto-install of triton is broken on some systems, reverting to manual install for now
            # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710
            warn("please install triton==1.0.0 if you want to use sparse attention")
            return False

        # Prefer a parsed version comparison; fall back to a plain string
        # comparison when the `packaging` module could not be imported.
        if pkg_version:
            installed_triton = pkg_version.parse(triton.__version__)
            triton_mismatch = installed_triton != pkg_version.parse("1.0.0")
        else:
            installed_triton = triton.__version__
            triton_mismatch = installed_triton != "1.0.0"

        if triton_mismatch:
            warn(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible")
            return False

        return super().is_compatible(verbose) and torch_compatible and cuda_compatible
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.