import unittest
import torch
from itertools import product
from common_utils import TestCase
import intel_extension_for_pytorch as ipex


class FusedROPETester(TestCase):
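    """Compare the fused torch_ipex.rotary_position_embedding kernel against an
    eager HuggingFace-style reference across several model RoPE configurations."""
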
    def setUp(self):
        self.batch = 1
        self.seq_len = 32
        self.max_seq_len = 384
        self.head_size = 256
        self.num_heads = 16
        self.hidden_size = self.head_size * self.num_heads
        return super().setUp()

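    # Builds a [num_pos, dim] sin/cos table: column j of the first half holds
    # sin(pos / 10000^(2j/dim)); the second half holds the matching cos values.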
    def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
        sinusoid_inp = torch.einsum(
            "i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq
        ).float()
        return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

    def test_rope(self):
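        # Device-sync and batch-expand helper for the sin/cos table, mirroring
        # the _get_embed_positions helper in HuggingFace's GPT-J implementation.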
        def _get_embed_positions(embed_positions, position_ids):
            if embed_positions.device != position_ids.device:
                embed_positions = embed_positions.to(position_ids.device)
                self.embed_positions = embed_positions
            return embed_positions.repeat(position_ids.shape[0], 1, 1)

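        # GPT-J-style interleaved rotation: adjacent pairs (x0, x1), (x2, x3),
        # ... are rotated in place, i.e. (x0, x1) -> (-x1, x0).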
        def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
            x1 = x[:, :, :, ::2]
            x2 = x[:, :, :, 1::2]
            x = torch.stack((-x2, x1), dim=-1)
            return x.flatten(-2)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

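        # offset == 1 selects the interleaved (GPT-J/chatglm-style) rotation;
        # any other offset selects the LLaMA/GPT-NeoX-style half rotation.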
        def apply_rotary_pos_emb(
            tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, offset: int = 1
        ) -> torch.Tensor:
            if offset == 1:
                sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
                cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
                return (tensor * cos) + (rotate_every_two(tensor) * sin)
            else:
                sin = sin[:, :, None, :].repeat(1, 1, 1, 2)
                cos = cos[:, :, None, :].repeat(1, 1, 1, 2)
                return (tensor * cos) + (rotate_half(tensor) * sin)

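        # Thin wrapper around the fused op so it can be handed to torch.compile.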
        def func(
            input,
            embed_positions,
            position_ids,
            num_heads,
            head_size,
            offset,
            rotary_dim,
        ):
            return torch.ops.torch_ipex.rotary_position_embedding(
                input,
                embed_positions,
                position_ids,
                num_heads,
                head_size,
                offset,
                rotary_dim,
            )

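        # Eager reference path. When rotary_dim < head_size only the leading
        # rotary_dim channels of q/k are rotated; the rest pass through as-is.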
        def hf_forward(
            query, key, position_ids, embed_positions, offset=None, rotary_dim=None
        ):
            embed_positions = _get_embed_positions(embed_positions, position_ids)
            sincos = embed_positions.squeeze()[position_ids]
            sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

            if rotary_dim < self.head_size:
                k_rot = key[:, :, :, :rotary_dim]
                k_pass = key[:, :, :, rotary_dim:]

                q_rot = query[:, :, :, :rotary_dim]
                q_pass = query[:, :, :, rotary_dim:]

                k_rot = apply_rotary_pos_emb(k_rot, sin, cos, offset)
                q_rot = apply_rotary_pos_emb(q_rot, sin, cos, offset)

                key = torch.cat([k_rot, k_pass], dim=-1)
                query = torch.cat([q_rot, q_pass], dim=-1)
            else:
                key = apply_rotary_pos_emb(key, sin, cos, offset)
                query = apply_rotary_pos_emb(query, sin, cos, offset)
            return query, key

        kv_heads = [self.num_heads, self.num_heads // 2]
        dtypes = [torch.float32, torch.bfloat16, torch.float16]
        position_ids_t = torch.arange(self.seq_len).unsqueeze(0)
        position_ids_s = torch.Tensor([0]).to(torch.int64)
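        # Each entry maps a model family to (rotary_dim, offset, position_ids);
        # e.g. gptj rotates only the first 64 channels with interleaved pairs.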
        model2rope_config = {
            "gptj": (64, 1, position_ids_t),
            "falcon": (self.head_size, 1, position_ids_s),
            "llama": (self.head_size, self.head_size // 2, position_ids_t),
            "gpt-neox": (24, 12, position_ids_t),
            "chatglm": (64, 1, position_ids_s),
            "codegen": (self.head_size, self.head_size // 2, position_ids_t),
        }
        for rope_config, kv_head, dtype in product(
            model2rope_config.values(), kv_heads, dtypes
        ):
            rotary_dim, offset, position_ids = rope_config
            # concatenated q/k/v linear output
            linear_outs = torch.rand(
                self.batch,
                self.seq_len,
                self.hidden_size + kv_head * 2 * self.head_size,
            ).to(dtype)

            query = (
                linear_outs[:, :, : self.hidden_size]
                .contiguous()
                .view(self.batch, self.seq_len, self.num_heads, self.head_size)
            )
            key = (
                linear_outs[
                    :, :, self.hidden_size : self.hidden_size + kv_head * self.head_size
                ]
                .contiguous()
                .view(self.batch, self.seq_len, kv_head, self.head_size)
            )
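            # sin/cos table covering 2048 positions, well beyond seq_len = 32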
            embed_positions = self.create_sinusoidal_positions(2048, rotary_dim)
            query_hf, key_hf = hf_forward(
                query, key, position_ids_t, embed_positions, offset, rotary_dim
            )
            # rotate q and k separately (no concatenated q/k/v input)
            query_ipex_no_concat, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
                query,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )
            key_ipex_no_concat, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
                key,
                embed_positions,
                position_ids,
                kv_head,
                self.head_size,
                offset,
                rotary_dim,
            )
            # concatenated q/k/v: qkv_concat -> ROPE -> (q, k, v)
            (
                query_ipex,
                key_ipex,
                value_ipex,
            ) = torch.ops.torch_ipex.rotary_position_embedding(
                linear_outs,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )

            # torch.compile with the IPEX backend
            torch._dynamo.reset()
            ipex._set_compiler_backend("inductor")
            func_compile = torch.compile(func, backend="ipex")

            query_compile_no_concat, _, _ = func_compile(
                query,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )
            query_compile, key_compile, value_compile = func_compile(
                linear_outs,
                embed_positions,
                position_ids,
                self.num_heads,
                self.head_size,
                offset,
                rotary_dim,
            )
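            # fp32 should match tightly; bf16/fp16 get a looser tolerance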
            prec = 1e-5 if dtype == torch.float32 else 5e-3
            self.assertEqual(query_compile_no_concat, query_hf, prec=prec)
            self.assertEqual(query_compile, query_hf, prec=prec)
            self.assertEqual(key_compile, key_hf, prec=prec)
            self.assertEqual(query_hf, query_ipex_no_concat, prec=prec)
            self.assertEqual(key_hf, key_ipex_no_concat, prec=prec)
            self.assertEqual(query_hf, query_ipex, prec=prec)
            self.assertEqual(key_hf, key_ipex, prec=prec)
            self.assertEqual(
                value_ipex,
                linear_outs[:, :, self.hidden_size + kv_head * self.head_size :].view(
                    self.batch, self.seq_len, kv_head, self.head_size
                ),
            )
            self.assertEqual(
                value_compile,
                linear_outs[:, :, self.hidden_size + kv_head * self.head_size :].view(
                    self.batch, self.seq_len, kv_head, self.head_size
                ),
            )


if __name__ == "__main__":
    test = unittest.main()