# pytorch
# Fork
# 0
# /
# test_xnnpack_integration.py
# 1500 lines · 53.5 KB
1
# Owner(s): ["oncall: mobile"]
2

3
import io
import itertools
import unittest

from hypothesis import assume, given, strategies as st

import torch
import torch.backends.xnnpack
import torch.testing._internal.hypothesis_utils as hu
from torch.nn import functional as F
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    run_tests,
    slowTest,
    TEST_WITH_TSAN,
    TestCase,
)
from torch.utils.mobile_optimizer import optimize_for_mobile
22

23

24
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
)
class TestXNNPACKOps(TestCase):
    """Compare XNNPACK prepacked ops against eager ``torch.nn.functional`` ops.

    Each test packs weight/bias with a ``torch.ops.prepacked.*_prepack`` op,
    runs the matching ``*_run`` op on hypothesis-generated input, and checks
    the result against ``F.linear`` / ``F.conv2d`` / ``F.conv_transpose2d``
    with ``rtol=1e-2, atol=1e-3``.
    """

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        """Prepacked linear on a batched input (batch size may be 0)."""
        data_shape = [batch_size] + list(data_shape)
        input_data = torch.rand(data_shape)
        # Weight maps the innermost input dimension to weight_output_dim.
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        input_size=st.integers(2, 32),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        """Prepacked linear on a 1-D (unbatched) input vector."""
        input_data = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        """Prepacked grouped conv2d vs. F.conv2d across input memory formats."""
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        # Discard configs whose dilated kernel extent exceeds the padded input.
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        # format=None leaves torch.rand's layout untouched.
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        ref_result = F.conv2d(
            input_data, weight, bias, strides, paddings, dilations, groups
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
            weight, bias, strides, paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(1, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        """Prepacked grouped conv_transpose2d vs. F.conv_transpose2d."""
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        # Same input-extent filter as test_conv2d.
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        # Keep output padding strictly below BOTH stride and dilation.
        # NOTE(review): PyTorch only documents "smaller than either stride or
        # dilation"; the stricter filter presumably avoids configs the
        # prepacked path rejects -- confirm before relaxing to `or`.
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        # Transposed-conv weights are (in_channels, out_channels / groups, kH, kW).
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        # Note that groups/dilation is in reverse order from conv2d
        ref_result = F.conv_transpose2d(
            input_data,
            weight,
            bias,
            strides,
            paddings,
            output_paddings,
            groups,
            dilations,
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
            weight, bias, strides, paddings, output_paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(
            input_data, packed_weight_bias
        )
        # Compare in a common contiguous layout.
        torch.testing.assert_close(
            ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3
        )
230

231

232
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
)
class TestXNNPACKSerDes(TestCase):
    """Check that TorchScript modules using XNNPACK prepacked ops survive save/load.

    Each test scripts a reference module (eager ``F.*`` ops) and an equivalent
    module that stores a prepacked weight/bias object. It compares their
    outputs, round-trips both through ``torch.jit.save`` / ``torch.jit.load``
    via an in-memory buffer, and compares again with ``rtol=1e-2, atol=1e-3``.
    """

    @staticmethod
    def _jit_roundtrip(scripted_module):
        """Return ``scripted_module`` after a torch.jit.save/torch.jit.load cycle.

        A fresh ``io.BytesIO`` buffer is used per call, so nothing is shared
        between round-trips.
        """
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        """Scripted prepacked linear matches F.linear before and after save/load."""

        class Linear(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearPrePacked(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(
                    weight, bias
                )

            def forward(self, x):
                return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)

        data_shape = [batch_size] + list(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        scripted_linear = torch.jit.script(Linear(weight, bias))
        scripted_linear_clamp_prepacked = torch.jit.script(
            LinearPrePacked(weight, bias)
        )
        input_data = torch.rand(data_shape)
        ref_result = scripted_linear(input_data)
        output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

        # Serialize the modules and then deserialize
        input_data = torch.rand(data_shape)
        deserialized_linear = self._jit_roundtrip(scripted_linear)
        deserialized_linear_clamp_prepacked = self._jit_roundtrip(
            scripted_linear_clamp_prepacked
        )
        ref_result = deserialized_linear(input_data)
        output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        """Scripted prepacked conv2d matches F.conv2d before and after save/load."""

        class Conv2D(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
                    weight, bias, strides, paddings, dilations, groups
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        # Discard configs whose dilated kernel extent exceeds the padded input.
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(
            Conv2D(weight, bias, strides, paddings, dilations, groups)
        )
        scripted_conv2d_clamp_prepacked = torch.jit.script(
            Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups)
        )
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        deserialized_conv2d = self._jit_roundtrip(scripted_conv2d)
        deserialized_conv2d_clamp_prepacked = self._jit_roundtrip(
            scripted_conv2d_clamp_prepacked
        )
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        """Scripted prepacked conv_transpose2d matches F.conv_transpose2d, pre/post save."""

        class Conv2DT(torch.nn.Module):
            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                # F.conv_transpose2d takes groups before dilation.
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        class Conv2DTPrePacked(torch.nn.Module):
            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.packed_weight_bias = (
                    torch.ops.prepacked.conv2d_transpose_clamp_prepack(
                        weight,
                        bias,
                        strides,
                        paddings,
                        output_paddings,
                        dilations,
                        groups,
                    )
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_transpose_clamp_run(
                    x, self.packed_weight_bias
                )

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        # NOTE(review): stricter than PyTorch's documented "smaller than either
        # stride or dilation" rule for output padding -- confirm before relaxing.
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        # Transposed-conv weights are (in_channels, out_channels / groups, kH, kW).
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(
            Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups)
        )
        scripted_conv2d_clamp_prepacked = torch.jit.script(
            Conv2DTPrePacked(
                weight, bias, strides, paddings, output_paddings, dilations, groups
            )
        )
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        deserialized_conv2d = self._jit_roundtrip(scripted_conv2d)
        deserialized_conv2d_clamp_prepacked = self._jit_roundtrip(
            scripted_conv2d_clamp_prepacked
        )
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        linear_weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_combined_model(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        linear_weight_output_dim,
        use_bias,
        format,
    ):
        """conv2d -> permute -> linear -> relu, reference vs. prepacked, pre/post save."""

        class M(torch.nn.Module):
            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv_weight = conv_weight
                self.conv_bias = conv_bias
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                # Move channels last so the linear layer acts on them.
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return F.relu(o)

        class MPrePacked(torch.nn.Module):
            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv2d_clamp_run_weight_bias = (
                    torch.ops.prepacked.conv2d_clamp_prepack(
                        conv_weight, conv_bias, strides, paddings, dilations, groups
                    )
                )
                self.linear_clamp_run_weight_bias = (
                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)
                )

            def forward(self, x):
                o = torch.ops.prepacked.conv2d_clamp_run(
                    x, self.conv2d_clamp_run_weight_bias
                )
                o = o.permute([0, 2, 3, 1])
                o = torch.ops.prepacked.linear_clamp_run(
                    o, self.linear_clamp_run_weight_bias
                )
                return F.relu(o)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        conv_weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = None
        if use_bias:
            conv_bias = torch.rand(output_channels)

        # This is done just to find the output shape of the result
        # so that the shape of weight for the following linear layer
        # can be determined.
        result = F.conv2d(
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
        )
        linear_input_shape = result.shape[1]

        linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
        linear_bias = None
        if use_bias:
            linear_bias = torch.rand(linear_weight_output_dim)

        scripted_m = torch.jit.script(
            M(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            )
        )
        scripted_m_prepacked = torch.jit.script(
            MPrePacked(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            )
        )
        ref_result = scripted_m(input_data)
        xnnpack_result = scripted_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        # NOTE(review): unlike the other SerDes tests, this pass always uses
        # channels_last and ignores `format` -- confirm this is intentional.
        input_data = input_data.contiguous(memory_format=torch.channels_last)
        deserialized_m = self._jit_roundtrip(scripted_m)
        deserialized_m_prepacked = self._jit_roundtrip(scripted_m_prepacked)
        ref_result = deserialized_m(input_data)
        xnnpack_result = deserialized_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
752

753

754
@unittest.skipUnless(
755
    torch.backends.xnnpack.enabled,
756
    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
757
)
758
@unittest.skipIf(
759
    TEST_WITH_TSAN,
760
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
761
)
762
class TestXNNPACKRewritePass(TestCase):
763
    @staticmethod
764
    def validate_transformed_module(
765
        # To please flake
766
        self,
767
        pattern_count_map,
768
        data_shape,
769
        prepack_removal=False,
770
        fuse_clamping_ops=False,
771
    ):
772
        input_data = torch.normal(1, 20, size=data_shape)
773

774
        for jit_method in ["script", "trace"]:
775
            module_instance = self
776
            if jit_method == "script":
777
                scripted_model = torch.jit.script(module_instance)
778
            else:
779
                scripted_model = torch.jit.trace(module_instance, input_data)
780
            scripted_model.eval()
781
            ref_result = scripted_model(input_data)
782
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
783
            if fuse_clamping_ops or prepack_removal:
784
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
785
            if fuse_clamping_ops:
786
                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
787
            if prepack_removal:
788
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
789

790
            buffer = io.BytesIO()
791
            torch.jit.save(scripted_model, buffer)
792
            buffer.seek(0)
793
            deserialized_scripted_model = torch.jit.load(buffer)
794
            for pattern, v in pattern_count_map.items():
795
                if v == 0:
796
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
797
                elif v == -1:
798
                    FileCheck().check_not(pattern).run(
799
                        deserialized_scripted_model.graph
800
                    )
801
                else:
802
                    FileCheck().check_count(pattern, v, exactly=True).run(
803
                        deserialized_scripted_model.graph
804
                    )
805
            xnnpack_result = deserialized_scripted_model(input_data)
806
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
807

808
    def test_linear(self):
809
        data_shape = [2, 3, 32]
810
        weight_output_dim = 24
811
        weight_shape = (weight_output_dim, data_shape[-1])
812

813
        class Linear(torch.nn.Module):
814
            def __init__(self) -> None:
815
                super().__init__()
816
                self.weight = torch.nn.Parameter(
817
                    torch.rand(weight_shape), requires_grad=False
818
                )
819
                self.bias = torch.nn.Parameter(
820
                    torch.rand(weight_output_dim), requires_grad=False
821
                )
822

823
            def forward(self, x):
824
                return F.linear(x, self.weight, self.bias)
825

826
        class LinearNoBias(torch.nn.Module):
827
            def __init__(self) -> None:
828
                super().__init__()
829
                self.weight = torch.nn.Parameter(
830
                    torch.rand(weight_shape), requires_grad=False
831
                )
832

833
            def forward(self, x):
834
                return F.linear(x, self.weight, None)
835

836
        # Linear with bias pattern.
837
        pattern_count_map = {
838
            "Tensor = prim::CallFunction": -1,
839
            "prepacked::linear_clamp_prepack": 1,
840
            "prepacked::linear_clamp_run": 1,
841
        }
842
        TestXNNPACKRewritePass.validate_transformed_module(
843
            Linear(), pattern_count_map, data_shape
844
        )
845
        TestXNNPACKRewritePass.validate_transformed_module(
846
            LinearNoBias(), pattern_count_map, data_shape
847
        )
848

849
        # Conv params
850
        batch_size = 2
851
        input_channels_per_group = 6
852
        height = 16
853
        width = 16
854
        output_channels_per_group = 6
855
        groups = 4
856
        kernel_h = kernel_w = 3
857
        stride_h = stride_w = 1
858
        pad_h = pad_w = 1
859
        output_pad_h = output_pad_w = 0
860
        dilation = 1
861
        input_channels = input_channels_per_group * groups
862
        output_channels = output_channels_per_group * groups
863
        kernels = (kernel_h, kernel_w)
864
        strides = (stride_h, stride_w)
865
        paddings = (pad_h, pad_w)
866
        output_paddings = (output_pad_h, output_pad_w)
867
        dilations = (dilation, dilation)
868
        conv_weight_shape = (
869
            output_channels,
870
            input_channels_per_group,
871
            kernel_h,
872
            kernel_w,
873
        )
874
        conv_transpose_weight_shape = (
875
            input_channels,
876
            output_channels_per_group,
877
            kernel_h,
878
            kernel_w,
879
        )
880
        conv_bias_shape = output_channels
881

882
        class Conv2D(torch.nn.Module):
883
            def __init__(self) -> None:
884
                super().__init__()
885
                self.weight = torch.nn.Parameter(
886
                    torch.rand(conv_weight_shape), requires_grad=False
887
                )
888
                self.bias = torch.nn.Parameter(
889
                    torch.rand(conv_bias_shape), requires_grad=False
890
                )
891
                self.strides = strides
892
                self.paddings = paddings
893
                self.dilations = dilations
894
                self.groups = groups
895

896
            def forward(self, x):
897
                return F.conv2d(
898
                    x,
899
                    self.weight,
900
                    self.bias,
901
                    self.strides,
902
                    self.paddings,
903
                    self.dilations,
904
                    self.groups,
905
                )
906

907
        class Conv2DT(torch.nn.Module):
908
            def __init__(self) -> None:
909
                super().__init__()
910
                self.weight = torch.nn.Parameter(
911
                    torch.rand(conv_transpose_weight_shape), requires_grad=False
912
                )
913
                self.bias = torch.nn.Parameter(
914
                    torch.rand(conv_bias_shape), requires_grad=False
915
                )
916
                self.strides = strides
917
                self.paddings = paddings
918
                self.output_paddings = output_paddings
919
                self.dilations = dilations
920
                self.groups = groups
921

922
            def forward(self, x):
923
                return F.conv_transpose2d(
924
                    x,
925
                    self.weight,
926
                    self.bias,
927
                    self.strides,
928
                    self.paddings,
929
                    self.output_paddings,
930
                    self.groups,
931
                    self.dilations,
932
                )
933

934
        data_shape = (batch_size, input_channels, height, width)
935
        pattern_count_map = {
936
            "Tensor = aten::conv2d": -1,
937
            "prepacked::conv2d_clamp_prepack": 1,
938
            "prepacked::conv2d_clamp_run": 1,
939
        }
940
        TestXNNPACKRewritePass.validate_transformed_module(
941
            Conv2D(), pattern_count_map, data_shape
942
        )
943

944
        transpose_data_shape = (batch_size, input_channels, height, width)
945
        transpose_pattern_count_map = {
946
            "Tensor = aten::conv_transpose2d": -1,
947
            "prepacked::conv2d_transpose_clamp_prepack": 1,
948
            "prepacked::conv2d_transpose_clamp_run": 1,
949
        }
950
        TestXNNPACKRewritePass.validate_transformed_module(
951
            Conv2DT(), transpose_pattern_count_map, data_shape
952
        )
953

954
        input_data = torch.rand((batch_size, input_channels, height, width))
955
        conv_weight = torch.rand(
956
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
957
        )
958
        conv_bias = torch.rand(output_channels)
959
        result = F.conv2d(
960
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
961
        )
962
        linear_input_shape = result.shape[1]
963
        linear_weight_shape = (weight_output_dim, linear_input_shape)
964

965
        class M(torch.nn.Module):
966
            def __init__(self, activation_fn=F.relu):
967
                super().__init__()
968
                self.conv_weight = torch.nn.Parameter(
969
                    torch.rand(conv_weight_shape), requires_grad=False
970
                )
971
                self.conv_bias = torch.nn.Parameter(
972
                    torch.rand(conv_bias_shape), requires_grad=False
973
                )
974
                self.linear_weight = torch.nn.Parameter(
975
                    torch.rand(linear_weight_shape), requires_grad=False
976
                )
977
                self.linear_bias = torch.nn.Parameter(
978
                    torch.rand(weight_output_dim), requires_grad=False
979
                )
980
                self.strides = strides
981
                self.paddings = paddings
982
                self.dilations = dilations
983
                self.groups = groups
984
                self.activation_fn = activation_fn
985

986
            def forward(self, x):
987
                o = F.conv2d(
988
                    x,
989
                    self.conv_weight,
990
                    self.conv_bias,
991
                    self.strides,
992
                    self.paddings,
993
                    self.dilations,
994
                    self.groups,
995
                )
996
                o = self.activation_fn(o)
997
                o = o.permute([0, 2, 3, 1])
998
                o = F.linear(o, self.linear_weight, self.linear_bias)
999
                return self.activation_fn(o)
1000

1001
        pattern_count_map = {
1002
            "Tensor = aten::conv2d": -1,
1003
            "prepacked::conv2d_clamp_prepack": 1,
1004
            "prepacked::conv2d_clamp_run": 1,
1005
            "prepacked::linear_clamp_prepack": 1,
1006
            "prepacked::linear_clamp_run": 1,
1007
        }
1008
        TestXNNPACKRewritePass.validate_transformed_module(
1009
            M(), pattern_count_map, data_shape
1010
        )
1011
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1012
        pattern_count_map["Tensor = prim::CallFunction"] = -1
1013
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1014
        TestXNNPACKRewritePass.validate_transformed_module(
1015
            M(), pattern_count_map, data_shape, prepack_removal=True
1016
        )
1017

1018
        # Not inplace relu fusion test.
1019
        pattern_count_map = {
1020
            "aten::relu": 2,
1021
            "prepacked::conv2d_clamp_prepack": -1,
1022
            "prepacked::conv2d_clamp_run": 1,
1023
            "prepacked::linear_clamp_prepack": -1,
1024
            "prepacked::linear_clamp_run": 1,
1025
        }
1026
        TestXNNPACKRewritePass.validate_transformed_module(
1027
            M(), pattern_count_map, data_shape, prepack_removal=True
1028
        )
1029
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1030
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1031
        pattern_count_map["aten::relu"] = -1
1032
        TestXNNPACKRewritePass.validate_transformed_module(
1033
            M(),
1034
            pattern_count_map,
1035
            data_shape,
1036
            prepack_removal=True,
1037
            fuse_clamping_ops=True,
1038
        )
1039

1040
        # Inplace relu fusion test.
1041
        pattern_count_map = {
1042
            "aten::relu": 2,
1043
            "prepacked::conv2d_clamp_prepack": -1,
1044
            "prepacked::conv2d_clamp_run": 1,
1045
            "prepacked::linear_clamp_prepack": -1,
1046
            "prepacked::linear_clamp_run": 1,
1047
        }
1048
        TestXNNPACKRewritePass.validate_transformed_module(
1049
            M(F.relu_), pattern_count_map, data_shape, prepack_removal=True
1050
        )
1051
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1052
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1053
        pattern_count_map["aten::relu"] = -1
1054
        TestXNNPACKRewritePass.validate_transformed_module(
1055
            M(F.relu_),
1056
            pattern_count_map,
1057
            data_shape,
1058
            prepack_removal=True,
1059
            fuse_clamping_ops=True,
1060
        )
1061

1062
        # Not inplace hardtanh fusion test.
1063
        pattern_count_map = {
1064
            "aten::hardtanh": 2,
1065
            "prepacked::conv2d_clamp_prepack": -1,
1066
            "prepacked::conv2d_clamp_run": 1,
1067
            "prepacked::linear_clamp_prepack": -1,
1068
            "prepacked::linear_clamp_run": 1,
1069
        }
1070
        TestXNNPACKRewritePass.validate_transformed_module(
1071
            M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True
1072
        )
1073
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1074
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1075
        pattern_count_map["aten::hardtanh"] = -1
1076
        TestXNNPACKRewritePass.validate_transformed_module(
1077
            M(F.hardtanh),
1078
            pattern_count_map,
1079
            data_shape,
1080
            prepack_removal=True,
1081
            fuse_clamping_ops=True,
1082
        )
1083

1084
        # Inplace hardtanh fusion test.
1085
        pattern_count_map = {
1086
            "aten::hardtanh_": 2,
1087
            "prepacked::conv2d_clamp_prepack": -1,
1088
            "prepacked::conv2d_clamp_run": 1,
1089
            "prepacked::linear_clamp_prepack": -1,
1090
            "prepacked::linear_clamp_run": 1,
1091
        }
1092
        TestXNNPACKRewritePass.validate_transformed_module(
1093
            M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True
1094
        )
1095
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1096
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1097
        pattern_count_map["aten::hardtanh_"] = -1
1098
        TestXNNPACKRewritePass.validate_transformed_module(
1099
            M(F.hardtanh_),
1100
            pattern_count_map,
1101
            data_shape,
1102
            prepack_removal=True,
1103
            fuse_clamping_ops=True,
1104
        )
1105

1106
        class MFusionAntiPattern(torch.nn.Module):
1107
            def __init__(self) -> None:
1108
                super().__init__()
1109
                self.linear_weight = torch.nn.Parameter(
1110
                    torch.rand(linear_weight_shape), requires_grad=False
1111
                )
1112
                self.linear_bias = torch.nn.Parameter(
1113
                    torch.rand(weight_output_dim), requires_grad=False
1114
                )
1115
                self.strides = strides
1116
                self.paddings = paddings
1117
                self.dilations = dilations
1118
                self.groups = groups
1119

1120
            def forward(self, x):
1121
                o = F.linear(x, self.linear_weight, self.linear_bias)
1122
                o = F.relu(o)
1123
                o = F.hardtanh(o)
1124
                return o
1125

1126
        # Unfusable hardtanh.
1127
        pattern_count_map = {
1128
            "aten::hardtanh": 1,  # hardtanh cannot be.
1129
            "aten::relu": -1,  # relu is fused.
1130
            "prepacked::linear_clamp_prepack": -1,
1131
            "prepacked::linear_clamp_run": 1,
1132
        }
1133
        TestXNNPACKRewritePass.validate_transformed_module(
1134
            MFusionAntiPattern(),
1135
            pattern_count_map,
1136
            (16, linear_weight_shape[1]),
1137
            prepack_removal=True,
1138
            fuse_clamping_ops=True,
1139
        )
1140

1141
        class MFusionAntiPatternParamMinMax(torch.nn.Module):
1142
            def __init__(self) -> None:
1143
                super().__init__()
1144
                self.linear_weight = torch.nn.Parameter(
1145
                    torch.rand(linear_weight_shape), requires_grad=False
1146
                )
1147
                self.linear_bias = torch.nn.Parameter(
1148
                    torch.rand(weight_output_dim), requires_grad=False
1149
                )
1150
                self.strides = strides
1151
                self.paddings = paddings
1152
                self.dilations = dilations
1153
                self.groups = groups
1154

1155
            def forward(self, x):
1156
                min = x[0, 0]
1157
                max = min + 10
1158
                o = F.linear(x, self.linear_weight, self.linear_bias)
1159
                o = F.hardtanh(o, min, max)
1160
                return o
1161

1162
        # Unfusable hardtanh.
1163
        pattern_count_map = {
1164
            "aten::hardtanh": 1,  # hardtanh cannot be.
1165
            "prepacked::linear_clamp_prepack": -1,
1166
            "prepacked::linear_clamp_run": 1,
1167
        }
1168
        TestXNNPACKRewritePass.validate_transformed_module(
1169
            MFusionAntiPatternParamMinMax(),
1170
            pattern_count_map,
1171
            (16, linear_weight_shape[1]),
1172
            prepack_removal=True,
1173
            fuse_clamping_ops=True,
1174
        )
1175

1176
    def test_decomposed_linear(self):
1177
        data_shape = [2, 32]
1178
        weight_output_dim = 24
1179
        weight_shape = (weight_output_dim, data_shape[-1])
1180

1181
        class DecomposedLinearAddmm(torch.nn.Module):
1182
            def __init__(self) -> None:
1183
                super().__init__()
1184
                self.weight = torch.nn.Parameter(
1185
                    torch.rand(weight_shape), requires_grad=False
1186
                )
1187
                self.bias = torch.nn.Parameter(
1188
                    torch.rand(weight_output_dim), requires_grad=False
1189
                )
1190

1191
            def forward(self, x):
1192
                weight_t = self.weight.t()
1193
                return torch.addmm(self.bias, x, weight_t)
1194

1195
        class DecomposedLinearMatmulAdd(torch.nn.Module):
1196
            def __init__(self) -> None:
1197
                super().__init__()
1198
                self.weight = torch.nn.Parameter(
1199
                    torch.rand(weight_shape), requires_grad=False
1200
                )
1201
                self.bias = torch.nn.Parameter(
1202
                    torch.rand(weight_output_dim), requires_grad=False
1203
                )
1204

1205
            def forward(self, x):
1206
                weight_t = self.weight.t()
1207
                y = torch.matmul(x, weight_t)
1208
                res = y.add_(self.bias)
1209
                return res
1210

1211
        class DecomposedLinearMatmul(torch.nn.Module):
1212
            def __init__(self) -> None:
1213
                super().__init__()
1214
                self.weight = torch.nn.Parameter(
1215
                    torch.rand(weight_shape), requires_grad=False
1216
                )
1217
                self.bias = torch.nn.Parameter(
1218
                    torch.rand(weight_output_dim), requires_grad=False
1219
                )
1220

1221
            def forward(self, x):
1222
                weight_t = self.weight.t()
1223
                res = torch.matmul(x, weight_t)
1224
                return res
1225

1226
        # Linear with bias pattern.
1227
        pattern_count_map = {
1228
            "Tensor = prim::CallFunction": -1,
1229
            "prepacked::linear_clamp_prepack": 1,
1230
            "prepacked::linear_clamp_run": 1,
1231
        }
1232
        TestXNNPACKRewritePass.validate_transformed_module(
1233
            DecomposedLinearAddmm(), pattern_count_map, data_shape
1234
        )
1235
        TestXNNPACKRewritePass.validate_transformed_module(
1236
            DecomposedLinearMatmulAdd(), pattern_count_map, data_shape
1237
        )
1238
        TestXNNPACKRewritePass.validate_transformed_module(
1239
            DecomposedLinearMatmul(), pattern_count_map, data_shape
1240
        )
1241

1242

1243
@unittest.skipUnless(
1244
    torch.backends.xnnpack.enabled,
1245
    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
1246
)
1247
@unittest.skipIf(
1248
    TEST_WITH_TSAN,
1249
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
1250
)
1251
class TestXNNPACKConv1dTransformPass(TestCase):
1252
    @staticmethod
1253
    def validate_transform_conv1d_to_conv2d(
1254
        self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape
1255
    ):
1256
        input_data = torch.normal(1, 20, size=data_shape)
1257

1258
        for jit_method in ["script", "trace"]:
1259
            module_instance = self
1260
            if jit_method == "script":
1261
                scripted_model = torch.jit.script(module_instance)
1262
            else:
1263
                scripted_model = torch.jit.trace(module_instance, input_data)
1264
            scripted_model.eval()
1265
            ref_result = scripted_model(input_data)
1266
            torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
1267
            optimized_scripted_model = optimize_for_mobile(scripted_model)
1268

1269
            buffer = io.BytesIO()
1270
            torch.jit.save(scripted_model, buffer)
1271
            buffer.seek(0)
1272
            deserialized_scripted_model = torch.jit.load(buffer)
1273

1274
            for pattern, v in pattern_count_transformed_map.items():
1275
                if v == 0:
1276
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
1277
                elif v == -1:
1278
                    FileCheck().check_not(pattern).run(
1279
                        deserialized_scripted_model.graph
1280
                    )
1281
                else:
1282
                    FileCheck().check_count(pattern, v, exactly=True).run(
1283
                        deserialized_scripted_model.graph
1284
                    )
1285
            transformed_result = deserialized_scripted_model(input_data)
1286
            torch.testing.assert_close(
1287
                ref_result, transformed_result, rtol=1e-2, atol=1e-3
1288
            )
1289

1290
            optimized_buffer = io.BytesIO()
1291
            torch.jit.save(optimized_scripted_model, optimized_buffer)
1292
            optimized_buffer.seek(0)
1293
            deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
1294

1295
            for pattern, v in pattern_count_optimized_map.items():
1296
                if v == 0:
1297
                    FileCheck().check(pattern).run(
1298
                        deserialized_optimized_scripted_model.graph
1299
                    )
1300
                elif v == -1:
1301
                    FileCheck().check_not(pattern).run(
1302
                        deserialized_optimized_scripted_model.graph
1303
                    )
1304
                else:
1305
                    FileCheck().check_count(pattern, v, exactly=True).run(
1306
                        deserialized_optimized_scripted_model.graph
1307
                    )
1308
            xnnpack_result = deserialized_optimized_scripted_model(input_data)
1309
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
1310

1311
    @unittest.skipIf(IS_FBCODE, "T137513244")
1312
    def test_conv1d_basic(self):
1313
        batch_size_list = range(1, 3)
1314
        input_channels_per_group_list = range(10, 12)
1315
        width_list = range(10, 12)
1316
        output_channels_per_group_list = range(10, 12)
1317
        groups_list = range(1, 3)
1318
        kernel_list = range(1, 4)
1319
        stride_list = range(1, 3)
1320
        padding_list = range(0, 3)
1321
        dilation_list = range(1, 3)
1322

1323
        for hparams in itertools.product(
1324
            batch_size_list,
1325
            input_channels_per_group_list,
1326
            width_list,
1327
            output_channels_per_group_list,
1328
            groups_list,
1329
            kernel_list,
1330
            stride_list,
1331
            padding_list,
1332
            dilation_list,
1333
        ):
1334
            (
1335
                batch_size,
1336
                input_channels_per_group,
1337
                width,
1338
                output_channels_per_group,
1339
                groups,
1340
                kernel,
1341
                stride,
1342
                padding,
1343
                dilation,
1344
            ) = hparams
1345

1346
            input_channels = input_channels_per_group * groups
1347
            output_channels = output_channels_per_group * groups
1348
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
1349
            conv_bias_shape = output_channels
1350

1351
            class Conv1D(torch.nn.Module):
1352
                def __init__(self) -> None:
1353
                    super().__init__()
1354
                    self.weight = torch.nn.Parameter(
1355
                        torch.rand(conv_weight_shape), requires_grad=False
1356
                    )
1357
                    self.bias = torch.nn.Parameter(
1358
                        torch.rand(conv_bias_shape), requires_grad=False
1359
                    )
1360
                    self.stride = stride
1361
                    self.padding = padding
1362
                    self.dilation = dilation
1363
                    self.groups = groups
1364

1365
                def forward(self, x):
1366
                    return F.conv1d(
1367
                        x,
1368
                        self.weight,
1369
                        self.bias,
1370
                        self.stride,
1371
                        self.padding,
1372
                        self.dilation,
1373
                        self.groups,
1374
                    )
1375

1376
            data_shape = (batch_size, input_channels, width)
1377
            pattern_count_transformed_map = {
1378
                "Tensor = aten::conv1d": -1,
1379
                "Tensor = aten::conv2d": 1,
1380
            }
1381
            pattern_count_optimized_map = {
1382
                "Tensor = aten::conv1d": -1,
1383
                "Tensor = aten::conv2d": -1,
1384
                "prepacked::conv2d_clamp_prepack": -1,
1385
                "prepacked::conv2d_clamp_run": 1,
1386
            }
1387

1388
            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
1389
                Conv1D(),
1390
                pattern_count_transformed_map,
1391
                pattern_count_optimized_map,
1392
                data_shape,
1393
            )
1394

1395
    # See https://github.com/pytorch/pytorch/issues/46066
1396
    @slowTest
1397
    def test_conv1d_with_relu_fc(self):
1398
        batch_size_list = range(1, 3)
1399
        input_channels_per_group_list = range(10, 12)
1400
        width_list = range(10, 12)
1401
        output_channels_per_group_list = range(10, 12)
1402
        groups_list = range(1, 3)
1403
        kernel_list = range(1, 4)
1404
        stride_list = range(1, 3)
1405
        padding_list = range(0, 3)
1406
        dilation_list = range(1, 3)
1407
        output_features_list = range(1, 3)
1408

1409
        for hparams in itertools.product(
1410
            batch_size_list,
1411
            input_channels_per_group_list,
1412
            width_list,
1413
            output_channels_per_group_list,
1414
            groups_list,
1415
            kernel_list,
1416
            stride_list,
1417
            padding_list,
1418
            dilation_list,
1419
            output_features_list,
1420
        ):
1421
            (
1422
                batch_size,
1423
                input_channels_per_group,
1424
                width,
1425
                output_channels_per_group,
1426
                groups,
1427
                kernel,
1428
                stride,
1429
                padding,
1430
                dilation,
1431
                output_features,
1432
            ) = hparams
1433

1434
            input_channels = input_channels_per_group * groups
1435
            output_channels = output_channels_per_group * groups
1436
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
1437
            conv_bias_shape = output_channels
1438
            conv_output_width = (
1439
                int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
1440
            )
1441
            fc_weight_shape = (output_features, output_channels * conv_output_width)
1442
            fc_bias_shape = output_features
1443

1444
            class Net(torch.nn.Module):
1445
                def __init__(self) -> None:
1446
                    super().__init__()
1447
                    self.conv_weight = torch.nn.Parameter(
1448
                        torch.rand(conv_weight_shape), requires_grad=False
1449
                    )
1450
                    self.conv_bias = torch.nn.Parameter(
1451
                        torch.rand(conv_bias_shape), requires_grad=False
1452
                    )
1453
                    self.stride = stride
1454
                    self.padding = padding
1455
                    self.dilation = dilation
1456
                    self.groups = groups
1457

1458
                    self.fc_weight = torch.nn.Parameter(
1459
                        torch.rand(fc_weight_shape), requires_grad=False
1460
                    )
1461
                    self.fc_bias = torch.nn.Parameter(
1462
                        torch.rand(fc_bias_shape), requires_grad=False
1463
                    )
1464

1465
                def forward(self, x):
1466
                    x = F.conv1d(
1467
                        x,
1468
                        self.conv_weight,
1469
                        self.conv_bias,
1470
                        self.stride,
1471
                        self.padding,
1472
                        self.dilation,
1473
                        self.groups,
1474
                    )
1475
                    x = F.relu(x)
1476
                    x = x.view(x.size(0), -1)
1477
                    x = F.linear(x, self.fc_weight, self.fc_bias)
1478
                    return x
1479

1480
            data_shape = (batch_size, input_channels, width)
1481
            pattern_count_transformed_map = {
1482
                "Tensor = aten::conv1d": -1,
1483
                "Tensor = aten::conv2d": 1,
1484
            }
1485
            pattern_count_optimized_map = {
1486
                "Tensor = aten::conv1d": -1,
1487
                "Tensor = aten::conv2d": -1,
1488
                "prepacked::conv2d_clamp_prepack": -1,
1489
                "prepacked::conv2d_clamp_run": 1,
1490
            }
1491
            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
1492
                Net(),
1493
                pattern_count_transformed_map,
1494
                pattern_count_optimized_map,
1495
                data_shape,
1496
            )
1497

1498

1499
if __name__ == "__main__":
    # Entry point when this file is executed directly: hand off to
    # ``run_tests`` from torch.testing._internal.common_utils.
    run_tests()
1501

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.