colossalai

from typing import Optional, Tuple

import torch
import xformers.ops as xops
from torch import Tensor
from transformers.models.opt.modeling_opt import OPTAttention


# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py
class XOPTAttention(OPTAttention):
    # def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
    #     return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
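    # NOTE: the base class's `_shape` returns [bsz, num_heads, seq_len, head_dim]
    # (view + transpose(1, 2)); the commented-out override above would instead keep
    # [bsz, seq_len, num_heads, head_dim]. The forward below uses the base `_shape`
    # and transposes back explicitly before calling xformers.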

    def forward(
        self,
        hidden_states: Tensor,
        key_value_states: Optional[Tensor] = None,
        past_key_value: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        if not self.training:
            # Outside of training, fall back to the reference OPTAttention implementation,
            # which handles padding masks, head masks and attention-weight outputs.
            return super().forward(
                hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
            )
        assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
        assert not output_attentions, "Xformers attention does not support output_attentions"

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
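        # After these transposes, q/k/v are laid out as [bsz, seq_len, num_heads, head_dim],
        # the layout expected by xformers.ops.memory_efficient_attention.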

        attn_output = xops.memory_efficient_attention(
            query_states,
            key_states,
            value_states,
            attn_bias=xops.LowerTriangularMask(),
            p=self.dropout if self.training else 0.0,
            scale=self.scaling,
        )
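        # Note that the `attention_mask` argument is not applied on this training path:
        # the kernel only sees the causal LowerTriangularMask, so padded positions are
        # not masked out here.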

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        attn_weights_reshaped = None

        return attn_output, attn_weights_reshaped, past_key_value
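
A minimal usage sketch of how such a module might be swapped into a Hugging Face OPT model for training. Since `XOPTAttention` only overrides `forward` and adds no parameters, reassigning the class of each existing `OPTAttention` module is enough; `replace_opt_attention` below is a hypothetical helper for illustration, not part of the ColossalAI API.

from transformers import OPTForCausalLM

def replace_opt_attention(model: OPTForCausalLM) -> OPTForCausalLM:
    """Patch every OPTAttention module in-place to use the xformers forward."""
    for module in model.modules():
        if isinstance(module, OPTAttention) and not isinstance(module, XOPTAttention):
            # Reuse the existing projection weights; only the forward method changes.
            module.__class__ = XOPTAttention
    return model

# Example:
# model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
# model = replace_opt_attention(model).cuda().train()  # xformers path is training-only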