longlm (fork) / llama_self_extend_patch_4_36.py · 188 lines · 9.3 KB
# transformers version 4.36.2
# Not comprehensively tested, but it should work.
import math
import warnings  # used by the deprecation warning in self_extend_forward
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import *


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a new axis, expand it n_rep times, then fold it into the head dimension.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
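
# Shape sketch for repeat_kv (illustrative example, not part of the original file):
# a (1, 8, 16, 64) key/value tensor with n_rep=4 becomes (1, 32, 16, 64), i.e. each
# of the 8 KV heads is repeated 4 times to line up with 32 query heads:
#   >>> repeat_kv(torch.randn(1, 8, 16, 64), 4).shape
#   torch.Size([1, 32, 16, 64])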


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    # Keep only the trailing q.shape[2] positions for the queries, in case position_ids
    # covers more positions than there are query tokens.
    q_embed = (q * cos[:, :, -q.shape[2]:]) + (rotate_half(q) * sin[:, :, -q.shape[2]:]) if q is not None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None
    return q_embed, k_embed


def apply_grouped_rotary_pos_emb(q, k, cos, sin, position_ids, g_size_1=8, g_size_2=1024):
    # Compress key positions by the group size; shift query positions by
    # g_size_2 - g_size_2 // g_size_1 so that, at the neighbor-window boundary,
    # grouped relative distances line up with standard RoPE distances.
    position_ids_q = position_ids // g_size_1 + g_size_2 - g_size_2 // g_size_1
    position_ids_k = position_ids // g_size_1

    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos_q = cos[position_ids_q].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_q = sin[position_ids_q].unsqueeze(1)  # [bs, 1, seq_len, dim]
    cos_k = cos[position_ids_k].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_k = sin[position_ids_k].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos_q) + (rotate_half(q) * sin_q) if q is not None else None
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k) if k is not None else None
    return q_embed, k_embed
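
# Worked example of the grouped mapping above (illustrative numbers, not from the
# original file), with g_size_1=8 and g_size_2=1024: a key at position 4096 maps to
# 4096 // 8 = 512, a query at position 5120 maps to 5120 // 8 + 1024 - 1024 // 8 = 1536,
# so their grouped relative distance is 1536 - 512 = 1024, exactly the true distance at
# the neighbor-window boundary. For a key at position 0 the grouped distance is 1536
# instead of 5120: everything beyond the last g_size_2 tokens is compressed by roughly
# a factor of g_size_1.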


def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    group_size_1: int = 8,
    group_size_2: int = 1024,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Queries keep their given positions; keys are re-rotated for the whole cache,
    # so their positions are rebuilt as 0 .. kv_seq_len - 1 for every batch row.
    query_position_ids = position_ids
    key_position_ids = (
        torch.arange(kv_seq_len, dtype=position_ids.dtype, device=position_ids.device)
        .unsqueeze(0)
        .expand(bsz, -1)
    )

    # Neighbor branch: standard RoPE, used inside the local window of group_size_2 tokens.
    neighbor_query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin, query_position_ids)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, cos, sin, key_position_ids)
    # Drop the group offset while all positions still fit inside the neighbor window,
    # so the shifted query positions cannot exceed the largest cached position.
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
    # Group branch: grouped RoPE, used for tokens beyond the neighbor window.
    group_query_states, _ = apply_grouped_rotary_pos_emb(
        query_states, None, cos, sin, query_position_ids, g_size_1=group_size_1, g_size_2=_re_group_size_2
    )
    _, group_key_states = apply_grouped_rotary_pos_emb(
        None, key_states, cos, sin, key_position_ids, g_size_1=group_size_1, g_size_2=_re_group_size_2
    )

    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if group_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {group_attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        group_attn_weights = group_attn_weights + attention_mask
        neighbor_attn_weights = neighbor_attn_weights + attention_mask

    # Merge the two branches: positions within the last group_size_2 tokens take the
    # neighbor (standard RoPE) scores, everything further away takes the grouped scores.
    if q_len == 1:
        # Single-token decoding: only the last group_size_2 cached keys are neighbors.
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        # Prefill: start from the causal mask and carve out the band that is more than
        # group_size_2 tokens away, so those positions fall back to the grouped branch.
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len - group_size_2 > 0:
            group_attention_mask = torch.tril(
                torch.ones((q_len - group_size_2, kv_seq_len - group_size_2), device=group_attn_weights.device)
            )
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
    else:
        raise ValueError("q_len should be 1 or kv_seq_len.")

    neighbor_attention_mask = neighbor_attention_mask.bool()
    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
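

# Usage sketch (illustrative, not part of the original file): one way to apply the
# patch is to rebind `forward` on every LlamaAttention module of a loaded model.
# The model name and group sizes below are example values, not recommendations; the
# eager attention implementation is requested because this patch replaces the eager
# forward pass.
#
#   import functools
#   from transformers import AutoModelForCausalLM
#   from transformers.models.llama.modeling_llama import LlamaAttention
#   import llama_self_extend_patch_4_36 as LlamaSE
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Llama-2-7b-hf", attn_implementation="eager"
#   )
#   for module in model.modules():
#       if isinstance(module, LlamaAttention):
#           module.forward = functools.partial(
#               LlamaSE.self_extend_forward, module, group_size_1=8, group_size_2=1024
#           )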