# intel-extension-for-pytorch: fused rotary position embedding (RoPE) test
import unittest
import torch
from itertools import product
from common_utils import TestCase
import intel_extension_for_pytorch as ipex


class FusedROPETester(TestCase):
    def setUp(self):
        self.batch = 1
        self.seq_len = 32
        self.max_seq_len = 384
        self.head_size = 256
        self.num_heads = 16
        self.hidden_size = self.head_size * self.num_heads
        return super().setUp()

    def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
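        # GPT-J-style table: inverse frequencies 1 / 10000^(2i/dim) for each
        # channel pair, outer product with the positions, then the sin and cos
        # halves concatenated along the last dim -> shape (num_pos, dim).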
19
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
20
        sinusoid_inp = torch.einsum(
21
            "i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq
22
        ).float()
23
        return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

    def test_rope(self):
        def _get_embed_positions(embed_positions, position_ids):
            if embed_positions.device != position_ids.device:
                embed_positions = embed_positions.to(position_ids.device)
                self.embed_positions = embed_positions
            return embed_positions.repeat(position_ids.shape[0], 1, 1)

        def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
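            # GPT-J-style rotation: adjacent channel pairs (x0, x1) are mapped
            # to (-x1, x0), i.e. a 2-D rotation over interleaved pairs.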
            x1 = x[:, :, :, ::2]
            x2 = x[:, :, :, 1::2]
            x = torch.stack((-x2, x1), dim=-1)
            return x.flatten(-2)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(
            tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, offset: int = 1
        ) -> torch.Tensor:
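            # offset == 1 selects the interleaved (GPT-J-style) rotation;
            # otherwise sin/cos are tiled to match the half-rotated layout
            # used by rotate_half (GPT-NeoX/LLaMA style).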
            if offset == 1:
                sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
                cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
                return (tensor * cos) + (rotate_every_two(tensor) * sin)
            else:
                sin = sin[:, :, None, :].repeat(1, 1, 1, 2)
                cos = cos[:, :, None, :].repeat(1, 1, 1, 2)
                return (tensor * cos) + (rotate_half(tensor) * sin)

        def func(
            input,
            embed_positions,
            position_ids,
            num_heads,
            head_size,
            offset,
            rotary_dim,
        ):
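            # Thin wrapper around the fused IPEX RoPE op so it can be passed
            # to torch.compile below.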
            return torch.ops.torch_ipex.rotary_position_embedding(
                input,
                embed_positions,
                position_ids,
                num_heads,
                head_size,
                offset,
                rotary_dim,
            )

        def hf_forward(
            query, key, position_ids, embed_positions, offset=None, rotary_dim=None
        ):
            embed_positions = _get_embed_positions(embed_positions, position_ids)
            sincos = embed_positions.squeeze()[position_ids]
            sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

            if rotary_dim < self.head_size:
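                # Partial rotary: only the first rotary_dim channels of each
                # head are rotated; the remainder passes through unchanged.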
                k_rot = key[:, :, :, :rotary_dim]
                k_pass = key[:, :, :, rotary_dim:]

                q_rot = query[:, :, :, :rotary_dim]
                q_pass = query[:, :, :, rotary_dim:]

                k_rot = apply_rotary_pos_emb(k_rot, sin, cos, offset)
                q_rot = apply_rotary_pos_emb(q_rot, sin, cos, offset)

                key = torch.cat([k_rot, k_pass], dim=-1)
                query = torch.cat([q_rot, q_pass], dim=-1)
            else:
                key = apply_rotary_pos_emb(key, sin, cos, offset)
                query = apply_rotary_pos_emb(query, sin, cos, offset)
            return query, key

        kv_heads = [self.num_heads, self.num_heads // 2]
        dtypes = [torch.float32, torch.bfloat16, torch.float16]
        position_ids_t = torch.arange(self.seq_len).unsqueeze(0)
        position_ids_s = torch.Tensor([0]).to(torch.int64)
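        # Per model family: (rotary_dim, offset, position_ids), where _t is
        # per-token ids of shape (1, seq_len) and _s is a single scalar id.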
        model2rope_config = {
            "gptj": (64, 1, position_ids_t),
            "falcon": (self.head_size, 1, position_ids_s),
            "llama": (self.head_size, self.head_size // 2, position_ids_t),
            "gpt-neox": (24, 12, position_ids_t),
            "chatglm": (64, 1, position_ids_s),
            "codegen": (self.head_size, self.head_size // 2, position_ids_t),
        }
        for rope_config, kv_head, dtype in product(
            model2rope_config.values(), kv_heads, dtypes
        ):
            rotary_dim, offset, position_ids = rope_config
            # concatenated q/k/v linear output
            linear_outs = torch.rand(
                self.batch,
                self.seq_len,
                self.hidden_size + kv_head * 2 * self.head_size,
            ).to(dtype)
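            # Layout along the last dim:
            # [num_heads * head_size query | kv_head * head_size key
            #  | kv_head * head_size value]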

            query = (
                linear_outs[:, :, : self.hidden_size]
                .contiguous()
                .view(self.batch, self.seq_len, self.num_heads, self.head_size)
            )
            key = (
                linear_outs[
                    :, :, self.hidden_size : self.hidden_size + kv_head * self.head_size
                ]
                .contiguous()
                .view(self.batch, self.seq_len, kv_head, self.head_size)
            )
            embed_positions = self.create_sinusoidal_positions(2048, rotary_dim)
            query_hf, key_hf = hf_forward(
                query, key, position_ids_t, embed_positions, offset, rotary_dim
            )
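            # Note: the HF reference always indexes the table with per-token
            # ids, even for models whose kernel call takes a scalar id.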
            # separate q and k inputs (no concatenated q/k/v)
            query_ipex_no_concat, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
                query,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )
            key_ipex_no_concat, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
                key,
                embed_positions,
                position_ids,
                kv_head,
                self.head_size,
                offset,
                rotary_dim,
            )
            # concatenated q/k/v: qkv_concat -> ROPE -> (q, k, v)
            (
                query_ipex,
                key_ipex,
                value_ipex,
            ) = torch.ops.torch_ipex.rotary_position_embedding(
                linear_outs,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )

            # Repeat both variants under torch.compile with the IPEX backend.
            torch._dynamo.reset()
            ipex._set_compiler_backend("inductor")
            func_compile = torch.compile(func, backend="ipex")

            query_compile_no_concat, _, _ = func_compile(
                query,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )
            query_compile, key_compile, value_compile = func_compile(
                linear_outs,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )
            prec = 1e-5 if dtype == torch.float32 else 5e-3
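            # fp32 is compared tightly; bf16/fp16 get a looser tolerance.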
            self.assertEqual(query_compile_no_concat, query_hf, prec=prec)
            self.assertEqual(query_compile, query_hf, prec=prec)
            self.assertEqual(key_compile, key_hf, prec=prec)
            self.assertEqual(query_hf, query_ipex_no_concat, prec=prec)
            self.assertEqual(key_hf, key_ipex_no_concat, prec=prec)
            self.assertEqual(query_hf, query_ipex, prec=prec)
            self.assertEqual(key_hf, key_ipex, prec=prec)
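            # value is never rotated: the fused concat path must return it as
            # a plain slice/reshape of the input.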
            self.assertEqual(
                value_ipex,
                linear_outs[:, :, self.hidden_size + kv_head * self.head_size :].view(
                    self.batch, self.seq_len, kv_head, self.head_size
                ),
            )
            self.assertEqual(
                value_compile,
                linear_outs[:, :, self.hidden_size + kv_head * self.head_size :].view(
                    self.batch, self.seq_len, kv_head, self.head_size
                ),
            )


if __name__ == "__main__":
    unittest.main()