longlm
/
llama_self_extend_patch_4_36.py
188 строк · 9.3 Кб
# transformers version 4.36.2
# Haven't done a comprehensive test, but it should work.
3import torch4from transformers.models.llama.modeling_llama import *5import torch.nn as nn6import math7from typing import Optional, Tuple8import torch.nn.functional as F9from transformers.cache_utils import Cache10
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Produces the same result as torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep):
    a (batch, num_key_value_heads, seqlen, head_dim) tensor becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a new axis after the head axis, broadcast it to n_rep, then fold it into the heads.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    # Nothing to repeat: hand back the input tensor itself.
    return hidden_states
22
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1) along the last axis."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
29
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embedding to `q` and/or `k`.

    Either `q` or `k` may be None, in which case None is returned in its slot.
    The query side uses only the trailing `q.shape[2]` entries of the gathered
    cos/sin tables; the key side uses them in full.

    Returns:
        Tuple `(q_embed, k_embed)`.
    """
    # The first two dimensions of cos and sin are always 1, so they are squeezed away.
    cos_table = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin_table = sin.squeeze(1).squeeze(0)  # [seq_len, dim]

    # Gather per-position rows and add a broadcastable head axis.
    cos_sel = cos_table[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_sel = sin_table[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    q_embed = None
    if q is not None:
        n_query = q.shape[2]
        q_embed = (q * cos_sel[:, :, -n_query:]) + (rotate_half(q) * sin_sel[:, :, -n_query:])

    k_embed = None
    if k is not None:
        k_embed = (k * cos_sel) + (rotate_half(k) * sin_sel)

    return q_embed, k_embed
def apply_grouped_rotary_pos_emb(q, k, cos, sin, position_ids, g_size_1=8, g_size_2=1024):
    """
    Apply rotary position embedding using grouped (floor-divided) positions.

    Key positions are `position_ids // g_size_1`. Query positions are the same
    value shifted by `g_size_2 - g_size_2 // g_size_1`.

    Either `q` or `k` may be None, in which case None is returned in its slot.

    Returns:
        Tuple `(q_embed, k_embed)`.
    """
    grouped_ids = position_ids // g_size_1
    key_ids = grouped_ids
    query_ids = grouped_ids + g_size_2 - g_size_2 // g_size_1

    # The first two dimensions of cos and sin are always 1, so they are squeezed away.
    cos_table = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin_table = sin.squeeze(1).squeeze(0)  # [seq_len, dim]

    # Both the query and the key lookups are performed unconditionally.
    q_cos = cos_table[query_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_sin = sin_table[query_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    k_cos = cos_table[key_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    k_sin = sin_table[key_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    if q is None:
        q_embed = None
    else:
        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)

    if k is None:
        k_embed = None
    else:
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)

    return q_embed, k_embed
55
56
def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 1024,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Attention forward pass that merges "neighbor" and "grouped" attention scores.

    Two sets of attention logits are computed from the same projections:

    * neighbor logits, using the ordinary rotary positions, and
    * group logits, using positions floor-divided by `group_size_1`
      (see `apply_grouped_rotary_pos_emb`).

    For each query, keys whose distance to the query is smaller than
    `group_size_2` take the neighbor logits; all other keys take the group
    logits. The merged logits then go through softmax, dropout, and the
    output projection.

    Args:
        hidden_states: Input of shape (bsz, q_len, hidden).
        attention_mask: Optional additive mask of shape (bsz, 1, q_len, kv_seq_len).
        position_ids: Query positions. Required: it is indexed and its dtype is read.
        past_key_value: Optional cache, updated in place through `.update(...)`.
        output_attentions: If False, the returned attention weights are None.
        use_cache: Accepted for interface compatibility; not read in this function.
        group_size_1: Divisor for grouped positions. Used in floor division on
            integer position ids, so pass an int.
        group_size_2: Neighbor window size. Used as a slice bound, so pass an int.
        **kwargs: Only inspected for the deprecated `padding_mask` key.

    Returns:
        Tuple `(attn_output, attn_weights or None, past_key_value)`.

    Raises:
        ValueError: If a cache is given but `self.layer_idx` is None, if the
            attention weights / mask / output have an unexpected shape, or if
            `q_len` is neither 1 nor `kv_seq_len`.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        # Tensor-parallel-style projection: split the weights, project per slice, concatenate.
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if past_key_value is not None:
        # Keys are cached WITHOUT rotary embedding applied; positions are re-applied below
        # to the full key sequence on every call.
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    query_position_ids = position_ids
    # Key positions are simply 0..kv_seq_len-1. Shape (1, kv_seq_len) so the gathered cos/sin
    # broadcast over the batch; the previous `.view(bsz, kv_seq_len)` raised for bsz > 1 because
    # the arange only holds kv_seq_len elements.
    # NOTE(review): every batch row shares the same key positions — confirm this is acceptable
    # for left-padded batches.
    key_position_ids = torch.arange(
        kv_seq_len, dtype=position_ids.dtype, device=query_position_ids.device
    ).view(1, kv_seq_len)

    neighbor_query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin, query_position_ids)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, cos, sin, key_position_ids)
    # in case that, the smallest q position, g2-g2//g1 exceed the max position
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
    group_query_states, _ = apply_grouped_rotary_pos_emb(
        query_states, None, cos, sin, query_position_ids, g_size_1=group_size_1, g_size_2=_re_group_size_2
    )
    _, group_key_states = apply_grouped_rotary_pos_emb(
        None, key_states, cos, sin, key_position_ids, g_size_1=group_size_1, g_size_2=_re_group_size_2
    )

    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if group_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {group_attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        group_attn_weights = group_attn_weights + attention_mask
        neighbor_attn_weights = neighbor_attn_weights + attention_mask

    # Build a (q_len, kv_seq_len) selector: 1 -> use neighbor logits, 0 -> use group logits.
    if q_len == 1:
        # Single-token decoding: the last `group_size_2` keys are neighbors.
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        # Full-sequence pass: start from the causal lower triangle, then clear the entries
        # whose query-key distance is >= group_size_2.
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len - group_size_2 > 0:
            group_attention_mask = torch.tril(
                torch.ones((q_len - group_size_2, kv_seq_len - group_size_2), device=group_attn_weights.device)
            )
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    neighbor_attention_mask = neighbor_attention_mask.bool()
    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
    # Softmax in float32, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
190