# Owner(s): ["oncall: mobile"]

import unittest

import torch
import torch.backends.xnnpack
from torch.nn import functional as F
from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.testing import FileCheck
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import TestCase, run_tests, slowTest
from hypothesis import given, assume
from hypothesis import strategies as st
import io
import itertools

from torch.testing._internal.common_utils import IS_FBCODE, TEST_WITH_TSAN
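

# TestXNNPACKOps drives the prepacked XNNPACK operators directly
# (torch.ops.prepacked.*_prepack / *_run) and checks that their output
# matches the reference torch.nn.functional implementations.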
@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.")
class TestXNNPACKOps(TestCase):
    @unittest.skip("Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488")
    @given(batch_size=st.integers(0, 3),
           data_shape=hu.array_shapes(1, 3, 2, 64),
           weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans())
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        data_shape = [batch_size] + list(data_shape)
        input_data = torch.rand(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
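
    # Same comparison with a 1-D input: F.linear treats it as a single
    # unbatched sample, and the prepacked op must match that behavior.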
    @given(input_size=st.integers(2, 32),
           weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans())
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        input_data = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)
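
    # The assume() calls below keep Hypothesis away from degenerate
    # geometries: the dilated kernel extent, dilation * (kernel - 1) + 1,
    # must fit inside the padded input (size + 2 * padding), otherwise the
    # convolution has no valid output positions.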
    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d(self,
                    batch_size,
                    input_channels_per_group,
                    height,
                    width,
                    output_channels_per_group,
                    groups,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                    dilation,
                    use_bias,
                    format):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0]
               >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1]
               >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        ref_result = F.conv2d(input_data, weight, bias,
                              strides, paddings, dilations, groups)
        packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(weight, bias,
                                                                      strides, paddings, dilations, groups)
        xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
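
    # For the transposed convolution, output_padding must additionally stay
    # below both the stride and the dilation so that the sampled arguments
    # are valid for conv_transpose2d.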
    @given(batch_size=st.integers(1, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           output_pad_h=st.integers(0, 2),
           output_pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d_transpose(self,
                              batch_size,
                              input_channels_per_group,
                              height,
                              width,
                              output_channels_per_group,
                              groups,
                              kernel_h,
                              kernel_w,
                              stride_h,
                              stride_w,
                              pad_h,
                              pad_w,
                              output_pad_h,
                              output_pad_w,
                              dilation,
                              use_bias,
                              format):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0]
               >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1]
               >= dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w))
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        # Note that groups/dilation is in reverse order from conv2d
        ref_result = F.conv_transpose2d(input_data, weight, bias,
                                        strides, paddings, output_paddings, groups, dilation)
        packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias,
                                                                                strides, paddings,
                                                                                output_paddings, dilations,
                                                                                groups)
        xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3)
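

# TestXNNPACKSerDes checks that modules using the prepacked ops survive a
# TorchScript save/load round trip (torch.jit.save into an io.BytesIO
# buffer, then torch.jit.load) and still agree numerically with their
# reference counterparts.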
@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.")
class TestXNNPACKSerDes(TestCase):
    @unittest.skip("Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488")
    @given(batch_size=st.integers(0, 3),
           data_shape=hu.array_shapes(1, 3, 2, 64),
           weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans())
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        class Linear(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearPrePacked(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)

            def forward(self, x):
                return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)

        data_shape = [batch_size] + list(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        scripted_linear = torch.jit.script(Linear(weight, bias))
        scripted_linear_clamp_prepacked = torch.jit.script(LinearPrePacked(weight, bias))
        input_data = torch.rand(data_shape)
        ref_result = scripted_linear(input_data)
        output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand(data_shape)
        buffer = io.BytesIO()
        torch.jit.save(scripted_linear, buffer)
        buffer.seek(0)
        deserialized_linear = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_linear_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_linear_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_linear(input_data)
        output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)

    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d(self,
                    batch_size,
                    input_channels_per_group,
                    height,
                    width,
                    output_channels_per_group,
                    groups,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                    dilation,
                    use_bias,
                    format):
        class Conv2D(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        class Conv2DPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(weight, bias,
                                                                                   strides, paddings, dilations, groups)

            def forward(self, x):
                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >=
               dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >=
               dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(Conv2D(weight, bias,
                                                  strides, paddings, dilations, groups))
        scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DPrePacked(
            weight, bias, strides, paddings, dilations, groups))
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d, buffer)
        buffer.seek(0)
        deserialized_conv2d = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           output_pad_h=st.integers(0, 2),
           output_pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d_transpose(self,
                              batch_size,
                              input_channels_per_group,
                              height,
                              width,
                              output_channels_per_group,
                              groups,
                              kernel_h,
                              kernel_w,
                              stride_h,
                              stride_w,
                              pad_h,
                              pad_w,
                              output_pad_h,
                              output_pad_w,
                              dilation,
                              use_bias,
                              format):
        class Conv2DT(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(x, self.weight, self.bias,
                                          self.strides, self.paddings, self.output_paddings, self.groups, self.dilations)

        class Conv2DTPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias,
                                                                                             strides, paddings,
                                                                                             output_paddings,
                                                                                             dilations, groups)

            def forward(self, x):
                return torch.ops.prepacked.conv2d_transpose_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >=
               dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >=
               dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w))
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(Conv2DT(weight, bias,
                                                   strides, paddings,
                                                   output_paddings, dilations, groups))
        scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DTPrePacked(
            weight, bias, strides, paddings, output_paddings, dilations, groups))
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d, buffer)
        buffer.seek(0)
        deserialized_conv2d = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
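
    # A conv2d followed by permute([0, 2, 3, 1]) (NCHW -> NHWC) and a linear
    # layer, exercising both prepacked ops in one scripted module and across
    # a serialization round trip.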
    @unittest.skip("Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488")
    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           linear_weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_combined_model(self,
                            batch_size,
                            input_channels_per_group,
                            height,
                            width,
                            output_channels_per_group,
                            groups,
                            kernel_h,
                            kernel_w,
                            stride_h,
                            stride_w,
                            pad_h,
                            pad_w,
                            dilation,
                            linear_weight_output_dim,
                            use_bias,
                            format):
        class M(torch.nn.Module):
            def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
                         strides, paddings, dilations, groups):
                super().__init__()
                self.conv_weight = conv_weight
                self.conv_bias = conv_bias
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return F.relu(o)

        class MPrePacked(torch.nn.Module):
            def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
                         strides, paddings, dilations, groups):
                super().__init__()
                self.conv2d_clamp_run_weight_bias = \
                    torch.ops.prepacked.conv2d_clamp_prepack(conv_weight, conv_bias,
                                                             strides, paddings, dilations, groups)
                self.linear_clamp_run_weight_bias = \
                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)

            def forward(self, x):
                o = torch.ops.prepacked.conv2d_clamp_run(x, self.conv2d_clamp_run_weight_bias)
                o = o.permute([0, 2, 3, 1])
                o = torch.ops.prepacked.linear_clamp_run(o, self.linear_clamp_run_weight_bias)
                return F.relu(o)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0]
               >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1]
               >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if (format is not None):
            input_data = input_data.contiguous(memory_format=format)
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = None
        if use_bias:
            conv_bias = torch.rand(output_channels)

        # This is done just to find the output shape of the result
        # so that the shape of weight for the following linear layer
        # can be determined.
        result = F.conv2d(input_data, conv_weight, conv_bias,
                          strides, paddings, dilations, groups)
        linear_input_shape = result.shape[1]

        linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
        linear_bias = None
        if use_bias:
            linear_bias = torch.rand(linear_weight_output_dim)

        scripted_m = torch.jit.script(M(conv_weight, conv_bias, linear_weight,
                                        linear_bias, strides, paddings, dilations, groups))
        scripted_m_prepacked = torch.jit.script(
            MPrePacked(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups))
        ref_result = scripted_m(input_data)
        xnnpack_result = scripted_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        input_data = input_data.contiguous(memory_format=torch.channels_last)
        buffer = io.BytesIO()
        torch.jit.save(scripted_m, buffer)
        buffer.seek(0)
        deserialized_m = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_m_prepacked, buffer)
        buffer.seek(0)
        deserialized_m_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_m(input_data)
        xnnpack_result = deserialized_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
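

# TestXNNPACKRewritePass runs the JIT graph rewrite passes that swap
# aten::linear/aten::conv2d for the prepacked XNNPACK ops, then uses
# FileCheck to verify the rewritten graph. In the pattern_count_map used
# below, a value of 0 means the pattern must occur at least once, -1 means
# it must not occur at all, and any other value is an exact occurrence count.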
@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.")
class TestXNNPACKRewritePass(TestCase):
    @staticmethod
    def validate_transformed_module(
            # To please flake
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
            if fuse_clamping_ops or prepack_removal:
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
            if fuse_clamping_ops:
                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
            if (prepack_removal):
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)
            for pattern, v in pattern_count_map.items():
                if (v == 0):
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif (v == -1):
                    FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
            xnnpack_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
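
    # Beyond op rewriting, this test covers prepack folding after freezing
    # (prepack_removal) and fusion of relu/hardtanh (in-place and
    # out-of-place) into the clamp bounds of the prepacked ops
    # (fuse_clamping_ops).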
    def test_linear(self):
        data_shape = [2, 3, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearNoBias(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)

            def forward(self, x):
                return F.linear(x, self.weight, None)

        # Linear with bias pattern.
        pattern_count_map = {"Tensor = prim::CallFunction": -1,
                             "prepacked::linear_clamp_prepack": 1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(Linear(), pattern_count_map, data_shape)
        TestXNNPACKRewritePass.validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)

        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        output_pad_h = output_pad_w = 0
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_transpose_weight_shape = (input_channels, output_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels)

        class Conv2D(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        class Conv2DT(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_transpose_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(x, self.weight, self.bias,
                                          self.strides, self.paddings, self.output_paddings, self.groups, self.dilations)

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "prepacked::conv2d_clamp_prepack": 1,
                             "prepacked::conv2d_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        transpose_data_shape = (batch_size, input_channels, height, width)
        transpose_pattern_count_map = {"Tensor = aten::conv_transpose2d": -1,
                                       "prepacked::conv2d_transpose_clamp_prepack": 1,
                                       "prepacked::conv2d_transpose_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(Conv2DT(), transpose_pattern_count_map, transpose_data_shape)

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(input_data, conv_weight, conv_bias,
                          strides, paddings, dilations, groups)
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class M(torch.nn.Module):
            def __init__(self, activation_fn=F.relu):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape), requires_grad=False)
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups
                self.activation_fn = activation_fn

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = self.activation_fn(o)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return self.activation_fn(o)

        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "prepacked::conv2d_clamp_prepack": 1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": 1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["Tensor = prim::CallFunction"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)

        # Not inplace relu fusion test.
        pattern_count_map = {"aten::relu": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        # Inplace relu fusion test.
        pattern_count_map = {"aten::relu": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        # Not inplace hardtanh fusion test.
        pattern_count_map = {"aten::hardtanh": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        # Inplace hardtanh fusion test.
        pattern_count_map = {"aten::hardtanh_": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh_"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class MFusionAntiPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape), requires_grad=False)
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.relu(o)
                o = F.hardtanh(o)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {"aten::hardtanh": 1,  # hardtanh cannot be fused.
                             "aten::relu": -1,  # relu is fused.
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPattern(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True)

        class MFusionAntiPatternParamMinMax(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape), requires_grad=False)
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                min = x[0, 0]
                max = min + 10
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.hardtanh(o, min, max)
                return o

        # Unfusable hardtanh: min/max come from tensor values, not constants.
        pattern_count_map = {"aten::hardtanh": 1,  # hardtanh cannot be fused.
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPatternParamMinMax(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True)
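
    # The rewrite pass should also recognize linear layers written as
    # addmm or as matmul (optionally followed by an add of the bias) and
    # replace them with the prepacked linear op.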
    def test_decomposed_linear(self):
        data_shape = [2, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class DecomposedLinearAddmm(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                weight_t = self.weight.t()
                return torch.addmm(self.bias, x, weight_t)

        class DecomposedLinearMatmulAdd(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                weight_t = self.weight.t()
                y = torch.matmul(x, weight_t)
                res = y.add_(self.bias)
                return res

        class DecomposedLinearMatmul(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                weight_t = self.weight.t()
                res = torch.matmul(x, weight_t)
                return res

        # Linear with bias pattern.
        pattern_count_map = {"Tensor = prim::CallFunction": -1,
                             "prepacked::linear_clamp_prepack": 1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearAddmm(), pattern_count_map, data_shape)
        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
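

# TestXNNPACKConv1dTransformPass checks the pass that rewrites aten::conv1d
# into an equivalent aten::conv2d (which optimize_for_mobile can then turn
# into a prepacked XNNPACK conv), validating both the FileCheck patterns and
# the numerics after serialization.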
@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment")
class TestXNNPACKConv1dTransformPass(TestCase):
    @staticmethod
    def validate_transform_conv1d_to_conv2d(
            self,
            pattern_count_transformed_map,
            pattern_count_optimized_map,
            data_shape):
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
            optimized_scripted_model = optimize_for_mobile(scripted_model)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)

            for pattern, v in pattern_count_transformed_map.items():
                if (v == 0):
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif (v == -1):
                    FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
            transformed_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, transformed_result, rtol=1e-2, atol=1e-3)

            optimized_buffer = io.BytesIO()
            torch.jit.save(optimized_scripted_model, optimized_buffer)
            optimized_buffer.seek(0)
            deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)

            for pattern, v in pattern_count_optimized_map.items():
                if (v == 0):
                    FileCheck().check(pattern).run(deserialized_optimized_scripted_model.graph)
                elif (v == -1):
                    FileCheck().check_not(pattern).run(deserialized_optimized_scripted_model.graph)
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
            xnnpack_result = deserialized_optimized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
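
    # Sweep a small grid of conv1d hyperparameters; after the transform the
    # graph must contain exactly one aten::conv2d and no aten::conv1d, and
    # the mobile-optimized graph must run the prepacked conv instead.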
    @unittest.skipIf(IS_FBCODE, "T137513244")
    def test_conv1d_basic(self):
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)

        for hparams in itertools.product(batch_size_list,
                                         input_channels_per_group_list,
                                         width_list,
                                         output_channels_per_group_list,
                                         groups_list,
                                         kernel_list,
                                         stride_list,
                                         padding_list,
                                         dilation_list):
            batch_size, input_channels_per_group, width, output_channels_per_group, \
                groups, kernel, stride, padding, dilation = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = (output_channels)

            class Conv1D(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                    self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                def forward(self, x):
                    return F.conv1d(x, self.weight, self.bias,
                                    self.stride, self.padding, self.dilation, self.groups)

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
                                             "Tensor = aten::conv2d": 1}
            pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
                                           "Tensor = aten::conv2d": -1,
                                           "prepacked::conv2d_clamp_prepack": -1,
                                           "prepacked::conv2d_clamp_run": 1}

            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(),
                                                                               pattern_count_transformed_map,
                                                                               pattern_count_optimized_map,
                                                                               data_shape)

    # See https://github.com/pytorch/pytorch/issues/46066
    @slowTest
    def test_conv1d_with_relu_fc(self):
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)
        output_features_list = range(1, 3)

        for hparams in itertools.product(batch_size_list,
                                         input_channels_per_group_list,
                                         width_list,
                                         output_channels_per_group_list,
                                         groups_list,
                                         kernel_list,
                                         stride_list,
                                         padding_list,
                                         dilation_list,
                                         output_features_list):
            batch_size, input_channels_per_group, width, output_channels_per_group, \
                groups, kernel, stride, padding, dilation, output_features = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = (output_channels)
            conv_output_width = int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
            fc_weight_shape = (output_features, output_channels * conv_output_width)
            fc_bias_shape = (output_features)

            class Net(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                    self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                    self.fc_weight = torch.nn.Parameter(torch.rand(fc_weight_shape), requires_grad=False)
                    self.fc_bias = torch.nn.Parameter(torch.rand(fc_bias_shape), requires_grad=False)

                def forward(self, x):
                    x = F.conv1d(x, self.conv_weight, self.conv_bias,
                                 self.stride, self.padding, self.dilation, self.groups)
                    x = F.relu(x)
                    x = x.view(x.size(0), -1)
                    x = F.linear(x, self.fc_weight, self.fc_bias)
                    return x

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
                                             "Tensor = aten::conv2d": 1}
            pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
                                           "Tensor = aten::conv2d": -1,
                                           "prepacked::conv2d_clamp_prepack": -1,
                                           "prepacked::conv2d_clamp_run": 1}
            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Net(),
                                                                               pattern_count_transformed_map,
                                                                               pattern_count_optimized_map,
                                                                               data_shape)


if __name__ == "__main__":
    run_tests()