# longlm/phi_self_extend_patch.py
# transformers version 4.36.2
# Should work for 'susnato/phi-2', a HF version of microsoft/phi-2; check the details on the Hugging Face Hub.
# Haven't tested it yet!
import math
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin) if q is not None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None
    return q_embed, k_embed

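# Note: unlike the upstream HF helper, the version above accepts q=None or
# k=None, so queries and keys can be rotated separately with different
# position ids (q_pos vs. k_pos in self_extend_forward below).
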
def apply_group_rotary_pos_emb(
    q, k, cos, sin, position_ids, unsqueeze_dim=1, group_size_1=2, group_size_2=512
):
    """Applies grouped Rotary Position Embedding to the query and key tensors, as in Self-Extend.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they broadcast
            against q and k: use 1 if q and k have the shape [batch_size, heads, seq_len, head_dim], and 2 if they
            have the shape [batch_size, seq_len, heads, head_dim].
        group_size_1 (`int`, *optional*, defaults to 2):
            The group size: position ids are floor-divided by this value, so every group_size_1 consecutive tokens
            share one rotary position.
        group_size_2 (`int`, *optional*, defaults to 512):
            The neighbor window size: query positions are shifted by group_size_2 - group_size_2 // group_size_1 so
            that grouped relative distances line up with the normal attention used inside the neighbor window.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    q_pos = position_ids // group_size_1 + group_size_2 - group_size_2 // group_size_1
    k_pos = position_ids // group_size_1

    q_cos = cos[q_pos].unsqueeze(unsqueeze_dim)
    q_sin = sin[q_pos].unsqueeze(unsqueeze_dim)
    k_cos = cos[k_pos].unsqueeze(unsqueeze_dim)
    k_sin = sin[k_pos].unsqueeze(unsqueeze_dim)
    q_embed = (q * q_cos) + (rotate_half(q) * q_sin) if q is not None else None
    k_embed = (k * k_cos) + (rotate_half(k) * k_sin) if k is not None else None
    return q_embed, k_embed

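# A toy illustration of the grouped mapping above (illustrative values:
# group_size_1=2, group_size_2=4, so the query shift is 4 - 4 // 2 = 2):
#   token position:  0 1 2 3 4 5 6 7
#   k_pos:           0 0 1 1 2 2 3 3
#   q_pos:           2 2 3 3 4 4 5 5
# The shift keeps the grouped distances continuous with the neighbor window:
# at relative distance w = group_size_2, the grouped distance (roughly
# w // group_size_1 + group_size_2 - group_size_2 // group_size_1) also equals
# group_size_2, so merging the two attention types below introduces no
# positional jump at the window boundary.
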

def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    group_size_1: int = 2,
    group_size_2: int = 512,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    # [batch_size, seq_length, 3 x hidden_size]
    fused_qkv = self.query_key_value(hidden_states)
    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_states, key_states, value_states) = self._split_heads(fused_qkv)
    if self.qk_layernorm:
        query_states = self.q_layernorm(query_states)
        key_states = self.k_layernorm(key_states)
    # [batch_size, seq_length, num_heads, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
    query_states = query_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    if past_key_value is not None:
        # Specific to RoPE models with partial rotation
        cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # The new Cache class does not support appending just one extra key_states tensor, so the RoPE
        # computation has to be done later, on the full key sequence.
    # Partial rotary embedding
    query_rot, query_pass = (
        query_states[..., : self.rotary_emb.dim],
        query_states[..., self.rotary_emb.dim :],
    )
    key_rot, key_pass = (
        key_states[..., : self.rotary_emb.dim],
        key_states[..., self.rotary_emb.dim :],
    )
    # [batch_size, num_heads, seq_length, rotary_emb.dim]
    # Keys cover the whole cache, so their positions need to be recomputed from scratch.
    k_pos = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0).repeat(bsz, 1)
    q_pos = position_ids
    neighbor_query_rot, _ = apply_rotary_pos_emb(query_rot, None, cos, sin, q_pos)
    _, neighbor_key_rot = apply_rotary_pos_emb(None, key_rot, cos, sin, k_pos)
    # If the maximum position is still smaller than group_size_2, drop the query shift: otherwise even the
    # smallest shifted query position, group_size_2 - group_size_2 // group_size_1, would exceed the maximum
    # available position.
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
    group_query_rot, _ = apply_group_rotary_pos_emb(
        query_rot, None, cos, sin, q_pos, group_size_1=group_size_1, group_size_2=_re_group_size_2
    )
    _, group_key_rot = apply_group_rotary_pos_emb(
        None, key_rot, cos, sin, k_pos, group_size_1=group_size_1, group_size_2=_re_group_size_2
    )
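    # For example (toy numbers, not from the original file): with group_size_1=2, group_size_2=512 and only
    # kv_seq_len=100 cached positions, even the smallest shifted query position would be 512 - 512 // 2 = 256,
    # well past index 99 in the cos/sin tables built above for kv_seq_len entries; setting the shift to 0
    # keeps the indexing in range.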
    # [batch_size, num_heads, seq_length, head_dim]
    neighbor_query_states = torch.cat((neighbor_query_rot, query_pass), dim=-1)
    neighbor_key_states = torch.cat((neighbor_key_rot, key_pass), dim=-1)

    group_query_states = torch.cat((group_query_rot, query_pass), dim=-1)
    group_key_states = torch.cat((group_key_rot, key_pass), dim=-1)
    neighbor_attn_weights = torch.matmul(
        neighbor_query_states, neighbor_key_states.transpose(2, 3)
    ) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(
        group_query_states, group_key_states.transpose(2, 3)
    ) / math.sqrt(self.head_dim)
    if neighbor_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {neighbor_attn_weights.size()}"
        )
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        neighbor_attn_weights = neighbor_attn_weights + attention_mask
        group_attn_weights = group_attn_weights + attention_mask
    if q_len == 1:
        # Single-token decoding: only the last group_size_2 keys fall inside the neighbor window.
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        # Prefill: start from a causal mask, then carve out the region handled by group attention.
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len > group_size_2:
            # The sequence is longer than group_size_2, so positions beyond the neighbor window
            # must be replaced with group attention.
            group_attention_mask = torch.tril(
                torch.ones((q_len - group_size_2, kv_seq_len - group_size_2), device=group_attn_weights.device)
            )
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
    else:
        raise ValueError("q_len should be 1 or kv_seq_len.")
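    # A toy picture of the resulting mask (illustrative only): with q_len = kv_seq_len = 6 and group_size_2 = 3,
    # neighbor_attention_mask ends up as
    #     1 0 0 0 0 0
    #     1 1 0 0 0 0
    #     1 1 1 0 0 0
    #     0 1 1 1 0 0
    #     0 0 1 1 1 0
    #     0 0 0 1 1 1
    # i.e. each query keeps neighbor attention for its group_size_2 closest keys and falls back to group
    # attention (the zeros below the diagonal band) for everything farther away.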
    # Replace the group attention with neighbor attention inside the neighbor window.
    merged_attn_weights = torch.where(neighbor_attention_mask.bool(), neighbor_attn_weights, group_attn_weights)
    # Upcast attention to fp32
    attn_weights = nn.functional.softmax(merged_attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
    attn_weights = self.attention_dropout(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.dense(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value
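
# A minimal usage sketch (an assumption, not part of the original patch): monkey-patch every attention layer's
# forward with self_extend_forward. This assumes the 'susnato/phi-2' remote code exposes its attention modules
# as model.model.layers[i].self_attn with the attributes used above (query_key_value, _split_heads, rotary_emb,
# dense); adjust the attribute paths for other checkpoints.
if __name__ == "__main__":
    import types
    from functools import partial

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("susnato/phi-2", trust_remote_code=True)
    patched_forward = partial(self_extend_forward, group_size_1=4, group_size_2=1024)
    for layer in model.model.layers:
        # Bind the patched function to each attention instance so nn.Module.__call__ picks it up.
        layer.self_attn.forward = types.MethodType(patched_forward, layer.self_attn)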