flash_attn_patch.py 
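"""Monkey patch that replaces LlamaAttention.forward in Hugging Face
transformers with a FlashAttention implementation based on
flash_attn_varlen_qkvpacked_func (flash-attn >= 2.0).

LlamaModel._prepare_decoder_attention_mask is also overridden so that the 2D
[bsz, seq_len] attention mask is passed through unchanged and used as the key
padding mask.  past_key_value, use_cache and output_attentions are not
supported, so the patch targets training-time forward passes rather than
generation.
"""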
import logging
from typing import List, Optional, Tuple

from einops import rearrange
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
# pip3 install "flash-attn>=2.0"
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
import torch
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                    self.head_dim).transpose(
                                                        1, 2))
    key_states = (self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                  self.head_dim).transpose(
                                                      1, 2))
    value_states = (self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                    self.head_dim).transpose(
                                                        1, 2))
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states],
                      dim=2)  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len,
                                 step=q_len,
                                 dtype=torch.int32,
                                 device=qkv.device)
        output = flash_attn_varlen_qkvpacked_func(qkv,
                                                  cu_q_lens,
                                                  max_s,
                                                  0.0,
                                                  softmax_scale=None,
                                                  causal=True)
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad,
                            "nnz (three h d) -> nnz three h d",
                            three=3,
                            h=nheads)
        output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad,
                                                        cu_q_lens,
                                                        max_s,
                                                        0.0,
                                                        softmax_scale=None,
                                                        causal=True)
        output = rearrange(
            pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices,
                      bsz, q_len),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
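

# Illustrative sketch (not used by the patch itself): the cumulative sequence
# lengths that the no-padding branch above passes to
# flash_attn_varlen_qkvpacked_func.  With bsz sequences of identical length
# q_len flattened along (b s), the offsets are [0, q_len, 2*q_len, ..., bsz*q_len].
def _example_cu_seqlens(bsz: int = 2, q_len: int = 4) -> torch.Tensor:
    # e.g. bsz=2, q_len=4 -> tensor([0, 4, 8], dtype=torch.int32)
    return torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)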


# Disable the transformation of the attention mask in LlamaModel, as flash
# attention requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                    inputs_embeds, past_key_values_length):
    # [bsz, seq_len]
    return attention_mask
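

# The pass-through above is enough for flash attention: causal masking is
# applied inside the kernel (causal=True), and padding is handled by
# unpad_input / pad_input, which only need the original 2D [bsz, seq_len]
# padding mask.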


def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        logging.warning(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask)
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
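

# Usage sketch (not part of the original patch).  Assumes a local Llama
# checkpoint (the path below is a placeholder), a CUDA GPU with compute
# capability >= 8.0, and fp16/bf16 weights, which the flash-attn kernels
# require.  The patch rebinds methods on the transformers Llama classes, so it
# only has to run before the forward pass; applying it before loading the
# model is the usual convention.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    replace_llama_attn_with_flash_attn()

    model_path = "path/to/llama-checkpoint"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16).cuda()

    inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
    # use_cache must stay disabled: the patched forward asserts not use_cache.
    outputs = model(**inputs, use_cache=False)
    print(outputs.logits.shape)  # [1, seq_len, vocab_size]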