# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
#    Copyright 2023 FastChat authors
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
"""
Copied directly from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py, with some adjustments.
"""

import logging
import math
from typing import Optional, Tuple

import torch
from torch import nn
import transformers.models.llama.modeling_llama

try:
    import xformers.ops
except ImportError:
    logging.error(
        "xformers not found! Please install it before trying to use it.")


def replace_llama_attn_with_xformers_attn():
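    """Monkey-patch LlamaAttention.forward to use xformers memory-efficient attention."""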
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


def xformers_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()
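    # hidden_states: (bsz, q_len, hidden_size)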

    query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                    self.head_dim).transpose(
                                                        1, 2))
    key_states = (self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                  self.head_dim).transpose(
                                                      1, 2))
    value_states = (self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                    self.head_dim).transpose(
                                                        1, 2))

    kv_seq_len = key_states.shape[-2]
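    # Include any cached tokens so the rotary embeddings cover the full sequence length.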
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    (
        query_states,
        key_states,
    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None
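    # When use_cache is set, past_key_value now carries the full key/value states for the next step.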

    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
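        # memory_efficient_attention expects (bsz, q_len, num_heads, head_dim),
        # so move the head dimension back before calling into xformers.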
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=None)
        else:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=xformers.ops.LowerTriangularMask(),
            )
        attn_weights = None
    else:
        attn_weights = torch.matmul(query_states, key_states.transpose(
            2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
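    # Heads are merged back into hidden_size; apply the output projection.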
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, past_key_value
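
# Usage sketch (illustrative, not part of the original patch). Since the patch
# swaps `forward` on the LlamaAttention class itself, calling it once is enough
# for any LLaMA model used afterwards; the checkpoint path is a placeholder.
#
#     from transformers import AutoModelForCausalLM
#
#     replace_llama_attn_with_xformers_attn()
#     model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")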
