colossalai
90 lines · 4.1 KB
1from typing import Optional, Tuple
2
3import torch
4import xformers.ops as xops
5from torch import Tensor
6from transformers.models.opt.modeling_opt import OPTAttention
7
8
# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py
class XOPTAttention(OPTAttention):
    """OPT attention that runs xformers' memory-efficient kernel during training.

    Only plain causal self-attention in training mode (no KV cache, no head mask,
    no attention-weight output) uses ``xops.memory_efficient_attention``. Every
    other call is delegated unchanged to the reference ``OPTAttention.forward``.
    """

    def forward(
        self,
        hidden_states: Tensor,
        key_value_states: Optional[Tensor] = None,
        past_key_value: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: layer input, unpacked below as ``(bsz, tgt_len, _)``.
            key_value_states: if given, the layer acts as cross-attention; handled
                by the reference implementation.
            past_key_value: cached ``(key, value)`` pair from an earlier step;
                handled by the reference implementation.
            attention_mask: forwarded to the reference implementation. The xformers
                path applies a causal mask only and does not read it (see NOTE below).
            layer_head_mask: per-head mask. Not supported by xformers, so passing one
                selects the reference implementation.
            output_attentions: when True, the reference implementation is used so
                attention weights can be returned.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. On the xformers path
            ``attn_weights`` is ``None``. ``past_key_value`` is
            ``(key_states, value_states)`` when ``self.is_decoder`` is set, else ``None``.
        """
        # The xformers path below hard-codes a lower-triangular (causal) bias over a
        # query/key pair of equal length. That is wrong for cross-attention (which
        # must not be causal), and with a KV cache the query is shorter than the keys,
        # so a plain lower-triangular mask would misalign query/key positions. Head
        # masks and attention-weight output are unsupported by xformers. Previously
        # the last two were rejected with `assert` (stripped under `python -O`); all
        # of these cases now use the reference implementation instead.
        use_reference = (
            not self.training
            or key_value_states is not None
            or past_key_value is not None
            or layer_head_mask is not None
            or output_attentions
        )
        if use_reference:
            return super().forward(
                hidden_states=hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # NOTE(review): `attention_mask` is not consulted on this path, so any padding
        # information it carries is ignored and only the causal mask applies. This is
        # presumably acceptable for right-padded batches whose pad positions are
        # excluded from the loss; confirm before training on left-padded inputs.

        bsz, tgt_len, _ = hidden_states.size()

        # `self._shape` is inherited from OPTAttention. The layout is assumed to be
        # (bsz, num_heads, seq, head_dim), which the cache returned below keeps;
        # `.transpose(1, 2)` moves seq before heads for xformers. Verify against the
        # installed transformers version.
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz).transpose(1, 2)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Uni-directional decoder self-attention: hand back this step's keys/values so
        # later calls can extend them. Otherwise there is nothing to cache (the
        # incoming `past_key_value` is necessarily None on this path).
        past_key_value = (key_states, value_states) if self.is_decoder else None

        attn_output = xops.memory_efficient_attention(
            query_states,
            key_states.transpose(1, 2),
            value_states.transpose(1, 2),
            attn_bias=xops.LowerTriangularMask(),
            # Always in training mode here (eval delegates above), so dropout applies.
            p=self.dropout,
            scale=self.scaling,
        )

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        # xformers never materialises the attention probabilities.
        return attn_output, None, past_key_value
91