import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import intel_extension_for_pytorch as ipex
import math
import copy
from common_utils import TestCase


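# The modules below reproduce attention blocks that the IPEX graph rewriter is
# expected to fuse. Each test traces and freezes the model, runs it a couple of
# times, and asserts that the corresponding fused node (e.g. ipex::sd_flash_mha,
# ipex::bert_flash_mha, ipex::transfree_vit_mha) appears in the optimized graph.

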
# (from Diffusers 0.12.1)
class SD_MHA_Model_v1(nn.Module):
    def __init__(self, scale, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v1, self).__init__()
        self.scale = scale
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size * head_size, seq_len, dim // head_size
        )
        return tensor

    def get_attention_scores(self, query, key):
        dtype = query.dtype
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)
        return attention_probs

    def forward(self, x):
        query = self.query(x)
        query = self.head_to_batch_dim(query)
        key = self.key(x)
        key = self.head_to_batch_dim(key)
        value = self.value(x)
        value = self.head_to_batch_dim(value)
        attention_probs = self.get_attention_scores(query, key)
        hidden_states = torch.bmm(attention_probs, value)
        output = self.batch_to_head_dim(hidden_states)
        return output


# (from Diffusers 0.12.1)
class SD_MHA_Model_v2(nn.Module):
    def __init__(self, scale, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v2, self).__init__()
        self.scale = scale
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size * head_size, seq_len, dim // head_size
        )
        return tensor

    def get_attention_scores(self, query, key):
        dtype = query.dtype
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)
        return attention_probs

    def forward(self, x, y):
        query = self.query(x)
        query = self.head_to_batch_dim(query)
        key = self.key(y)
        key = self.head_to_batch_dim(key)
        value = self.value(y)
        value = self.head_to_batch_dim(value)
        attention_probs = self.get_attention_scores(query, key)
        hidden_states = torch.bmm(attention_probs, value)
        output = self.batch_to_head_dim(hidden_states)
        return output


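# The v3/v4 variants below compute the same attention, softmax(q @ k.T * scale) @ v,
# via F.scaled_dot_product_attention instead of the explicit baddbmm/softmax/bmm
# chain used by v1/v2.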
# (from Diffusers 0.13)
class SD_MHA_Model_v3(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v3, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (from Diffusers 0.13)
class SD_MHA_Model_scale_v3(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize, scale):
        super(SD_MHA_Model_scale_v3, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.scale = scale
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=self.scale,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (from Diffusers 0.13)
class SD_MHA_Model_v4(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v4, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x, y):
        query = self.query(x)
        key = self.key(y)
        value = self.value(y)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (from Diffusers 0.13)
class SD_MHA_Model_scale_v4(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize, scale):
        super(SD_MHA_Model_scale_v4, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.scale = scale
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x, y):
        query = self.query(x)
        key = self.key(y)
        value = self.value(y)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=self.scale,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (Fake Diffusers Model - Fall back to ipex::mha_scores_calc)
class Fake_SD_MHA_Model(nn.Module):
    def __init__(self, dim_per_head, softmax_dim=-1):
        super(Fake_SD_MHA_Model, self).__init__()
        self.softmax = nn.Softmax(dim=softmax_dim)
        self.dim_per_head = dim_per_head

    def forward(self, mat1, mat2, mat3, bias):
        mat1 = mat1 / math.sqrt(self.dim_per_head)
        qk = torch.matmul(mat1, mat2.transpose(2, 3))
        scores = self.softmax(qk + bias)
        output = torch.matmul(scores, mat3)
        return output


class MHA_Model_BERT(nn.Module):
    def __init__(self, scale, num_heads, head_dims, permute_idx, trans_a, trans_b):
        super(MHA_Model_BERT, self).__init__()
        self.scale = scale
        self.num_heads = num_heads
        self.head_dims = head_dims
        self.embed_dims = self.num_heads * self.head_dims
        self.query = nn.Linear(self.embed_dims, self.embed_dims, bias=True)
        self.key = nn.Linear(self.embed_dims, self.embed_dims, bias=True)
        self.value = nn.Linear(self.embed_dims, self.embed_dims, bias=True)
        self.permute_idx = permute_idx
        self.trans_a = trans_a
        self.trans_b = trans_b

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dims)
        x = x.view(new_x_shape)
        return x.permute(self.permute_idx)

    def forward(self, x, mask):
        query_layer = self.transpose_for_scores(self.query(x))
        key_layer = self.transpose_for_scores(self.key(x)).transpose(
            self.trans_a, self.trans_b
        )
        value_layer = self.transpose_for_scores(self.value(x))
        attention_scores = torch.matmul(query_layer, key_layer) / self.scale + mask
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(self.permute_idx).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dims,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer


class MHA_Model_Distil(nn.Module):
    def __init__(
        self,
        scale,
        num_heads,
        head_dims,
        trans_a,
        trans_b,
        trans_c,
        fill_value=-float("inf"),
    ):
        super(MHA_Model_Distil, self).__init__()
        self.scale = scale
        self.n_head = num_heads
        self.head_dims = head_dims
        self.dim = self.n_head * self.head_dims
        self.q_lin = nn.Linear(self.dim, self.dim, bias=True)
        self.k_lin = nn.Linear(self.dim, self.dim, bias=True)
        self.v_lin = nn.Linear(self.dim, self.dim, bias=True)
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.trans_c = trans_c
        self.fill_value = fill_value

    def forward(self, x, mask):
        bs, q_length, dim = x.size()
        k_length = x.size(1)

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_head, self.head_dims).transpose(
                self.trans_a, self.trans_b
            )

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return (
                x.transpose(self.trans_a, self.trans_b)
                .contiguous()
                .view(bs, -1, self.n_head * self.head_dims)
            )

        q = shape(self.q_lin(x))
        k = shape(self.k_lin(x))
        v = shape(self.v_lin(x))
        mask_reshp = (bs, 1, 1, k_length)
        q = q / self.scale
        scores = torch.matmul(q, k.transpose(self.trans_b, self.trans_c))
        mask = (mask == 0).view(mask_reshp).expand_as(scores)
        scores = scores.masked_fill(mask, self.fill_value)
        weights = nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)
        context_layer = unshape(context)

        return context_layer


class MHA_Model_ViT(nn.Module):
    def __init__(
        self,
        scale,
        num_heads,
        head_dims,
        permute_idx,
        trans_a,
        trans_b,
        select_a,
        select_b,
    ):
        super(MHA_Model_ViT, self).__init__()
        self.scale = 1.0 / scale
        self.num_heads = num_heads
        self.head_dims = head_dims
        self.embed_dims = self.num_heads * self.head_dims
        self.qkv = nn.Linear(self.embed_dims, self.embed_dims * 3, bias=True)
        self.permute_idx = permute_idx
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.select_a = select_a
        self.select_b = select_b

    def forward(self, x):
        B, N, _ = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dims)
            .permute(self.permute_idx)
        )
        q, k, v = qkv[0], qkv[self.select_a], qkv[self.select_b]
        attn = (q @ k.transpose(self.trans_a, self.trans_b)) * self.scale
        attn = attn.softmax(dim=-1)
        context_layer = (
            (attn @ v)
            .transpose(self.select_a, self.select_b)
            .reshape(B, N, self.embed_dims)
        )

        return context_layer


bs = [5, 3, 11]
seq = [128, 384, 31]
scales = [8, 13, 21]
num_heads = [12, 16, 29]
head_dims = [64, 96, 17]
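# The five lists above are indexed together: each i selects one
# (batch, seq_len, scale, num_heads, head_dim) configuration, e.g. i == 0 uses
# tensors of shape (5, 128, 12 * 64) with scale 8.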
425
426
427# In this UT case, "+15" is desgined to trigger the overflow of SoftMax when using pos_FLT_MIN.
428# Since the input values are very large for the BMM and SoftMax, the resulting accumulations of MHA
429# result will also be large, thus the tolerance value should be set to 1.5e-0 for such case.
class TransFreeMHATester(TestCase):
    def sd_mha_bf16_common(self, model, mat1, mat2=None):
        for neg_FLT_MIN in [True, False]:
            sd_mha_model = copy.deepcopy(model)
            if mat2 is not None:
                inputs = (
                    (mat1.to(torch.bfloat16), mat2.to(torch.bfloat16))
                    if not neg_FLT_MIN
                    else (
                        (mat1 + 15).to(torch.bfloat16),
                        (mat2 + 15).to(torch.bfloat16),
                    )
                )
            else:
                inputs = (
                    (mat1.to(torch.bfloat16),)
                    if not neg_FLT_MIN
                    else ((mat1 + 15).to(torch.bfloat16),)
                )
            mha_ipex = ipex.optimize(sd_mha_model, dtype=torch.bfloat16, level="O1")
            with torch.cpu.amp.autocast(), torch.no_grad():
                mha_ipex = torch.jit.trace(mha_ipex, inputs)
                mha_ipex = torch.jit.freeze(mha_ipex)

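                # The two runs below are likely a warm-up so that graph_for()
                # afterwards reflects the optimized (fused) graph.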
                for _ in range(2):
                    mha_jit = mha_ipex(*inputs)
                mha_ref = sd_mha_model(*inputs)
                self.assertEqual(mha_ref, mha_jit, prec=1.5e-0 if neg_FLT_MIN else 1e-2)

                mha_graph = mha_ipex.graph_for(*inputs)
                self.assertTrue(
                    any(n.kind() == "ipex::sd_flash_mha" for n in mha_graph.nodes())
                )

    def test_sd_mha_bf16_v1(self):
        mat = torch.randn(2, 4096, 320)
        sd_mha_model = SD_MHA_Model_v1(0.3, 8, 320, 320).eval()
        self.sd_mha_bf16_common(sd_mha_model, mat)

    def test_sd_mha_bf16_v2(self):
        mat1 = torch.randn(2, 4096, 320)
        mat2 = torch.randn(2, 77, 320)
        sd_mha_model = SD_MHA_Model_v2(0.3, 8, 320, 320).eval()
        self.sd_mha_bf16_common(sd_mha_model, mat1, mat2)

    # def test_sd_mha_bf16_v3(self):
    #     mat = torch.randn(2, 4096, 320)
    #     sd_mha_model = SD_MHA_Model_v3(8, 320, 320).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat)

    # def test_sd_mha_bf16_scale_v3(self):
    #     mat = torch.randn(2, 4096, 320)
    #     sd_mha_model = SD_MHA_Model_scale_v3(8, 320, 320, 0.3).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat)

    # def test_sd_mha_bf16_v4(self):
    #     mat1 = torch.randn(2, 4096, 320)
    #     mat2 = torch.randn(2, 77, 320)
    #     sd_mha_model = SD_MHA_Model_v4(8, 320, 320).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat1, mat2)

    # def test_sd_mha_bf16_scale_v4(self):
    #     mat1 = torch.randn(2, 4096, 320)
    #     mat2 = torch.randn(2, 77, 320)
    #     sd_mha_model = SD_MHA_Model_scale_v4(8, 320, 320, 0.11).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat1, mat2)

    def test_fake_sd_mha_bf16(self):
        mat1 = (torch.randn(1, 2, 64, 64) + 20).to(torch.bfloat16)
        mat2 = (torch.randn(1, 2, 64, 64) - 20).to(torch.bfloat16)
        mat3 = torch.randn(1, 2, 64, 64).to(torch.bfloat16)
        mask = (torch.ones(1, 1, 1, 64)).to(torch.bfloat16)
        fake_sd_mha_model = Fake_SD_MHA_Model(64, -1).eval()
        fake_mha_ipex = ipex.optimize(
            fake_sd_mha_model, dtype=torch.bfloat16, level="O1"
        )

        with torch.cpu.amp.autocast(), torch.no_grad():
            fake_mha_ipex = torch.jit.trace(
                fake_mha_ipex,
                (
                    mat1,
                    mat2,
                    mat3,
                    mask,
                ),
            )
            fake_mha_ipex = torch.jit.freeze(fake_mha_ipex)

            for _ in range(2):
                fake_mha_jit = fake_mha_ipex(mat1, mat2, mat3, mask)
            fake_mha_ref = fake_sd_mha_model(mat1, mat2, mat3, mask)
            self.assertEqual(fake_mha_ref, fake_mha_jit, prec=1e-1)

            fake_mha_graph = fake_mha_ipex.graph_for(mat1, mat2, mat3, mask)
            self.assertTrue(
                any(n.kind() == "ipex::mha_scores_calc" for n in fake_mha_graph.nodes())
            )

    def test_transfree_mha_bf16(self):
        for i in range(len(bs)):
            mat = torch.randn(bs[i], seq[i], num_heads[i] * head_dims[i]).to(
                torch.bfloat16
            )
            mask_base = torch.randn(bs[i], 1, 1, seq[i]).to(torch.bfloat16)
            mask_distil = torch.randn(bs[i], seq[i]).to(torch.bfloat16)

            mha_model = MHA_Model_BERT(
                scales[i], num_heads[i], head_dims[i], [0, 2, 1, 3], -1, -2
            ).eval()
            mha_ipex = ipex.optimize(mha_model, dtype=torch.bfloat16, level="O1")

            vit_mha_model = MHA_Model_ViT(
                scales[i], num_heads[i], head_dims[i], [2, 0, 3, 1, 4], -2, -1, 1, 2
            ).eval()
            vit_mha_ipex = ipex.optimize(
                vit_mha_model, dtype=torch.bfloat16, level="O1"
            )

            with torch.cpu.amp.autocast(), torch.no_grad():
                mha_ipex = torch.jit.trace(
                    mha_ipex,
                    (
                        mat,
                        mask_base,
                    ),
                )
                mha_ipex = torch.jit.freeze(mha_ipex)

                vit_mha_ipex = torch.jit.trace(vit_mha_ipex, (mat,))
                vit_mha_ipex = torch.jit.freeze(vit_mha_ipex)

                for _ in range(2):
                    mha_jit = mha_ipex(mat, mask_base)
                    vit_mha_jit = vit_mha_ipex(mat)

                mha_ref = mha_model(mat, mask_base)
                vit_mha_ref = vit_mha_model(mat)

                self.assertEqual(mha_ref, mha_jit, prec=1e-2)
                self.assertEqual(vit_mha_ref, vit_mha_jit, prec=1e-2)

                mha_graph = mha_ipex.graph_for(mat, mask_base)
                vit_mha_graph = vit_mha_ipex.graph_for(mat)

                self.assertTrue(
                    any(n.kind() == "ipex::bert_flash_mha" for n in mha_graph.nodes())
                )
                self.assertTrue(
                    any(
                        n.kind() == "ipex::transfree_vit_mha"
                        for n in vit_mha_graph.nodes()
                    )
                )

            for fill_value in [-float("inf"), torch.tensor(torch.finfo(float).min)]:
                distil_mha_model = MHA_Model_Distil(
                    scales[i], num_heads[i], head_dims[i], 1, 2, 3, fill_value
                ).eval()
                distil_mha_ipex = ipex.optimize(
                    distil_mha_model, dtype=torch.bfloat16, level="O1"
                )

                with torch.cpu.amp.autocast(), torch.no_grad():
                    distil_mha_ipex = torch.jit.trace(
                        distil_mha_ipex,
                        (
                            mat,
                            mask_distil,
                        ),
                    )
                    distil_mha_ipex = torch.jit.freeze(distil_mha_ipex)

                    for _ in range(2):
                        distil_mha_jit = distil_mha_ipex(mat, mask_distil)
                    distil_mha_ref = distil_mha_model(mat, mask_distil)
                    self.assertEqual(distil_mha_ref, distil_mha_jit, prec=1e-2)
                    distil_mha_graph = distil_mha_ipex.graph_for(mat, mask_distil)
                    self.assertTrue(
                        any(
                            n.kind() == "ipex::distil_mha_scores_calc"
                            for n in distil_mha_graph.nodes()
                        )
                    )

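    # The "fake" MHA models below use permute/transpose/select indices that do not
    # match the canonical attention layout, so the flash-MHA fusions should not fire;
    # instead the tests check the fallback fusions (ipex::mha_scores_calc,
    # ipex::distil_mha_scores_calc) and, for the ViT variants, that
    # ipex::transfree_vit_mha is absent.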
    def test_fake_mha_bf16(self):
        mat = torch.randn(16, 16, 256).to(torch.bfloat16)
        mask_base = torch.randn(16, 1, 1, 16).to(torch.bfloat16)
        mask_distil = torch.randn(16, 16).to(torch.bfloat16)

        fake_mha_model = []
        fake_mha_ipex = []

        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 3, 1], -1, -2).eval())
        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 1, 3], -2, -3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[0], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[1], dtype=torch.bfloat16, level="O1")
        )

        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 1, 2, 1).eval())
        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 2, 1, 3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[2], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[3], dtype=torch.bfloat16, level="O1")
        )

        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 1, 3, 4], -2, -1, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -3, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -1, 0, 2).eval()
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[4], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[5], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[6], dtype=torch.bfloat16, level="O1")
        )

        with torch.cpu.amp.autocast(), torch.no_grad():
            fake_mha_jit = []
            fake_mha_ref = []

            for i in range(0, 2):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_base,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_base)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_base))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_base))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_base)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )

            for i in range(2, 4):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_distil,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_distil)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_distil))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_distil))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_distil)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::distil_mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )

            for i in range(4, 7):
                fake_mha_ipex[i] = torch.jit.trace(fake_mha_ipex[i], mat)
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat)
                fake_mha_jit.append(fake_mha_ipex[i](mat))
                fake_mha_ref.append(fake_mha_model[i](mat))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat)
                self.assertFalse(
                    any(
                        n.kind() == "ipex::transfree_vit_mha"
                        for n in fake_mha_graph.nodes()
                    )
                )

        for i in range(7):
            self.assertEqual(fake_mha_ref[i], fake_mha_jit[i], prec=1e-2)

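    # For fp32 the flash-MHA kernels are not expected here; the optimized graphs
    # are checked for the ipex::matmul_outtrans fusion instead.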
    def test_transfree_mha_fp32(self):
        for i in range(len(bs)):
            mat = torch.randn(bs[i], seq[i], num_heads[i] * head_dims[i]).to(
                torch.float
            )
            mask_base = torch.randn(bs[i], 1, 1, seq[i]).to(torch.float)
            mask_distil = torch.randn(bs[i], seq[i]).to(torch.float)

            mha_model = MHA_Model_BERT(
                scales[i], num_heads[i], head_dims[i], [0, 2, 1, 3], -1, -2
            ).eval()
            mha_ipex = ipex.optimize(mha_model, dtype=torch.float, level="O1")

            distil_mha_model = MHA_Model_Distil(
                scales[i], num_heads[i], head_dims[i], 1, 2, 3
            ).eval()
            distil_mha_ipex = ipex.optimize(
                distil_mha_model, dtype=torch.float, level="O1"
            )

            vit_mha_model = MHA_Model_ViT(
                scales[i], num_heads[i], head_dims[i], [2, 0, 3, 1, 4], -2, -1, 1, 2
            ).eval()
            vit_mha_ipex = ipex.optimize(vit_mha_model, dtype=torch.float, level="O1")

            with torch.no_grad():
                mha_ipex = torch.jit.trace(
                    mha_ipex,
                    (
                        mat,
                        mask_base,
                    ),
                )
                mha_ipex = torch.jit.freeze(mha_ipex)

                distil_mha_ipex = torch.jit.trace(
                    distil_mha_ipex,
                    (
                        mat,
                        mask_distil,
                    ),
                )
                distil_mha_ipex = torch.jit.freeze(distil_mha_ipex)

                vit_mha_ipex = torch.jit.trace(vit_mha_ipex, (mat,))
                vit_mha_ipex = torch.jit.freeze(vit_mha_ipex)

                for _ in range(2):
                    mha_jit = mha_ipex(mat, mask_base)
                    distil_mha_jit = distil_mha_ipex(mat, mask_distil)
                    vit_mha_jit = vit_mha_ipex(mat)

                mha_ref = mha_model(mat, mask_base)
                distil_mha_ref = distil_mha_model(mat, mask_distil)
                vit_mha_ref = vit_mha_model(mat)

                self.assertEqual(mha_ref, mha_jit, prec=1e-5)
                self.assertEqual(distil_mha_ref, distil_mha_jit, prec=1e-5)
                self.assertEqual(vit_mha_ref, vit_mha_jit, prec=1e-5)

                mha_graph = mha_ipex.graph_for(mat, mask_base)
                distil_mha_graph = distil_mha_ipex.graph_for(mat, mask_distil)
                vit_mha_graph = vit_mha_ipex.graph_for(mat)

                self.assertTrue(
                    any(n.kind() == "ipex::matmul_outtrans" for n in mha_graph.nodes())
                )
                self.assertTrue(
                    any(
                        n.kind() == "ipex::matmul_outtrans"
                        for n in distil_mha_graph.nodes()
                    )
                )
                self.assertTrue(
                    any(
                        n.kind() == "ipex::matmul_outtrans"
                        for n in vit_mha_graph.nodes()
                    )
                )

    def test_fake_mha_fp32(self):
        mat = torch.randn(16, 16, 256)
        mask_base = torch.randn(16, 1, 1, 16)
        mask_distil = torch.randn(16, 16)

        fake_mha_model = []
        fake_mha_ipex = []

        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 3, 1], -1, -2).eval())
        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 1, 3], -2, -3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[0], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[1], dtype=torch.float, level="O1")
        )

        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 1, 2, 1).eval())
        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 2, 1, 3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[2], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[3], dtype=torch.float, level="O1")
        )

        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 1, 3, 4], -2, -1, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -3, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -1, 0, 2).eval()
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[4], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[5], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[6], dtype=torch.float, level="O1")
        )

        with torch.no_grad():
            fake_mha_jit = []
            fake_mha_ref = []

            for i in range(0, 2):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_base,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_base)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_base))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_base))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_base)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as p:
                    fake_mha_ipex[i](mat, mask_base)
                if i == 0:
                    self.assertTrue("dil_matmul" in str(p.key_averages()))
                else:
                    self.assertTrue("dil_mha_bmm" in str(p.key_averages()))

            for i in range(2, 4):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_distil,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_distil)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_distil))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_distil))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_distil)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::distil_mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as p:
                    fake_mha_ipex[i](mat, mask_distil)
                if i == 2:
                    self.assertTrue("dil_mha_bmm" in str(p.key_averages()))
                else:
                    self.assertTrue("dil_matmul" in str(p.key_averages()))

            for i in range(4, 7):
                fake_mha_ipex[i] = torch.jit.trace(fake_mha_ipex[i], mat)
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat)
                fake_mha_jit.append(fake_mha_ipex[i](mat))
                fake_mha_ref.append(fake_mha_model[i](mat))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat)
                self.assertTrue(
                    any(n.kind() == "ipex::matmul_mul" for n in fake_mha_graph.nodes())
                )
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as p:
                    fake_mha_ipex[i](mat)
                if i == 6:
                    self.assertTrue("dil_matmul" in str(p.key_averages()))
                else:
                    self.assertTrue("dil_mha_bmm" in str(p.key_averages()))

        for i in range(7):
            self.assertEqual(fake_mha_ref[i], fake_mha_jit[i], prec=1e-5)


if __name__ == "__main__":
    test = unittest.main()