from ..base_extension import _Extension


class FlashAttentionXformersCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        # the cuda extension can only be built if cuda is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except Exception:
            cuda_available = False
        return cuda_available

    def assert_hardware_compatible(self) -> None:
        pass

    def build_aot(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub README."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub README."
        )

    def load(self):
        try:
            from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
            from xformers.ops.fmha.attn_bias import (
                BlockDiagonalCausalMask,
                BlockDiagonalMask,
                LowerTriangularMask,
                LowerTriangularMaskWithTensorBias,
            )
        except ImportError:
            raise ModuleNotFoundError(
                "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub README."
            )
        from typing import Optional

        import torch

        # alibi-style tensor biases are only usable if every op in the CUTLASS
        # attention pipeline supports the LowerTriangularMaskWithTensorBias mask
        allow_alibi = True
        for op in MemoryEfficientAttentionCutlassOp:
            allow_alibi = allow_alibi and (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)

        def mem_eff_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q: "SeqLenInfo",
            seq_len_info_kv: "SeqLenInfo",
            origin_attn_mask: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            causal: bool = False,
            padded: bool = False,
        ):
            attn_bias = None
            if padded:  # bert style: variable-length sequences packed into one batch
                if not causal:
                    attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
                else:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            elif causal:  # gpt style: full-length sequences with a causal mask
                attn_bias = LowerTriangularMask()

            if bias is not None:  # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported on this system."
                assert causal, "attention with bias is only supported for causal attention so far."
                attn_bias = attn_bias.add_bias(bias)

            if padded:
                # packed sequences are treated as a single batch entry: (b*s, n, d) -> (1, b*s, n, d)
                q = q.unsqueeze(0)
                k = k.unsqueeze(0)
                v = v.unsqueeze(0)

            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

            # restore the packed shape: (1, b*s, n, d) -> (b*s, n, d)
            if padded:
                out = out.squeeze(0)

            return out

        return mem_eff_attention
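
# A minimal usage sketch (an illustration, not part of the extension). It assumes
# a CUDA device with xformers installed and exercises only the gpt-style path
# (causal, unpadded), which never touches the SeqLenInfo arguments, so None is
# passed for them here; the padded (bert-style) path would instead need real
# SeqLenInfo objects carrying per-sequence lengths. Tensors follow xformers'
# (batch, seq_len, heads, head_dim) layout.
#
#     ext = FlashAttentionXformersCudaExtension()
#     if ext.is_hardware_available():
#         import torch
#
#         attn = ext.load()
#         q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
#         k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
#         v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
#         out = attn(q, k, v, seq_len_info_q=None, seq_len_info_kv=None, causal=True)
#         # out has the same shape as q: (2, 128, 8, 64)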