#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import ctypes
import os
import unittest
from typing import Tuple

import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_quantized import supported_qengines
from torch.testing._internal.common_utils import run_tests, TestCase


def qpt(t, scale, zero_point, dtype=torch.quint8):
    """Quantize array-like data per-tensor with the given scale and zero point."""
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)
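
# For reference (affine quantization): qpt([1.0], 0.25, 128) stores
# round(1.0 / 0.25) + 128 = 132 as its quint8 int representation.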


def nhwc(t):
    """Return a channels-last clone of ``t``, tagged so the NNAPI converter
    treats it as an NHWC operand.
    """
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t
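
# The nnapi_nhwc attribute is a hint read by the NNAPI converter: inputs
# carrying it are serialized in NHWC dimension order rather than NCHW.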


@unittest.skipUnless(
    "qnnpack" in supported_qengines,
    "This PyTorch build has not been built with, or does not support, QNNPACK",
)
class TestNNAPI(TestCase):
    def setUp(self):
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = "qnnpack"

        # Actual NNAPI execution requires libneuralnetworks; without it the
        # tests below only verify that models convert successfully.
        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (e.g. TestNnapiBackend).
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (e.g. TestNnapiBackend).
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run

    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        """Trace ``module``, lower it to NNAPI, and, when libneuralnetworks
        is available, run both versions and compare their outputs.
        """
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = eager_output.int_repr().to(
                    torch.int32
                ) - nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(
                    nnapi_output.is_contiguous(memory_format=expected_memory_format)
                )
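
    # Convention used throughout: tensors passed via convert_args with
    # 0-sized dims (e.g. torch.zeros(0, 16)) mark those dims as flexible,
    # so the converted model accepts runtime-determined sizes there.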

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        torch.manual_seed(29)
        # NOTE: the scale and zero_point parameters are currently unused;
        # the quantized variant is built with a fixed scale of 0.03 and a
        # zero point of 128.
        inp_quant = qpt(inp_float, 0.03, 128)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]
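
    # The helper above feeds the pooling and upsample tests below, covering
    # float and quantized inputs in both contiguous and NHWC layouts.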

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])),
        )

    def test_dequantize(self):
        self.check(
            torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))
        )

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1))

        self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4),
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
        )

        # channels last
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)))
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)))

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)))
        with self.assertRaisesRegex(
            Exception, "Flattening flexible dims is not supported yet"
        ):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)
            )

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start : self.stop : self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2))
        self.check(SliceModule2(), torch.randn(5))

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)],
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)],
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)],
        )

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):

                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):

                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ],
                )

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ],
                )

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ],
                    )

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = (
                        torch.channels_last if use_nhwc else torch.contiguous_format
                    )
                    self.check(mod_class(), arg, expected_memory_format=memory_format)

    def test_hardtanh(self):
        # Only clamp ranges with native NNAPI fused equivalents convert:
        # the default (-1.0, 1.0) and (0.0, 6.0). Other ranges should fail.
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # Add PReLU, since an NNAPI input operand can't also be an output.
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(),
            torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                    torch.nn.AvgPool2d(2),
                    torch.nn.AvgPool2d((3, 4)),
                    torch.nn.AvgPool2d((3, 4), (1, 2)),
                ):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_adaptive_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(
            self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)
        )
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)),
                    inp,
                    convert_args=[convert_args[name]],
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)),
                    inp,
                    convert_args=[convert_args[name]],
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32),
            torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)],
        )
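
    # The quantized cases below follow the standard eager-mode quantization
    # recipe: wrap the conv in a Sequential, attach a qconfig, prepare(),
    # run one calibration batch, convert(), then lower the result to NNAPI.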

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
            (4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),  # noqa: E201,E241
            (4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            (4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),  # noqa: E201,E241
            (8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),  # noqa: E201,E241
            (4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),  # noqa: E201,E241
            (4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),  # noqa: E201,E241
            (8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),  # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                (
                    in_ch,
                    out_ch,
                    kernel,
                    stride,
                    padding,
                    groups,
                    bias,
                    input_dim,
                    name,
                ) = case
                with self.subTest(f"{kind}-{name}"):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel,
                        stride,
                        padding,
                        groups=groups,
                        bias=bool(bias),
                    )
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig(
                            "qnnpack"
                        )
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit,
                )
618
    def test_qadd(self):
619
        func = torch.ao.nn.quantized.QFunctional()
620
        func.scale = 0.5
621
        func.zero_point = 120
622

623
        class AddMod(torch.nn.Module):
624
            def forward(self, lhs, rhs):
625
                return func.add(lhs, rhs)
626

627
        class AddReluMod(torch.nn.Module):
628
            def forward(self, lhs, rhs):
629
                return func.add_relu(lhs, rhs)
630

631
        class MulMod(torch.nn.Module):
632
            def forward(self, lhs, rhs):
633
                return func.mul(lhs, rhs)
634

635
        for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
636
            with self.subTest(name):
637
                self.check(
638
                    mod(),
639
                    [
640
                        qpt([1.0, 2.0], 0.25, 128),
641
                        qpt([3.0, 4.0], 0.25, 128),
642
                    ],
643
                )
644
                self.check(
645
                    mod(),
646
                    [
647
                        qpt([[1.0, 2.0]], 0.25, 128),
648
                        qpt([[3.0, 4.0]], 0.25, 128),
649
                    ],
650
                    convert_args=[
651
                        qpt([[1.0, 2.0]], 0.25, 128),
652
                        qpt(torch.zeros((1, 2)), 0.25, 128),
653
                    ],
654
                )
655
                self.check(
656
                    mod(),
657
                    [
658
                        qpt([[1.0, 2.0]], 0.25, 128),
659
                        qpt([[3.0, 4.0]], 0.25, 128),
660
                    ],
661
                    convert_args=[
662
                        qpt(torch.zeros((1, 2)), 0.25, 128),
663
                        qpt([[3.0, 4.0]], 0.25, 128),
664
                    ],
665
                )
666
                self.check(
667
                    mod(),
668
                    [
669
                        qpt([[1.0, 2.0]], 0.25, 128),
670
                        qpt([[3.0, 4.0]], 0.25, 128),
671
                    ],
672
                    convert_args=[
673
                        qpt(torch.zeros((1, 2)), 0.25, 128),
674
                        qpt(torch.zeros((1, 2)), 0.25, 128),
675
                    ],
676
                )
677
                # NOTE: NNAPI qadd supports broadcast, but PT does not.
678

679
    def test_qlinear(self):
680
        torch.manual_seed(29)
681
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
682
        bias = torch.randn(16)
683
        mod = torch.ao.nn.quantized.Linear(32, 16)
684
        mod.set_weight_bias(weight, bias)
685
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
686
        self.check(mod, inp)
687

688
    def test_seblock_mul(self):
689
        class MulModel(torch.nn.Module):
690
            def forward(self, lhs, rhs):
691
                return lhs * rhs
692

693
        self.check(
694
            MulModel(),
695
            [
696
                nhwc(torch.randn(2, 3, 4, 4)),
697
                torch.randn(1, 3, 1, 1),
698
            ],
699
        )
700

701
    def test_multi_output(self):
702
        class MultiModel(torch.nn.Module):
703
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
704
                the_sum = lhs + rhs
705
                the_diff = lhs - rhs
706
                return the_sum, the_diff
707

708
        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])
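

# To exercise actual NNAPI execution (rather than conversion only), set
# LIBNEURALNETWORKS_PATH to a libneuralnetworks shared library before
# running, e.g. (illustrative path):
#   LIBNEURALNETWORKS_PATH=/path/to/libneuralnetworks.so python test_nnapi.py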


if __name__ == "__main__":
    run_tests()