# Owner(s): ["oncall: mobile"]

import unittest

import torch
import torch.nn as nn
import torch.backends.xnnpack
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import (LintCode,
                                          generate_mobile_module_lints,
                                          optimize_for_mobile,
                                          MobileOptimizerType)
from torch.nn import functional as F
from torch.testing._internal.common_quantized import override_quantized_engine

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

FileCheck = torch._C.FileCheck
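# FileCheck pattern-matches the printed TorchScript graph IR; the tests below run
# it over optimized graphs to assert which ops do (or do not) appear.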


class TestOptimizer(TestCase):

    @skipIfNoXNNPACK
    def test_optimize_for_mobile(self):
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
        weight_output_dim = 24
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

            @torch.jit.export
            def foo(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

        class BNTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)
        initial_foo_result = scripted_model.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model(input_data)
        optimized_foo_result = optimized_scripted_model.foo(input_data)

        # conv2d and linear should be rewritten to XNNPACK prepacked ops (with the
        # prepack calls folded to constants) and the trailing add + relu fused into
        # a single _add_relu call.
        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)
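
        # optimize_for_mobile applies a default pass list (conv/BN folding, XNNPACK
        # prepack insertion and folding, add/relu fusion, dropout removal); passing
        # MobileOptimizerType values as an optimization blocklist skips those passes.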
        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_run") \
                   .check_not("prepacked::conv2d_clamp_run") \
                   .run(optimized_scripted_model_no_prepack.graph)
        torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)
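
        # With CONV_BN_FUSION enabled (the default), the BatchNorm below folds into
        # the preceding conv; blocking the pass keeps aten::batch_norm in the graph.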
        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(str(get_forward(bn_scripted_module._c).graph))

        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
                   .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        mobile_optimized_tag_module = MyMobileOptimizedTagTest()
        m = torch.jit.script(mobile_optimized_tag_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertEqual(no_preserveThis, None)

        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertNotEqual(preserveThis, None)

        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)

        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        # Dropout should be removed from the eval-mode graph and add + relu fused.
        FileCheck().check_not("dropout.__") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        bn_test_no_forward_module = BNTestNoForwardModule()
        bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
        bn_no_forward_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(bn_no_forward_scripted_module.foo.graph)

        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_forward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)

    @skipIfNoXNNPACK
    def test_quantized_conv_no_asan_failures(self):
        # There were ASAN failures when fold_conv_bn was run on
        # already quantized conv modules. Verifying that this does
        # not happen again.
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # this line should not have ASAN failures
            model_optim = optimize_for_mobile(model)
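
    # generate_mobile_module_lints returns a list of dicts; each dict's 'name' field
    # matches a LintCode member (BUNDLED_INPUT, DROPOUT, BATCHNORM, REQUIRES_GRAD).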
    def test_generate_mobile_module_lints(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                bn = self.bn(inputs)
                return bn

        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_list):
            return len([lint_dict for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name])

        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        bi_module = torch.jit.script(MyBundledInputModule())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)
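
    # augment_model_with_bundled_inputs attaches bundled-input accessor methods
    # (get_all_bundled_inputs, get_num_bundled_inputs, ...) to a scripted module,
    # and optimize_for_mobile preserves them automatically when all are present.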
    @skipIfNoXNNPACK
    def test_preserve_bundled_inputs_methods(self):
        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
        )

        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
            hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
        )

        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it is the only
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    @skipIfNoXNNPACK
    def test_hoist_conv_packed_params(self):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Standalone(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.ao.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.ao.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim
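
            # Hoisting moves the quantized convs' packed params out of (sub)module
            # attributes and into the graph as constants, so the conv attributes
            # should be gone from the optimized module below.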
            m, m_optim = _quant_script_and_optimize(Standalone())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

            m, m_optim = _quant_script_and_optimize(Parent())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

    @skipIfNoXNNPACK
    @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
    def test_mobilenet_optimize_for_mobile(self):
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)

        # Run forward three times; this used to segfault,
        # see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
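
    # torch._C._hack_do_not_use_clone_module_with_class(module._c, ignored_methods,
    # ignored_attributes) clones a scripted module minus the listed members; it
    # throws if anything kept still references a dropped method or attribute.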
    def test_clone_module_with_class(self):
        class MyInnerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pqr = torch.Tensor([10., 20., 30.])

            def forward(self, inputs):
                # Minimal stand-in body; the exact computation is immaterial here.
                return inputs

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 20

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pqr = torch.Tensor([1., 2., 3.])
                self.inner = MyInnerTestModule()

            def forward(self, inputs):
                x = self.dummy_method_cloned()
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The access to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return (inputs, x, y, z)

            @torch.jit.export
            def dummy_method_not_cloned2(self):
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The access to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return self.pqr, self.dummy_method_not_cloned(), y, z

            # Trivial bodies below; only the method/attribute references matter
            # for the cloning behavior under test.
            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 30

            @torch.jit.export
            def dummy_method_cloned(self):
                return 40

            @torch.jit.export
            def dummy_method_ref_attr_pqr(self):
                return self.pqr, self.inner.pqr

        m = torch.jit.script(MyTestModule())

        # Check that the methods exist on the original model.
        self.assertEqual(hasattr(m, "dummy_method_not_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_not_cloned2"), True)
        self.assertEqual(hasattr(m, "pqr"), True)

        # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
            [],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "pqr"), True)

        # Check that the cloned class has a classname that starts with __torch__.
        self.assertTrue(
            cloned.qualified_name.startswith('__torch__.'),
            ("Expected the cloned module's name to start with the string "
             f"'__torch__.', but got: {cloned.qualified_name}"),
        )

        # Case-2: Successfully clone the module, ignoring the attribute pqr, and the method that references it.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],  # ignored_methods
            ["pqr"],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_ref_attr_pqr"), False)
        self.assertEqual(hasattr(cloned, "pqr"), False)

        # Case-3: The statement below will throw since dummy_method_not_cloned2 is preserved,
        # and references dummy_method_not_cloned, which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

        # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
        # is preserved, and references "pqr", which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(
                m._c,
                ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
                ["pqr"],  # ignored_attributes
            )


if __name__ == '__main__':
    run_tests()