colossalai

Форк
0
/
flash_attention_xformers_cuda.py 
94 строки · 3.5 Кб
1
from ..base_extension import _Extension
2

3

4
class FlashAttentionXformersCudaExtension(_Extension):
    """Flash-attention extension backed by the third-party xformers library.

    This extension does not build any kernel itself: both AOT and JIT
    building are unsupported (``support_aot=False``, ``support_jit=False``)
    and ``build_aot``/``build_jit`` always raise.  ``load`` returns a
    ``mem_eff_attention`` wrapper around xformers'
    ``memory_efficient_attention``.
    """

    # Single source of truth for the "install xformers" instructions.
    # The original duplicated this literal three times; the text is kept
    # byte-identical because it is a user-facing error message.
    _INSTALL_MESSAGE = (
        "We rely on the third-party xformers library for flash attention "
        "(https://github.com/facebookresearch/xformers). "
        "Please install xformers according to the GitHub Readme."
    )

    def __init__(self):
        super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        """Return True if CUDA is available; the kernels are CUDA-only."""
        try:
            import torch

            return torch.cuda.is_available()
        except Exception:
            # Best-effort probe: any failure (torch missing, broken CUDA
            # runtime, ...) means "not available".  Narrowed from a bare
            # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
            return False

    def assert_hardware_compatible(self) -> bool:
        # No extra compatibility constraints beyond CUDA availability.
        # NOTE(review): annotated ``-> bool`` but returns None; kept as-is
        # to match the base-class signature — confirm against _Extension.
        pass

    def build_aot(self) -> None:
        """AOT building is unsupported; xformers provides the kernels."""
        raise NotImplementedError(self._INSTALL_MESSAGE)

    def build_jit(self) -> None:
        """JIT building is unsupported; xformers provides the kernels."""
        raise NotImplementedError(self._INSTALL_MESSAGE)

    def load(self):
        """Return a ``mem_eff_attention`` callable built on xformers.

        Raises:
            ModuleNotFoundError: if xformers is not installed.  The
                underlying ImportError is chained as the cause so the
                real failure stays visible in the traceback.
        """
        try:
            from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
            from xformers.ops.fmha.attn_bias import (
                BlockDiagonalCausalMask,
                BlockDiagonalMask,
                LowerTriangularMask,
                LowerTriangularMaskWithTensorBias,
            )
        except ImportError as err:
            raise ModuleNotFoundError(self._INSTALL_MESSAGE) from err
        from typing import Optional

        import torch

        # Tensor biases (ALiBi / relative position) are usable only if every
        # cutlass op supports LowerTriangularMaskWithTensorBias.
        allow_alibi = all(
            LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES
            for op in MemoryEfficientAttentionCutlassOp
        )

        def mem_eff_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q: "SeqLenInfo",
            seq_len_info_kv: "SeqLenInfo",
            origin_attn_mask: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,  # was ``float = None`` — mistyped optional
            causal: bool = False,
            padded: bool = False,
        ):
            """Run xformers memory-efficient attention.

            ``padded`` selects the packed variable-length (BERT-style)
            layout driven by ``seq_len_info_*``; ``causal`` selects a
            lower-triangular (GPT-style) mask.  ``origin_attn_mask`` is
            accepted for interface parity but unused here.
            """
            # Build the xformers attention-bias object for this layout.
            attn_bias = None
            if padded:  # bert style: packed variable-length sequences
                if not causal:
                    attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
                else:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            elif causal:  # gpt style
                attn_bias = LowerTriangularMask()

            if bias is not None:  # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported in this system."
                assert causal, "attention with bias is only supported for causal attention so far."
                attn_bias = attn_bias.add_bias(bias)

            if padded:
                # The packed layout expects a leading batch dimension of 1.
                q = q.unsqueeze(0)
                k = k.unsqueeze(0)
                v = v.unsqueeze(0)

            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

            # shape: (b*s, n, d)
            if padded:
                out = out.squeeze(0)

            return out

        return mem_eff_attention
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.