intel-extension-for-pytorch

test_ao_jit_llga_quantization_fuser.py 
2384 lines · 85.2 KB
# This Python file uses the following encoding: utf-8
# !/usr/bin/env python
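# Tests for INT8 fusion through the oneDNN Graph (LLGA) fuser in the IPEX JIT
# quantization path: each case traces a module, quantizes it, and checks the
# resulting fusion groups and fused aten ops.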

import unittest
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from test_ao_jit_llga_utils import (
    JitLlgaTestCase,
    LLGA_FUSION_GROUP,
    get_eltwise_fn,
)
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_qat_fx
from torch.testing._internal.common_utils import run_tests
from torch.ao.quantization import (
    MinMaxObserver,
    PerChannelMinMaxObserver,
    HistogramObserver,
    QConfig,
)

default_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

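# Static quantization recipes under test: per-tensor affine/symmetric activation
# observers (MinMax and Histogram), each paired with a per-channel symmetric
# int8 weight observer.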
static_qconfig = [
    QConfig(
        activation=MinMaxObserver.with_args(
            qscheme=torch.per_tensor_affine, dtype=torch.quint8
        ),
        weight=default_weight_observer,
    ),
    QConfig(
        activation=MinMaxObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        weight=default_weight_observer,
    ),
    QConfig(
        activation=HistogramObserver.with_args(
            qscheme=torch.per_tensor_affine, dtype=torch.quint8, reduce_range=True
        ),
        weight=default_weight_observer,
    ),
    QConfig(
        activation=HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8, reduce_range=True
        ),
        weight=default_weight_observer,
    ),
]

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


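# Single-op coverage: each case quantizes and traces a small module, then checks
# the LLGA fusion groups and the fused aten ops in the resulting JIT graph.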
class TestOp(JitLlgaTestCase):
    def test_conv_int8_in_f32_out(self):
        for [
            spatial,
            in_channels,
            out_channels,
            kernel,
            padding,
            stride,
            dilation,
            g,
            bias,
            memory_format,
            module,
        ] in itertools.product(
            [7],
            [2],
            [3],
            [3],
            [0, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [True, False],
            [torch.contiguous_format, torch.channels_last],
            [torch.nn.Conv2d, torch.nn.Conv3d],
        ):
            m = module(
                in_channels=in_channels * g,
                out_channels=out_channels * g,
                kernel_size=kernel,
                padding=padding,
                stride=stride,
                dilation=dilation,
                groups=g,
                bias=bias,
            )
            input_shape = [1, in_channels * g, spatial, spatial]
            if isinstance(m, torch.nn.Conv3d):
                input_shape.append(spatial)
                if memory_format == torch.channels_last:
                    memory_format = torch.channels_last_3d
            x = torch.rand(input_shape).to(memory_format=memory_format)
            patterns = [["aten::dequantize", "aten::_convolution"]]
            # TODO: enable more config cases.
            for qconfig in static_qconfig:
                input_shape[0] = 5
                x_var = [torch.rand(input_shape, requires_grad=False)]
                graph = self.checkQuantizeTrace(
                    m, [x], x_var=x_var, atol=2e-1, qconfig=qconfig
                )
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::_convolution", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_deconv_int8_in_f32_out(self):
        class M(nn.Module):
            def __init__(
                self,
                in_channels,
                out_channels,
                kernel_size,
                padding,
                stride,
                dilation,
                groups,
                bias,
                module,
            ):
                super(M, self).__init__()
                self.conv = module(
                    in_channels=in_channels * groups,
                    out_channels=out_channels * groups,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    dilation=dilation,
                    groups=groups,
                    bias=bias,
                )
                inverse_module = (
                    torch.nn.ConvTranspose2d
                    if (module == torch.nn.Conv2d)
                    else torch.nn.ConvTranspose3d
                )
                self.deconv = inverse_module(
                    in_channels=out_channels * groups,
                    out_channels=in_channels * groups,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    dilation=dilation,
                    groups=groups,
                    bias=bias,
                )

            def forward(self, x):
                y = self.conv(x)
                return self.deconv(y)

        for [
            spatial,
            in_channels,
            out_channels,
            kernel,
            padding,
            stride,
            dilation,
            g,
            bias,
            memory_format,
            module,
        ] in itertools.product(
            [7],
            [3],
            [3],
            [3],
            [0, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [True, False],
            [torch.contiguous_format, torch.channels_last],
            [torch.nn.Conv2d, torch.nn.Conv3d],
        ):
            m = M(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                padding=padding,
                stride=stride,
                dilation=dilation,
                groups=g,
                bias=bias,
                module=module,
            )

            input_shape = [1, in_channels * g, spatial, spatial]
            if module == torch.nn.Conv3d:
                input_shape.append(spatial)
                if memory_format == torch.channels_last:
                    memory_format = torch.channels_last_3d
            x = torch.rand(input_shape).to(memory_format=memory_format)

            patterns = [
                ["aten::dequantize", "aten::_convolution"],
                ["aten::dequantize", "aten::_convolution"],
            ]

            # TODO: enable more config cases.
            for qconfig in static_qconfig:
                input_shape[0] = 5
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(graph, ["aten::_convolution", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_conv_no_freeze(self):
        m = nn.Conv2d(
            in_channels=3,
            out_channels=3,
            kernel_size=3,
            padding=1,
            stride=1,
            dilation=1,
            groups=1,
            bias=True,
        )
        x = torch.rand(1, 3, 5, 5)
        graph = self.checkQuantizeTrace(
            m, [x], atol=2e-1, qconfig=static_qconfig[0], freeze=False
        )
        patterns = [
            ["aten::dequantize", "aten::quantize_per_channel", "aten::_convolution"]
        ]
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(
            graph,
            ["aten::_convolution", "aten::quantize_per_channel", "aten::dequantize"],
        )
        self.checkPatterns(graph, patterns)

    def test_conv_share_dequant_weight(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1, bias=True)

            def forward(self, x):
                # type: (List[Tensor]) -> Tensor
                all_logits = []
                for feature in x:
                    logits = self.conv(feature)
                    all_logits.append(logits)
                return torch.cat(all_logits, dim=1)

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            patterns = [
                ["aten::dequantize", "aten::_convolution"],
                ["aten::dequantize", "aten::_convolution"],
                ["aten::dequantize", "aten::_convolution"],
            ]
            a = torch.randn(1, 32, 28, 28).to(memory_format=memory_format)
            b = torch.randn(1, 32, 28, 28).to(memory_format=memory_format)
            c = torch.randn(1, 32, 28, 28).to(memory_format=memory_format)
            x = [a, b, c]
            for qconfig in static_qconfig:
                m = M()
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
                self.assertFused(graph, ["aten::_convolution", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_linear_int8_in_f32_out(self):
        for bias in [True, False]:
            x = torch.rand(32, 28)
            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)

            patterns = [
                ["aten::dequantize", "aten::linear"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::linear", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_linear_int8_in_int8_out(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20, bias=bias)
                self.linear2 = nn.Linear(20, 3, bias=bias)

            def forward(self, x, y):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        for bias in [True, False]:
            x = torch.randn(2, 15)
            y = torch.randn(2, 20)
            m = M(bias)

            patterns = [
                ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
                ["aten::dequantize", "aten::linear"],
            ]

            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(
                    graph,
                    ["aten::linear", "aten::quantize_per_channel", "aten::dequantize"],
                )
                self.checkPatterns(graph, patterns)

    def test_linear_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20, bias=bias)

            def forward(self, x):
                x = self.linear1(x)
                return x

        for bias in [True]:  # TODO:[True, False] when supported in backend
            x = torch.randn(2, 15)

            patterns = [
                ["aten::dequantize", "aten::to", "aten::linear"],
            ]

            for qconfig in static_qconfig:
                m = M(bias)
                graph = self.checkQuantizeTrace(
                    m, [x], atol=2e-1, qconfig=qconfig, int8_bf16=True
                )
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                # single aten::to won't be rewritten by llga backend
                self.assertFused(graph, ["aten::dequantize", "aten::linear"])
                self.checkPatterns(graph, patterns)

    def test_max_pool2d(self):
        class M(nn.Module):
            def __init__(self, **kargs):
                super(M, self).__init__()
                self.conv = nn.Conv2d(3, 3, 1, 1)
                self.max_pool = nn.MaxPool2d(**kargs)

            def forward(self, x):
                x = self.conv(x)
                x = self.max_pool(x)
                return x

        for [
            spatial,
            kernel,
            padding,
            stride,
            dilation,
            ceil_mode,
            memory_format,
        ] in itertools.product(
            [15],  # [15, 16], TODO: check backend
            [3, 5],  # [3, 4, 5], TODO: check backend
            [0, 1],
            [1, 2],  # [1, 2, 4], TODO: fix issue in pad calculation
            [1, 2],
            [True, False],
            [torch.contiguous_format, torch.channels_last],
        ):
            m = M(
                kernel_size=kernel,
                stride=stride,
                padding=padding,
                dilation=dilation,
                ceil_mode=ceil_mode,
            )
            x = torch.rand(1, 3, spatial, spatial).to(memory_format=memory_format)

            patterns = [
                [
                    "aten::dequantize",
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::quantize_per_tensor",
                ],
                ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(graph, ["aten::max_pool2d"])
                self.checkPatterns(graph, patterns)

    def test_add_scalar_input(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                x_shape1 = x.size()[0]
                x_shape2 = x.size()[1]
                y1 = x_shape1 + 2
                y2 = x_shape2 + 3
                return y1 + y2

        # input[0] to add being scalar is unsupported
        x = torch.randn(3, 3)
        m = M()
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
        self.assertGraphContainsExactly(graph, "aten::add", 3)

    def test_reshape_6D_linear(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = torch.nn.Linear(
                    in_features=64, out_features=192, bias=True
                )

            def forward(self, x):
                x = x.reshape(4, 8, 7, 8, 8, 64).transpose(2, 3)
                x = self.linear(x)
                return x

        for bias in [True, False]:
            x = torch.randn(4, 56, 64, 64)
            m = M()

            patterns = [["aten::dequantize", "aten::linear"]]

            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::linear", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_3d_bmm_int8_in_f32_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                return torch.bmm(x, y)

        x = torch.randn(128, 3, 4) * 0.1
        y = torch.randn(128, 4, 5) * 0.1
        patterns = [
            ["aten::dequantize", "aten::bmm"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::dequantize", "aten::bmm"])
        self.checkPatterns(graph, patterns)

    def test_bmm_int8_in_f32_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = torch.matmul(x, y)
                return mm_res

        x = torch.randn(128, 16, 384, 64) * 0.1
        y = torch.randn(128, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::matmul"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_strided_bmm_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # inputs to matmul have been permuted or transposed, thus are strided tensors
                return torch.matmul(z1, z2.transpose(-1, -2))

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)

        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::dequantize"])
        self.checkPatterns(graph, patterns)

    def test_mixed_precision_softmax(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y, z, a):
                o = torch.matmul(x, y) / 8.0
                o = o + a.to(o.dtype)
                o = torch.softmax(o, -1)
                o = o.matmul(z)
                return o

        x = torch.randn(1, 16, 16, 64)
        y = torch.randn(1, 16, 64, 16)
        z = torch.randn(1, 16, 16, 64)
        a = torch.randn(1, 1, 1, 16)
        m = M()

        # fp32 in int8 out softmax
        graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=False)
        self.assertFused(
            graph, ["aten::matmul", "aten::div", "aten::add", "aten::softmax"]
        )

        # bf16 in int8 out softmax
        graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=True)
        self.assertFused(
            graph, ["aten::matmul", "aten::div", "aten::add", "aten::softmax"]
        )


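# Multi-op fusion coverage: conv/linear combined with eltwise, binary, and
# normalization ops, checking the expected LLGA partitions and fused patterns.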
class TestFusionPattern(JitLlgaTestCase):
    def test_conv2d_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x

        for eltwise in [
            "relu",
            "leaky_relu",
            "sigmoid",
            "round",
            "abs",
            "square",
            "abs",
            "round",
            "exp",
            "hardswish",
            "tanh",
            "hardtanh",
            "mish",
        ]:
            for inplace in [False, True]:
                for memory_format in [torch.contiguous_format, torch.channels_last]:
                    eltwise_fn_name = eltwise + "_" if inplace else eltwise
                    eltwise_fn = get_eltwise_fn(eltwise_fn_name)

                    m = M(eltwise_fn)
                    x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)

                    patterns = [
                        [
                            "aten::dequantize",
                            "aten::_convolution",
                            "aten::" + eltwise,
                            "aten::quantize_per_tensor",
                        ],  # inplace op will become outplace op on the JIT graph
                        ["aten::dequantize", "aten::_convolution"],
                    ]
                    for qconfig in static_qconfig:
                        graph = self.checkQuantizeTrace(
                            m, [x], atol=2e-1, qconfig=qconfig
                        )
                        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                        self.assertFused(
                            graph,
                            [
                                "aten::_convolution",
                                "aten::" + eltwise,
                                "aten::quantize_per_channel",
                                "aten::dequantize",
                            ],
                        )
                        self.checkPatterns(graph, patterns)

    def test_conv2d_clamp(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True)

            def forward(self, x):
                x = self.conv1(x)
                x = torch.clamp(x, min=float("-inf"))
                x = self.conv2(x)
                x = torch.clamp(x, min=-5)
                x = self.conv3(x)
                x = torch.clamp(x, min=0, max=float("inf"))
                x = self.conv4(x)
                x = torch.clamp(x, min=1, max=5)
                x = self.conv5(x)
                x = torch.clamp(x, max=2)
                return x

        for inplace in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
                m = M()
                for qconfig in static_qconfig:
                    graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
                    self.assertFused(
                        graph,
                        [
                            "aten::_convolution",
                            "aten::" + "clamp",
                            "aten::quantize_per_channel",
                            "aten::dequantize",
                        ],
                    )

    def test_conv2d_silu(self):
        class M(nn.Module):
            def __init__(self, inplace):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x

        for inplace in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(inplace)
                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)

                graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)

                silu_op = "aten::silu_" if inplace else "aten::silu"

                # oneDNN graph does not have silu OP. The bridge will convert silu to sigmoid - mul
                patterns = [
                    [
                        "aten::dequantize",
                        "aten::_convolution",
                        "aten::sigmoid",
                        "aten::mul",
                        "aten::quantize_per_tensor",
                    ],  # inplace op will become outplace op on the JIT graph
                    ["aten::dequantize", "aten::_convolution"],
                ]

                self.assertFused(
                    graph, ["aten::_convolution", silu_op, "aten::dequantize"]
                )
                self.checkPatterns(graph, patterns)

    def test_deconv_silu(self):
        class M(nn.Module):
            def __init__(self, inplace):
                super(M, self).__init__()
                self.deconv = nn.ConvTranspose2d(3, 2, 3, stride=2)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.deconv(x)
                x = self.eltwise(x)
                return x

        for inplace in [False, True]:
            m = M(inplace)
            x = torch.rand(1, 3, 28, 28)
            graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
            patterns = [
                ["aten::dequantize", "aten::_convolution", "aten::sigmoid", "aten::mul"]
            ]
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.checkPatterns(graph, patterns)

    def test_ensure_tensor_is_rewrapped(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = eltwise_fn
                self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                y = self.eltwise(y)
                x = torch.add(x, y)
                x = self.adaptive_avg_pool_2d(x)
                return x

        eltwise_fn_name = "relu"
        eltwise_fn = get_eltwise_fn(eltwise_fn_name)

        m = M(eltwise_fn)
        x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        for qconfig in static_qconfig:
            # The output of the fourth partition is input to adaptive_avg_pool2d, which is
            # unsupported by LLGA. In resnext101 32x16d, we had encountered an accuracy issue.
            # The UT checks that the input to adaptive_avg_pool_2d has not been wrapped by
            # LlgaTensorImpl (assertEqual would fail in that case).
            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)

    def test_conv2d_bn(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 5, 3, padding=1, bias=False)
                self.bn1 = nn.BatchNorm2d(5)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                return x

        for bias in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(bias).eval()
                x = torch.rand(1, 32, 16, 16).to(memory_format=memory_format)
                # TODO: This shape will fail
                # x = torch.rand(1, 32, 28, 28)

                patterns = [["aten::dequantize", "aten::_convolution"]]
                # TODO: add torch.per_tensor_symmetric case.
                for qconfig in static_qconfig:
                    graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                    self.assertFused(
                        graph,
                        [
                            "aten::_convolution",
                            "aten::quantize_per_channel",
                            "aten::dequantize",
                        ],
                    )
                    self.checkPatterns(graph, patterns)

    def test_conv2d_bn_relu(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = F.relu(x)
                return x

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M().eval()
            x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
            patterns = [
                ["aten::dequantize", "aten::_convolution", "aten::relu"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    ["aten::_convolution", "aten::relu", "aten::quantize_per_channel"],
                )
                self.checkPatterns(graph, patterns)

    def test_linear_bn(self):
        class M(nn.Module):
            def __init__(self, dim):
                super(M, self).__init__()
                self.linear = nn.Linear(32, 32)
                if dim == 1:
                    self.input1 = torch.randn(1, 32)
                    self.bn = nn.BatchNorm1d(32)
                elif dim == 2:
                    self.input1 = torch.randn(1, 32, 32, 32)
                    self.bn = nn.BatchNorm2d(32)
                elif dim == 3:
                    self.input1 = torch.randn(1, 32, 32, 32, 32)
                    self.bn = nn.BatchNorm3d(32)

            def forward(self, x):
                x = self.linear(x)
                x = self.bn(x)
                return x

        for dim in [1, 2, 3]:
            m = M(dim=dim)
            x = m.input1
            patterns = [["aten::dequantize", "aten::linear"]]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["ipex::batch_norm"])
                self.checkPatterns(graph, patterns)

    def test_conv_bn_linear_bn(self):
        class M(nn.Module):
            def __init__(
                self,
            ):
                super(M, self).__init__()
                self.input1 = torch.randn(1, 32, 32, 32)
                self.conv = nn.Conv2d(32, 32, 1)
                self.bn1 = nn.BatchNorm2d(32)
                self.linear = nn.Linear(32, 32)
                self.bn2 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn1(x)
                x = self.linear(x)
                x = self.bn2(x)
                return x

        m = M()
        x = m.input1
        patterns = [
            ["aten::dequantize", "aten::_convolution"],
            ["aten::dequantize", "aten::linear"],
        ]
        for qconfig in static_qconfig:
            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.assertFused(graph, ["ipex::batch_norm"])
            self.checkPatterns(graph, patterns)

    def test_linear_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        # TODO: use itertools.product once all combinations are supported
        for [has_bias, eltwise] in [
            [True, "relu"],
            [False, "relu"],
            # [True, 'gelu'], # TODO: enable it once linear_gelu default recipe is fixed
            # [False, 'gelu'], # TODO: enable it once linear_gelu default recipe is fixed
            [True, "sigmoid"],
            [False, "sigmoid"],
        ]:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn, has_bias)
            x = torch.rand(32, 28, requires_grad=False)
            patterns = [
                ["aten::dequantize", "aten::linear", "aten::" + eltwise],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(
                    m,
                    [x],
                    x_var=[torch.rand(2, 28, requires_grad=False)],
                    atol=1e-1,
                    qconfig=qconfig,
                )
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::" + eltwise])
                self.checkPatterns(graph, patterns)

    def test_linear_silu(self):
        class M(nn.Module):
            def __init__(self, inplace):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        for inplace in [False, True]:
            m = M(inplace)
            x = torch.rand(1, 28, requires_grad=False)

            silu_op = "aten::silu_" if inplace else "aten::silu"

            patterns = [
                ["aten::dequantize", "aten::linear", "aten::sigmoid", "aten::mul"],
            ]
            graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ["aten::linear", silu_op, "aten::dequantize"])
            self.checkPatterns(graph, patterns)

    def test_conv_relu_sigmoid_mul(self):
        #        dequant
        #           |
        #         conv
        #           |
        #         relu
        #          /  |
        #       quant |
        #        /    |
        #     dequant |
        #       |     |
        #     conv    |
        #       |     |
        #     relu    |
        #       |     |
        #     quant   |
        #       |     |
        #    dequant  |
        #       |     |
        #     conv    |
        #       |     |
        #    sigmoid  |
        #         \   /
        #          mul

        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                x = self.conv1(x)

                # The output y of relu is used by mul
                y = x.relu()

                z = self.conv2(y)
                z = z.relu()
                z = self.conv3(z)
                z = z.sigmoid()
                z = z.mul(y)
                return z

        x = torch.rand(1, 32, 16, 16, requires_grad=False)
        m = M()
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        patterns = [
            ["aten::dequantize", "aten::_convolution", "aten::relu"],
            [
                "aten::dequantize",
                "aten::_convolution",
                "aten::relu",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::_convolution", "aten::sigmoid", "aten::mul"],
        ]
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        self.assertFused(
            graph, ["aten::_convolution", "aten::relu", "aten::sigmoid", "aten::mul"]
        )
        self.checkPatterns(graph, patterns)

    def test_conv_eltwise_tensor_method(self):
        class ConvSigmoid(nn.Module):
            def __init__(self):
                super(ConvSigmoid, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                x = self.conv(x)
                x = x.sigmoid()
                return x

        class ConvReLU(nn.Module):
            def __init__(self):
                super(ConvReLU, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                x = self.conv(x)
                x = x.relu()
                return x

        m = ConvSigmoid().eval()
        x = torch.rand(1, 32, 16, 16, requires_grad=False)
        patterns = [["aten::dequantize", "aten::_convolution", "aten::sigmoid"]]
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::_convolution", "aten::sigmoid"])
        self.checkPatterns(graph, patterns)

        m = ConvReLU().eval()
        x = torch.rand(1, 32, 16, 16, requires_grad=False)
        patterns = [["aten::dequantize", "aten::_convolution", "aten::relu"]]
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::_convolution", "aten::relu"])
        self.checkPatterns(graph, patterns)

    def test_conv2d_sum(self):
        class M(nn.Module):
            def __init__(self, bias=False):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn1 = nn.BatchNorm2d(32)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn2 = nn.BatchNorm2d(32)
                self.relu = nn.ReLU()
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn3 = nn.BatchNorm2d(32)

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.bn1(x)
                y = self.conv2(y)
                y = self.bn2(y)
                z = self.relu(x + y)
                z = self.conv3(z)
                z = self.bn3(z)
                return z

        for bias in [True, False]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(bias).eval()
                x = torch.rand(1, 32, 16, 16, requires_grad=False).to(
                    memory_format=memory_format
                )
                y = torch.rand(1, 32, 16, 16, requires_grad=False).to(
                    memory_format=memory_format
                )
                patterns = [
                    [
                        "aten::dequantize",
                        "aten::_convolution",
                        "aten::quantize_per_tensor",
                    ],
                    [
                        "aten::dequantize",
                        "aten::_convolution",
                        "aten::relu",
                        "aten::add",
                        "aten::quantize_per_tensor",
                    ],
                    ["aten::dequantize", "aten::_convolution"],
                ]
                for qconfig in static_qconfig:
                    graph = self.checkQuantizeTrace(
                        m, [x, y], atol=1e-1, qconfig=qconfig
                    )
                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
                    self.assertFused(
                        graph,
                        [
                            "aten::_convolution",
                            "aten::relu",
                            "aten::add",
                            "aten::quantize_per_channel",
                            "aten::dequantize",
                        ],
                    )
                    self.checkPatterns(graph, patterns)

    def test_add_quantization(self):
        class M(nn.Module):
            def __init__(self, bias=False):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(16, 16, 1)
                self.conv2 = nn.Conv2d(16, 16, 1)

            def forward(self, x):
                x = self.conv1(x)
                y = self.conv2(x)
                y = y.mul(10)
                z = torch.add(x, y)
                return z

        m = M().eval()
        x = torch.rand(1, 16, 16, 16, requires_grad=False)
        x2 = torch.rand(1, 16, 16, 16, requires_grad=False)

        patterns = [
            ["aten::dequantize", "aten::_convolution"],
            ["aten::dequantize", "aten::_convolution"],
        ]
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        self.assertFused(graph, ["aten::_convolution", "aten::quantize_per_channel"])
        self.checkPatterns(graph, patterns)

    def test_conv2d_sigmoid_mul_(self):
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x):
                a = self.conv(x)
                b = torch.sigmoid(a)
                res = a.mul_(b)
                return res

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M(3, 16, 3, 224).eval()
            x = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            patterns = [
                [
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::sigmoid",
                    "aten::mul",
                ],
            ]
            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                graph = self.checkQuantizeTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::sigmoid",
                        "aten::mul",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

        # inplace mul_ cannot be replaced with mul
        class M2(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(M2, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x):
                a = self.conv(x)
                b = torch.sigmoid(a)
                c = a[0]
                res = a.mul_(b)
                c += 2
                return c

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M2(3, 16, 3, 224).eval()
            x = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            patterns = [
                ["aten::dequantize", "aten::_convolution"],
            ]
            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                graph = self.checkQuantizeTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

    def test_conv2d_hardsigmoid_mul_(self):
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )
                self.activation = torch.nn.Hardsigmoid()

            def forward(self, x):
                a = self.conv(x)
                b = self.activation(a)
                res = a.mul_(b)
                return res

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M(3, 16, 3, 224).eval()
            x = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            patterns = [
                [
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::hardsigmoid",
                    "aten::mul",
                ],
            ]
            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                graph = self.checkQuantizeTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::hardsigmoid",
                        "aten::mul",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

    def test_linear_dropout_sum(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20)
                self.dropout = nn.Dropout()
                self.linear2 = nn.Linear(20, 3)

            def forward(self, x, y):
                x = self.linear1(x)
                x = self.dropout(x)
                z = self.linear2(x + y)
                return z

        x = torch.randn(2, 15)
        y = torch.randn(2, 20)
        m = M()
        patterns = [
            [
                "aten::dequantize",
                "aten::linear",
                "aten::add",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::linear"],
        ]
        for qconfig in static_qconfig:
            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.assertFused(
                graph,
                [
                    "aten::linear",
                    "aten::add",
                    "aten::quantize_per_channel",
                    "aten::dequantize",
                ],
            )
        self.checkPatterns(graph, patterns)

    def test_linear_sum_inplace(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20)

            def forward(self, x, y):
                x = self.linear1(x)
                x += y.clone()
                return x

        x = torch.randn(2, 15)
        y = torch.randn(2, 20)
        m = M()
        patterns = [
            ["aten::dequantize", "aten::linear", "aten::dequantize"],
        ]
        # HistogramObserver failed, need to do some checks?
        for qconfig in static_qconfig[:2]:
            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(
                graph,
                ["aten::linear", "aten::quantize_per_channel", "aten::dequantize"],
            )
            self.checkPatterns(graph, patterns)

1321
    def test_linear_with_multiple_add(self):
1322
        class M(nn.Module):
1323
            def __init__(self):
1324
                super(M, self).__init__()
1325
                self.linear1 = nn.Linear(15, 20)
1326
                self.linear2 = nn.Linear(15, 20)
1327

1328
            def forward(self, x1, y1, x2, y2):
1329
                x1 = self.linear1(x1)
1330
                x1 += y1.clone()
1331
                x2 = self.linear2(x2)
1332
                x2 += y2.clone()
1333
                return x1 + x2
1334

1335
        x1 = torch.randn(2, 15)
1336
        y1 = torch.randn(2, 20)
1337
        x2 = torch.randn(2, 15)
1338
        y2 = torch.randn(2, 20)
1339

1340
        m = M()
1341
        patterns = [
1342
            ["aten::dequantize", "aten::linear", "aten::add"],
1343
            ["aten::dequantize", "aten::linear", "aten::add", "aten::add"],
1344
        ]
1345
        for qconfig in static_qconfig[:2]:
1346
            graph = self.checkQuantizeTrace(
1347
                m, [x1, y1, x2, y2], atol=2e-1, qconfig=qconfig
1348
            )
1349
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
1350
            # There shouldn't have single add node which doesn't fused into subgraph.
1351
            self.assertFused(
1352
                graph,
1353
                ["aten::linear", "aten::add"],
1354
            )
1355
            self.checkPatterns(graph, patterns)
1356

1357
    def test_linear_dropout_sum_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20, bias=True)
                self.dropout = nn.Dropout()
                self.linear2 = nn.Linear(15, 20, bias=True)

            def forward(self, x, y):
                x = self.linear1(x)
                x = self.dropout(x)
                z = self.linear2(y) + x
                return z

        x = torch.randn(2, 15)
        y = torch.randn(2, 15)
        m = M()
        patterns = [
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
            ],
            ["aten::dequantize", "aten::to", "aten::linear", "aten::add"],
        ]
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        # TODO: the oneDNN primitive added more limitations on sum post-ops, which forced fusion changes on the oneDNN graph side.
        # The dequant node connected to aten::add can no longer be fused into the INT8 linear-add partition.
        # oneDNN graph expects no end-to-end model performance impact.
        # Revisit this change if validation finds a model-level regression.
        self.assertFused(graph, ["aten::linear", "aten::add"])
        self.checkPatterns(graph, patterns)

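    # Note on the int8-bf16 expectations above and below: when int8_bf16=True, the
    # reference graph casts the dequantized fp32 values to bfloat16 before the
    # compute op, roughly (a sketch, not an exact JIT dump):
    #   %w  = aten::dequantize(%qw)
    #   %wb = aten::to(%w)          # cast to bf16
    #   %y  = aten::linear(%xb, %wb, %b)
    # which is why the expected fusion patterns include "aten::to" in this mode.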
    def test_linear_gelu_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias=True)
                self.eltwise = nn.GELU()
                self.linear2 = nn.Linear(64, 1, bias=True)

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                x = self.linear2(x)
                return x

        patterns = [
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
                "aten::gelu",
                "aten::to",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::to", "aten::linear"],
        ]
        m = M()
        x = torch.rand(32, 28, requires_grad=False)
        for qscheme in [torch.per_tensor_affine]:
            graph = self.checkQuantizeTrace(
                m,
                [x],
                x_var=[torch.rand(2, 28, requires_grad=False)],
                atol=1e-1,
                int8_bf16=True,
            )
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.assertFused(graph, ["aten::dequantize", "aten::linear", "aten::gelu"])
            self.checkPatterns(graph, patterns)

    def test_defer_size(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                y = self.conv2(x)
                y = y.reshape(x.size(0), -1)
                return y

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M()
            x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
            patterns = [
                [
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::relu",
                    "aten::quantize_per_tensor",
                ],
                ["aten::dequantize", "aten::_convolution"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::relu",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

    def test_lift_up_quant(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.linear2 = nn.Linear(28, 64, bias=True)
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                x = self.linear(x)
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                y = self.linear2(y)
                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                return torch.matmul(z1, z2.transpose(-1, -2))

        m = M(bias=True)
        x = torch.randn(2, 3, 28)
        y = torch.randn(2, 3, 28)

        patterns = [
            ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
            ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
            ["aten::dequantize", "aten::matmul"],
        ]

        # TODO: test shape fallback
        graph = self.checkQuantizeTrace(m, [x, y], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        self.assertFused(graph, ["aten::dequantize", "aten::linear", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_lift_up_to_quant_bf16(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.linear2 = nn.Linear(28, 64, bias=True)
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                x = self.linear(x)
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                y = self.linear2(y)
                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                return torch.matmul(z1, z2.transpose(-1, -2))

        m = M(bias=True)
        x = torch.randn(2, 3, 28)
        y = torch.randn(2, 3, 28)

        patterns = [
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
                "aten::to",
                "aten::quantize_per_tensor",
            ],
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
                "aten::to",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::to", "aten::matmul"],
        ]

        # TODO: test shape fallback
        graph = self.checkQuantizeTrace(m, [x, y], atol=1e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        self.assertFused(graph, ["aten::dequantize", "aten::linear", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_lift_up_quant_unsupported(self):
        # Original graph:
        #          |
        #        view
        #      /  (f32)\   /(f32)
        #   quant       add
        #     |

        # Lifting up the quant in this case would raise:
        # promoteTypes with quantized numbers is not handled in aten::add;
        #          |
        #        quant
        #          |
        #         view
        #         (int8)\  /(f32)
        #                add
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(3, 8, 1)
                self.conv2 = nn.Conv2d(8, 8, 1)

            def forward(self, x, y):
                x = self.conv1(x)
                z1 = x.permute(0, 3, 1, 2)
                z2 = self.conv2(z1)
                z = z1 + y
                output = z2 + z
                return output

        x = torch.randn(1, 3, 8, 8)
        y = torch.randn(1, 8, 8, 8)
        m = M()

        patterns = [
            ["aten::dequantize", "aten::_convolution"],
            ["aten::dequantize", "aten::_convolution", "aten::add"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        # TODO: the oneDNN primitive added more limitations on sum post-ops, which forced fusion changes on the oneDNN graph side.
        # The dequant node connected to aten::add can no longer be fused into the INT8 conv-add partition.
        # oneDNN graph expects no end-to-end model performance impact.
        # Revisit this change if validation finds a model-level regression.
        self.assertFused(graph, ["aten::_convolution"])
        self.checkPatterns(graph, patterns)

    def test_wildcard(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                y = self.eltwise(x)
                return [x, y]

        # The pattern is as follows:
        #      conv
        #     |    \
        # eltwise   \
        #    |       \
        #  ListConstruct
        #
        # The output of conv is used by a wildcard op: ListConstruct.
        # Thus conv-eltwise cannot be selected into the same Partition.
        m = M()
        x = torch.rand(1, 32, 28, 28)
        patterns = [
            ["aten::dequantize", "aten::_convolution"],
        ]
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertGraphContainsExactly(graph, "aten::relu", 1)
        self.assertFused(graph, ["aten::_convolution", "aten::quantize_per_channel"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_scalar(self):
        class M(nn.Module):
            def __init__(self, div_value):
                super(M, self).__init__()
                self.div_value = div_value

            def forward(self, x, y):
                mm_res = torch.matmul(x, y)
                return mm_res.div(self.div_value)

        x = torch.randn(1, 16, 384, 64)
        y = torch.randn(1, 1, 64, 384)
        patterns = [
            ["aten::dequantize", "aten::matmul", "aten::div"],
        ]
        m = M(8.0)
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::div"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_identity(self):
        class M(nn.Module):
            def __init__(self, div_value):
                super(M, self).__init__()
                self.div_value = div_value

            def forward(self, x, y):
                mm_res = torch.matmul(x, y)
                return mm_res.div(self.div_value)

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::matmul"],
        ]
        m = M(1.0)
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        # divide by 1 should be removed by Constant Propagation
        self.assertGraphContainsExactly(graph, "aten::div", 0, consider_subgraphs=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_tensor(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y, z):
                mm_res = torch.matmul(x, y)
                return mm_res.div(z)

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        z = torch.randn(
            1
        )  # TODO: enable torch.randn(20) and torch.randn(1, 1, 20, 20) once the backend supports them
        patterns = [
            ["aten::dequantize", "aten::matmul", "aten::div"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::div"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = torch.matmul(x, y) / 2
                return mm_res

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        # a single aten::to won't be rewritten by the LLGA backend
        self.assertFused(graph, ["aten::dequantize", "aten::matmul", "aten::div"])
        self.checkPatterns(graph, patterns)

    def test_bmm_method_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = x.matmul(y)
                return mm_res

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        # a single aten::to won't be rewritten by the LLGA backend
        self.assertFused(graph, ["aten::dequantize", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_bmm_method_fp32(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = x.matmul(y)
                return mm_res

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::matmul"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::dequantize", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_strided_bmm_div_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, so they are strided tensors
                return torch.matmul(z1, z2.transpose(-1, -2)) / 0.4

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)

        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::dequantize"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_add_int8_fp32(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y, z):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, so they are strided tensors
                s = torch.matmul(z1, z2.transpose(-1, -2)) / 0.4
                s = s + z
                return s

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)
        z = torch.randn(2, 1, 1, 3)

        patterns = [
            ["aten::dequantize", "aten::matmul", "aten::div", "aten::add"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(
            graph, ["aten::matmul", "aten::dequantize", "aten::div", "aten::add"]
        )
        self.checkPatterns(graph, patterns)

    @unittest.skip("Graph Compiler unit-test")
    def test_mha_pattern_int8_fp32(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(1024, 1024, False)

            def forward(self, x, y, z, a):
                x = x.permute(0, 2, 1, 3)

                y = y.permute(0, 2, 1, 3)
                y = y.transpose(-1, -2)

                z = z.permute(0, 2, 1, 3)
                tmp = torch.matmul(x, y) / 8.0 + a
                tmp = torch.softmax(tmp, -1)
                tmp = tmp.matmul(z)
                tmp = tmp.permute(0, 2, 1, 3)
                tmp = tmp.contiguous()
                tmp = tmp.view(1, 16, 1024)
                tmp = self.linear(tmp)
                return tmp

        x = torch.randn(1, 16, 16, 64)
        y = torch.randn(1, 16, 16, 64)
        z = torch.randn(1, 16, 16, 64)
        m = M()
        a = torch.randn(1, 1, 1, 16)
        graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        self.assertFused(
            graph,
            [
                "aten::matmul",
                "aten::div",
                "aten::add",
                "aten::softmax",
                "aten::contiguous",
                "aten::dequantize",
            ],
        )

    @unittest.skip("Graph Compiler unit-test")
    def test_mha_pattern_int8_bf16(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(1024, 1024, False)

            def forward(self, x, y, z, a):
                x = x.permute(0, 2, 1, 3)

                y = y.permute(0, 2, 1, 3)
                y = y.transpose(-1, -2)

                z = z.permute(0, 2, 1, 3)
                tmp = torch.matmul(x, y) / 8.0 + a
                tmp = torch.softmax(tmp, -1)
                tmp = tmp.matmul(z)
                tmp = tmp.permute(0, 2, 1, 3)
                tmp = tmp.contiguous()
                tmp = tmp.view(1, 16, 1024)
                tmp = self.linear(tmp)
                return tmp

        x = torch.randn(1, 16, 16, 64)
        y = torch.randn(1, 16, 16, 64)
        z = torch.randn(1, 16, 16, 64)
        m = M()
        a = torch.randn(1, 1, 1, 16, dtype=torch.bfloat16)
        graph = self.checkQuantizeTrace(
            m,
            [x, y, z, a],
            atol=2e-1,
            config_name="mha_pattern",
            qscheme=torch.per_tensor_affine,
            int8_bf16=True,
        )
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
        self.assertFused(
            graph,
            [
                "aten::matmul",
                "aten::div",
                "aten::add",
                "aten::softmax",
                "aten::contiguous",
                "aten::dequantize",
                "aten::quantize_per_tensor",
            ],
        )

    def test_bmm_div_add_int8_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y, z):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, so they are strided tensors
                s = torch.matmul(z1, z2.transpose(-1, -2)) / 0.4
                s = s + z.to(s.dtype)
                return s

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)
        z = torch.randn(2, 1, 1, 3)

        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div", "aten::add"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(
            graph, ["aten::matmul", "aten::dequantize", "aten::div", "aten::add"]
        )
        self.checkPatterns(graph, patterns)

    def test_split_dequant_to(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(2, 1, bias=True)
                self.linear2 = nn.Linear(2, 1, bias=True)
                self.linear3 = nn.Linear(2, 1, bias=True)

            def forward(self, x):
                a = self.linear1(x)
                b = self.linear2(x)
                c = self.linear3(x)
                return torch.cat([a, b, c])

        # The following pattern:
        #         quant
        #           |
        #        dequant
        #           |
        #          to
        #     /    |    \
        # linear linear linear
        #    |     |      |
        #
        # should be transformed to:
        #               to
        #               |
        #             quant
        #        /      |     \
        #   dequant dequant  dequant
        #      |       |       |
        #     to       to     to
        #      |       |       |
        #  linear   linear  linear
        #      |       |       |

        patterns = [
            ["aten::dequantize", "aten::to", "aten::linear"],
            ["aten::dequantize", "aten::to", "aten::linear"],
            ["aten::dequantize", "aten::to", "aten::linear"],
        ]
        m = M()
        x = torch.randn(2, 2)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        # a single aten::to won't be rewritten by the LLGA backend
        self.assertFused(graph, ["aten::dequantize", "aten::linear"])
        self.checkPatterns(graph, patterns)

    def test_dequant_remove_attr(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                x = torch.quantize_per_channel(
                    x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8
                )
                x = torch.dequantize(x)
                return x

        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        m = M()
        traced = torch.jit.trace(m, x)
        traced(x)
        graph = traced.graph_for(x)
        self.checkAttr(graph, "aten::dequantize", "qtype")

    def test_fx_converted_model(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(15, 20)

            def forward(self, x):
                x = self.linear(x)
                return x

        x = torch.randn(2, 15)
        m = M()
        m.eval()

        qconfig_dict = {"": static_qconfig[0]}

        m = prepare_fx(m, qconfig_dict, x)
        m = convert_fx(m)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

    def test_fx_ao_qat_converted_model(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(15, 20)

            def forward(self, x):
                x = self.linear(x)
                return x

        x = torch.randn(2, 15)
        m = M()
        m.eval()

        qconfig_dict = {"": static_qconfig[0]}

        m = prepare_qat_fx(m, qconfig_dict, x)
        m = convert_to_reference_fx(m)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        # dequant -> linear should be mapped to LLGA
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

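    # Why the two FX tests above expect different fusion counts (an editorial note):
    # convert_fx lowers to fused quantized ops (e.g. quantized::linear), which the
    # LLGA fuser does not rewrite, hence 0 fusion groups; convert_to_reference_fx
    # keeps the reference pattern quantize_per_tensor -> dequantize -> aten::linear,
    # which maps onto one LLGA partition. A rough sketch of the reference flow,
    # mirroring the calls used above:
    #   m = prepare_qat_fx(M(), {"": static_qconfig[0]}, x)
    #   m = convert_to_reference_fx(m)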
    @unittest.skipIf(True, "Poor accuracy")
    @skipIfNoTorchVision
    def test_fx_ao_qat_model(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x

        data = torch.randn(1, 32, 224, 224).to(memory_format=torch.channels_last)
        m = M()
        m.eval()
        #
        # quantization aware training for static quantization
        #
        qconfig_dict = {"": torch.quantization.get_default_qat_qconfig("fbgemm")}
        m.train()
        model_prepared = prepare_qat_fx(m, qconfig_dict, example_inputs=data)
        model_quantized = convert_to_reference_fx(model_prepared)
        model_quantized = model_quantized.eval()
        model = model_quantized.to(memory_format=torch.channels_last)
        graph = self.checkQuantizeTrace(model, [data], atol=2e-1)
        self.checkPatterns(
            graph,
            [
                [
                    "aten::dequantize",
                    "aten::quantize_per_channel",
                    "aten::_convolution",
                    "aten::relu",
                    "aten::quantize_per_tensor",
                ],
                [
                    "aten::dequantize",
                    "aten::quantize_per_channel",
                    "aten::_convolution",
                    "aten::quantize_per_tensor",
                ],
            ],
        )

    def test_ffn_residual(self):
        class FFN_Residual(nn.Module):
            def __init__(self, hidden_size, intermediate_size):
                super(FFN_Residual, self).__init__()
                self.linear1 = nn.Linear(hidden_size, intermediate_size)
                self.linear2 = nn.Linear(intermediate_size, hidden_size)
                self.LayerNorm1 = nn.LayerNorm(hidden_size)
                self.LayerNorm2 = nn.LayerNorm(hidden_size)
                self.intermediate_act_fn = nn.functional.gelu

            def forward(self, x):
                x1 = self.LayerNorm1(x)
                x2 = self.linear1(x1)
                x3 = self.intermediate_act_fn(x2)
                x4 = self.linear2(x3)
                x5 = self.LayerNorm2(x4 + x)
                return x5

        patterns = [
            [
                "aten::dequantize",
                "aten::linear",
                "aten::gelu",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::linear", "aten::add"],
        ]
        m = FFN_Residual(1024, 4096).eval()
        x = torch.rand(128, 1024)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        self.assertFused(graph, ["aten::linear", "aten::gelu"])
        self.assertFused(graph, ["aten::linear", "aten::add"])
        self.checkPatterns(graph, patterns)

    def test_inplace_computation_accuracy(self):
        class LowRankCrossNet(nn.Module):
            def __init__(
                self, in_features: int, num_layers: int, low_rank: int
            ) -> None:
                super().__init__()
                assert low_rank >= 1, "Low rank must be greater than or equal to 1"
                self._num_layers = num_layers
                self._low_rank = low_rank
                W_kernels: nn.ParameterList = nn.ParameterList()
                for i in range(self._num_layers):
                    Wp = nn.Parameter(torch.randn(in_features, self._low_rank))
                    W_kernels.append(Wp)
                V_kernels: nn.ParameterList = nn.ParameterList()
                for i in range(self._num_layers):
                    V_kernels.append(
                        nn.Parameter(torch.randn(self._low_rank, in_features))
                    )
                bias: nn.ParameterList = nn.ParameterList(
                    [
                        nn.Parameter(nn.init.zeros_(torch.empty(in_features)))
                        for i in range(self._num_layers)
                    ]
                )
                self.MLPs = nn.ModuleDict()
                for i in range(num_layers):
                    self.MLPs[f"V{i}"] = nn.Linear(in_features, low_rank, bias=False)
                    self.MLPs[f"W{i}"] = nn.Linear(low_rank, in_features, bias=True)
                    self.MLPs[f"V{i}"].weight = V_kernels[i]
                    self.MLPs[f"W{i}"].weight = W_kernels[i]
                    self.MLPs[f"W{i}"].bias = bias[i]

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                x_0 = input
                x_l = x_0  # .clone()
                for layer in range(self._num_layers):
                    x_l_v = self.MLPs[f"V{layer}"](x_l)
                    x_l_w = self.MLPs[f"W{layer}"](x_l_v)
                    x_l = x_0 * x_l_w + x_l  # (B, N)
                return x_l, x_0

        class FakeQuant(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)
                return x.dequantize()

        class TinyDLRM(nn.Module):
            def __init__(self):
                super().__init__()
                self.pre_model = FakeQuant()
                self.cross_net = LowRankCrossNet(2, 2, 5)

            def forward(self, x):
                out = self.pre_model(x)
                out = self.cross_net(out)
                return out

        m = TinyDLRM().eval()
        x = torch.rand(2048, 2)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        print(graph)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
        self.assertFused(graph, ["aten::linear", "aten::mul", "aten::add"])


class TestShapeFallback(JitLlgaTestCase):
    @unittest.skipIf(True, "Size peephole optimization not enabled yet")
    def test_view_permute(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                new_x_shape = x.size()[:-1] + (3, 5)
                x = x.view(*new_x_shape)
                return x.permute(0, 2, 1, 3)

        x = torch.randn(5, 10, 15)
        m = M()

        for qconfig in static_qconfig:
            graph = self.checkQuantizeTrace(m, [x], qconfig=qconfig)
            self.assertGraphContainsExactly(graph, "aten::size", 0)
            self.assertGraphContainsExactly(graph, "prim::ListConstruct", 0)

            # change the size of the input
            x2 = torch.randn(6, 4, 15)
            # Bailout gets triggered here
            y2 = m(x2)

    def test_conv_reshape(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(4, 4, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(4, 32, 3, padding=1, bias=True)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x).reshape(x.size(0), 4, -1)
                return x

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            x = torch.randn(15, 4, 28, 28).to(memory_format=memory_format)
            # change the size of the input, check the fallback
            x_var = torch.randn(7, 4, 16, 16).to(memory_format=memory_format)
            m = M()
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(
                    m, [x], x_var=[x_var], atol=2e-1, qconfig=qconfig
                )

                # TODO: enable this check when size peephole optimization is enabled
                # self.assertGraphContainsExactly(graph, "aten::size", 0)

    def test_add_recipe(self):
        class ConvAddRelu(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(ConvAddRelu, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x1, x2):
                return torch.relu(torch.add(self.conv(x1), x2))

        class ConvAdd(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(ConvAdd, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x1, x2):
                return torch.add(self.conv(x1), x2)

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            conv_add_relu = ConvAddRelu(3, 16, 3, 2)
            conv_add = ConvAdd(3, 16, 3, 2)
            x1 = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            x2 = torch.rand(1, 16, 111, 111, requires_grad=False).to(
                memory_format=memory_format
            )
            input = [x1, x2]
            graph1 = self.checkQuantizeTrace(conv_add_relu, input, atol=1e-2)
            self.assertGraphContainsExactly(graph1, "aten::quantize_per_tensor", 2)
            graph2 = self.checkQuantizeTrace(conv_add, input, atol=1e-2)
            self.assertGraphContainsExactly(graph2, "aten::quantize_per_tensor", 1)
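        # Editorial note on the recipe checked above, inferred from the assertions:
        # with the trailing ReLU both the conv input and the second add input are
        # quantized (two aten::quantize_per_tensor), while without ReLU only the
        # conv input is (one) -- presumably because the int8 add recipe is applied
        # only for the conv-add-relu shape.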


class TestModel(JitLlgaTestCase):
    @skipIfNoTorchVision
    def _test_vision(self, model_name):
        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = getattr(torchvision.models, model_name)().eval()
            x = (torch.rand(1, 3, 224, 224) / 10).to(memory_format=memory_format)

            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)

                # TODO: aten::adaptive_avg_pool2d also needs to be fused once the backend supports it
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::relu",
                        "aten::max_pool2d",
                        "aten::linear",
                        "aten::quantize_per_channel",
                    ],
                )
                # large partition: 7 fusion groups in total
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 7)


for model_name, enabled in [
    ["resnet50", True],
]:

    def wrapper(mname):
        @unittest.skipIf(not enabled, "Disabled")
        def test(self):
            return self._test_vision(mname)

        return test

    setattr(TestModel, "test_vision_%s" % model_name, wrapper(model_name))

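# Editorial note: the loop above registers one test per entry at import time, e.g.
# TestModel.test_vision_resnet50; additional torchvision models could be covered by
# appending ["<model_name>", <enabled>] pairs (hypothetical extension, not part of
# the current matrix). Assuming run_tests() accepts the standard unittest CLI, a
# single test can be run with, for example:
#   python test_ao_jit_llga_quantization_fuser.py TestOp.test_linear_with_multiple_add -v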
if __name__ == "__main__":
    run_tests()