# Owner(s): ["oncall: mobile"]

import unittest
import torch
import torch.nn as nn
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import (LintCode,
                                          generate_mobile_module_lints,
                                          optimize_for_mobile,
                                          MobileOptimizerType)
from torch.nn import functional as F
from torch.testing._internal.common_quantized import override_quantized_engine

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
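
# FileCheck is TorchScript's IR pattern-matching test helper: each check_*()
# call asserts that a pattern does (or does not) appear in a graph dump, which
# is how these tests verify that an optimization pass actually fired.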
FileCheck = torch._C.FileCheck

class TestOptimizer(TestCase):

    @skipIfNoXNNPACK
    def test_optimize_for_mobile(self):
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
        weight_output_dim = 24
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

            @torch.jit.export
            def foo(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)


        class BNTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)
        initial_foo_result = scripted_model.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model(input_data)
        optimized_foo_result = optimized_scripted_model.foo(input_data)
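
        # By default, optimize_for_mobile folds Conv2d+BatchNorm, rewrites
        # conv2d/linear into XNNPACK prepacked ops (the prepack call itself is
        # folded away, leaving only prepacked::*_clamp_run), and fuses add+relu
        # into aten::_add_relu; the FileCheck patterns below verify each rewrite.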
        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)
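

        # Entries in optimization_blocklist disable individual passes; blocking
        # INSERT_FOLD_PREPACK_OPS keeps the plain aten::conv2d call in place.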
        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_run") \
                   .check_not("prepacked::conv2d_clamp_run") \
                   .run(optimized_scripted_model_no_prepack.graph)
        torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)


        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(str(get_forward(bn_scripted_module._c).graph))

        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
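        # CONV_BN_FUSION folds the BatchNorm into the convolution's parameters:
        # w' = w * gamma / sqrt(running_var + eps) and
        # b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta,
        # collapsing the module to a single exported op (asserted below).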
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
                   .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        mobile_optimized_tag_module = MyMobileOptimizedTagTest()
        m = torch.jit.script(mobile_optimized_tag_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertEqual(no_preserveThis, None)
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertNotEqual(preserveThis, None)

        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)
        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        FileCheck().check_not("dropout.__") \
            .check_count("aten::_add_relu(", 1, exactly=True) \
            .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        bn_test_no_forward_module = BNTestNoForwardModule()
        bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
        bn_no_forward_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(bn_no_forward_scripted_module.foo.graph)

        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_forward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)

    @skipIfNoXNNPACK
    def test_quantized_conv_no_asan_failures(self):
        # There were ASAN failures when fold_conv_bn was run on
        # already quantized conv modules. Verifying that this does
        # not happen again.

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
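            # Standard eager-mode quantization flow: prepare() inserts
            # observers, a calibration forward pass records ranges, and
            # convert() swaps in quantized modules before scripting.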
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # this line should not have ASAN failures
            model_optim = optimize_for_mobile(model)

    def test_generate_mobile_module_lints(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                bn = self.bn(inputs)
                return bn

        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_list):
            return len([lint_dict for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name])

        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
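        # Expect 4 lints: one missing-bundled-input, one dropout, and one
        # requires_grad lint for each of the Linear layer's two parameters.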
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        bi_module = torch.jit.script(MyBundledInputModule())
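        # A module with bundled inputs attached and no parameters, dropout, or
        # batch norm should produce no lints at all.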
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)

    @skipIfNoXNNPACK
    def test_preserve_bundled_inputs_methods(self):
        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
        )

        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
            hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
        )

        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it's the only
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    @skipIfNoXNNPACK
    def test_hoist_conv_packed_params(self):

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Standalone(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.ao.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.ao.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim

            # basic case
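
            # HOIST_CONV_PACKED_PARAMS lifts each quantized conv's packed
            # weights out of its submodule into a top-level prim::Constant,
            # which is why the conv1/conv2 attributes disappear below.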
            m, m_optim = _quant_script_and_optimize(Standalone())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

            # generic case

            m, m_optim = _quant_script_and_optimize(Parent())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

    @skipIfNoXNNPACK
    @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
    def test_mobilenet_optimize_for_mobile(self):
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)
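        # mobilenet_v3_small ends in a 1000-way ImageNet classifier head, so
        # each forward pass should return 1000 elements.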

        # run forward 3 times; this used to segfault, see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)

    def test_clone_module_with_class(self):
        class MyInnerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pqr = torch.Tensor([10., 20., 30.])

            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 20

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.abc = 23
                self.pqr = torch.Tensor([1., 2., 3.])
                self.inner = MyInnerTestModule()

            def forward(self, inputs):
                x = self.dummy_method_cloned()
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The access to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return (inputs, x, y, z)

            @torch.jit.export
            def dummy_method_not_cloned2(self):
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The access to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return self.pqr, self.dummy_method_not_cloned(), y, z

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_ref_attr_pqr(self):
                return self.pqr, self.inner.pqr
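
        # torch._C._hack_do_not_use_clone_module_with_class clones the
        # underlying C++ module, dropping the listed methods and attributes;
        # the clone fails if anything that is kept still references them.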
        m = torch.jit.script(MyTestModule())

        # Check that the methods exist on the original model.
        self.assertEqual(hasattr(m, "dummy_method_not_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_not_cloned2"), True)
        self.assertEqual(hasattr(m, "pqr"), True)

        # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
            [],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "pqr"), True)

        # Check that the cloned class has a classname that starts with __torch__.
        self.assertTrue(
            cloned.qualified_name.startswith('__torch__.'),
            ("Expected the cloned module's name to start with the string "
             f"'__torch__.', but got: {cloned.qualified_name}"),
        )


        # Case-2: Successfully clone the module, ignoring the attribute pqr, and the method that references it.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],
            ["pqr"],
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_ref_attr_pqr"), False)
        self.assertEqual(hasattr(cloned, "pqr"), False)

        # Case-3: The statement below will throw since dummy_method_not_cloned2
        # is preserved, and references dummy_method_not_cloned, which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

        # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
        # is preserved, and references "pqr", which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(
                m._c,
                ["dummy_method_not_cloned", "dummy_method_not_cloned2"],
                ["pqr"],
            )


if __name__ == '__main__':
    run_tests()