#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import ctypes
import os
import unittest
from typing import Tuple

import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_quantized import supported_qengines
from torch.testing._internal.common_utils import run_tests, TestCase
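
# If LIBNEURALNETWORKS_PATH points at a libneuralnetworks.so, converted models
# are also executed and compared against eager mode; otherwise these tests
# only verify that conversion succeeds (see setUp and check below).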


def qpt(t, scale, zero_point, dtype=torch.quint8):
    # Quantize-per-tensor helper; accepts lists as well as tensors.
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)


def nhwc(t):
    # Return a channels-last copy tagged so the converter treats it as NHWC.
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t


@unittest.skipUnless(
    "qnnpack" in supported_qengines,
    "This PyTorch build has not been built with or does not support QNNPACK",
)
class TestNNAPI(TestCase):
    def setUp(self):
        # Use the qnnpack engine to avoid saturation in fbgemm.
        torch.backends.quantized.engine = "qnnpack"

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (e.g. TestNnapiBackend).
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (e.g. TestNnapiBackend).
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run
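
    # check() traces `module` with `trace_args or args`, lowers the trace to
    # NNAPI with `convert_args or args`, and (when libneuralnetworks is
    # loadable) runs both versions and compares outputs. Per the converter's
    # convention, a size-0 dimension in `convert_args` marks that dimension
    # as flexible (unknown until runtime).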
    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = eager_output.int_repr().to(
                    torch.int32
                ) - nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches. Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(
                    nnapi_output.is_contiguous(memory_format=expected_memory_format)
                )
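
    # Produce (name, tensor) variants of an input (float and quantized, each
    # contiguous and channels-last) so op tests cover all four layouts.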
    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        inp_quant = qpt(inp_float, scale, zero_point)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])),
        )

    def test_dequantize(self):
        self.check(
            torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))
        )

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1))
        self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1)))
        with self.assertRaisesRegex(Exception, "target size"):
            self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4),
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # Flexible input size
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
        )

        # Channels-last inputs
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)))
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)))

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)))
        with self.assertRaisesRegex(
            Exception, "Flattening flexible dims is not supported yet"
        ):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)
            )

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start : self.stop : self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2))
        self.check(SliceModule2(), torch.randn(5))

        # Flexible input size
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)],
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)],
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)],
        )

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):

                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):

                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ],
                )

                # Broadcast with equal ranks is supported.
                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ],
                )

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ],
                    )

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = (
                        torch.channels_last if use_nhwc else torch.contiguous_format
                    )
                    self.check(mod_class(), arg, expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # Add prelu since an input operand can't also be an output.
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(),
            torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                    torch.nn.AvgPool2d(2),
                    torch.nn.AvgPool2d((3, 4)),
                    torch.nn.AvgPool2d((3, 4), (1, 2)),
                ):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_adaptive_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(
            self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)
        )
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)),
                    inp,
                    convert_args=[convert_args[name]],
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)),
                    inp,
                    convert_args=[convert_args[name]],
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32),
            torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)],
        )
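
    # Sweep a small grid of Conv2d configurations (kernel size, stride,
    # padding, groups, bias) across float/quant and contiguous/NHWC variants.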
    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
            ( 4,     8,      (3, 3), 1,      0,       1,      1,    (2, 4, 16, 16), "3x3"),        # noqa: E201,E241
            ( 4,     8,      (3, 3), 1,      0,       1,      0,    (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            ( 4,     16,     (3, 3), 1,      1,       1,      1,    (2, 4, 16, 16), "3x3p1"),      # noqa: E201,E241
            ( 8,     8,      (3, 3), 2,      0,       1,      1,    (2, 8, 16, 16), "3x3s2"),      # noqa: E201,E241
            ( 4,     8,      (5, 5), 1,      0,       1,      1,    (2, 4, 16, 16), "5x5"),        # noqa: E201,E241
            ( 4,     4,      (3, 3), 1,      0,       4,      1,    (2, 4, 16, 16), "3x3dw"),      # noqa: E201,E241
            ( 8,     4,      (1, 1), 1,      0,       1,      1,    (2, 8, 16, 16), "1x1"),        # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                (
                    in_ch,
                    out_ch,
                    kernel,
                    stride,
                    padding,
                    groups,
                    bias,
                    input_dim,
                    name,
                ) = case
                with self.subTest(f"{kind}-{name}"):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel,
                        stride,
                        padding,
                        groups=groups,
                        bias=bool(bias),
                    )
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig(
                            "qnnpack"
                        )
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit,
                )
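
    # QFunctional's scale and zero_point attributes set the quantization
    # parameters of the outputs produced by add/add_relu/mul below.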
    def test_qadd(self):
        func = torch.ao.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )

        # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.ao.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ],
        )

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])


if __name__ == "__main__":
    run_tests()