7
from hypothesis import assume, given, strategies as st
10
import torch.backends.xnnpack
11
import torch.testing._internal.hypothesis_utils as hu
12
from torch.nn import functional as F
13
from torch.testing import FileCheck
14
from torch.testing._internal.common_utils import (
21
from torch.utils.mobile_optimizer import optimize_for_mobile
25
torch.backends.xnnpack.enabled,
26
" XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
30
"TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
32
class TestXNNPACKOps(TestCase):
34
"Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
37
batch_size=st.integers(0, 3),
38
data_shape=hu.array_shapes(1, 3, 2, 64),
39
weight_output_dim=st.integers(2, 64),
40
use_bias=st.booleans(),
42
def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
43
data_shape = [batch_size] + list(data_shape)
44
input_data = torch.rand(data_shape)
45
weight = torch.rand((weight_output_dim, data_shape[-1]))
47
bias = torch.rand(weight_output_dim)
50
ref_result = F.linear(input_data, weight, bias)
51
packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
52
output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
53
input_data, packed_weight_bias
55
torch.testing.assert_close(
56
ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
60
input_size=st.integers(2, 32),
61
weight_output_dim=st.integers(2, 64),
62
use_bias=st.booleans(),
64
def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
65
input_data = torch.rand(input_size)
66
weight = torch.rand((weight_output_dim, input_data.shape[-1]))
68
bias = torch.rand(weight_output_dim)
71
ref_result = F.linear(input_data, weight, bias)
72
packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
73
output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
74
input_data, packed_weight_bias
76
torch.testing.assert_close(
77
ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
81
batch_size=st.integers(0, 3),
82
input_channels_per_group=st.integers(1, 32),
83
height=st.integers(5, 64),
84
width=st.integers(5, 64),
85
output_channels_per_group=st.integers(1, 32),
86
groups=st.integers(1, 16),
87
kernel_h=st.integers(1, 7),
88
kernel_w=st.integers(1, 7),
89
stride_h=st.integers(1, 2),
90
stride_w=st.integers(1, 2),
91
pad_h=st.integers(0, 2),
92
pad_w=st.integers(0, 2),
93
dilation=st.integers(1, 2),
94
use_bias=st.booleans(),
95
format=st.sampled_from(
96
[None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
102
input_channels_per_group,
105
output_channels_per_group,
117
input_channels = input_channels_per_group * groups
118
output_channels = output_channels_per_group * groups
119
kernels = (kernel_h, kernel_w)
120
strides = (stride_h, stride_w)
121
paddings = (pad_h, pad_w)
122
dilations = (dilation, dilation)
123
assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
124
assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
126
input_data = torch.rand((batch_size, input_channels, height, width))
127
if format is not None:
128
input_data = input_data.contiguous(memory_format=format)
130
(output_channels, input_channels_per_group, kernel_h, kernel_w)
134
bias = torch.rand(output_channels)
136
ref_result = F.conv2d(
137
input_data, weight, bias, strides, paddings, dilations, groups
139
packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
140
weight, bias, strides, paddings, dilations, groups
142
xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(
143
input_data, packed_weight_bias
145
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
148
batch_size=st.integers(1, 3),
149
input_channels_per_group=st.integers(1, 32),
150
height=st.integers(5, 64),
151
width=st.integers(5, 64),
152
output_channels_per_group=st.integers(1, 32),
153
groups=st.integers(1, 16),
154
kernel_h=st.integers(1, 7),
155
kernel_w=st.integers(1, 7),
156
stride_h=st.integers(1, 2),
157
stride_w=st.integers(1, 2),
158
pad_h=st.integers(0, 2),
159
pad_w=st.integers(0, 2),
160
output_pad_h=st.integers(0, 2),
161
output_pad_w=st.integers(0, 2),
162
dilation=st.integers(1, 2),
163
use_bias=st.booleans(),
164
format=st.sampled_from(
165
[None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
168
def test_conv2d_transpose(
171
input_channels_per_group,
174
output_channels_per_group,
188
input_channels = input_channels_per_group * groups
189
output_channels = output_channels_per_group * groups
190
kernels = (kernel_h, kernel_w)
191
strides = (stride_h, stride_w)
192
paddings = (pad_h, pad_w)
193
output_paddings = (output_pad_h, output_pad_w)
194
dilations = (dilation, dilation)
195
assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
196
assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
197
assume((output_pad_h < stride_h) and (output_pad_h < dilation))
198
assume((output_pad_w < stride_w) and (output_pad_w < dilation))
200
input_data = torch.rand((batch_size, input_channels, height, width))
201
if format is not None:
202
input_data = input_data.contiguous(memory_format=format)
204
(input_channels, output_channels_per_group, kernel_h, kernel_w)
208
bias = torch.rand(output_channels)
211
ref_result = F.conv_transpose2d(
221
packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
222
weight, bias, strides, paddings, output_paddings, dilations, groups
224
xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(
225
input_data, packed_weight_bias
227
torch.testing.assert_close(
228
ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3
233
torch.backends.xnnpack.enabled,
234
" XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
238
"TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
240
class TestXNNPACKSerDes(TestCase):
242
"Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
245
batch_size=st.integers(0, 3),
246
data_shape=hu.array_shapes(1, 3, 2, 64),
247
weight_output_dim=st.integers(2, 64),
248
use_bias=st.booleans(),
250
def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
251
class Linear(torch.nn.Module):
252
def __init__(self, weight, bias=None):
257
def forward(self, x):
258
return F.linear(x, self.weight, self.bias)
260
class LinearPrePacked(torch.nn.Module):
261
def __init__(self, weight, bias=None):
263
self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(
267
def forward(self, x):
268
return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)
270
data_shape = [batch_size] + list(data_shape)
271
weight = torch.rand((weight_output_dim, data_shape[-1]))
273
bias = torch.rand(weight_output_dim)
276
scripted_linear = torch.jit.script(Linear(weight, bias))
277
scripted_linear_clamp_prepacked = torch.jit.script(
278
LinearPrePacked(weight, bias)
280
input_data = torch.rand(data_shape)
281
ref_result = scripted_linear(input_data)
282
output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
283
torch.testing.assert_close(
284
ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
288
input_data = torch.rand(data_shape)
289
buffer = io.BytesIO()
290
torch.jit.save(scripted_linear, buffer)
292
deserialized_linear = torch.jit.load(buffer)
293
buffer = io.BytesIO()
294
torch.jit.save(scripted_linear_clamp_prepacked, buffer)
296
deserialized_linear_clamp_prepacked = torch.jit.load(buffer)
297
ref_result = deserialized_linear(input_data)
298
output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
299
torch.testing.assert_close(
300
ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
304
batch_size=st.integers(0, 3),
305
input_channels_per_group=st.integers(1, 32),
306
height=st.integers(5, 64),
307
width=st.integers(5, 64),
308
output_channels_per_group=st.integers(1, 32),
309
groups=st.integers(1, 16),
310
kernel_h=st.integers(1, 7),
311
kernel_w=st.integers(1, 7),
312
stride_h=st.integers(1, 2),
313
stride_w=st.integers(1, 2),
314
pad_h=st.integers(0, 2),
315
pad_w=st.integers(0, 2),
316
dilation=st.integers(1, 2),
317
use_bias=st.booleans(),
318
format=st.sampled_from(
319
[None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
325
input_channels_per_group,
328
output_channels_per_group,
340
class Conv2D(torch.nn.Module):
341
def __init__(self, weight, bias, strides, paddings, dilations, groups):
345
self.strides = strides
346
self.paddings = paddings
347
self.dilations = dilations
350
def forward(self, x):
361
class Conv2DPrePacked(torch.nn.Module):
362
def __init__(self, weight, bias, strides, paddings, dilations, groups):
364
self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
365
weight, bias, strides, paddings, dilations, groups
368
def forward(self, x):
369
return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)
371
input_channels = input_channels_per_group * groups
372
output_channels = output_channels_per_group * groups
373
kernels = (kernel_h, kernel_w)
374
strides = (stride_h, stride_w)
375
paddings = (pad_h, pad_w)
376
dilations = (dilation, dilation)
377
assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
378
assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
380
input_data = torch.rand((batch_size, input_channels, height, width))
381
if format is not None:
382
input_data = input_data.contiguous(memory_format=format)
384
(output_channels, input_channels_per_group, kernel_h, kernel_w)
388
bias = torch.rand(output_channels)
390
scripted_conv2d = torch.jit.script(
391
Conv2D(weight, bias, strides, paddings, dilations, groups)
393
scripted_conv2d_clamp_prepacked = torch.jit.script(
394
Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups)
396
ref_result = scripted_conv2d(input_data)
397
xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
398
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
401
input_data = torch.rand((batch_size, input_channels, height, width))
402
if format is not None:
403
input_data = input_data.contiguous(memory_format=format)
404
buffer = io.BytesIO()
405
torch.jit.save(scripted_conv2d, buffer)
407
deserialized_conv2d = torch.jit.load(buffer)
408
buffer = io.BytesIO()
409
torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
411
deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
412
ref_result = deserialized_conv2d(input_data)
413
xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
414
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
417
batch_size=st.integers(0, 3),
418
input_channels_per_group=st.integers(1, 32),
419
height=st.integers(5, 64),
420
width=st.integers(5, 64),
421
output_channels_per_group=st.integers(1, 32),
422
groups=st.integers(1, 16),
423
kernel_h=st.integers(1, 7),
424
kernel_w=st.integers(1, 7),
425
stride_h=st.integers(1, 2),
426
stride_w=st.integers(1, 2),
427
pad_h=st.integers(0, 2),
428
pad_w=st.integers(0, 2),
429
output_pad_h=st.integers(0, 2),
430
output_pad_w=st.integers(0, 2),
431
dilation=st.integers(1, 2),
432
use_bias=st.booleans(),
433
format=st.sampled_from(
434
[None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
437
def test_conv2d_transpose(
440
input_channels_per_group,
443
output_channels_per_group,
457
class Conv2DT(torch.nn.Module):
471
self.strides = strides
472
self.paddings = paddings
473
self.output_paddings = output_paddings
474
self.dilations = dilations
477
def forward(self, x):
478
return F.conv_transpose2d(
484
self.output_paddings,
489
class Conv2DTPrePacked(torch.nn.Module):
501
self.packed_weight_bias = (
502
torch.ops.prepacked.conv2d_transpose_clamp_prepack(
513
def forward(self, x):
514
return torch.ops.prepacked.conv2d_transpose_clamp_run(
515
x, self.packed_weight_bias
518
input_channels = input_channels_per_group * groups
519
output_channels = output_channels_per_group * groups
520
kernels = (kernel_h, kernel_w)
521
strides = (stride_h, stride_w)
522
paddings = (pad_h, pad_w)
523
output_paddings = (output_pad_h, output_pad_w)
524
dilations = (dilation, dilation)
525
assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
526
assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
527
assume((output_pad_h < stride_h) and (output_pad_h < dilation))
528
assume((output_pad_w < stride_w) and (output_pad_w < dilation))
530
input_data = torch.rand((batch_size, input_channels, height, width))
531
if format is not None:
532
input_data = input_data.contiguous(memory_format=format)
534
(input_channels, output_channels_per_group, kernel_h, kernel_w)
538
bias = torch.rand(output_channels)
540
scripted_conv2d = torch.jit.script(
541
Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups)
543
scripted_conv2d_clamp_prepacked = torch.jit.script(
545
weight, bias, strides, paddings, output_paddings, dilations, groups
548
ref_result = scripted_conv2d(input_data)
549
xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
550
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
553
input_data = torch.rand((batch_size, input_channels, height, width))
554
if format is not None:
555
input_data = input_data.contiguous(memory_format=format)
556
buffer = io.BytesIO()
557
torch.jit.save(scripted_conv2d, buffer)
559
deserialized_conv2d = torch.jit.load(buffer)
560
buffer = io.BytesIO()
561
torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
563
deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
564
ref_result = deserialized_conv2d(input_data)
565
xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
566
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
569
"Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
572
batch_size=st.integers(0, 3),
573
input_channels_per_group=st.integers(1, 32),
574
height=st.integers(5, 64),
575
width=st.integers(5, 64),
576
output_channels_per_group=st.integers(1, 32),
577
groups=st.integers(1, 16),
578
kernel_h=st.integers(1, 7),
579
kernel_w=st.integers(1, 7),
580
stride_h=st.integers(1, 2),
581
stride_w=st.integers(1, 2),
582
pad_h=st.integers(0, 2),
583
pad_w=st.integers(0, 2),
584
dilation=st.integers(1, 2),
585
linear_weight_output_dim=st.integers(2, 64),
586
use_bias=st.booleans(),
587
format=st.sampled_from(
588
[None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
591
def test_combined_model(
594
input_channels_per_group,
597
output_channels_per_group,
606
linear_weight_output_dim,
610
class M(torch.nn.Module):
623
self.conv_weight = conv_weight
624
self.conv_bias = conv_bias
625
self.linear_weight = linear_weight
626
self.linear_bias = linear_bias
627
self.strides = strides
628
self.paddings = paddings
629
self.dilations = dilations
632
def forward(self, x):
642
o = o.permute([0, 2, 3, 1])
643
o = F.linear(o, self.linear_weight, self.linear_bias)
646
class MPrePacked(torch.nn.Module):
659
self.conv2d_clamp_run_weight_bias = (
660
torch.ops.prepacked.conv2d_clamp_prepack(
661
conv_weight, conv_bias, strides, paddings, dilations, groups
664
self.linear_clamp_run_weight_bias = (
665
torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)
668
def forward(self, x):
669
o = torch.ops.prepacked.conv2d_clamp_run(
670
x, self.conv2d_clamp_run_weight_bias
672
o = o.permute([0, 2, 3, 1])
673
o = torch.ops.prepacked.linear_clamp_run(
674
o, self.linear_clamp_run_weight_bias
678
input_channels = input_channels_per_group * groups
679
output_channels = output_channels_per_group * groups
680
kernels = (kernel_h, kernel_w)
681
strides = (stride_h, stride_w)
682
paddings = (pad_h, pad_w)
683
dilations = (dilation, dilation)
684
assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
685
assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
687
input_data = torch.rand((batch_size, input_channels, height, width))
688
if format is not None:
689
input_data = input_data.contiguous(memory_format=format)
690
conv_weight = torch.rand(
691
(output_channels, input_channels_per_group, kernel_h, kernel_w)
695
conv_bias = torch.rand(output_channels)
701
input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
703
linear_input_shape = result.shape[1]
705
linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
708
linear_bias = torch.rand(linear_weight_output_dim)
710
scripted_m = torch.jit.script(
722
scripted_m_prepacked = torch.jit.script(
734
ref_result = scripted_m(input_data)
735
xnnpack_result = scripted_m_prepacked(input_data)
736
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
739
input_data = torch.rand((batch_size, input_channels, height, width))
740
input_data = input_data.contiguous(memory_format=torch.channels_last)
741
buffer = io.BytesIO()
742
torch.jit.save(scripted_m, buffer)
744
deserialized_m = torch.jit.load(buffer)
745
buffer = io.BytesIO()
746
torch.jit.save(scripted_m_prepacked, buffer)
748
deserialized_m_prepacked = torch.jit.load(buffer)
749
ref_result = deserialized_m(input_data)
750
xnnpack_result = deserialized_m_prepacked(input_data)
751
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
755
torch.backends.xnnpack.enabled,
756
" XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
760
"TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
762
class TestXNNPACKRewritePass(TestCase):
764
def validate_transformed_module(
769
prepack_removal=False,
770
fuse_clamping_ops=False,
772
input_data = torch.normal(1, 20, size=data_shape)
774
for jit_method in ["script", "trace"]:
775
module_instance = self
776
if jit_method == "script":
777
scripted_model = torch.jit.script(module_instance)
779
scripted_model = torch.jit.trace(module_instance, input_data)
780
scripted_model.eval()
781
ref_result = scripted_model(input_data)
782
torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
783
if fuse_clamping_ops or prepack_removal:
784
scripted_model._c = torch._C._freeze_module(scripted_model._c)
785
if fuse_clamping_ops:
786
torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
788
torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
790
buffer = io.BytesIO()
791
torch.jit.save(scripted_model, buffer)
793
deserialized_scripted_model = torch.jit.load(buffer)
794
for pattern, v in pattern_count_map.items():
796
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
798
FileCheck().check_not(pattern).run(
799
deserialized_scripted_model.graph
802
FileCheck().check_count(pattern, v, exactly=True).run(
803
deserialized_scripted_model.graph
805
xnnpack_result = deserialized_scripted_model(input_data)
806
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
808
def test_linear(self):
809
data_shape = [2, 3, 32]
810
weight_output_dim = 24
811
weight_shape = (weight_output_dim, data_shape[-1])
813
class Linear(torch.nn.Module):
814
def __init__(self) -> None:
816
self.weight = torch.nn.Parameter(
817
torch.rand(weight_shape), requires_grad=False
819
self.bias = torch.nn.Parameter(
820
torch.rand(weight_output_dim), requires_grad=False
823
def forward(self, x):
824
return F.linear(x, self.weight, self.bias)
826
class LinearNoBias(torch.nn.Module):
827
def __init__(self) -> None:
829
self.weight = torch.nn.Parameter(
830
torch.rand(weight_shape), requires_grad=False
833
def forward(self, x):
834
return F.linear(x, self.weight, None)
837
pattern_count_map = {
838
"Tensor = prim::CallFunction": -1,
839
"prepacked::linear_clamp_prepack": 1,
840
"prepacked::linear_clamp_run": 1,
842
TestXNNPACKRewritePass.validate_transformed_module(
843
Linear(), pattern_count_map, data_shape
845
TestXNNPACKRewritePass.validate_transformed_module(
846
LinearNoBias(), pattern_count_map, data_shape
851
input_channels_per_group = 6
854
output_channels_per_group = 6
856
kernel_h = kernel_w = 3
857
stride_h = stride_w = 1
859
output_pad_h = output_pad_w = 0
861
input_channels = input_channels_per_group * groups
862
output_channels = output_channels_per_group * groups
863
kernels = (kernel_h, kernel_w)
864
strides = (stride_h, stride_w)
865
paddings = (pad_h, pad_w)
866
output_paddings = (output_pad_h, output_pad_w)
867
dilations = (dilation, dilation)
868
conv_weight_shape = (
870
input_channels_per_group,
874
conv_transpose_weight_shape = (
876
output_channels_per_group,
880
conv_bias_shape = output_channels
882
class Conv2D(torch.nn.Module):
883
def __init__(self) -> None:
885
self.weight = torch.nn.Parameter(
886
torch.rand(conv_weight_shape), requires_grad=False
888
self.bias = torch.nn.Parameter(
889
torch.rand(conv_bias_shape), requires_grad=False
891
self.strides = strides
892
self.paddings = paddings
893
self.dilations = dilations
896
def forward(self, x):
907
class Conv2DT(torch.nn.Module):
908
def __init__(self) -> None:
910
self.weight = torch.nn.Parameter(
911
torch.rand(conv_transpose_weight_shape), requires_grad=False
913
self.bias = torch.nn.Parameter(
914
torch.rand(conv_bias_shape), requires_grad=False
916
self.strides = strides
917
self.paddings = paddings
918
self.output_paddings = output_paddings
919
self.dilations = dilations
922
def forward(self, x):
923
return F.conv_transpose2d(
929
self.output_paddings,
934
data_shape = (batch_size, input_channels, height, width)
935
pattern_count_map = {
936
"Tensor = aten::conv2d": -1,
937
"prepacked::conv2d_clamp_prepack": 1,
938
"prepacked::conv2d_clamp_run": 1,
940
TestXNNPACKRewritePass.validate_transformed_module(
941
Conv2D(), pattern_count_map, data_shape
944
transpose_data_shape = (batch_size, input_channels, height, width)
945
transpose_pattern_count_map = {
946
"Tensor = aten::conv_transpose2d": -1,
947
"prepacked::conv2d_transpose_clamp_prepack": 1,
948
"prepacked::conv2d_transpose_clamp_run": 1,
950
TestXNNPACKRewritePass.validate_transformed_module(
951
Conv2DT(), transpose_pattern_count_map, data_shape
954
input_data = torch.rand((batch_size, input_channels, height, width))
955
conv_weight = torch.rand(
956
(output_channels, input_channels_per_group, kernel_h, kernel_w)
958
conv_bias = torch.rand(output_channels)
960
input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
962
linear_input_shape = result.shape[1]
963
linear_weight_shape = (weight_output_dim, linear_input_shape)
965
class M(torch.nn.Module):
966
def __init__(self, activation_fn=F.relu):
968
self.conv_weight = torch.nn.Parameter(
969
torch.rand(conv_weight_shape), requires_grad=False
971
self.conv_bias = torch.nn.Parameter(
972
torch.rand(conv_bias_shape), requires_grad=False
974
self.linear_weight = torch.nn.Parameter(
975
torch.rand(linear_weight_shape), requires_grad=False
977
self.linear_bias = torch.nn.Parameter(
978
torch.rand(weight_output_dim), requires_grad=False
980
self.strides = strides
981
self.paddings = paddings
982
self.dilations = dilations
984
self.activation_fn = activation_fn
986
def forward(self, x):
996
o = self.activation_fn(o)
997
o = o.permute([0, 2, 3, 1])
998
o = F.linear(o, self.linear_weight, self.linear_bias)
999
return self.activation_fn(o)
1001
pattern_count_map = {
1002
"Tensor = aten::conv2d": -1,
1003
"prepacked::conv2d_clamp_prepack": 1,
1004
"prepacked::conv2d_clamp_run": 1,
1005
"prepacked::linear_clamp_prepack": 1,
1006
"prepacked::linear_clamp_run": 1,
1008
TestXNNPACKRewritePass.validate_transformed_module(
1009
M(), pattern_count_map, data_shape
1011
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1012
pattern_count_map["Tensor = prim::CallFunction"] = -1
1013
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1014
TestXNNPACKRewritePass.validate_transformed_module(
1015
M(), pattern_count_map, data_shape, prepack_removal=True
1019
pattern_count_map = {
1021
"prepacked::conv2d_clamp_prepack": -1,
1022
"prepacked::conv2d_clamp_run": 1,
1023
"prepacked::linear_clamp_prepack": -1,
1024
"prepacked::linear_clamp_run": 1,
1026
TestXNNPACKRewritePass.validate_transformed_module(
1027
M(), pattern_count_map, data_shape, prepack_removal=True
1029
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1030
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1031
pattern_count_map["aten::relu"] = -1
1032
TestXNNPACKRewritePass.validate_transformed_module(
1036
prepack_removal=True,
1037
fuse_clamping_ops=True,
1041
pattern_count_map = {
1043
"prepacked::conv2d_clamp_prepack": -1,
1044
"prepacked::conv2d_clamp_run": 1,
1045
"prepacked::linear_clamp_prepack": -1,
1046
"prepacked::linear_clamp_run": 1,
1048
TestXNNPACKRewritePass.validate_transformed_module(
1049
M(F.relu_), pattern_count_map, data_shape, prepack_removal=True
1051
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1052
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1053
pattern_count_map["aten::relu"] = -1
1054
TestXNNPACKRewritePass.validate_transformed_module(
1058
prepack_removal=True,
1059
fuse_clamping_ops=True,
1063
pattern_count_map = {
1064
"aten::hardtanh": 2,
1065
"prepacked::conv2d_clamp_prepack": -1,
1066
"prepacked::conv2d_clamp_run": 1,
1067
"prepacked::linear_clamp_prepack": -1,
1068
"prepacked::linear_clamp_run": 1,
1070
TestXNNPACKRewritePass.validate_transformed_module(
1071
M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True
1073
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1074
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1075
pattern_count_map["aten::hardtanh"] = -1
1076
TestXNNPACKRewritePass.validate_transformed_module(
1080
prepack_removal=True,
1081
fuse_clamping_ops=True,
1085
pattern_count_map = {
1086
"aten::hardtanh_": 2,
1087
"prepacked::conv2d_clamp_prepack": -1,
1088
"prepacked::conv2d_clamp_run": 1,
1089
"prepacked::linear_clamp_prepack": -1,
1090
"prepacked::linear_clamp_run": 1,
1092
TestXNNPACKRewritePass.validate_transformed_module(
1093
M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True
1095
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1096
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1097
pattern_count_map["aten::hardtanh_"] = -1
1098
TestXNNPACKRewritePass.validate_transformed_module(
1102
prepack_removal=True,
1103
fuse_clamping_ops=True,
1106
class MFusionAntiPattern(torch.nn.Module):
1107
def __init__(self) -> None:
1109
self.linear_weight = torch.nn.Parameter(
1110
torch.rand(linear_weight_shape), requires_grad=False
1112
self.linear_bias = torch.nn.Parameter(
1113
torch.rand(weight_output_dim), requires_grad=False
1115
self.strides = strides
1116
self.paddings = paddings
1117
self.dilations = dilations
1118
self.groups = groups
1120
def forward(self, x):
1121
o = F.linear(x, self.linear_weight, self.linear_bias)
1127
pattern_count_map = {
1128
"aten::hardtanh": 1,
1130
"prepacked::linear_clamp_prepack": -1,
1131
"prepacked::linear_clamp_run": 1,
1133
TestXNNPACKRewritePass.validate_transformed_module(
1134
MFusionAntiPattern(),
1136
(16, linear_weight_shape[1]),
1137
prepack_removal=True,
1138
fuse_clamping_ops=True,
1141
class MFusionAntiPatternParamMinMax(torch.nn.Module):
1142
def __init__(self) -> None:
1144
self.linear_weight = torch.nn.Parameter(
1145
torch.rand(linear_weight_shape), requires_grad=False
1147
self.linear_bias = torch.nn.Parameter(
1148
torch.rand(weight_output_dim), requires_grad=False
1150
self.strides = strides
1151
self.paddings = paddings
1152
self.dilations = dilations
1153
self.groups = groups
1155
def forward(self, x):
1158
o = F.linear(x, self.linear_weight, self.linear_bias)
1159
o = F.hardtanh(o, min, max)
1163
pattern_count_map = {
1164
"aten::hardtanh": 1,
1165
"prepacked::linear_clamp_prepack": -1,
1166
"prepacked::linear_clamp_run": 1,
1168
TestXNNPACKRewritePass.validate_transformed_module(
1169
MFusionAntiPatternParamMinMax(),
1171
(16, linear_weight_shape[1]),
1172
prepack_removal=True,
1173
fuse_clamping_ops=True,
1176
def test_decomposed_linear(self):
1177
data_shape = [2, 32]
1178
weight_output_dim = 24
1179
weight_shape = (weight_output_dim, data_shape[-1])
1181
class DecomposedLinearAddmm(torch.nn.Module):
1182
def __init__(self) -> None:
1184
self.weight = torch.nn.Parameter(
1185
torch.rand(weight_shape), requires_grad=False
1187
self.bias = torch.nn.Parameter(
1188
torch.rand(weight_output_dim), requires_grad=False
1191
def forward(self, x):
1192
weight_t = self.weight.t()
1193
return torch.addmm(self.bias, x, weight_t)
1195
class DecomposedLinearMatmulAdd(torch.nn.Module):
1196
def __init__(self) -> None:
1198
self.weight = torch.nn.Parameter(
1199
torch.rand(weight_shape), requires_grad=False
1201
self.bias = torch.nn.Parameter(
1202
torch.rand(weight_output_dim), requires_grad=False
1205
def forward(self, x):
1206
weight_t = self.weight.t()
1207
y = torch.matmul(x, weight_t)
1208
res = y.add_(self.bias)
1211
class DecomposedLinearMatmul(torch.nn.Module):
1212
def __init__(self) -> None:
1214
self.weight = torch.nn.Parameter(
1215
torch.rand(weight_shape), requires_grad=False
1217
self.bias = torch.nn.Parameter(
1218
torch.rand(weight_output_dim), requires_grad=False
1221
def forward(self, x):
1222
weight_t = self.weight.t()
1223
res = torch.matmul(x, weight_t)
1227
pattern_count_map = {
1228
"Tensor = prim::CallFunction": -1,
1229
"prepacked::linear_clamp_prepack": 1,
1230
"prepacked::linear_clamp_run": 1,
1232
TestXNNPACKRewritePass.validate_transformed_module(
1233
DecomposedLinearAddmm(), pattern_count_map, data_shape
1235
TestXNNPACKRewritePass.validate_transformed_module(
1236
DecomposedLinearMatmulAdd(), pattern_count_map, data_shape
1238
TestXNNPACKRewritePass.validate_transformed_module(
1239
DecomposedLinearMatmul(), pattern_count_map, data_shape
1243
@unittest.skipUnless(
1244
torch.backends.xnnpack.enabled,
1245
" XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
1249
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
1251
class TestXNNPACKConv1dTransformPass(TestCase):
1253
def validate_transform_conv1d_to_conv2d(
1254
self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape
1256
input_data = torch.normal(1, 20, size=data_shape)
1258
for jit_method in ["script", "trace"]:
1259
module_instance = self
1260
if jit_method == "script":
1261
scripted_model = torch.jit.script(module_instance)
1263
scripted_model = torch.jit.trace(module_instance, input_data)
1264
scripted_model.eval()
1265
ref_result = scripted_model(input_data)
1266
torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
1267
optimized_scripted_model = optimize_for_mobile(scripted_model)
1269
buffer = io.BytesIO()
1270
torch.jit.save(scripted_model, buffer)
1272
deserialized_scripted_model = torch.jit.load(buffer)
1274
for pattern, v in pattern_count_transformed_map.items():
1276
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
1278
FileCheck().check_not(pattern).run(
1279
deserialized_scripted_model.graph
1282
FileCheck().check_count(pattern, v, exactly=True).run(
1283
deserialized_scripted_model.graph
1285
transformed_result = deserialized_scripted_model(input_data)
1286
torch.testing.assert_close(
1287
ref_result, transformed_result, rtol=1e-2, atol=1e-3
1290
optimized_buffer = io.BytesIO()
1291
torch.jit.save(optimized_scripted_model, optimized_buffer)
1292
optimized_buffer.seek(0)
1293
deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
1295
for pattern, v in pattern_count_optimized_map.items():
1297
FileCheck().check(pattern).run(
1298
deserialized_optimized_scripted_model.graph
1301
FileCheck().check_not(pattern).run(
1302
deserialized_optimized_scripted_model.graph
1305
FileCheck().check_count(pattern, v, exactly=True).run(
1306
deserialized_optimized_scripted_model.graph
1308
xnnpack_result = deserialized_optimized_scripted_model(input_data)
1309
torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
1311
@unittest.skipIf(IS_FBCODE, "T137513244")
1312
def test_conv1d_basic(self):
1313
batch_size_list = range(1, 3)
1314
input_channels_per_group_list = range(10, 12)
1315
width_list = range(10, 12)
1316
output_channels_per_group_list = range(10, 12)
1317
groups_list = range(1, 3)
1318
kernel_list = range(1, 4)
1319
stride_list = range(1, 3)
1320
padding_list = range(0, 3)
1321
dilation_list = range(1, 3)
1323
for hparams in itertools.product(
1325
input_channels_per_group_list,
1327
output_channels_per_group_list,
1336
input_channels_per_group,
1338
output_channels_per_group,
1346
input_channels = input_channels_per_group * groups
1347
output_channels = output_channels_per_group * groups
1348
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
1349
conv_bias_shape = output_channels
1351
class Conv1D(torch.nn.Module):
1352
def __init__(self) -> None:
1354
self.weight = torch.nn.Parameter(
1355
torch.rand(conv_weight_shape), requires_grad=False
1357
self.bias = torch.nn.Parameter(
1358
torch.rand(conv_bias_shape), requires_grad=False
1360
self.stride = stride
1361
self.padding = padding
1362
self.dilation = dilation
1363
self.groups = groups
1365
def forward(self, x):
1376
data_shape = (batch_size, input_channels, width)
1377
pattern_count_transformed_map = {
1378
"Tensor = aten::conv1d": -1,
1379
"Tensor = aten::conv2d": 1,
1381
pattern_count_optimized_map = {
1382
"Tensor = aten::conv1d": -1,
1383
"Tensor = aten::conv2d": -1,
1384
"prepacked::conv2d_clamp_prepack": -1,
1385
"prepacked::conv2d_clamp_run": 1,
1388
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
1390
pattern_count_transformed_map,
1391
pattern_count_optimized_map,
1397
def test_conv1d_with_relu_fc(self):
1398
batch_size_list = range(1, 3)
1399
input_channels_per_group_list = range(10, 12)
1400
width_list = range(10, 12)
1401
output_channels_per_group_list = range(10, 12)
1402
groups_list = range(1, 3)
1403
kernel_list = range(1, 4)
1404
stride_list = range(1, 3)
1405
padding_list = range(0, 3)
1406
dilation_list = range(1, 3)
1407
output_features_list = range(1, 3)
1409
for hparams in itertools.product(
1411
input_channels_per_group_list,
1413
output_channels_per_group_list,
1419
output_features_list,
1423
input_channels_per_group,
1425
output_channels_per_group,
1434
input_channels = input_channels_per_group * groups
1435
output_channels = output_channels_per_group * groups
1436
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
1437
conv_bias_shape = output_channels
1438
conv_output_width = (
1439
int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
1441
fc_weight_shape = (output_features, output_channels * conv_output_width)
1442
fc_bias_shape = output_features
1444
class Net(torch.nn.Module):
1445
def __init__(self) -> None:
1447
self.conv_weight = torch.nn.Parameter(
1448
torch.rand(conv_weight_shape), requires_grad=False
1450
self.conv_bias = torch.nn.Parameter(
1451
torch.rand(conv_bias_shape), requires_grad=False
1453
self.stride = stride
1454
self.padding = padding
1455
self.dilation = dilation
1456
self.groups = groups
1458
self.fc_weight = torch.nn.Parameter(
1459
torch.rand(fc_weight_shape), requires_grad=False
1461
self.fc_bias = torch.nn.Parameter(
1462
torch.rand(fc_bias_shape), requires_grad=False
1465
def forward(self, x):
1476
x = x.view(x.size(0), -1)
1477
x = F.linear(x, self.fc_weight, self.fc_bias)
1480
data_shape = (batch_size, input_channels, width)
1481
pattern_count_transformed_map = {
1482
"Tensor = aten::conv1d": -1,
1483
"Tensor = aten::conv2d": 1,
1485
pattern_count_optimized_map = {
1486
"Tensor = aten::conv1d": -1,
1487
"Tensor = aten::conv2d": -1,
1488
"prepacked::conv2d_clamp_prepack": -1,
1489
"prepacked::conv2d_clamp_run": 1,
1491
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
1493
pattern_count_transformed_map,
1494
pattern_count_optimized_map,
1499
if __name__ == "__main__":