import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import intel_extension_for_pytorch as ipex
import math
import copy
from common_utils import TestCase
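
# Unit tests for the IPEX TorchScript MHA (multi-head attention) fusion passes.
# Stable Diffusion style attention is expected to fuse into ipex::sd_flash_mha,
# BERT/DistilBERT/ViT style attention into ipex::bert_flash_mha,
# ipex::distil_mha_scores_calc and ipex::transfree_vit_mha, with
# ipex::mha_scores_calc as the generic fallback. Each test runs ipex.optimize,
# traces and freezes the model with TorchScript, then compares the JIT output
# and the fused graph against the eager reference module.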


# (from Diffusers 0.12.1)
class SD_MHA_Model_v1(nn.Module):
    def __init__(self, scale, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v1, self).__init__()
        self.scale = scale
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

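    # head_to_batch_dim folds the head dimension into the batch dimension:
    # (batch, seq, heads * head_dim) -> (batch * heads, seq, head_dim);
    # batch_to_head_dim performs the inverse reshape, mirroring the helpers in
    # the Diffusers attention module this class is taken from.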
    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size * head_size, seq_len, dim // head_size
        )
        return tensor

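    # Attention probabilities: softmax(scale * Q @ K^T). With beta=0, baddbmm ignores
    # the uninitialized torch.empty tensor and uses it only for shape/dtype/device.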
    def get_attention_scores(self, query, key):
        dtype = query.dtype
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)
        return attention_probs

    def forward(self, x):
        query = self.query(x)
        query = self.head_to_batch_dim(query)
        key = self.key(x)
        key = self.head_to_batch_dim(key)
        value = self.value(x)
        value = self.head_to_batch_dim(value)
        attention_probs = self.get_attention_scores(query, key)
        hidden_states = torch.bmm(attention_probs, value)
        output = self.batch_to_head_dim(hidden_states)
        return output


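# Cross-attention counterpart of SD_MHA_Model_v1: queries are projected from the
# hidden states x, while keys and values come from a second input y (e.g. encoder states).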
# (from Diffusers 0.12.1)
class SD_MHA_Model_v2(nn.Module):
    def __init__(self, scale, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v2, self).__init__()
        self.scale = scale
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size * head_size, seq_len, dim // head_size
        )
        return tensor

    def get_attention_scores(self, query, key):
        dtype = query.dtype
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)
        return attention_probs

    def forward(self, x, y):
        query = self.query(x)
        query = self.head_to_batch_dim(query)
        key = self.key(y)
        key = self.head_to_batch_dim(key)
        value = self.value(y)
        value = self.head_to_batch_dim(value)
        attention_probs = self.get_attention_scores(query, key)
        hidden_states = torch.bmm(attention_probs, value)
        output = self.batch_to_head_dim(hidden_states)
        return output


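# The v3/v4 variants below follow Diffusers 0.13, which computes attention through
# F.scaled_dot_product_attention instead of the explicit baddbmm/softmax/bmm sequence;
# the "scale" variants additionally pass an explicit scale to SDPA.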
# (from Diffusers 0.13)
class SD_MHA_Model_v3(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v3, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (from Diffusers 0.13)
class SD_MHA_Model_scale_v3(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize, scale):
        super(SD_MHA_Model_scale_v3, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.scale = scale
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=self.scale,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (from Diffusers 0.13)
class SD_MHA_Model_v4(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize):
        super(SD_MHA_Model_v4, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x, y):
        query = self.query(x)
        key = self.key(y)
        value = self.value(y)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


# (from Diffusers 0.13)
class SD_MHA_Model_scale_v4(nn.Module):
    def __init__(self, num_heads, weightsize, hiddensize, scale):
        super(SD_MHA_Model_scale_v4, self).__init__()
        self.heads = num_heads
        self.weightsize = weightsize
        self.hiddensize = hiddensize
        self.scale = scale
        self.query = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.key = nn.Linear(self.weightsize, self.hiddensize, bias=True)
        self.value = nn.Linear(self.weightsize, self.hiddensize, bias=True)

    def forward(self, x, y):
        query = self.query(x)
        key = self.key(y)
        value = self.value(y)
        batch_size, sequence_length, inner_dim = x.shape
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=self.scale,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )
        output = hidden_states.to(query.dtype)
        return output


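# A hand-written matmul/softmax/matmul pattern that is not an SD attention layout:
# the JIT fusion pass falls back to the generic ipex::mha_scores_calc kernel instead
# of the SD flash MHA kernel.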
# (Fake Diffusers Model - Fall back to ipex::mha_scores_calc)
class Fake_SD_MHA_Model(nn.Module):
    def __init__(self, dim_per_head, softmax_dim=-1):
        super(Fake_SD_MHA_Model, self).__init__()
        self.softmax = nn.Softmax(dim=softmax_dim)
        self.dim_per_head = dim_per_head

    def forward(self, mat1, mat2, mat3, bias):
        mat1 = mat1 / math.sqrt(self.dim_per_head)
        qk = torch.matmul(mat1, mat2.transpose(2, 3))
        scores = self.softmax(qk + bias)
        output = torch.matmul(scores, mat3)
        return output


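# BERT-style self-attention. The permute/transpose indices are parameterized so the
# tests can exercise both the standard BERT layout (which the fusion passes match)
# and deliberately wrong layouts (which they must not).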
class MHA_Model_BERT(nn.Module):
    def __init__(self, scale, num_heads, head_dims, permute_idx, trans_a, trans_b):
        super(MHA_Model_BERT, self).__init__()
        self.scale = scale
        self.num_heads = num_heads
        self.head_dims = head_dims
        self.embed_dims = self.num_heads * self.head_dims
        self.query = nn.Linear(self.embed_dims, self.embed_dims, bias=True)
        self.key = nn.Linear(self.embed_dims, self.embed_dims, bias=True)
        self.value = nn.Linear(self.embed_dims, self.embed_dims, bias=True)
        self.permute_idx = permute_idx
        self.trans_a = trans_a
        self.trans_b = trans_b

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dims)
        x = x.view(new_x_shape)
        return x.permute(self.permute_idx)

    def forward(self, x, mask):
        query_layer = self.transpose_for_scores(self.query(x))
        key_layer = self.transpose_for_scores(self.key(x)).transpose(
            self.trans_a, self.trans_b
        )
        value_layer = self.transpose_for_scores(self.value(x))
        attention_scores = torch.matmul(query_layer, key_layer) / self.scale + mask
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(self.permute_idx).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dims,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer


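# DistilBERT-style self-attention: the boolean mask is reshaped, broadcast over the
# score tensor and applied with masked_fill before the softmax. fill_value lets the
# tests use either -inf or the float minimum as the masked score.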
class MHA_Model_Distil(nn.Module):
    def __init__(
        self,
        scale,
        num_heads,
        head_dims,
        trans_a,
        trans_b,
        trans_c,
        fill_value=-float("inf"),
    ):
        super(MHA_Model_Distil, self).__init__()
        self.scale = scale
        self.n_head = num_heads
        self.head_dims = head_dims
        self.dim = self.n_head * self.head_dims
        self.q_lin = nn.Linear(self.dim, self.dim, bias=True)
        self.k_lin = nn.Linear(self.dim, self.dim, bias=True)
        self.v_lin = nn.Linear(self.dim, self.dim, bias=True)
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.trans_c = trans_c
        self.fill_value = fill_value

    def forward(self, x, mask):
        bs, q_length, dim = x.size()
        k_length = x.size(1)

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_head, self.head_dims).transpose(
                self.trans_a, self.trans_b
            )

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return (
                x.transpose(self.trans_a, self.trans_b)
                .contiguous()
                .view(bs, -1, self.n_head * self.head_dims)
            )

        q = shape(self.q_lin(x))
        k = shape(self.k_lin(x))
        v = shape(self.v_lin(x))
        mask_reshp = (bs, 1, 1, k_length)
        q = q / self.scale
        scores = torch.matmul(q, k.transpose(self.trans_b, self.trans_c))
        mask = (mask == 0).view(mask_reshp).expand_as(scores)
        scores = scores.masked_fill(mask, self.fill_value)
        weights = nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)
        context_layer = unshape(context)

        return context_layer


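# ViT-style self-attention with a single fused QKV projection; q, k and v are sliced
# out of the projected tensor via the select indices, and the attention scores are
# multiplied by 1/scale.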
class MHA_Model_ViT(nn.Module):
    def __init__(
        self,
        scale,
        num_heads,
        head_dims,
        permute_idx,
        trans_a,
        trans_b,
        select_a,
        select_b,
    ):
        super(MHA_Model_ViT, self).__init__()
        self.scale = 1.0 / scale
        self.num_heads = num_heads
        self.head_dims = head_dims
        self.embed_dims = self.num_heads * self.head_dims
        self.qkv = nn.Linear(self.embed_dims, self.embed_dims * 3, bias=True)
        self.permute_idx = permute_idx
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.select_a = select_a
        self.select_b = select_b

    def forward(self, x):
        B, N, _ = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dims)
            .permute(self.permute_idx)
        )
        q, k, v = qkv[0], qkv[self.select_a], qkv[self.select_b]
        attn = (q @ k.transpose(self.trans_a, self.trans_b)) * self.scale
        attn = attn.softmax(dim=-1)
        context_layer = (
            (attn @ v)
            .transpose(self.select_a, self.select_b)
            .reshape(B, N, self.embed_dims)
        )

        return context_layer


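# Shape and scale combinations swept by the BERT/DistilBERT/ViT tests below;
# index i selects one configuration per loop iteration.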
bs = [5, 3, 11]
seq = [128, 384, 31]
scales = [8, 13, 21]
num_heads = [12, 16, 29]
head_dims = [64, 96, 17]


# In this UT case, "+15" is designed to trigger the overflow of SoftMax when using pos_FLT_MIN.
# Since the input values to the BMM and SoftMax are very large, the resulting MHA accumulations
# will also be large, so the tolerance is set to 1.5e-0 for that case.
class TransFreeMHATester(TestCase):
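    # Shared helper: deep-copies the model, runs ipex.optimize plus TorchScript
    # trace/freeze in bf16, compares the JIT output against the eager reference, and
    # checks that the fused ipex::sd_flash_mha node appears in the frozen graph. With
    # neg_FLT_MIN=True the inputs are shifted by +15 (see the comment above) and the
    # looser 1.5e-0 tolerance is used.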
    def sd_mha_bf16_common(self, model, mat1, mat2=None):
        for neg_FLT_MIN in [True, False]:
            sd_mha_model = copy.deepcopy(model)
            if mat2 is not None:
                inputs = (
                    (mat1.to(torch.bfloat16), mat2.to(torch.bfloat16))
                    if not neg_FLT_MIN
                    else (
                        (mat1 + 15).to(torch.bfloat16),
                        (mat2 + 15).to(torch.bfloat16),
                    )
                )
            else:
                inputs = (
                    (mat1.to(torch.bfloat16),)
                    if not neg_FLT_MIN
                    else ((mat1 + 15).to(torch.bfloat16),)
                )
            mha_ipex = ipex.optimize(sd_mha_model, dtype=torch.bfloat16, level="O1")
            with torch.cpu.amp.autocast(), torch.no_grad():
                mha_ipex = torch.jit.trace(mha_ipex, inputs)
                mha_ipex = torch.jit.freeze(mha_ipex)

                for _ in range(2):
                    mha_jit = mha_ipex(*inputs)
                mha_ref = sd_mha_model(*inputs)
                self.assertEqual(mha_ref, mha_jit, prec=1.5e-0 if neg_FLT_MIN else 1e-2)

                mha_graph = mha_ipex.graph_for(*inputs)
                self.assertTrue(
                    any(n.kind() == "ipex::sd_flash_mha" for n in mha_graph.nodes())
                )

    def test_sd_mha_bf16_v1(self):
        mat = torch.randn(2, 4096, 320)
        sd_mha_model = SD_MHA_Model_v1(0.3, 8, 320, 320).eval()
        self.sd_mha_bf16_common(sd_mha_model, mat)

    def test_sd_mha_bf16_v2(self):
        mat1 = torch.randn(2, 4096, 320)
        mat2 = torch.randn(2, 77, 320)
        sd_mha_model = SD_MHA_Model_v2(0.3, 8, 320, 320).eval()
        self.sd_mha_bf16_common(sd_mha_model, mat1, mat2)

    # def test_sd_mha_bf16_v3(self):
    #     mat = torch.randn(2, 4096, 320)
    #     sd_mha_model = SD_MHA_Model_v3(8, 320, 320).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat)

    # def test_sd_mha_bf16_scale_v3(self):
    #     mat = torch.randn(2, 4096, 320)
    #     sd_mha_model = SD_MHA_Model_scale_v3(8, 320, 320, 0.3).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat)

    # def test_sd_mha_bf16_v4(self):
    #     mat1 = torch.randn(2, 4096, 320)
    #     mat2 = torch.randn(2, 77, 320)
    #     sd_mha_model = SD_MHA_Model_v4(8, 320, 320).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat1, mat2)

    # def test_sd_mha_bf16_scale_v4(self):
    #     mat1 = torch.randn(2, 4096, 320)
    #     mat2 = torch.randn(2, 77, 320)
    #     sd_mha_model = SD_MHA_Model_scale_v4(8, 320, 320, 0.11).eval()
    #     self.sd_mha_bf16_common(sd_mha_model, mat1, mat2)

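    # The fake SD pattern with a bias mask should hit the generic ipex::mha_scores_calc
    # fusion rather than ipex::sd_flash_mha; the +20/-20 shifts give large-magnitude
    # scores through the softmax, hence the looser 1e-1 tolerance.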
    def test_fake_sd_mha_bf16(self):
        mat1 = (torch.randn(1, 2, 64, 64) + 20).to(torch.bfloat16)
        mat2 = (torch.randn(1, 2, 64, 64) - 20).to(torch.bfloat16)
        mat3 = torch.randn(1, 2, 64, 64).to(torch.bfloat16)
        mask = (torch.ones(1, 1, 1, 64)).to(torch.bfloat16)
        fake_sd_mha_model = Fake_SD_MHA_Model(64, -1).eval()
        fake_mha_ipex = ipex.optimize(
            fake_sd_mha_model, dtype=torch.bfloat16, level="O1"
        )

        with torch.cpu.amp.autocast(), torch.no_grad():
            fake_mha_ipex = torch.jit.trace(
                fake_mha_ipex,
                (
                    mat1,
                    mat2,
                    mat3,
                    mask,
                ),
            )
            fake_mha_ipex = torch.jit.freeze(fake_mha_ipex)

            for _ in range(2):
                fake_mha_jit = fake_mha_ipex(mat1, mat2, mat3, mask)
            fake_mha_ref = fake_sd_mha_model(mat1, mat2, mat3, mask)
            self.assertEqual(fake_mha_ref, fake_mha_jit, prec=1e-1)

            fake_mha_graph = fake_mha_ipex.graph_for(mat1, mat2, mat3, mask)
            self.assertTrue(
                any(n.kind() == "ipex::mha_scores_calc" for n in fake_mha_graph.nodes())
            )

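    # bf16 path: BERT and ViT attention with the canonical layouts should fuse into
    # ipex::bert_flash_mha and ipex::transfree_vit_mha; the DistilBERT variant (with
    # either -inf or the float minimum as the mask fill value) should fuse into
    # ipex::distil_mha_scores_calc.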
    def test_transfree_mha_bf16(self):
        for i in range(len(bs)):
            mat = torch.randn(bs[i], seq[i], num_heads[i] * head_dims[i]).to(
                torch.bfloat16
            )
            mask_base = torch.randn(bs[i], 1, 1, seq[i]).to(torch.bfloat16)
            mask_distil = torch.randn(bs[i], seq[i]).to(torch.bfloat16)

            mha_model = MHA_Model_BERT(
                scales[i], num_heads[i], head_dims[i], [0, 2, 1, 3], -1, -2
            ).eval()
            mha_ipex = ipex.optimize(mha_model, dtype=torch.bfloat16, level="O1")

            vit_mha_model = MHA_Model_ViT(
                scales[i], num_heads[i], head_dims[i], [2, 0, 3, 1, 4], -2, -1, 1, 2
            ).eval()
            vit_mha_ipex = ipex.optimize(
                vit_mha_model, dtype=torch.bfloat16, level="O1"
            )

            with torch.cpu.amp.autocast(), torch.no_grad():
                mha_ipex = torch.jit.trace(
                    mha_ipex,
                    (
                        mat,
                        mask_base,
                    ),
                )
                mha_ipex = torch.jit.freeze(mha_ipex)

                vit_mha_ipex = torch.jit.trace(vit_mha_ipex, (mat,))
                vit_mha_ipex = torch.jit.freeze(vit_mha_ipex)

                for _ in range(2):
                    mha_jit = mha_ipex(mat, mask_base)
                    vit_mha_jit = vit_mha_ipex(mat)

                mha_ref = mha_model(mat, mask_base)
                vit_mha_ref = vit_mha_model(mat)

                self.assertEqual(mha_ref, mha_jit, prec=1e-2)
                self.assertEqual(vit_mha_ref, vit_mha_jit, prec=1e-2)

                mha_graph = mha_ipex.graph_for(mat, mask_base)
                vit_mha_graph = vit_mha_ipex.graph_for(mat)

                self.assertTrue(
                    any(n.kind() == "ipex::bert_flash_mha" for n in mha_graph.nodes())
                )
                self.assertTrue(
                    any(
                        n.kind() == "ipex::transfree_vit_mha"
                        for n in vit_mha_graph.nodes()
                    )
                )

            for fill_value in [-float("inf"), torch.tensor(torch.finfo(float).min)]:
                distil_mha_model = MHA_Model_Distil(
                    scales[i], num_heads[i], head_dims[i], 1, 2, 3, fill_value
                ).eval()
                distil_mha_ipex = ipex.optimize(
                    distil_mha_model, dtype=torch.bfloat16, level="O1"
                )

                with torch.cpu.amp.autocast(), torch.no_grad():
                    distil_mha_ipex = torch.jit.trace(
                        distil_mha_ipex,
                        (
                            mat,
                            mask_distil,
                        ),
                    )
                    distil_mha_ipex = torch.jit.freeze(distil_mha_ipex)

                    for _ in range(2):
                        distil_mha_jit = distil_mha_ipex(mat, mask_distil)
                    distil_mha_ref = distil_mha_model(mat, mask_distil)
                    self.assertEqual(distil_mha_ref, distil_mha_jit, prec=1e-2)
                    distil_mha_graph = distil_mha_ipex.graph_for(mat, mask_distil)
                    self.assertTrue(
                        any(
                            n.kind() == "ipex::distil_mha_scores_calc"
                            for n in distil_mha_graph.nodes()
                        )
                    )

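    # "Fake" bf16 variants: the same modules built with non-standard permute/transpose/
    # select indices. They should still fall back to the generic scores-calc fusions
    # (or, for ViT, not match ipex::transfree_vit_mha at all) while staying accurate.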
    def test_fake_mha_bf16(self):
        mat = torch.randn(16, 16, 256).to(torch.bfloat16)
        mask_base = torch.randn(16, 1, 1, 16).to(torch.bfloat16)
        mask_distil = torch.randn(16, 16).to(torch.bfloat16)

        fake_mha_model = []
        fake_mha_ipex = []

        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 3, 1], -1, -2).eval())
        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 1, 3], -2, -3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[0], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[1], dtype=torch.bfloat16, level="O1")
        )

        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 1, 2, 1).eval())
        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 2, 1, 3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[2], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[3], dtype=torch.bfloat16, level="O1")
        )

        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 1, 3, 4], -2, -1, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -3, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -1, 0, 2).eval()
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[4], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[5], dtype=torch.bfloat16, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[6], dtype=torch.bfloat16, level="O1")
        )

        with torch.cpu.amp.autocast(), torch.no_grad():
            fake_mha_jit = []
            fake_mha_ref = []

            for i in range(0, 2):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_base,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_base)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_base))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_base))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_base)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )

            for i in range(2, 4):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_distil,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_distil)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_distil))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_distil))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_distil)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::distil_mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )

            for i in range(4, 7):
                fake_mha_ipex[i] = torch.jit.trace(fake_mha_ipex[i], mat)
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat)
                fake_mha_jit.append(fake_mha_ipex[i](mat))
                fake_mha_ref.append(fake_mha_model[i](mat))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat)
                self.assertFalse(
                    any(
                        n.kind() == "ipex::transfree_vit_mha"
                        for n in fake_mha_graph.nodes()
                    )
                )

            for i in range(7):
                self.assertEqual(fake_mha_ref[i], fake_mha_jit[i], prec=1e-2)

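    # fp32 path: unlike the bf16 tests above, the graphs are checked for the partially
    # fused ipex::matmul_outtrans node rather than the flash MHA kernels.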
    def test_transfree_mha_fp32(self):
        for i in range(len(bs)):
            mat = torch.randn(bs[i], seq[i], num_heads[i] * head_dims[i]).to(
                torch.float
            )
            mask_base = torch.randn(bs[i], 1, 1, seq[i]).to(torch.float)
            mask_distil = torch.randn(bs[i], seq[i]).to(torch.float)

            mha_model = MHA_Model_BERT(
                scales[i], num_heads[i], head_dims[i], [0, 2, 1, 3], -1, -2
            ).eval()
            mha_ipex = ipex.optimize(mha_model, dtype=torch.float, level="O1")

            distil_mha_model = MHA_Model_Distil(
                scales[i], num_heads[i], head_dims[i], 1, 2, 3
            ).eval()
            distil_mha_ipex = ipex.optimize(
                distil_mha_model, dtype=torch.float, level="O1"
            )

            vit_mha_model = MHA_Model_ViT(
                scales[i], num_heads[i], head_dims[i], [2, 0, 3, 1, 4], -2, -1, 1, 2
            ).eval()
            vit_mha_ipex = ipex.optimize(vit_mha_model, dtype=torch.float, level="O1")

            with torch.no_grad():
                mha_ipex = torch.jit.trace(
                    mha_ipex,
                    (
                        mat,
                        mask_base,
                    ),
                )
                mha_ipex = torch.jit.freeze(mha_ipex)

                distil_mha_ipex = torch.jit.trace(
                    distil_mha_ipex,
                    (
                        mat,
                        mask_distil,
                    ),
                )
                distil_mha_ipex = torch.jit.freeze(distil_mha_ipex)

                vit_mha_ipex = torch.jit.trace(vit_mha_ipex, (mat,))
                vit_mha_ipex = torch.jit.freeze(vit_mha_ipex)

                for _ in range(2):
                    mha_jit = mha_ipex(mat, mask_base)
                    distil_mha_jit = distil_mha_ipex(mat, mask_distil)
                    vit_mha_jit = vit_mha_ipex(mat)

                mha_ref = mha_model(mat, mask_base)
                distil_mha_ref = distil_mha_model(mat, mask_distil)
                vit_mha_ref = vit_mha_model(mat)

                self.assertEqual(mha_ref, mha_jit, prec=1e-5)
                self.assertEqual(distil_mha_ref, distil_mha_jit, prec=1e-5)
                self.assertEqual(vit_mha_ref, vit_mha_jit, prec=1e-5)

                mha_graph = mha_ipex.graph_for(mat, mask_base)
                distil_mha_graph = distil_mha_ipex.graph_for(mat, mask_distil)
                vit_mha_graph = vit_mha_ipex.graph_for(mat)

                self.assertTrue(
                    any(n.kind() == "ipex::matmul_outtrans" for n in mha_graph.nodes())
                )
                self.assertTrue(
                    any(
                        n.kind() == "ipex::matmul_outtrans"
                        for n in distil_mha_graph.nodes()
                    )
                )
                self.assertTrue(
                    any(
                        n.kind() == "ipex::matmul_outtrans"
                        for n in vit_mha_graph.nodes()
                    )
                )

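    # fp32 "fake" variants: besides checking the expected graph nodes, this test
    # profiles one run per model and inspects the kernel names (dil_matmul /
    # dil_mha_bmm) that the different layouts are expected to dispatch to.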
    def test_fake_mha_fp32(self):
        mat = torch.randn(16, 16, 256)
        mask_base = torch.randn(16, 1, 1, 16)
        mask_distil = torch.randn(16, 16)

        fake_mha_model = []
        fake_mha_ipex = []

        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 3, 1], -1, -2).eval())
        fake_mha_model.append(MHA_Model_BERT(16, 16, 16, [0, 2, 1, 3], -2, -3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[0], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[1], dtype=torch.float, level="O1")
        )

        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 1, 2, 1).eval())
        fake_mha_model.append(MHA_Model_Distil(16, 16, 16, 2, 1, 3).eval())
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[2], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[3], dtype=torch.float, level="O1")
        )

        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 1, 3, 4], -2, -1, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -3, 1, 2).eval()
        )
        fake_mha_model.append(
            MHA_Model_ViT(16, 16, 16, [2, 0, 3, 1, 4], -2, -1, 0, 2).eval()
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[4], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[5], dtype=torch.float, level="O1")
        )
        fake_mha_ipex.append(
            ipex.optimize(fake_mha_model[6], dtype=torch.float, level="O1")
        )

        with torch.no_grad():
            fake_mha_jit = []
            fake_mha_ref = []

            for i in range(0, 2):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_base,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_base)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_base))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_base))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_base)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as p:
                    fake_mha_ipex[i](mat, mask_base)
                if i == 0:
                    self.assertTrue("dil_matmul" in str(p.key_averages()))
                else:
                    self.assertTrue("dil_mha_bmm" in str(p.key_averages()))

            for i in range(2, 4):
                fake_mha_ipex[i] = torch.jit.trace(
                    fake_mha_ipex[i],
                    (
                        mat,
                        mask_distil,
                    ),
                )
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat, mask_distil)
                fake_mha_jit.append(fake_mha_ipex[i](mat, mask_distil))
                fake_mha_ref.append(fake_mha_model[i](mat, mask_distil))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat, mask_distil)
                self.assertTrue(
                    any(
                        n.kind() == "ipex::distil_mha_scores_calc"
                        for n in fake_mha_graph.nodes()
                    )
                )
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as p:
                    fake_mha_ipex[i](mat, mask_distil)
                if i == 2:
                    self.assertTrue("dil_mha_bmm" in str(p.key_averages()))
                else:
                    self.assertTrue("dil_matmul" in str(p.key_averages()))

            for i in range(4, 7):
                fake_mha_ipex[i] = torch.jit.trace(fake_mha_ipex[i], mat)
                fake_mha_ipex[i] = torch.jit.freeze(fake_mha_ipex[i])
                for _ in range(2):
                    fake_mha_ipex[i](mat)
                fake_mha_jit.append(fake_mha_ipex[i](mat))
                fake_mha_ref.append(fake_mha_model[i](mat))
                fake_mha_graph = fake_mha_ipex[i].graph_for(mat)
                self.assertTrue(
                    any(n.kind() == "ipex::matmul_mul" for n in fake_mha_graph.nodes())
                )
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as p:
                    fake_mha_ipex[i](mat)
                if i == 6:
                    self.assertTrue("dil_matmul" in str(p.key_averages()))
                else:
                    self.assertTrue("dil_mha_bmm" in str(p.key_averages()))

            for i in range(7):
                self.assertEqual(fake_mha_ref[i], fake_mha_jit[i], prec=1e-5)


if __name__ == "__main__":
    test = unittest.main()
