# Owner(s): ["oncall: mobile"]

import io
import unittest

import torch
import torch.backends.xnnpack
from torch.nn import functional as F
from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.testing import FileCheck
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import TestCase, run_tests, slowTest
from hypothesis import given, assume
from hypothesis import strategies as st
from torch.testing._internal.common_utils import IS_FBCODE, TEST_WITH_TSAN


@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.")
class TestXNNPACKOps(TestCase):
    """Compare XNNPACK prepacked ops against their ``torch.nn.functional`` references.

    Every test draws shapes and hyper-parameters with hypothesis, runs the
    reference functional op and the matching ``torch.ops.prepacked``
    prepack/run pair on the same random input, and asserts that the two
    outputs agree within ``rtol=1e-2, atol=1e-3``.
    """

    @unittest.skip("Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488")
    @given(batch_size=st.integers(0, 3),
           data_shape=hu.array_shapes(1, 3, 2, 64),
           weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans())
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        """``F.linear`` vs ``prepacked::linear_clamp_{prepack,run}`` on batched N-d input."""
        data_shape = [batch_size] + list(data_shape)
        input_data = torch.rand(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        # Honour ``use_bias`` so the bias-less prepack path is exercised as well.
        bias = torch.rand(weight_output_dim) if use_bias else None

        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)

    @given(input_size=st.integers(2, 32),
           weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans())
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        """Same comparison as ``test_linear`` but with an unbatched 1-d input vector."""
        input_data = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None

        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)

    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d(self,
                    batch_size,
                    input_channels_per_group,
                    height,
                    width,
                    output_channels_per_group,
                    groups,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                    dilation,
                    use_bias,
                    format):
        """``F.conv2d`` vs ``prepacked::conv2d_clamp_{prepack,run}`` (grouped, strided, dilated)."""
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        # Discard draws where the dilated kernel does not fit inside the padded input.
        assume(height + 2 * paddings[0]
               >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1]
               >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        bias = torch.rand(output_channels) if use_bias else None

        ref_result = F.conv2d(input_data, weight, bias,
                              strides, paddings, dilations, groups)
        packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(weight, bias,
                                                                      strides, paddings, dilations, groups)
        xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(batch_size=st.integers(1, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           output_pad_h=st.integers(0, 2),
           output_pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d_transpose(self,
                              batch_size,
                              input_channels_per_group,
                              height,
                              width,
                              output_channels_per_group,
                              groups,
                              kernel_h,
                              kernel_w,
                              stride_h,
                              stride_w,
                              pad_h,
                              pad_w,
                              output_pad_h,
                              output_pad_w,
                              dilation,
                              use_bias,
                              format):
        """``F.conv_transpose2d`` vs ``prepacked::conv2d_transpose_clamp_{prepack,run}``."""
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0]
               >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1]
               >= dilations[1] * (kernels[1] - 1) + 1)
        # Keep output padding strictly below both stride and dilation.
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        # Transposed-conv weights are laid out (in_channels, out_channels / groups, kH, kW).
        weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w))
        bias = torch.rand(output_channels) if use_bias else None

        # Note that groups/dilation is in reverse order from conv2d
        ref_result = F.conv_transpose2d(input_data, weight, bias,
                                        strides, paddings, output_paddings, groups, dilation)
        packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias,
                                                                                strides, paddings,
                                                                                output_paddings, dilations,
                                                                                groups)
        xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(input_data, packed_weight_bias)
        torch.testing.assert_close(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3)


@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.")
class TestXNNPACKSerDes(TestCase):
    """Check that modules holding XNNPACK prepacked state survive TorchScript serialization.

    Each test scripts a reference module (plain functional op) and a module that
    stores the ``torch.ops.prepacked`` packed weight/bias, compares their
    outputs, round-trips both through ``torch.jit.save``/``torch.jit.load`` on an
    in-memory buffer, and compares the reloaded modules on fresh input.
    """

    @staticmethod
    def _jit_roundtrip(scripted_module):
        """Save ``scripted_module`` into an in-memory buffer and load it back."""
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        # Rewind so that load() reads the bytes save() just wrote rather than starting at EOF.
        buffer.seek(0)
        return torch.jit.load(buffer)

    @unittest.skip("Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488")
    @given(batch_size=st.integers(0, 3),
           data_shape=hu.array_shapes(1, 3, 2, 64),
           weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans())
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        class Linear(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearPrePacked(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)

            def forward(self, x):
                return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)

        data_shape = [batch_size] + list(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None

        scripted_linear = torch.jit.script(Linear(weight, bias))
        scripted_linear_clamp_prepacked = torch.jit.script(LinearPrePacked(weight, bias))
        input_data = torch.rand(data_shape)
        ref_result = scripted_linear(input_data)
        output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand(data_shape)
        deserialized_linear = self._jit_roundtrip(scripted_linear)
        deserialized_linear_clamp_prepacked = self._jit_roundtrip(scripted_linear_clamp_prepacked)
        ref_result = deserialized_linear(input_data)
        output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3)

    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d(self,
                    batch_size,
                    input_channels_per_group,
                    height,
                    width,
                    output_channels_per_group,
                    groups,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                    dilation,
                    use_bias,
                    format):
        class Conv2D(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        class Conv2DPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(weight, bias,
                                                                                   strides, paddings, dilations, groups)

            def forward(self, x):
                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        # Discard draws where the dilated kernel does not fit inside the padded input.
        assume(height + 2 * paddings[0] >=
               dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >=
               dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        bias = torch.rand(output_channels) if use_bias else None

        scripted_conv2d = torch.jit.script(Conv2D(weight, bias,
                                                  strides, paddings, dilations, groups))
        scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DPrePacked(
            weight, bias, strides, paddings, dilations, groups))
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        deserialized_conv2d = self._jit_roundtrip(scripted_conv2d)
        deserialized_conv2d_clamp_prepacked = self._jit_roundtrip(scripted_conv2d_clamp_prepacked)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           output_pad_h=st.integers(0, 2),
           output_pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_conv2d_transpose(self,
                              batch_size,
                              input_channels_per_group,
                              height,
                              width,
                              output_channels_per_group,
                              groups,
                              kernel_h,
                              kernel_w,
                              stride_h,
                              stride_w,
                              pad_h,
                              pad_w,
                              output_pad_h,
                              output_pad_w,
                              dilation,
                              use_bias,
                              format):
        class Conv2DT(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                # conv_transpose2d takes groups before dilation (reverse of conv2d).
                return F.conv_transpose2d(x, self.weight, self.bias,
                                          self.strides, self.paddings, self.output_paddings, self.groups, self.dilations)

        class Conv2DTPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias,
                                                                                             strides, paddings,
                                                                                             output_paddings,
                                                                                             dilations, groups)

            def forward(self, x):
                return torch.ops.prepacked.conv2d_transpose_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >=
               dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >=
               dilations[1] * (kernels[1] - 1) + 1)
        # Keep output padding strictly below both stride and dilation.
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w))
        bias = torch.rand(output_channels) if use_bias else None

        scripted_conv2d = torch.jit.script(Conv2DT(weight, bias,
                                                   strides, paddings,
                                                   output_paddings, dilations, groups))
        scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DTPrePacked(
            weight, bias, strides, paddings, output_paddings, dilations, groups))
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        deserialized_conv2d = self._jit_roundtrip(scripted_conv2d)
        deserialized_conv2d_clamp_prepacked = self._jit_roundtrip(scripted_conv2d_clamp_prepacked)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @unittest.skip("Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488")
    @given(batch_size=st.integers(0, 3),
           input_channels_per_group=st.integers(1, 32),
           height=st.integers(5, 64),
           width=st.integers(5, 64),
           output_channels_per_group=st.integers(1, 32),
           groups=st.integers(1, 16),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           linear_weight_output_dim=st.integers(2, 64),
           use_bias=st.booleans(),
           format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
    def test_combined_model(self,
                            batch_size,
                            input_channels_per_group,
                            height,
                            width,
                            output_channels_per_group,
                            groups,
                            kernel_h,
                            kernel_w,
                            stride_h,
                            stride_w,
                            pad_h,
                            pad_w,
                            dilation,
                            linear_weight_output_dim,
                            use_bias,
                            format):
        """conv2d -> permute to channels-last -> linear, reference vs. prepacked, before/after save."""
        class M(torch.nn.Module):
            def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
                         strides, paddings, dilations, groups):
                super().__init__()
                self.conv_weight = conv_weight
                self.conv_bias = conv_bias
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return o

        class MPrePacked(torch.nn.Module):
            def __init__(self, conv_weight, conv_bias, linear_weight, linear_bias,
                         strides, paddings, dilations, groups):
                super().__init__()
                self.conv2d_clamp_run_weight_bias = \
                    torch.ops.prepacked.conv2d_clamp_prepack(conv_weight, conv_bias,
                                                             strides, paddings, dilations, groups)
                self.linear_clamp_run_weight_bias = \
                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)

            def forward(self, x):
                o = torch.ops.prepacked.conv2d_clamp_run(x, self.conv2d_clamp_run_weight_bias)
                o = o.permute([0, 2, 3, 1])
                o = torch.ops.prepacked.linear_clamp_run(o, self.linear_clamp_run_weight_bias)
                return o

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0]
               >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1]
               >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand(output_channels) if use_bias else None

        # This is done just to find the output shape of the result
        # so that the shape of weight for the following linear layer
        # can be determined (dim 1 becomes the last dim after the permute).
        result = F.conv2d(input_data, conv_weight, conv_bias,
                          strides, paddings, dilations, groups)
        linear_input_shape = result.shape[1]

        linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
        linear_bias = torch.rand(linear_weight_output_dim) if use_bias else None

        scripted_m = torch.jit.script(M(conv_weight, conv_bias, linear_weight,
                                        linear_bias, strides, paddings, dilations, groups))
        scripted_m_prepacked = torch.jit.script(
            MPrePacked(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups))
        ref_result = scripted_m(input_data)
        xnnpack_result = scripted_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        # NOTE(review): unlike the other round-trip checks, this input is always
        # channels_last rather than following ``format`` — confirm that is intended.
        input_data = torch.rand((batch_size, input_channels, height, width))
        input_data = input_data.contiguous(memory_format=torch.channels_last)
        deserialized_m = self._jit_roundtrip(scripted_m)
        deserialized_m_prepacked = self._jit_roundtrip(scripted_m_prepacked)
        ref_result = deserialized_m(input_data)
        xnnpack_result = deserialized_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)


@unittest.skipUnless(torch.backends.xnnpack.enabled,
                     " XNNPACK must be enabled for these tests."
                     " Please build with USE_XNNPACK=1.")
@unittest.skipIf(TEST_WITH_TSAN, "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.")
class TestXNNPACKRewritePass(TestCase):
    """Check the JIT passes that rewrite linear/conv graphs into XNNPACK prepacked ops.

    Modules are scripted and traced, run through the prepack-insertion pass
    (optionally followed by freezing, clamp fusion and prepack folding),
    serialized and reloaded, and the reloaded graph is checked structurally
    with FileCheck and numerically against the untransformed module.
    """

    def validate_transformed_module(
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
        """Apply the XNNPACK rewrite passes to a module and validate the resulting graph.

        Called as ``TestXNNPACKRewritePass.validate_transformed_module(module, ...)``,
        so ``self`` is the ``torch.nn.Module`` under test, not a TestCase.

        Args:
            pattern_count_map: graph substring -> expectation: ``0`` must appear,
                ``-1`` must not appear, any other ``n`` must appear exactly ``n`` times.
            data_shape: shape of the random input fed to the module.
            prepack_removal: freeze the module and fold the prepacking ops.
            fuse_clamping_ops: freeze the module and fuse clamping ops into the
                prepacked linear/conv ops.
        """
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
            # Both follow-up passes need constant (frozen) weights to act on.
            if fuse_clamping_ops or prepack_removal:
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
            if fuse_clamping_ops:
                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
            if prepack_removal:
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            # Rewind so that load() reads the bytes save() just wrote.
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)
            for pattern, v in pattern_count_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif v == -1:
                    FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
            xnnpack_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    def test_linear(self):
        """Linear and conv2d/conv_transpose2d rewrites, prepack folding and clamp fusion."""
        data_shape = [2, 3, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearNoBias(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)

            def forward(self, x):
                return F.linear(x, self.weight, None)

        # Linear with bias pattern.
        pattern_count_map = {"Tensor = prim::CallFunction": -1,
                             "prepacked::linear_clamp_prepack": 1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(Linear(), pattern_count_map, data_shape)
        TestXNNPACKRewritePass.validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)

        # Conv params.
        # NOTE(review): batch_size, height, width, groups, pad_h/pad_w and dilation were
        # missing from the listing this was restored from; the values below are a
        # reconstruction — confirm against upstream.
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        output_pad_h = output_pad_w = 0
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_transpose_weight_shape = (input_channels, output_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        class Conv2D(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        class Conv2DT(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_transpose_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(x, self.weight, self.bias,
                                          self.strides, self.paddings, self.output_paddings, self.groups, self.dilations)

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "prepacked::conv2d_clamp_prepack": 1,
                             "prepacked::conv2d_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        transpose_data_shape = (batch_size, input_channels, height, width)
        transpose_pattern_count_map = {"Tensor = aten::conv_transpose2d": -1,
                                       "prepacked::conv2d_transpose_clamp_prepack": 1,
                                       "prepacked::conv2d_transpose_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2DT(), transpose_pattern_count_map, transpose_data_shape)

        # Only used to learn the conv output channel count, which sizes the linear layer.
        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(input_data, conv_weight, conv_bias,
                          strides, paddings, dilations, groups)
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class M(torch.nn.Module):
            def __init__(self, activation_fn=F.relu):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape), requires_grad=False)
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups
                self.activation_fn = activation_fn

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = self.activation_fn(o)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return self.activation_fn(o)

        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "prepacked::conv2d_clamp_prepack": 1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": 1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["Tensor = prim::CallFunction"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)

        # NOTE(review): the activation passed to M(...) in the fusion checks below was
        # missing from the restored listing; it is inferred from each section's comment.

        # Not inplace relu fusion test.
        pattern_count_map = {"aten::relu": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        # Inplace relu fusion test.
        pattern_count_map = {"aten::relu": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        # Not inplace hardtanh fusion test.
        pattern_count_map = {"aten::hardtanh": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        # Inplace hardtanh fusion test.
        pattern_count_map = {"aten::hardtanh_": 2,
                             "prepacked::conv2d_clamp_prepack": -1,
                             "prepacked::conv2d_clamp_run": 1,
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh_"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class MFusionAntiPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape), requires_grad=False)
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                # relu directly follows linear and is fusable; the hardtanh after it is not.
                # NOTE(review): these lines were missing from the restored listing; the
                # order is inferred from the pattern map below — confirm against upstream.
                o = F.relu(o)
                o = F.hardtanh(o)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {"aten::hardtanh": 1,  # hardtanh cannot be fused.
                             "aten::relu": -1,  # relu is fused.
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPattern(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True)

        class MFusionAntiPatternParamMinMax(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape), requires_grad=False)
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                # hardtanh bounds derived from the input are not constants after freezing,
                # so the clamp cannot be folded into the prepacked linear op.
                # NOTE(review): the exact min/max expressions were missing from the
                # restored listing — confirm against upstream.
                min_val = x[0, 0]
                max_val = min_val + 10
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.hardtanh(o, min_val, max_val)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {"aten::hardtanh": 1,  # hardtanh cannot be fused.
                             "prepacked::linear_clamp_prepack": -1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPatternParamMinMax(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True)

    def test_decomposed_linear(self):
        """addmm / matmul(+add_) spellings of linear must also be rewritten to prepacked linear."""
        # NOTE(review): the data_shape line was missing from the restored listing; it must be
        # 2-d for torch.addmm — confirm the value against upstream.
        data_shape = [2, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class DecomposedLinearAddmm(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                weight_t = self.weight.t()
                return torch.addmm(self.bias, x, weight_t)

        class DecomposedLinearMatmulAdd(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                weight_t = self.weight.t()
                y = torch.matmul(x, weight_t)
                res = y.add_(self.bias)
                return res

        class DecomposedLinearMatmul(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(weight_output_dim), requires_grad=False)

            def forward(self, x):
                weight_t = self.weight.t()
                res = torch.matmul(x, weight_t)
                return res

        # Linear with bias pattern.
        pattern_count_map = {"Tensor = prim::CallFunction": -1,
                             "prepacked::linear_clamp_prepack": 1,
                             "prepacked::linear_clamp_run": 1}
        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearAddmm(), pattern_count_map, data_shape)
        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
        TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)


@unittest.skipUnless(torch.backends.xnnpack.enabled,
936
" XNNPACK must be enabled for these tests."
937
" Please build with USE_XNNPACK=1.")
938
@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment")
939
class TestXNNPACKConv1dTransformPass(TestCase):
# Tests for torch._C._jit_pass_transform_conv1d_to_conv2d. A module that calls
# F.conv1d is scripted/traced, the pass rewrites the conv1d into aten::conv2d, and
# optimize_for_mobile is then expected to lower that conv2d to
# prepacked::conv2d_clamp_run. Skipped when XNNPACK is not built in, and under TSAN.
#
# NOTE(review): this chunk looks like a lossy extraction of the real file -- the bare
# integers between statements appear to be leaked line numbers, all indentation is
# gone, and some statements are not visible (called out below). Restore the original
# from version control rather than editing this text directly.
941
# Helper shared by the tests below. It is invoked unbound, e.g.
# TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(), ...),
# so `self` is the nn.Module under test (see `module_instance = self`), not a TestCase.
# For each of "script" and "trace" it:
#   1. records a reference output from the JIT module,
#   2. runs the conv1d->conv2d pass in place on scripted_model._c,
#   3. builds an optimize_for_mobile version of the transformed module,
#   4. round-trips each module through torch.jit.save/load, checks the expected graph
#      patterns with FileCheck, and compares outputs against the reference
#      (rtol=1e-2, atol=1e-3).
# pattern_count_*_map map a graph substring to an expected count; -1 appears to mean
# "must not occur" -- TODO confirm against the full file.
# NOTE(review): the `data_shape` parameter used below is not visible in this
# signature, and the `else:` before the torch.jit.trace call is missing in this chunk.
def validate_transform_conv1d_to_conv2d(
943
pattern_count_transformed_map,
944
pattern_count_optimized_map,
946
input_data = torch.normal(1, 20, size=data_shape)
948
for jit_method in ["script", "trace"]:
949
module_instance = self
950
if jit_method == "script":
951
scripted_model = torch.jit.script(module_instance)
953
scripted_model = torch.jit.trace(module_instance, input_data)
954
scripted_model.eval()
955
# Reference output is taken before the conv1d->conv2d pass mutates the module.
ref_result = scripted_model(input_data)
956
torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
957
optimized_scripted_model = optimize_for_mobile(scripted_model)
959
buffer = io.BytesIO()
960
torch.jit.save(scripted_model, buffer)
# NOTE(review): no buffer.seek(0) is visible between this save and the load below
# (the optimized path further down does rewind); presumably a line is missing from
# this chunk -- confirm, since loading from the end of the stream would not read the
# saved module.
962
deserialized_scripted_model = torch.jit.load(buffer)
964
# NOTE(review): as visible, check / check_not / check_count all run for every
# pattern; presumably the original selects one of them based on `v`, and the
# branching lines are missing from this chunk -- confirm.
for pattern, v in pattern_count_transformed_map.items():
966
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
968
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
970
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
971
transformed_result = deserialized_scripted_model(input_data)
972
torch.testing.assert_close(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
974
# Same save/load + pattern + numerics check for the optimize_for_mobile output.
optimized_buffer = io.BytesIO()
975
torch.jit.save(optimized_scripted_model, optimized_buffer)
976
optimized_buffer.seek(0)
977
deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
979
for pattern, v in pattern_count_optimized_map.items():
981
FileCheck().check(pattern).run(deserialized_optimized_scripted_model.graph)
983
FileCheck().check_not(pattern).run(deserialized_optimized_scripted_model.graph)
985
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
986
xnnpack_result = deserialized_optimized_scripted_model(input_data)
987
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
990
@unittest.skipIf(IS_FBCODE, "T137513244")
991
# Sweeps conv1d hyper-parameters (batch size, channels per group, width, groups,
# kernel size, stride, padding, dilation). For each combination it checks that a
# module calling only F.conv1d becomes exactly one aten::conv2d after the transform
# pass, and exactly one prepacked::conv2d_clamp_run after optimize_for_mobile.
# NOTE(review): the itertools.product(...) argument list is only partly visible here
# (width_list, groups_list, kernel_list, stride_list, padding_list and dilation_list
# are unpacked from hparams but not visibly passed), as are the Conv1D __init__
# header and the trailing data_shape argument of the validate call -- confirm
# against the original.
def test_conv1d_basic(self):
992
batch_size_list = range(1, 3)
993
input_channels_per_group_list = range(10, 12)
994
width_list = range(10, 12)
995
output_channels_per_group_list = range(10, 12)
996
groups_list = range(1, 3)
997
kernel_list = range(1, 4)
998
stride_list = range(1, 3)
999
padding_list = range(0, 3)
1000
dilation_list = range(1, 3)
1002
for hparams in itertools.product(batch_size_list,
1003
input_channels_per_group_list,
1005
output_channels_per_group_list,
1011
batch_size, input_channels_per_group, width, output_channels_per_group, \
1012
groups, kernel, stride, padding, dilation = hparams
1014
input_channels = input_channels_per_group * groups
1015
output_channels = output_channels_per_group * groups
1016
# Grouped-conv weight layout: (out_channels, in_channels // groups, kernel).
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
1017
# (output_channels) is a plain int, not a 1-tuple; torch.rand accepts either form.
conv_bias_shape = (output_channels)
1019
class Conv1D(torch.nn.Module):
1022
self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
1023
self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
1024
self.stride = stride
1025
self.padding = padding
1026
self.dilation = dilation
1027
self.groups = groups
1029
def forward(self, x):
1030
return F.conv1d(x, self.weight, self.bias,
1031
self.stride, self.padding, self.dilation, self.groups)
1033
# Input is (batch, channels, width).
data_shape = (batch_size, input_channels, width)
1034
# After the transform pass: no aten::conv1d left, exactly one aten::conv2d.
pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
1035
"Tensor = aten::conv2d": 1}
1036
# After optimize_for_mobile: conv1d, conv2d and conv2d_clamp_prepack are expected to
# be gone (assuming -1 means "absent"), leaving a single conv2d_clamp_run.
pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
1037
"Tensor = aten::conv2d": -1,
1038
"prepacked::conv2d_clamp_prepack" : -1,
1039
"prepacked::conv2d_clamp_run": 1}
1041
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(),
1042
pattern_count_transformed_map,
1043
pattern_count_optimized_map,
1046
# See https://github.com/pytorch/pytorch/issues/46066
1048
# Same conv1d sweep as the basic test plus an output_features dimension. Here the
# conv1d output is flattened with view(batch, -1) and fed to F.linear, so the
# rewritten conv must still produce the layout the fully-connected layer expects.
# NOTE(review): the test name mentions relu, but no relu call is visible in forward
# in this chunk, nor is its return statement; the Net __init__ header, part of the
# itertools.product(...) arguments, and the trailing data_shape argument of the
# validate call are also not visible -- confirm against the original.
def test_conv1d_with_relu_fc(self):
1049
batch_size_list = range(1, 3)
1050
input_channels_per_group_list = range(10, 12)
1051
width_list = range(10, 12)
1052
output_channels_per_group_list = range(10, 12)
1053
groups_list = range(1, 3)
1054
kernel_list = range(1, 4)
1055
stride_list = range(1, 3)
1056
padding_list = range(0, 3)
1057
dilation_list = range(1, 3)
1058
output_features_list = range(1, 3)
1060
for hparams in itertools.product(batch_size_list,
1061
input_channels_per_group_list,
1063
output_channels_per_group_list,
1069
output_features_list):
1070
batch_size, input_channels_per_group, width, output_channels_per_group, \
1071
groups, kernel, stride, padding, dilation, output_features = hparams
1073
input_channels = input_channels_per_group * groups
1074
output_channels = output_channels_per_group * groups
1075
# Grouped-conv weight layout: (out_channels, in_channels // groups, kernel).
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
1076
# (output_channels) is a plain int, not a 1-tuple; torch.rand accepts either form.
conv_bias_shape = (output_channels)
1077
# Conv output length: floor((W + 2*p - d*(k - 1) - 1) / s) + 1. int() of the true
# division truncates toward zero, which equals floor here because the numerator is
# at least 5 for the ranges defined above (10 + 0 - 2*2 - 1).
conv_output_width = int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
1078
# The FC layer consumes the flattened conv output of size output_channels * width_out.
fc_weight_shape = (output_features, output_channels * conv_output_width)
1079
fc_bias_shape = (output_features)
1081
class Net(torch.nn.Module):
1084
self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
1085
self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
1086
self.stride = stride
1087
self.padding = padding
1088
self.dilation = dilation
1089
self.groups = groups
1091
self.fc_weight = torch.nn.Parameter(torch.rand(fc_weight_shape), requires_grad=False)
1092
self.fc_bias = torch.nn.Parameter(torch.rand(fc_bias_shape), requires_grad=False)
1094
def forward(self, x):
1095
x = F.conv1d(x, self.conv_weight, self.conv_bias,
1096
self.stride, self.padding, self.dilation, self.groups)
1098
# Flatten (batch, channels, width_out) -> (batch, channels * width_out) for F.linear.
x = x.view(x.size(0), -1)
1099
x = F.linear(x, self.fc_weight, self.fc_bias)
1102
data_shape = (batch_size, input_channels, width)
1103
pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
1104
"Tensor = aten::conv2d": 1}
1105
pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
1106
"Tensor = aten::conv2d": -1,
1107
"prepacked::conv2d_clamp_prepack" : -1,
1108
"prepacked::conv2d_clamp_run": 1}
1109
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Net(),
1110
pattern_count_transformed_map,
1111
pattern_count_optimized_map,
1114
if __name__ == "__main__":