# This Python file uses the following encoding: utf-8
# !/usr/bin/env python

import unittest
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from test_ao_jit_llga_utils import (
    JitLlgaTestCase,
    LLGA_FUSION_GROUP,
    get_eltwise_fn,
)
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_qat_fx
from torch.testing._internal.common_utils import run_tests
from torch.ao.quantization import (
    MinMaxObserver,
    PerChannelMinMaxObserver,
    HistogramObserver,
    QConfig,
)

default_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

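# Four representative static qconfigs: per-tensor affine or symmetric activation
# observers (MinMax, plus Histogram variants with reduce_range) paired with a
# per-channel symmetric int8 weight observer.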
static_qconfig = [
    QConfig(
        activation=MinMaxObserver.with_args(
            qscheme=torch.per_tensor_affine, dtype=torch.quint8
        ),
        weight=default_weight_observer,
    ),
    QConfig(
        activation=MinMaxObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
        ),
        weight=default_weight_observer,
    ),
    QConfig(
        activation=HistogramObserver.with_args(
            qscheme=torch.per_tensor_affine, dtype=torch.quint8, reduce_range=True
        ),
        weight=default_weight_observer,
    ),
    QConfig(
        activation=HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8, reduce_range=True
        ),
        weight=default_weight_observer,
    ),
]

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class TestOp(JitLlgaTestCase):
    def test_conv_int8_in_f32_out(self):
        for [
            spatial,
            in_channels,
            out_channels,
            kernel,
            padding,
            stride,
            dilation,
            g,
            bias,
            memory_format,
            module,
        ] in itertools.product(
            [7],
            [2],
            [3],
            [3],
            [0, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [True, False],
            [torch.contiguous_format, torch.channels_last],
            [torch.nn.Conv2d, torch.nn.Conv3d],
        ):
            m = module(
                in_channels=in_channels * g,
                out_channels=out_channels * g,
                kernel_size=kernel,
                padding=padding,
                stride=stride,
                dilation=dilation,
                groups=g,
                bias=bias,
            )
            input_shape = [1, in_channels * g, spatial, spatial]
            if isinstance(m, torch.nn.Conv3d):
                input_shape.append(spatial)
                if memory_format == torch.channels_last:
                    memory_format = torch.channels_last_3d
            x = torch.rand(input_shape).to(memory_format=memory_format)
            patterns = [["aten::dequantize", "aten::_convolution"]]
            # TODO: enable more config cases.
            for qconfig in static_qconfig:
                input_shape[0] = 5
                x_var = [torch.rand(input_shape, requires_grad=False)]
                graph = self.checkQuantizeTrace(
                    m, [x], x_var=x_var, atol=2e-1, qconfig=qconfig
                )
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::_convolution", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

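    # Conv followed by its matching ConvTranspose: each op forms its own
    # dequant -> convolution partition, so two fusion groups are expected below.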
    def test_deconv_int8_in_f32_out(self):
        class M(nn.Module):
            def __init__(
                self,
                in_channels,
                out_channels,
                kernel_size,
                padding,
                stride,
                dilation,
                groups,
                bias,
                module,
            ):
                super(M, self).__init__()
                self.conv = module(
                    in_channels=in_channels * groups,
                    out_channels=out_channels * groups,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    dilation=dilation,
                    groups=groups,
                    bias=bias,
                )
                inverse_module = (
                    torch.nn.ConvTranspose2d
                    if (module == torch.nn.Conv2d)
                    else torch.nn.ConvTranspose3d
                )
                self.deconv = inverse_module(
                    in_channels=out_channels * groups,
                    out_channels=in_channels * groups,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    dilation=dilation,
                    groups=groups,
                    bias=bias,
                )

            def forward(self, x):
                y = self.conv(x)
                return self.deconv(y)

        for [
            spatial,
            in_channels,
            out_channels,
            kernel,
            padding,
            stride,
            dilation,
            g,
            bias,
            memory_format,
            module,
        ] in itertools.product(
            [7],
            [3],
            [3],
            [3],
            [0, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [True, False],
            [torch.contiguous_format, torch.channels_last],
            [torch.nn.Conv2d, torch.nn.Conv3d],
        ):
            m = M(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                padding=padding,
                stride=stride,
                dilation=dilation,
                groups=g,
                bias=bias,
                module=module,
            )

            input_shape = [1, in_channels * g, spatial, spatial]
            if module == torch.nn.Conv3d:
                input_shape.append(spatial)
                if memory_format == torch.channels_last:
                    memory_format = torch.channels_last_3d
            x = torch.rand(input_shape).to(memory_format=memory_format)

            patterns = [
                ["aten::dequantize", "aten::_convolution"],
                ["aten::dequantize", "aten::_convolution"],
            ]

            # TODO: enable more config cases.
            for qconfig in static_qconfig:
                input_shape[0] = 5
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(graph, ["aten::_convolution", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_conv_no_freeze(self):
        m = nn.Conv2d(
            in_channels=3,
            out_channels=3,
            kernel_size=3,
            padding=1,
            stride=1,
            dilation=1,
            groups=1,
            bias=True,
        )
        x = torch.rand(1, 3, 5, 5)
        graph = self.checkQuantizeTrace(
            m, [x], atol=2e-1, qconfig=static_qconfig[0], freeze=False
        )
        patterns = [
            ["aten::dequantize", "aten::quantize_per_channel", "aten::_convolution"]
        ]
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(
            graph,
            ["aten::_convolution", "aten::quantize_per_channel", "aten::dequantize"],
        )
        self.checkPatterns(graph, patterns)

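    # The same conv module (and thus the same dequantized weight) is applied to
    # every tensor in the input list; each call should still land in its own
    # fused partition.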
    def test_conv_share_dequant_weight(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1, bias=True)

            def forward(self, x):
                # type: (List[Tensor]) -> Tensor
                all_logits = []
                for feature in x:
                    logits = self.conv(feature)
                    all_logits.append(logits)
                return torch.cat(all_logits, dim=1)

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            patterns = [
                ["aten::dequantize", "aten::_convolution"],
                ["aten::dequantize", "aten::_convolution"],
                ["aten::dequantize", "aten::_convolution"],
            ]
            a = torch.randn(1, 32, 28, 28).to(memory_format=memory_format)
            b = torch.randn(1, 32, 28, 28).to(memory_format=memory_format)
            c = torch.randn(1, 32, 28, 28).to(memory_format=memory_format)
            x = [a, b, c]
            for qconfig in static_qconfig:
                m = M()
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
                self.assertFused(graph, ["aten::_convolution", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_linear_int8_in_f32_out(self):
        for bias in [True, False]:
            x = torch.rand(32, 28)
            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)

            patterns = [
                ["aten::dequantize", "aten::linear"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::linear", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_linear_int8_in_int8_out(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20, bias=bias)
                self.linear2 = nn.Linear(20, 3, bias=bias)

            def forward(self, x, y):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        for bias in [True, False]:
            x = torch.randn(2, 15)
            y = torch.randn(2, 20)
            m = M(bias)

            patterns = [
                ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
                ["aten::dequantize", "aten::linear"],
            ]

            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(
                    graph,
                    ["aten::linear", "aten::quantize_per_channel", "aten::dequantize"],
                )
                self.checkPatterns(graph, patterns)

    def test_linear_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20, bias=bias)

            def forward(self, x):
                x = self.linear1(x)
                return x

        for bias in [True]:  # TODO: [True, False] when supported in backend
            x = torch.randn(2, 15)

            patterns = [
                ["aten::dequantize", "aten::to", "aten::linear"],
            ]

            for qconfig in static_qconfig:
                m = M(bias)
                graph = self.checkQuantizeTrace(
                    m, [x], atol=2e-1, qconfig=qconfig, int8_bf16=True
                )
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                # a single aten::to won't be rewritten by the llga backend
                self.assertFused(graph, ["aten::dequantize", "aten::linear"])
                self.checkPatterns(graph, patterns)

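    # max_pool2d is quantizable on its own: besides the conv partition, a
    # separate dequant -> max_pool2d -> quant partition is expected.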
    def test_max_pool2d(self):
        class M(nn.Module):
            def __init__(self, **kargs):
                super(M, self).__init__()
                self.conv = nn.Conv2d(3, 3, 1, 1)
                self.max_pool = nn.MaxPool2d(**kargs)

            def forward(self, x):
                x = self.conv(x)
                x = self.max_pool(x)
                return x

        for [
            spatial,
            kernel,
            padding,
            stride,
            dilation,
            ceil_mode,
            memory_format,
        ] in itertools.product(
            [15],  # [15, 16], TODO: check backend
            [3, 5],  # [3, 4, 5], TODO: check backend
            [0, 1],
            [1, 2],  # [1, 2, 4], TODO: fix issue in pad calculation
            [1, 2],
            [True, False],
            [torch.contiguous_format, torch.channels_last],
        ):
            m = M(
                kernel_size=kernel,
                stride=stride,
                padding=padding,
                dilation=dilation,
                ceil_mode=ceil_mode,
            )
            x = torch.rand(1, 3, spatial, spatial).to(memory_format=memory_format)

            patterns = [
                [
                    "aten::dequantize",
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::quantize_per_tensor",
                ],
                ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(graph, ["aten::max_pool2d"])
                self.checkPatterns(graph, patterns)

    def test_add_scalar_input(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                x_shape1 = x.size()[0]
                x_shape2 = x.size()[1]
                y1 = x_shape1 + 2
                y2 = x_shape2 + 3
                return y1 + y2

        # an add whose input[0] is a scalar is unsupported
        x = torch.randn(3, 3)
        m = M()
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
        self.assertGraphContainsExactly(graph, "aten::add", 3)

    def test_reshape_6D_linear(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = torch.nn.Linear(
                    in_features=64, out_features=192, bias=True
                )

            def forward(self, x):
                x = x.reshape(4, 8, 7, 8, 8, 64).transpose(2, 3)
                x = self.linear(x)
                return x

        for bias in [True, False]:
            x = torch.randn(4, 56, 64, 64)
            m = M()

            patterns = [["aten::dequantize", "aten::linear"]]

            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::linear", "aten::dequantize"])
                self.checkPatterns(graph, patterns)

    def test_3d_bmm_int8_in_f32_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                return torch.bmm(x, y)

        x = torch.randn(128, 3, 4) * 0.1
        y = torch.randn(128, 4, 5) * 0.1
        patterns = [
            ["aten::dequantize", "aten::bmm"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::dequantize", "aten::bmm"])
        self.checkPatterns(graph, patterns)

    def test_bmm_int8_in_f32_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = torch.matmul(x, y)
                return mm_res

        x = torch.randn(128, 16, 384, 64) * 0.1
        y = torch.randn(128, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::matmul"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_strided_bmm_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, thus are strided tensors
                return torch.matmul(z1, z2.transpose(-1, -2))

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)

        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::dequantize"])
        self.checkPatterns(graph, patterns)

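    # Runs the same attention-like block through both the fp32-in and the
    # bf16-in int8 softmax paths; matmul/div/add/softmax must fuse in both.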
    def test_mixed_precision_softmax(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y, z, a):
                o = torch.matmul(x, y) / 8.0
                o = o + a.to(o.dtype)
                o = torch.softmax(o, -1)
                o = o.matmul(z)
                return o

        x = torch.randn(1, 16, 16, 64)
        y = torch.randn(1, 16, 64, 16)
        z = torch.randn(1, 16, 16, 64)
        a = torch.randn(1, 1, 1, 16)
        m = M()

        # fp32 in, int8 out softmax
        graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=False)
        self.assertFused(
            graph, ["aten::matmul", "aten::div", "aten::add", "aten::softmax"]
        )

        # bf16 in, int8 out softmax
        graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1, int8_bf16=True)
        self.assertFused(
            graph, ["aten::matmul", "aten::div", "aten::add", "aten::softmax"]
        )


class TestFusionPattern(JitLlgaTestCase):
    def test_conv2d_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x

        for eltwise in [
            "relu",
            "leaky_relu",
            "sigmoid",
            "round",
            "abs",
            "square",
            "abs",
            "round",
            "exp",
            "hardswish",
            "tanh",
            "hardtanh",
            "mish",
        ]:
            for inplace in [False, True]:
                for memory_format in [torch.contiguous_format, torch.channels_last]:
                    eltwise_fn_name = eltwise + "_" if inplace else eltwise
                    eltwise_fn = get_eltwise_fn(eltwise_fn_name)

                    m = M(eltwise_fn)
                    x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)

                    patterns = [
                        [
                            "aten::dequantize",
                            "aten::_convolution",
                            "aten::" + eltwise,
                            "aten::quantize_per_tensor",
                        ],  # an inplace op becomes an outplace op on the JIT graph
                        ["aten::dequantize", "aten::_convolution"],
                    ]
                    for qconfig in static_qconfig:
                        graph = self.checkQuantizeTrace(
                            m, [x], atol=2e-1, qconfig=qconfig
                        )
                        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                        self.assertFused(
                            graph,
                            [
                                "aten::_convolution",
                                "aten::" + eltwise,
                                "aten::quantize_per_channel",
                                "aten::dequantize",
                            ],
                        )
                        self.checkPatterns(graph, patterns)

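    # clamp with one-sided, two-sided, and infinite bounds is expected to fuse
    # with the preceding conv, like the other eltwise ops above.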
    def test_conv2d_clamp(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv4 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv5 = nn.Conv2d(32, 32, 3, padding=1, bias=True)

            def forward(self, x):
                x = self.conv1(x)
                x = torch.clamp(x, min=float("-inf"))
                x = self.conv2(x)
                x = torch.clamp(x, min=-5)
                x = self.conv3(x)
                x = torch.clamp(x, min=0, max=float("inf"))
                x = self.conv4(x)
                x = torch.clamp(x, min=1, max=5)
                x = self.conv5(x)
                x = torch.clamp(x, max=2)
                return x

        for inplace in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
                m = M()
                for qconfig in static_qconfig:
                    graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
                    self.assertFused(
                        graph,
                        [
                            "aten::_convolution",
                            "aten::clamp",
                            "aten::quantize_per_channel",
                            "aten::dequantize",
                        ],
                    )

    def test_conv2d_silu(self):
        class M(nn.Module):
            def __init__(self, inplace):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x

        for inplace in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(inplace)
                x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)

                graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)

                silu_op = "aten::silu_" if inplace else "aten::silu"

                # oneDNN Graph does not have a silu op. The bridge converts silu
                # into sigmoid followed by mul.
                patterns = [
                    [
                        "aten::dequantize",
                        "aten::_convolution",
                        "aten::sigmoid",
                        "aten::mul",
                        "aten::quantize_per_tensor",
                    ],  # an inplace op becomes an outplace op on the JIT graph
                    ["aten::dequantize", "aten::_convolution"],
                ]

                self.assertFused(
                    graph, ["aten::_convolution", silu_op, "aten::dequantize"]
                )
                self.checkPatterns(graph, patterns)

    def test_deconv_silu(self):
        class M(nn.Module):
            def __init__(self, inplace):
                super(M, self).__init__()
                self.deconv = nn.ConvTranspose2d(3, 2, 3, stride=2)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.deconv(x)
                x = self.eltwise(x)
                return x

        for inplace in [False, True]:
            m = M(inplace)
            x = torch.rand(1, 3, 28, 28)
            graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
            patterns = [
                ["aten::dequantize", "aten::_convolution", "aten::sigmoid", "aten::mul"]
            ]
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.checkPatterns(graph, patterns)

    def test_ensure_tensor_is_rewrapped(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = eltwise_fn
                self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                y = self.eltwise(y)
                x = torch.add(x, y)
                x = self.adaptive_avg_pool_2d(x)
                return x

        eltwise_fn_name = "relu"
        eltwise_fn = get_eltwise_fn(eltwise_fn_name)

        m = M(eltwise_fn)
        x = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        y = torch.rand(1, 32, 28, 28).to(memory_format=torch.channels_last)
        for qconfig in static_qconfig:
            # The output of the fourth partition is input to adaptive_avg_pool2d,
            # which is unsupported by LLGA. In resnext101 32x16d, we had
            # encountered an accuracy issue. The UT checks that the input to
            # adaptive_avg_pool_2d has not been wrapped by LlgaTensorImpl
            # (assertEqual would fail in that case).
            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)

    def test_conv2d_bn(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 5, 3, padding=1, bias=False)
                self.bn1 = nn.BatchNorm2d(5)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                return x

        for bias in [False, True]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(bias).eval()
                x = torch.rand(1, 32, 16, 16).to(memory_format=memory_format)
                # TODO: This shape will fail
                # x = torch.rand(1, 32, 28, 28)

                patterns = [["aten::dequantize", "aten::_convolution"]]
                # TODO: add torch.per_tensor_symmetric case.
                for qconfig in static_qconfig:
                    graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                    self.assertFused(
                        graph,
                        [
                            "aten::_convolution",
                            "aten::quantize_per_channel",
                            "aten::dequantize",
                        ],
                    )
                    self.checkPatterns(graph, patterns)

    def test_conv2d_bn_relu(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.bn1 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = F.relu(x)
                return x

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M().eval()
            x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
            patterns = [
                ["aten::dequantize", "aten::_convolution", "aten::relu"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=1e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    ["aten::_convolution", "aten::relu", "aten::quantize_per_channel"],
                )
                self.checkPatterns(graph, patterns)

    def test_linear_bn(self):
        class M(nn.Module):
            def __init__(self, dim):
                super(M, self).__init__()
                self.linear = nn.Linear(32, 32)
                if dim == 1:
                    self.input1 = torch.randn(1, 32)
                    self.bn = nn.BatchNorm1d(32)
                elif dim == 2:
                    self.input1 = torch.randn(1, 32, 32, 32)
                    self.bn = nn.BatchNorm2d(32)
                elif dim == 3:
                    self.input1 = torch.randn(1, 32, 32, 32, 32)
                    self.bn = nn.BatchNorm3d(32)

            def forward(self, x):
                x = self.linear(x)
                x = self.bn(x)
                return x

        for dim in [1, 2, 3]:
            m = M(dim=dim)
            x = m.input1
            patterns = [["aten::dequantize", "aten::linear"]]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["ipex::batch_norm"])
                self.checkPatterns(graph, patterns)

    def test_conv_bn_linear_bn(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.input1 = torch.randn(1, 32, 32, 32)
                self.conv = nn.Conv2d(32, 32, 1)
                self.bn1 = nn.BatchNorm2d(32)
                self.linear = nn.Linear(32, 32)
                self.bn2 = nn.BatchNorm2d(32)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn1(x)
                x = self.linear(x)
                x = self.bn2(x)
                return x

        m = M()
        x = m.input1
        patterns = [
            ["aten::dequantize", "aten::_convolution"],
            ["aten::dequantize", "aten::linear"],
        ]
        for qconfig in static_qconfig:
            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.assertFused(graph, ["ipex::batch_norm"])
            self.checkPatterns(graph, patterns)

    def test_linear_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        # TODO: use itertools.product once all combinations are supported
        for [has_bias, eltwise] in [
            [True, "relu"],
            [False, "relu"],
            # [True, "gelu"],  # TODO: enable once the linear_gelu default recipe is fixed
            # [False, "gelu"],  # TODO: enable once the linear_gelu default recipe is fixed
            [True, "sigmoid"],
            [False, "sigmoid"],
        ]:
            eltwise_fn = get_eltwise_fn(eltwise)
            m = M(eltwise_fn, has_bias)
            x = torch.rand(32, 28, requires_grad=False)
            patterns = [
                ["aten::dequantize", "aten::linear", "aten::" + eltwise],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(
                    m,
                    [x],
                    x_var=[torch.rand(2, 28, requires_grad=False)],
                    atol=1e-1,
                    qconfig=qconfig,
                )
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(graph, ["aten::" + eltwise])
                self.checkPatterns(graph, patterns)

    def test_linear_silu(self):
        class M(nn.Module):
            def __init__(self, inplace):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64)
                self.eltwise = nn.SiLU(inplace=inplace)

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                return x

        for inplace in [False, True]:
            m = M(inplace)
            x = torch.rand(1, 28, requires_grad=False)

            silu_op = "aten::silu_" if inplace else "aten::silu"

            patterns = [
                ["aten::dequantize", "aten::linear", "aten::sigmoid", "aten::mul"],
            ]
            graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(graph, ["aten::linear", silu_op, "aten::dequantize"])
            self.checkPatterns(graph, patterns)

    def test_conv_relu_sigmoid_mul(self):
        #  dequant
        #     |
        #    conv
        #     |
        #    relu
        #    /   \
        # quant   \
        #   |      |
        # dequant  |
        #   |      |
        #  conv    |
        #   |      |
        #  relu    |
        #   |      |
        # quant    |
        #   |      |
        # dequant  |
        #   |      |
        #  conv    |
        #   |      |
        # sigmoid  |
        #    \    /
        #     mul

        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                x = self.conv1(x)

                # the output y of relu is used by mul
                y = x.relu()

                z = self.conv2(y)
                z = z.relu()
                z = self.conv3(z)
                z = z.sigmoid()
                z = z.mul(y)
                return z

        x = torch.rand(1, 32, 16, 16, requires_grad=False)
        m = M()
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        patterns = [
            ["aten::dequantize", "aten::_convolution", "aten::relu"],
            [
                "aten::dequantize",
                "aten::_convolution",
                "aten::relu",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::_convolution", "aten::sigmoid", "aten::mul"],
        ]
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        self.assertFused(
            graph, ["aten::_convolution", "aten::relu", "aten::sigmoid", "aten::mul"]
        )
        self.checkPatterns(graph, patterns)

    def test_conv_eltwise_tensor_method(self):
        class ConvSigmoid(nn.Module):
            def __init__(self):
                super(ConvSigmoid, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                x = self.conv(x)
                x = x.sigmoid()
                return x

        class ConvReLU(nn.Module):
            def __init__(self):
                super(ConvReLU, self).__init__()
                self.conv = nn.Conv2d(32, 32, 3, padding=1)

            def forward(self, x):
                x = self.conv(x)
                x = x.relu()
                return x

        m = ConvSigmoid().eval()
        x = torch.rand(1, 32, 16, 16, requires_grad=False)
        patterns = [["aten::dequantize", "aten::_convolution", "aten::sigmoid"]]
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::_convolution", "aten::sigmoid"])
        self.checkPatterns(graph, patterns)

        m = ConvReLU().eval()
        x = torch.rand(1, 32, 16, 16, requires_grad=False)
        patterns = [["aten::dequantize", "aten::_convolution", "aten::relu"]]
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::_convolution", "aten::relu"])
        self.checkPatterns(graph, patterns)

    def test_conv2d_sum(self):
        class M(nn.Module):
            def __init__(self, bias=False):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn1 = nn.BatchNorm2d(32)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn2 = nn.BatchNorm2d(32)
                self.relu = nn.ReLU()
                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
                self.bn3 = nn.BatchNorm2d(32)

            def forward(self, x, y):
                x = self.conv1(x)
                x = self.bn1(x)
                y = self.conv2(y)
                y = self.bn2(y)
                z = self.relu(x + y)
                z = self.conv3(z)
                z = self.bn3(z)
                return z

        for bias in [True, False]:
            for memory_format in [torch.contiguous_format, torch.channels_last]:
                m = M(bias).eval()
                x = torch.rand(1, 32, 16, 16, requires_grad=False).to(
                    memory_format=memory_format
                )
                y = torch.rand(1, 32, 16, 16, requires_grad=False).to(
                    memory_format=memory_format
                )
                patterns = [
                    [
                        "aten::dequantize",
                        "aten::_convolution",
                        "aten::quantize_per_tensor",
                    ],
                    [
                        "aten::dequantize",
                        "aten::_convolution",
                        "aten::relu",
                        "aten::add",
                        "aten::quantize_per_tensor",
                    ],
                    ["aten::dequantize", "aten::_convolution"],
                ]
                for qconfig in static_qconfig:
                    graph = self.checkQuantizeTrace(
                        m, [x, y], atol=1e-1, qconfig=qconfig
                    )
                    self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
                    self.assertFused(
                        graph,
                        [
                            "aten::_convolution",
                            "aten::relu",
                            "aten::add",
                            "aten::quantize_per_channel",
                            "aten::dequantize",
                        ],
                    )
                    self.checkPatterns(graph, patterns)

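    # y.mul(10) gives the two inputs of the add different observed ranges, so
    # the add is expected to stay outside the two fused conv partitions
    # (only two LLGA fusion groups are asserted below).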
    def test_add_quantization(self):
        class M(nn.Module):
            def __init__(self, bias=False):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(16, 16, 1)
                self.conv2 = nn.Conv2d(16, 16, 1)

            def forward(self, x):
                x = self.conv1(x)
                y = self.conv2(x)
                y = y.mul(10)
                z = torch.add(x, y)
                return z

        m = M().eval()
        x = torch.rand(1, 16, 16, 16, requires_grad=False)
        x2 = torch.rand(1, 16, 16, 16, requires_grad=False)

        patterns = [
            ["aten::dequantize", "aten::_convolution"],
            ["aten::dequantize", "aten::_convolution"],
        ]
        graph = self.checkQuantizeTrace(m, [x], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        self.assertFused(graph, ["aten::_convolution", "aten::quantize_per_channel"])
        self.checkPatterns(graph, patterns)

    def test_conv2d_sigmoid_mul_(self):
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x):
                a = self.conv(x)
                b = torch.sigmoid(a)
                res = a.mul_(b)
                return res

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M(3, 16, 3, 224).eval()
            x = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            patterns = [
                [
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::sigmoid",
                    "aten::mul",
                ],
            ]
            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                graph = self.checkQuantizeTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::sigmoid",
                        "aten::mul",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

        # here the inplace mul_ cannot be replaced with mul, since its first
        # input is aliased by c, which is mutated afterwards
        class M2(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(M2, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x):
                a = self.conv(x)
                b = torch.sigmoid(a)
                c = a[0]
                res = a.mul_(b)
                c += 2
                return c

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M2(3, 16, 3, 224).eval()
            x = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            patterns = [
                ["aten::dequantize", "aten::_convolution"],
            ]
            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                graph = self.checkQuantizeTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

    def test_conv2d_hardsigmoid_mul_(self):
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )
                self.activation = torch.nn.Hardsigmoid()

            def forward(self, x):
                a = self.conv(x)
                b = self.activation(a)
                res = a.mul_(b)
                return res

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M(3, 16, 3, 224).eval()
            x = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            patterns = [
                [
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::hardsigmoid",
                    "aten::mul",
                ],
            ]
            for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
                graph = self.checkQuantizeTrace(m, [x])
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::hardsigmoid",
                        "aten::mul",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

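    # Dropout is an identity in inference mode and should disappear from the
    # traced graph, leaving a linear+add partition and a plain linear partition.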
    def test_linear_dropout_sum(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20)
                self.dropout = nn.Dropout()
                self.linear2 = nn.Linear(20, 3)

            def forward(self, x, y):
                x = self.linear1(x)
                x = self.dropout(x)
                z = self.linear2(x + y)
                return z

        x = torch.randn(2, 15)
        y = torch.randn(2, 20)
        m = M()
        patterns = [
            [
                "aten::dequantize",
                "aten::linear",
                "aten::add",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::linear"],
        ]
        for qconfig in static_qconfig:
            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.assertFused(
                graph,
                [
                    "aten::linear",
                    "aten::add",
                    "aten::quantize_per_channel",
                    "aten::dequantize",
                ],
            )
            self.checkPatterns(graph, patterns)

    def test_linear_sum_inplace(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20)

            def forward(self, x, y):
                x = self.linear1(x)
                x += y.clone()
                return x

        x = torch.randn(2, 15)
        y = torch.randn(2, 20)
        m = M()
        patterns = [
            ["aten::dequantize", "aten::linear", "aten::dequantize"],
        ]
        # TODO: HistogramObserver fails for this case and needs further investigation
        for qconfig in static_qconfig[:2]:
            graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, qconfig=qconfig)
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
            self.assertFused(
                graph,
                ["aten::linear", "aten::quantize_per_channel", "aten::dequantize"],
            )
            self.checkPatterns(graph, patterns)

    def test_linear_with_multiple_add(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20)
                self.linear2 = nn.Linear(15, 20)

            def forward(self, x1, y1, x2, y2):
                x1 = self.linear1(x1)
                x1 += y1.clone()
                x2 = self.linear2(x2)
                x2 += y2.clone()
                return x1 + x2

        x1 = torch.randn(2, 15)
        y1 = torch.randn(2, 20)
        x2 = torch.randn(2, 15)
        y2 = torch.randn(2, 20)

        m = M()
        patterns = [
            ["aten::dequantize", "aten::linear", "aten::add"],
            ["aten::dequantize", "aten::linear", "aten::add", "aten::add"],
        ]
        for qconfig in static_qconfig[:2]:
            graph = self.checkQuantizeTrace(
                m, [x1, y1, x2, y2], atol=2e-1, qconfig=qconfig
            )
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            # there shouldn't be a standalone add node left unfused outside the subgraphs
            self.assertFused(
                graph,
                ["aten::linear", "aten::add"],
            )
            self.checkPatterns(graph, patterns)

    def test_linear_dropout_sum_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(15, 20, bias=True)
                self.dropout = nn.Dropout()
                self.linear2 = nn.Linear(15, 20, bias=True)

            def forward(self, x, y):
                x = self.linear1(x)
                x = self.dropout(x)
                z = self.linear2(y) + x
                return z

        x = torch.randn(2, 15)
        y = torch.randn(2, 15)
        m = M()
        patterns = [
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
            ],
            ["aten::dequantize", "aten::to", "aten::linear", "aten::add"],
        ]
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        # TODO: oneDNN primitives raised more limitations on sum post-ops, which
        # forced fusion changes on the oneDNN graph side. The dequant node
        # connected to aten::add can't be fused into the INT8 linear-add
        # partition any more. oneDNN graph expects no end-to-end model
        # performance impact. Revisit this change if validation finds a model
        # level regression.
        self.assertFused(graph, ["aten::linear", "aten::add"])
        self.checkPatterns(graph, patterns)

    def test_linear_gelu_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias=True)
                self.eltwise = nn.GELU()
                self.linear2 = nn.Linear(64, 1, bias=True)

            def forward(self, x):
                x = self.linear(x)
                x = self.eltwise(x)
                x = self.linear2(x)
                return x

        patterns = [
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
                "aten::gelu",
                "aten::to",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::to", "aten::linear"],
        ]
        m = M()
        x = torch.rand(32, 28, requires_grad=False)
        for qscheme in [torch.per_tensor_affine]:
            graph = self.checkQuantizeTrace(
                m,
                [x],
                x_var=[torch.rand(2, 28, requires_grad=False)],
                atol=1e-1,
                int8_bf16=True,
            )
            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
            self.assertFused(graph, ["aten::dequantize", "aten::linear", "aten::gelu"])
            self.checkPatterns(graph, patterns)

    def test_defer_size(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                y = self.conv2(x)
                y = y.reshape(x.size(0), -1)
                return y

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = M()
            x = torch.rand(1, 32, 28, 28).to(memory_format=memory_format)
            patterns = [
                [
                    "aten::dequantize",
                    "aten::_convolution",
                    "aten::relu",
                    "aten::quantize_per_tensor",
                ],
                ["aten::dequantize", "aten::_convolution"],
            ]
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::relu",
                        "aten::quantize_per_channel",
                        "aten::dequantize",
                    ],
                )
                self.checkPatterns(graph, patterns)

    def test_lift_up_quant(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.linear2 = nn.Linear(28, 64, bias=True)
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                x = self.linear(x)
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                y = self.linear2(y)
                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                return torch.matmul(z1, z2.transpose(-1, -2))

        m = M(bias=True)
        x = torch.randn(2, 3, 28)
        y = torch.randn(2, 3, 28)

        patterns = [
            ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
            ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
            ["aten::dequantize", "aten::matmul"],
        ]

        # TODO: test shape fallback
        graph = self.checkQuantizeTrace(m, [x, y], atol=1e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        self.assertFused(graph, ["aten::dequantize", "aten::linear", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_lift_up_to_quant_bf16(self):
        class M(nn.Module):
            def __init__(self, bias):
                super(M, self).__init__()
                self.linear = nn.Linear(28, 64, bias)
                self.linear2 = nn.Linear(28, 64, bias=True)
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                x = self.linear(x)
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                y = self.linear2(y)
                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                return torch.matmul(z1, z2.transpose(-1, -2))

        m = M(bias=True)
        x = torch.randn(2, 3, 28)
        y = torch.randn(2, 3, 28)

        patterns = [
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
                "aten::to",
                "aten::quantize_per_tensor",
            ],
            [
                "aten::dequantize",
                "aten::to",
                "aten::linear",
                "aten::to",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::to", "aten::matmul"],
        ]

        # TODO: test shape fallback
        graph = self.checkQuantizeTrace(m, [x, y], atol=1e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        self.assertFused(graph, ["aten::dequantize", "aten::linear", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_lift_up_quant_unsupported(self):
        # Original graph:
        #            |
        #           view
        #     (f32)/    \(f32)
        #      quant    add
        #        |
        #
        # Lifting up the quant in this case would raise:
        # "promoteTypes with quantized numbers is not handled" in aten::add:
        #        |
        #      quant
        #        |
        #       view
        # (int8)/    \(f32)
        #       add
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(3, 8, 1)
                self.conv2 = nn.Conv2d(8, 8, 1)

            def forward(self, x, y):
                x = self.conv1(x)
                z1 = x.permute(0, 3, 1, 2)
                z2 = self.conv2(z1)
                z = z1 + y
                output = z2 + z
                return output

        x = torch.randn(1, 3, 8, 8)
        y = torch.randn(1, 8, 8, 8)
        m = M()

        patterns = [
            ["aten::dequantize", "aten::_convolution"],
            ["aten::dequantize", "aten::_convolution", "aten::add"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        # TODO: oneDNN primitives raised more limitations on sum post-ops, which
        # forced fusion changes on the oneDNN graph side. The dequant node
        # connected to aten::add can't be fused into the INT8 conv-add partition
        # any more. oneDNN graph expects no end-to-end model performance impact.
        # Revisit this change if validation finds a model level regression.
        self.assertFused(graph, ["aten::_convolution"])
        self.checkPatterns(graph, patterns)

    def test_wildcard(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                y = self.eltwise(x)
                return [x, y]

        # The pattern is as follows:
        #      conv
        #      |   \
        #  eltwise  \
        #      |     \
        #    ListConstruct
        #
        # The output of conv is used by a wildcard op: ListConstruct.
        # Thus conv-eltwise cannot be selected into the same partition.
        m = M()
        x = torch.rand(1, 32, 28, 28)
        patterns = [
            ["aten::dequantize", "aten::_convolution"],
        ]
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertGraphContainsExactly(graph, "aten::relu", 1)
        self.assertFused(graph, ["aten::_convolution", "aten::quantize_per_channel"])
        self.checkPatterns(graph, patterns)

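    # The following matmul+div tests: a scalar divisor and a tensor divisor
    # should fuse with matmul, while div by 1.0 should be constant-folded away
    # before fusion.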
    def test_bmm_div_scalar(self):
        class M(nn.Module):
            def __init__(self, div_value):
                super(M, self).__init__()
                self.div_value = div_value

            def forward(self, x, y):
                mm_res = torch.matmul(x, y)
                return mm_res.div(self.div_value)

        x = torch.randn(1, 16, 384, 64)
        y = torch.randn(1, 1, 64, 384)
        patterns = [
            ["aten::dequantize", "aten::matmul", "aten::div"],
        ]
        m = M(8.0)
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::div"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_identity(self):
        class M(nn.Module):
            def __init__(self, div_value):
                super(M, self).__init__()
                self.div_value = div_value

            def forward(self, x, y):
                mm_res = torch.matmul(x, y)
                return mm_res.div(self.div_value)

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::matmul"],
        ]
        m = M(1.0)
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        # divide by 1 should be removed by constant propagation
        self.assertGraphContainsExactly(graph, "aten::div", 0, consider_subgraphs=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_tensor(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y, z):
                mm_res = torch.matmul(x, y)
                return mm_res.div(z)

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        # TODO: enable torch.randn(20) and torch.randn(1, 1, 20, 20) once the
        # backend supports them
        z = torch.randn(1)
        patterns = [
            ["aten::dequantize", "aten::matmul", "aten::div"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::div"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = torch.matmul(x, y) / 2
                return mm_res

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        # a single aten::to won't be rewritten by the llga backend
        self.assertFused(graph, ["aten::dequantize", "aten::matmul", "aten::div"])
        self.checkPatterns(graph, patterns)

    def test_bmm_method_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = x.matmul(y)
                return mm_res

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        # a single aten::to won't be rewritten by the llga backend
        self.assertFused(graph, ["aten::dequantize", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_bmm_method_fp32(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                mm_res = x.matmul(y)
                return mm_res

        x = torch.randn(1, 16, 384, 64) * 0.1
        y = torch.randn(1, 1, 64, 384) * 0.1
        patterns = [
            ["aten::dequantize", "aten::matmul"],
        ]
        m = M()
        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::dequantize", "aten::matmul"])
        self.checkPatterns(graph, patterns)

    def test_strided_bmm_div_int8_in_bf16_out(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, thus are strided tensors
                return torch.matmul(z1, z2.transpose(-1, -2)) / 0.4

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)

        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(graph, ["aten::matmul", "aten::dequantize"])
        self.checkPatterns(graph, patterns)

    def test_bmm_div_add_int8_fp32(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y, z):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, thus are strided tensors
                s = torch.matmul(z1, z2.transpose(-1, -2)) / 0.4
                s = s + z
                return s

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)
        z = torch.randn(2, 1, 1, 3)

        patterns = [
            ["aten::dequantize", "aten::matmul", "aten::div", "aten::add"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(
            graph, ["aten::matmul", "aten::dequantize", "aten::div", "aten::add"]
        )
        self.checkPatterns(graph, patterns)

    @unittest.skip("Graph Compiler unit-test")
    def test_mha_pattern_int8_fp32(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(1024, 1024, False)

            def forward(self, x, y, z, a):
                x = x.permute(0, 2, 1, 3)

                y = y.permute(0, 2, 1, 3)
                y = y.transpose(-1, -2)

                z = z.permute(0, 2, 1, 3)
                tmp = torch.matmul(x, y) / 8.0 + a
                tmp = torch.softmax(tmp, -1)
                tmp = tmp.matmul(z)
                tmp = tmp.permute(0, 2, 1, 3)
                tmp = tmp.contiguous()
                tmp = tmp.view(1, 16, 1024)
                tmp = self.linear(tmp)
                return tmp

        x = torch.randn(1, 16, 16, 64)
        y = torch.randn(1, 16, 16, 64)
        z = torch.randn(1, 16, 16, 64)
        m = M()
        a = torch.randn(1, 1, 1, 16)
        graph = self.checkQuantizeTrace(m, [x, y, z, a], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        self.assertFused(
            graph,
            [
                "aten::matmul",
                "aten::div",
                "aten::add",
                "aten::softmax",
                "aten::contiguous",
                "aten::dequantize",
            ],
        )

    @unittest.skip("Graph Compiler unit-test")
    def test_mha_pattern_int8_bf16(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(1024, 1024, False)

            def forward(self, x, y, z, a):
                x = x.permute(0, 2, 1, 3)

                y = y.permute(0, 2, 1, 3)
                y = y.transpose(-1, -2)

                z = z.permute(0, 2, 1, 3)
                tmp = torch.matmul(x, y) / 8.0 + a
                tmp = torch.softmax(tmp, -1)
                tmp = tmp.matmul(z)
                tmp = tmp.permute(0, 2, 1, 3)
                tmp = tmp.contiguous()
                tmp = tmp.view(1, 16, 1024)
                tmp = self.linear(tmp)
                return tmp

        x = torch.randn(1, 16, 16, 64)
        y = torch.randn(1, 16, 16, 64)
        z = torch.randn(1, 16, 16, 64)
        m = M()
        a = torch.randn(1, 1, 1, 16, dtype=torch.bfloat16)
        graph = self.checkQuantizeTrace(
            m,
            [x, y, z, a],
            atol=2e-1,
            config_name="mha_pattern",
            qscheme=torch.per_tensor_affine,
            int8_bf16=True,
        )
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 5)
        self.assertFused(
            graph,
            [
                "aten::matmul",
                "aten::div",
                "aten::add",
                "aten::softmax",
                "aten::contiguous",
                "aten::dequantize",
                "aten::quantize_per_tensor",
            ],
        )

    def test_bmm_div_add_int8_bf16(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.num_attention_heads = 16
                self.attention_head_size = 4

            def forward(self, x, y, z):
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(*new_x_shape)
                z1 = x.permute(0, 2, 1, 3)

                new_y_shape2 = y.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                y = y.view(*new_y_shape2)
                z2 = y.permute(0, 2, 1, 3)

                # the inputs to matmul have been permuted or transposed, thus are strided tensors
                s = torch.matmul(z1, z2.transpose(-1, -2)) / 0.4
                s = s + z.to(s.dtype)
                return s

        m = M()
        x = torch.randn(2, 3, 64)
        y = torch.randn(2, 3, 64)
        z = torch.randn(2, 1, 1, 3)

        patterns = [
            ["aten::dequantize", "aten::to", "aten::matmul", "aten::div", "aten::add"],
        ]

        graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
        self.assertFused(
            graph, ["aten::matmul", "aten::dequantize", "aten::div", "aten::add"]
        )
        self.checkPatterns(graph, patterns)

    def test_split_dequant_to(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear1 = nn.Linear(2, 1, bias=True)
                self.linear2 = nn.Linear(2, 1, bias=True)
                self.linear3 = nn.Linear(2, 1, bias=True)

            def forward(self, x):
                a = self.linear1(x)
                b = self.linear2(x)
                c = self.linear3(x)
                return torch.cat([a, b, c])

        # The below pattern:
        #          quant
        #            |
        #         dequant
        #            |
        #            to
        #         /  |  \
        #   linear linear linear
        #      |      |      |
        #
        # should be transformed to:
        #            to
        #            |
        #          quant
        #         /  |  \
        #  dequant dequant dequant
        #      |      |      |
        #     to     to     to
        #      |      |      |
        #   linear linear linear
        #      |      |      |

        patterns = [
            ["aten::dequantize", "aten::to", "aten::linear"],
            ["aten::dequantize", "aten::to", "aten::linear"],
            ["aten::dequantize", "aten::to", "aten::linear"],
        ]
        m = M()
        x = torch.randn(2, 2)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, int8_bf16=True)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
        # a single aten::to won't be rewritten by the llga backend
        self.assertFused(graph, ["aten::dequantize", "aten::linear"])
        self.checkPatterns(graph, patterns)

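    # Trace a bare quantize_per_channel/dequantize pair and check, via checkAttr
    # from the shared test utils, how the "qtype" attribute is handled on the
    # resulting aten::dequantize node.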
    def test_dequant_remove_attr(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                x = torch.quantize_per_channel(
                    x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8
                )
                x = torch.dequantize(x)
                return x

        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        m = M()
        traced = torch.jit.trace(m, x)
        traced(x)
        graph = traced.graph_for(x)
        self.checkAttr(graph, "aten::dequantize", "qtype")

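    # convert_fx produces quantized nn modules rather than reference quant/dequant
    # ops, so no LLGA fusion groups are expected in the traced graph.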
    def test_fx_converted_model(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(15, 20)

            def forward(self, x):
                x = self.linear(x)
                return x

        x = torch.randn(2, 15)
        m = M()
        m.eval()

        qconfig_dict = {"": static_qconfig[0]}

        m = prepare_fx(m, qconfig_dict, x)
        m = convert_fx(m)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)

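    # In contrast to convert_fx above, convert_to_reference_fx keeps explicit
    # quant/dequant ops in the graph, so the dequant -> linear pattern below can
    # be mapped to a single LLGA fusion group.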
    def test_fx_ao_qat_converted_model(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(15, 20)

            def forward(self, x):
                x = self.linear(x)
                return x

        x = torch.randn(2, 15)
        m = M()
        m.eval()

        qconfig_dict = {"": static_qconfig[0]}

        m = prepare_qat_fx(m, qconfig_dict, x)
        m = convert_to_reference_fx(m)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        # dequant -> linear should be mapped to LLGA
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

    @unittest.skipIf(True, "Poor accuracy")
    @skipIfNoTorchVision
    def test_fx_ao_qat_model(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
                self.eltwise = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.eltwise(x)
                x = self.conv2(x)
                return x

        data = torch.randn(1, 32, 224, 224).to(memory_format=torch.channels_last)
        m = M()
        m.eval()
        #
        # quantization aware training for static quantization
        #
        qconfig_dict = {"": torch.quantization.get_default_qat_qconfig("fbgemm")}
        m.train()
        model_prepared = prepare_qat_fx(m, qconfig_dict, example_inputs=data)
        model_quantized = convert_to_reference_fx(model_prepared)
        model_quantized = model_quantized.eval()
        model = model_quantized.to(memory_format=torch.channels_last)
        graph = self.checkQuantizeTrace(model, [data], atol=2e-1)
        self.checkPatterns(
            graph,
            [
                [
                    "aten::dequantize",
                    "aten::quantize_per_channel",
                    "aten::_convolution",
                    "aten::relu",
                    "aten::quantize_per_tensor",
                ],
                [
                    "aten::dequantize",
                    "aten::quantize_per_channel",
                    "aten::_convolution",
                    "aten::quantize_per_tensor",
                ],
            ],
        )

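    # Transformer-style FFN block with a residual connection: linear + gelu and
    # linear + add are each expected to fuse into their own LLGA partition.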
    def test_ffn_residual(self):
        class FFN_Residual(nn.Module):
            def __init__(self, hidden_size, intermediate_size):
                super(FFN_Residual, self).__init__()
                self.linear1 = nn.Linear(hidden_size, intermediate_size)
                self.linear2 = nn.Linear(intermediate_size, hidden_size)
                self.LayerNorm1 = nn.LayerNorm(hidden_size)
                self.LayerNorm2 = nn.LayerNorm(hidden_size)
                self.intermediate_act_fn = nn.functional.gelu

            def forward(self, x):
                x1 = self.LayerNorm1(x)
                x2 = self.linear1(x1)
                x3 = self.intermediate_act_fn(x2)
                x4 = self.linear2(x3)
                x5 = self.LayerNorm2(x4 + x)
                return x5

        patterns = [
            [
                "aten::dequantize",
                "aten::linear",
                "aten::gelu",
                "aten::quantize_per_tensor",
            ],
            ["aten::dequantize", "aten::linear", "aten::add"],
        ]
        m = FFN_Residual(1024, 4096).eval()
        x = torch.rand(128, 1024)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2)
        self.assertFused(graph, ["aten::linear", "aten::gelu"])
        self.assertFused(graph, ["aten::linear", "aten::add"])
        self.checkPatterns(graph, patterns)

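    # Accuracy check for a low-rank cross net in which x_l is accumulated across
    # layers: the fused linear/mul/add chain must match eager mode within atol.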
    def test_inplace_computation_accuracy(self):
        class LowRankCrossNet(nn.Module):
            def __init__(
                self, in_features: int, num_layers: int, low_rank: int
            ) -> None:
                super().__init__()
                assert low_rank >= 1, "Low rank must be greater than or equal to 1"
                self._num_layers = num_layers
                self._low_rank = low_rank
                W_kernels: nn.ParameterList = nn.ParameterList()
                for i in range(self._num_layers):
                    Wp = nn.Parameter(torch.randn(in_features, self._low_rank))
                    W_kernels.append(Wp)
                V_kernels: nn.ParameterList = nn.ParameterList()
                for i in range(self._num_layers):
                    V_kernels.append(
                        nn.Parameter(torch.randn(self._low_rank, in_features))
                    )
                bias: nn.ParameterList = nn.ParameterList(
                    [
                        nn.Parameter(nn.init.zeros_(torch.empty(in_features)))
                        for i in range(self._num_layers)
                    ]
                )
                self.MLPs = nn.ModuleDict()
                for i in range(num_layers):
                    self.MLPs[f"V{i}"] = nn.Linear(in_features, low_rank, bias=False)
                    self.MLPs[f"W{i}"] = nn.Linear(low_rank, in_features, bias=True)
                    self.MLPs[f"V{i}"].weight = V_kernels[i]
                    self.MLPs[f"W{i}"].weight = W_kernels[i]
                    self.MLPs[f"W{i}"].bias = bias[i]

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                x_0 = input
                x_l = x_0  # .clone()
                for layer in range(self._num_layers):
                    x_l_v = self.MLPs[f"V{layer}"](x_l)
                    x_l_w = self.MLPs[f"W{layer}"](x_l_v)
                    x_l = x_0 * x_l_w + x_l  # (B, N)
                return x_l, x_0

        class FakeQuant(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)
                return x.dequantize()

        class TinyDLRM(nn.Module):
            def __init__(self):
                super().__init__()
                self.pre_model = FakeQuant()
                self.cross_net = LowRankCrossNet(2, 2, 5)

            def forward(self, x):
                out = self.pre_model(x)
                out = self.cross_net(out)
                return out

        m = TinyDLRM().eval()
        x = torch.rand(2048, 2)
        graph = self.checkQuantizeTrace(m, [x], atol=2e-1)
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
        self.assertFused(graph, ["aten::linear", "aten::mul", "aten::add"])


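# Tests around dynamic input shapes: traced graphs should either avoid
# shape-dependent nodes (aten::size / prim::ListConstruct) or fall back
# gracefully when the input size changes after tracing.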
class TestShapeFallback(JitLlgaTestCase):
    @unittest.skipIf(True, "Size peephole optimization not enabled yet")
    def test_view_permute(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x):
                new_x_shape = x.size()[:-1] + (3, 5)
                x = x.view(*new_x_shape)
                return x.permute(0, 2, 1, 3)

        x = torch.randn(5, 10, 15)
        m = M()

        for qconfig in static_qconfig:
            graph = self.checkQuantizeTrace(m, [x], qconfig=qconfig)
            self.assertGraphContainsExactly(graph, "aten::size", 0)
            self.assertGraphContainsExactly(graph, "prim::ListConstruct", 0)

            # change the size of the input
            x2 = torch.randn(6, 4, 15)
            # Bailout gets triggered here
            y2 = m(x2)

    def test_conv_reshape(self):
        class M(nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(4, 4, 3, padding=1, bias=True)
                self.conv2 = nn.Conv2d(4, 32, 3, padding=1, bias=True)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x).reshape(x.size(0), 4, -1)
                return x

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            x = torch.randn(15, 4, 28, 28).to(memory_format=memory_format)
            # change the size of the input, check the fallback
            x_var = torch.randn(7, 4, 16, 16).to(memory_format=memory_format)
            m = M()
            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(
                    m, [x], x_var=[x_var], atol=2e-1, qconfig=qconfig
                )

                # TODO: enable this check when size peephole optimization is enabled
                # self.assertGraphContainsExactly(graph, "aten::size", 0)

    def test_add_recipe(self):
        class ConvAddRelu(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(ConvAddRelu, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x1, x2):
                return torch.relu(torch.add(self.conv(x1), x2))

        class ConvAdd(nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, image_size):
                super(ConvAdd, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, image_size
                )

            def forward(self, x1, x2):
                return torch.add(self.conv(x1), x2)

        for memory_format in [torch.contiguous_format, torch.channels_last]:
            conv_add_relu = ConvAddRelu(3, 16, 3, 2)
            conv_add = ConvAdd(3, 16, 3, 2)
            x1 = torch.rand(1, 3, 224, 224, requires_grad=False).to(
                memory_format=memory_format
            )
            x2 = torch.rand(1, 16, 111, 111, requires_grad=False).to(
                memory_format=memory_format
            )
            input = [x1, x2]
            graph1 = self.checkQuantizeTrace(conv_add_relu, input, atol=1e-2)
            self.assertGraphContainsExactly(graph1, "aten::quantize_per_tensor", 2)
            graph2 = self.checkQuantizeTrace(conv_add, input, atol=1e-2)
            self.assertGraphContainsExactly(graph2, "aten::quantize_per_tensor", 1)


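# End-to-end coverage: quantize and trace full torchvision models, checking
# that the expected ops end up inside LLGA fusion groups.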
class TestModel(JitLlgaTestCase):
    @skipIfNoTorchVision
    def _test_vision(self, model_name):
        for memory_format in [torch.contiguous_format, torch.channels_last]:
            m = getattr(torchvision.models, model_name)().eval()
            x = (torch.rand(1, 3, 224, 224) / 10).to(memory_format=memory_format)

            for qconfig in static_qconfig:
                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=qconfig)

                # TODO: aten::adaptive_avg_pool2d also needs to be fused once the backend supports it
                self.assertFused(
                    graph,
                    [
                        "aten::_convolution",
                        "aten::relu",
                        "aten::max_pool2d",
                        "aten::linear",
                        "aten::quantize_per_channel",
                    ],
                )
                # large partition: 7 fusion groups in total
                self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 7)


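# Generate one test method per (model_name, enabled) entry so each torchvision
# model shows up as its own test case (e.g. test_vision_resnet50).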
for model_name, enabled in [
    ["resnet50", True],
]:

    def wrapper(mname):
        @unittest.skipIf(not enabled, "Disabled")
        def test(self):
            return self._test_vision(mname)

        return test

    setattr(TestModel, "test_vision_%s" % model_name, wrapper(model_name))


if __name__ == "__main__":
    run_tests()