longlm (fork)
phi_self_extend_patch.py · 211 lines · 10.4 KB
# transformers version 4.36.2
# Should work for 'susnato/phi-2', an HF port of microsoft/phi-2; see the details on the Hugging Face Hub.
# Haven't tested it!!!
import math
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin) if q is not None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None
    return q_embed, k_embed


def apply_group_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1, group_size_1=2, group_size_2=512):
    """Applies grouped Rotary Position Embedding (Self-Extend) to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        group_size_1 (`int`, *optional*, defaults to 2):
            The grouping factor: positions are floor-divided by this value for the grouped (long-distance) attention.
        group_size_2 (`int`, *optional*, defaults to 512):
            The neighbor-window size: query positions are shifted by `group_size_2 - group_size_2 // group_size_1`
            so that the grouped relative positions line up with the un-grouped ones at the window boundary.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    q_pos = position_ids // group_size_1 + group_size_2 - group_size_2 // group_size_1
    k_pos = position_ids // group_size_1

    q_cos = cos[q_pos].unsqueeze(unsqueeze_dim)
    q_sin = sin[q_pos].unsqueeze(unsqueeze_dim)
    k_cos = cos[k_pos].unsqueeze(unsqueeze_dim)
    k_sin = sin[k_pos].unsqueeze(unsqueeze_dim)
    q_embed = (q * q_cos) + (rotate_half(q) * q_sin) if q is not None else None
    k_embed = (k * k_cos) + (rotate_half(k) * k_sin) if k is not None else None
    return q_embed, k_embed
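

# Worked example (illustrative, not part of the original patch): with group_size_1=2 and group_size_2=512,
# a query at absolute position 1024 is re-indexed to 1024 // 2 + 512 - 512 // 2 = 768, while a key at
# position 512 is re-indexed to 512 // 2 = 256. The grouped relative distance at the neighbor-window
# boundary is therefore 768 - 256 = 512, identical to the un-grouped distance, so the switch from neighbor
# to grouped attention in self_extend_forward below is position-continuous; keys farther away collapse
# onto shared (floored) positions, which is what extends the usable context without retraining.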


def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    group_size_1: int = 2,
    group_size_2: int = 512,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    # [batch_size, seq_length, 3 x hidden_size]
    fused_qkv = self.query_key_value(hidden_states)

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_states, key_states, value_states) = self._split_heads(fused_qkv)

    if self.qk_layernorm:
        query_states = self.q_layernorm(query_states)
        key_states = self.k_layernorm(key_states)

    # [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
    query_states = query_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if past_key_value is not None:
        # Specific to RoPE models with partial rotation
        cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    # The new Cache class does not support appending one more key_states entry, so the RoPE computation has to
    # happen after the cache update.

    # Partial rotary embedding
    query_rot, query_pass = (
        query_states[..., : self.rotary_emb.dim],
        query_states[..., self.rotary_emb.dim :],
    )
    key_rot, key_pass = (
        key_states[..., : self.rotary_emb.dim],
        key_states[..., self.rotary_emb.dim :],
    )
    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
    # Key positions need to be recomputed here so they cover the whole KV cache, not just the current chunk.
    k_pos = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0).expand(bsz, kv_seq_len)
    q_pos = position_ids

    neighbor_query_rot, _ = apply_rotary_pos_emb(query_rot, None, cos, sin, q_pos)
    _, neighbor_key_rot = apply_rotary_pos_emb(None, key_rot, cos, sin, k_pos)

    # Disable the query-position shift until the sequence is long enough; otherwise the smallest shifted query
    # position (group_size_2 - group_size_2 // group_size_1) could exceed the largest position seen so far.
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
    group_query_rot, _ = apply_group_rotary_pos_emb(query_rot, None, cos, sin, q_pos, group_size_1=group_size_1, group_size_2=_re_group_size_2)
    _, group_key_rot = apply_group_rotary_pos_emb(None, key_rot, cos, sin, k_pos, group_size_1=group_size_1, group_size_2=_re_group_size_2)

    # [batch_size, seq_length, num_heads, head_dim]
    neighbor_query_states = torch.cat((neighbor_query_rot, query_pass), dim=-1)
    neighbor_key_states = torch.cat((neighbor_key_rot, key_pass), dim=-1)

    group_query_states = torch.cat((group_query_rot, query_pass), dim=-1)
    group_key_states = torch.cat((group_key_rot, key_pass), dim=-1)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if neighbor_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {neighbor_attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        neighbor_attn_weights = neighbor_attn_weights + attention_mask
        group_attn_weights = group_attn_weights + attention_mask

    if q_len == 1:
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len > group_size_2:
            # The sequence is longer than group_size_2, so the distant part should be replaced with grouped attention.
            group_attention_mask = torch.tril(torch.ones((q_len - group_size_2, kv_seq_len - group_size_2), device=group_attn_weights.device))
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
    else:
        raise ValueError("q_len should be 1 or seq_len.")
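    # Illustrative example of the merge mask built above (not part of the original patch): with
    # q_len = kv_seq_len = 4 and group_size_2 = 2, neighbor_attention_mask ends up as
    #     [[1, 0, 0, 0],
    #      [1, 1, 0, 0],
    #      [0, 1, 1, 0],
    #      [0, 0, 1, 1]]
    # i.e. every query keeps exact (neighbor) attention over its nearest group_size_2 causal keys and falls
    # back to grouped attention, selected by the torch.where below, for everything farther away.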
    merged_attn_weights = torch.where(neighbor_attention_mask.bool(), neighbor_attn_weights, group_attn_weights)  # replace the group attention with neighbor attention within the neighbor window.
    # upcast attention to fp32
    attn_weights = nn.functional.softmax(merged_attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
    attn_weights = self.attention_dropout(attn_weights)

    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.dense(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
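

# ---------------------------------------------------------------------------
# Usage sketch (untested, not part of the original patch). It assumes
# transformers==4.36.2, where the architecture used by 'susnato/phi-2' lives in
# transformers.models.phi.modeling_phi as `PhiAttention`; if the class name or
# module path differs in your version, adjust the import accordingly. The
# helper simply rebinds `forward` on every attention instance so the model
# runs self_extend_forward with the chosen group sizes.
# ---------------------------------------------------------------------------
def apply_self_extend(model, group_size_1: int = 2, group_size_2: int = 512):
    """Monkey-patch every PhiAttention module of `model` to use self_extend_forward."""
    from functools import partial

    from transformers.models.phi.modeling_phi import PhiAttention  # assumed location in v4.36.2

    for module in model.modules():
        if isinstance(module, PhiAttention):
            # Bind `self` and the grouping hyper-parameters; the remaining arguments
            # from the decoder layer (hidden_states, attention_mask, ...) pass through unchanged.
            module.forward = partial(
                self_extend_forward,
                module,
                group_size_1=group_size_1,
                group_size_2=group_size_2,
            )
    return model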