pytorch

Форк
0
/
test_backends.py 
824 строки · 28.4 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import io
4
import os
5
import sys
6
import unittest
7

8
import torch
9
import torch._C
10
from torch.jit.mobile import _load_for_lite_interpreter
11
from torch.testing import FileCheck
12
from torch.testing._internal.common_utils import (
13
    find_library_location,
14
    IS_FBCODE,
15
    IS_MACOS,
16
    IS_SANDCASTLE,
17
    IS_WINDOWS,
18
    skipIfRocm,
19
    TEST_WITH_ROCM,
20
)
21
from torch.testing._internal.jit_utils import JitTestCase
22

23

24
# Make the helper files in test/ importable: resolve this file's real path and
# append its grandparent directory (presumably test/, since this file sits in a
# subdirectory of it) to the module search path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# This module only defines test cases that test_jit.py collects; running it as a
# script is a usage error.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n\tpython test/test_jit.py TESTNAME\n\ninstead."
    )
34

35

36
def to_test_backend(module, method_compile_spec):
    """
    Lower ``module`` to the "test_backend" backend, compiling only ``forward``.

    ``method_compile_spec`` is the compile spec used for ``forward``.
    """
    per_method_spec = {"forward": method_compile_spec}
    return torch._C._jit_to_backend("test_backend", module, per_method_spec)
40

41

42
def to_test_backend_multi(module, method_compile_spec):
    """
    Lower ``module`` to "test_backend" with a caller-supplied per-method spec.

    ``method_compile_spec`` maps method names to their compile specs, e.g.
    ``{"forward": {"": ""}, "accum": {"": ""}}``.
    """
    backend_name = "test_backend"
    return torch._C._jit_to_backend(backend_name, module, method_compile_spec)
44

45

46
def to_test_backend_selective(module, method_compile_spec, submodules):
    """
    Lower only the submodules of ``module`` named in ``submodules`` to
    "test_backend".

    ``submodules`` holds dotted paths such as ``"sub1.submodule"``. Each selected
    submodule is lowered via ``to_test_backend``, so ``method_compile_spec``
    applies to its ``forward`` method.
    """
    return torch._C._jit_to_backend_selective(
        module,
        lambda submodule: to_test_backend(submodule, method_compile_spec),
        submodules,
    )
51

52

53
class BasicModule(torch.nn.Module):
    """
    Minimal module used to exercise the to_backend lowering machinery.

    ``forward`` delegates to two helper methods, so tests can lower (and
    compare) more than one method of the same module.
    """

    def forward(self, x, h):
        """Return the pair ``(accum(x, h), sub_accum(x, h))``."""
        summed = self.accum(x, h)
        difference = self.sub_accum(x, h)
        return summed, difference

    def accum(self, x, h):
        """Return ``x + h``."""
        return x + h

    def sub_accum(self, x, h):
        """Return ``x - h``."""
        return x - h
66

67

68
# The skipIf below only takes effect when unittest collects this class (or a
# subclass) itself. TestBackends calls these test methods directly, which
# bypasses the decorator, so TestBackends repeats the same condition.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase):
    """
    A common base class for JIT backend tests that contains common utility
    functions for output comparison and serialization/deserialization.

    Subclasses are expected to set three attributes in their ``setUp``:
      * ``module`` - a regular, Python version of the module being tested
      * ``scripted_module`` - a scripted version of ``module``
      * ``lowered_module`` - a version of ``module`` lowered to a backend
    """

    def setUp(self):
        super().setUp()
        # Presumably this library registers the backends the tests lower to
        # (e.g. "test_backend") -- confirm against the C++ test sources.
        lib_file_path = find_library_location("libjitbackend_test.so")
        torch.ops.load_library(str(lib_file_path))

    def check_function(self, function_name, input):
        """
        Check that the method named ``function_name`` produces the same output
        when run in Python, through the regular JIT and through the backend.

        Args:
            function_name: name of the method looked up on each module version.
            input: tuple of positional arguments passed to every version.
        """
        # getattr() consults normal attribute lookup first and then falls back
        # to the modules' __getattr__ hooks. Script methods therefore resolve
        # exactly as with explicit dunder calls, without invoking dunders by hand.
        python_method = getattr(self.module, function_name)
        jit_method = getattr(self.scripted_module, function_name)
        backend_method = getattr(self.lowered_module, function_name)

        # Run methods.
        python_output = python_method(*input)
        jit_output = jit_method(*input)
        backend_output = backend_method(*input)

        # The answers returned by Python, JIT and to_backend should all match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)

    def save_load(self):
        """
        Replace ``self.lowered_module`` with a saved-and-reloaded copy of itself.
        """
        self.lowered_module = self.getExportImportCopy(self.lowered_module)

    def test_execution(self):
        """
        Stub for correctness tests.
        """

    def test_save_load(self):
        """
        Stub for serialization tests.
        """

    def test_errors(self):
        """
        Stub for testing error checking.
        """
127

128

129
class BasicModuleTest(JitBackendTestCase):
    """
    Tests for BasicModule with all three of its methods lowered to the test backend.
    """

    def setUp(self):
        super().setUp()
        # Python, scripted and lowered variants of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        per_method_spec = {
            "accum": {"": ""},
            "sub_accum": {"": ""},
            "forward": {"": ""},
        }
        self.lowered_module = to_test_backend_multi(
            self.scripted_module, per_method_spec
        )

    def test_execution(self):
        # Compare every lowered method against the Python and JIT versions.
        sample = torch.randn(5)
        for method_name in ("accum", "sub_accum", "forward"):
            self.check_function(method_name, (sample, sample))

    @skipIfRocm
    def test_save_load(self):
        def compile_spec():
            # Look the attributes up by string: written as attribute syntax
            # inside this class, "__method_compile_spec" would be name-mangled.
            inner = getattr(self.lowered_module, "__loweredModule__")
            return getattr(inner, "__method_compile_spec")

        # The freshly lowered module must match Python and JIT...
        self.test_execution()
        spec_before = compile_spec()

        # ...and a serialized round trip must keep both the compile spec and
        # the outputs.
        self.save_load()
        spec_after = compile_spec()
        self.assertEqual(spec_before, spec_after)
        self.test_execution()
176

177

178
class BasicModuleUnavailableTest(JitBackendTestCase):
    """
    Tests for BasicModule with a backend that is not available.
    Fundamentally:
      * _jit_to_backend is successful.
      * Execution fails with an exception.
      * Saving is successful.
      * Loading fails with an exception.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = torch._C._jit_to_backend(
            "test_backend_unavailable",
            self.scripted_module,
            {"forward": {"": ""}},
        )

    def test_execution(self):
        # Executing the lowered module must fail because its backend is not available.
        input = torch.randn(5)

        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            backend_method = getattr(self.lowered_module, "forward")
            # Only the raised exception matters; the result is never used.
            backend_method(input, input)

    @skipIfRocm
    def test_save_load(self):
        # Saving the lowered module succeeds, but loading it fails because the
        # backend is not available.
        buffer = io.BytesIO()
        torch.jit.save(self.lowered_module, buffer)
        buffer.seek(0)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            torch.jit.load(buffer)
224

225

226
class NestedModuleTest(JitBackendTestCase):
    """
    Tests for NestedModule that check that a module lowered to a backend can be used
    as a submodule.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module with one submodule that is used to test that lowered Modules
        can be used as submodules.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            return self.submodule.forward(x, h)

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of NestedModule.
        # Both modules in self.module are regular Python modules.
        self.module = NestedModuleTest.NestedModule(BasicModule())
        # Both modules in self.scripted_module are ScriptModules.
        self.scripted_module = torch.jit.script(
            NestedModuleTest.NestedModule(BasicModule())
        )

        # Script a separate BasicModule instance (not the one inside
        # self.scripted_module) and lower all three of its methods to the test
        # backend.
        lowered_module = to_test_backend_multi(
            torch.jit.script(BasicModule()),
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )
        # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
        self.lowered_module = torch.jit.script(
            NestedModuleTest.NestedModule(lowered_module)
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input = torch.randn(5)

        # Test forward.
        self.check_function("forward", (input, input))

    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save and load the lowered module.
        self.save_load()

        # Loaded module should produce the same outputs.
        self.test_execution()
282

283

284
class SelectiveLoweringTest(JitBackendTestCase):
    """
    Tests for the selective lowering API.
    """

    class OuterModule(torch.nn.Module):
        def __init__(self, sub1, sub2, other):
            super().__init__()
            self.sub1 = sub1
            self.sub2 = sub2
            self.other = other

        def forward(self, x, y):
            # Call the module that will be lowered directly to test
            # type remapping in modules that are not its parent.
            a, b = self.sub1.submodule.forward(x, y)
            c, d = self.sub2.forward(x, y)
            e, f = self.other.forward(x, y)
            return a + c + e, b + d + f

    class MiddleModule(torch.nn.Module):
        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, y):
            return self.submodule.forward(x, y)

    def setUp(self):
        super().setUp()
        OuterModule = SelectiveLoweringTest.OuterModule
        MiddleModule = SelectiveLoweringTest.MiddleModule

        def script_without_type_sharing(mod):
            # Give every module in the hierarchy its own type; selective lowering
            # rejects hierarchies without unique types (see test_errors).
            return torch.jit._recursive.create_script_module(
                mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
            )

        # Create Python, JIT and backend versions of a hierarchy that looks like this:
        #                 --------- OuterModule --------
        #                 |              |              |
        #           MiddleModule    MiddleModule   MiddleModule
        #                |               |              |
        #           BasicModule     BasicModule    BasicModule
        #
        # Two BasicModules will be lowered and the third will not.
        self.module = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )
        self.scripted_module = script_without_type_sharing(
            OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )
        )
        self.lowered_module = script_without_type_sharing(
            OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )
        )
        self.lowered_module = to_test_backend_selective(
            self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
        )

    def test_execution(self):
        input = torch.randn(5)
        self.check_function("forward", (input, input))

        self.test_selective_lowering_type_remap()

    def test_save_load(self):
        self.test_execution()
        self.save_load()
        # test_execution() also runs test_selective_lowering_type_remap(), so the
        # reloaded module's types are checked without a separate call.
        self.test_execution()

    def test_selective_lowering_type_remap(self):
        """
        Check that type remapping and replacement occurred during selective lowering.
        """
        # Baseline: the never-lowered scripted OuterModule refers to BasicModule,
        # because its forward calls sub1.submodule directly.
        FileCheck().check("OuterModule").check("BasicModule").run(
            self.scripted_module.graph
        )
        # self.lowered_module itself was not lowered, but its graph now refers to
        # the lowered wrapper type rather than the backend class.
        FileCheck().check("OuterModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)

        # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub1.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)

        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub2.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)

        # Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
        # __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
        # the TorchBind class for executing functions on the test JIT backend.
        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)

        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)

        # Check that `other` and `other.submodule` were left untouched by the
        # selective lowering process. Only checking self.scripted_module (which is
        # never lowered) would verify nothing, so the selectively lowered module
        # is checked as well.
        for root in (self.scripted_module, self.lowered_module):
            FileCheck().check("MiddleModule").check("BasicModule").check_not(
                "__torch__.torch.classes.__backends__.test_backend"
            ).check_not("LoweredWrapper.test_backend").run(root.other.graph)
            FileCheck().check("BasicModule").check_not(
                "__torch__.torch.classes.__backends__.test_backend"
            ).check_not("LoweredModule.test_backend").run(root.other.submodule.graph)

    def test_errors(self):
        """
        Check errors associated with selective lowering.
        """
        # Check error messages thrown when attempting to lower something that is not a ScriptModule.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Object .* is not a ScriptModule", ""
        ):
            to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])

        MiddleModule = SelectiveLoweringTest.MiddleModule
        mod = MiddleModule(BasicModule())
        mod.new_attr = 3

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute named new_attr is not a Module", ""
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["new_attr"]
            )

        # Check error message thrown when module hierarchy doesn't have unique types.
        OuterModule = SelectiveLoweringTest.OuterModule
        mod = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"Selective lowering is only supported for module hierarchies with unique types",
            "",
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
            )
452

453

454
# The skipIf on JitBackendTestCase has no effect when its subclasses' test
# methods are called directly from here, so the same condition is repeated.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
    """
    Runs every JitBackendTestCase subclass from one place, so test_jit.py only
    needs to import this class instead of each subclass individually.
    """

    def __init__(self, name):
        super().__init__(name)
        # One fixture per JitBackendTestCase subclass, sharing this test's name.
        self.basic_module_test = BasicModuleTest(name)
        self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
        self.nested_module_test = NestedModuleTest(name)
        self.selective_lowering_test = SelectiveLoweringTest(name)

    def setUp(self):
        super().setUp()
        # Every test below is @skipIfRocm, so skip fixture setup on ROCm.
        if TEST_WITH_ROCM:
            return
        for fixture in (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        ):
            fixture.setUp()

    @skipIfRocm
    def test_execution(self):
        for fixture in (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        ):
            fixture.test_execution()

    @skipIfRocm
    def test_save_load(self):
        for fixture in (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        ):
            fixture.test_save_load()

    @skipIfRocm
    def test_errors(self):
        # Only selective lowering currently has dedicated error tests.
        self.selective_lowering_test.test_errors()
497

498

499
"""
500
Unit Tests for backend with compiler
501
This test case and the existing TestBackends are separate because they cover different aspects.
502
The actual backend implementation in this test is different.
503
It has a simple demo compiler to test the end-to-end flow in mobile.
504
However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
505
"""
506

507

508
class BasicModuleAdd(torch.nn.Module):
    """
    Single-operator (add) module used to exercise backend lowering with a
    compiler.
    """

    def forward(self, x, h):
        """Return ``x + h``."""
        result = x + h
        return result
515

516

517
# The skipIf below only takes effect when unittest collects these classes
# itself. TestBackendsWithCompiler calls their test methods directly, so it
# carries a skip condition of its own.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCaseWithCompiler(JitTestCase):
    """
    Shared base class for JIT backend tests that use a backend with a compiler.

    Subclasses must set four attributes in ``setUp``:
      * ``module`` - the plain Python module under test
      * ``scripted_module`` - a scripted copy of ``module``
      * ``lowered_module`` - ``module`` lowered to a backend
      * ``mobile_module`` - a copy in a format PyTorch Mobile can execute
    """

    def setUp(self):
        super().setUp()
        library_path = find_library_location("libbackend_with_compiler.so")
        torch.ops.load_library(str(library_path))

    def check_forward(self, input):
        """
        Assert that ``forward`` returns the same result from Python, the regular
        JIT, the backend and mobile for the positional arguments ``input``.
        """
        python_output = self.module.forward(*input)
        jit_output = self.scripted_module.forward(*input)
        backend_output = self.lowered_module(*input)
        mobile_output = self.mobile_module(*input)

        # Every other execution path is compared against the backend result.
        for reference in (python_output, jit_output, mobile_output):
            self.assertEqual(reference, backend_output)

    def test_execution(self):
        """
        Stub for correctness tests.
        """

    def test_errors(self):
        """
        Stub for testing error checking.
        """
564

565

566
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
    """
    Checks BasicModuleAdd lowered to the demo compiler backend against its
    Python, JIT and mobile versions.
    """

    def setUp(self):
        super().setUp()
        self.module = BasicModuleAdd()
        self.scripted_module = torch.jit.script(BasicModuleAdd())

        forward_spec = {
            "input_shapes": "((1, 1, 320, 240), (1, 3))",
            "some_other_option": "True",
        }
        self.lowered_module = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            self.scripted_module,
            {"forward": forward_spec},
        )

        # Round-trip the lowered module through the lite-interpreter format.
        # A freshly constructed BytesIO is already positioned at offset 0.
        serialized = self.lowered_module._save_to_buffer_for_lite_interpreter()
        self.mobile_module = _load_for_lite_interpreter(io.BytesIO(serialized))

    def test_execution(self):
        # Python, JIT, backend and mobile must all agree on forward.
        ones = torch.ones(1, dtype=torch.float)
        self.check_forward((ones, ones))
594

595

596
class ErrorMessagesWithCompiler(JitBackendTestCase):
    """
    Tests for errors that occur with compiler, specifically:
        * an operator is not supported by the backend
    """

    class ModuleNotSupported(torch.nn.Module):
        """
        A module with an operator that is not supported.
        """

        def forward(self, x, h):
            return x * h
            self._loweredmodule.forward()

    def test_errors(self):
        # The expected message embeds the source context of
        # ModuleNotSupported.forward, including its exact indentation and the
        # deliberately unreachable line after the return. Editing that method
        # requires updating this pattern. Regex special characters in the
        # source ("(", ")", "*") are written as "." here.
        expected_error = r"""The node of aten::mul is not supported in this compiler. .*
        def forward.self, x, h.:
            return x . h
                   ~~~~~ <--- HERE
            self._loweredmodule.forward..
"""
        unsupported = torch.jit.script(ErrorMessagesWithCompiler.ModuleNotSupported())

        # Lowering itself must fail, since aten::mul is unsupported by the backend.
        with self.assertRaisesRegexWithHighlight(RuntimeError, expected_error, ""):
            torch._C._jit_to_backend(
                "backend_with_compiler_demo", unsupported, {"forward": {"": ""}}
            )
630

631

632
class CompModuleTestWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules
    """

    class BasicModuleSub(torch.nn.Module):
        """
        A simple subtraction Module to be used in CompModule.
        """

        def forward(self, x, h):
            return x - h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules.
        """

        def __init__(self, addmodule, submodule):
            super().__init__()
            self.lowered_add = addmodule
            self.lowered_sub = submodule

        def forward(self, a, b, s):
            c = self.lowered_add.forward(a, b)
            d = self.lowered_sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()
        # Create Python and JIT versions of CompModule with lowered submodules.
        compile_spec = {
            "forward": {
                "input_shapes": "((1, 1, 320, 240), (1, 3))",
                "some_other_option": "True",
            },
        }
        lowered_add = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            torch.jit.script(BasicModuleAdd()),
            compile_spec,
        )
        lowered_sub = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
            {"forward": {"": ""}},
        )
        self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
        self.scripted_module = torch.jit.script(
            CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
        )
        # No backend version of CompModule currently, so this is filler.
        self.lowered_module = self.scripted_module
        # Create a mobile version of CompModule from JIT version
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        input1 = torch.ones(1, dtype=torch.float)
        input2 = torch.ones(1, dtype=torch.float)
        args = (input1, input2, input2)

        # Test forward.
        self.check_function("forward", args)

        # The lite-interpreter copy built in setUp must agree too. Previously it
        # was constructed but never exercised.
        self.assertEqual(self.mobile_module(*args), self.scripted_module(*args))
698

699

700
# The skipIf on the per-test classes has no effect when their test methods are
# called directly from here, so a skip condition is repeated on this class.
# NOTE(review): unlike the other skip conditions in this file, this one omits
# TEST_WITH_ROCM -- confirm whether that is intentional.
@unittest.skipIf(
    IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackendsWithCompiler(JitTestCase):
    """
    Runs the compiler-backed backend tests from one place, so test_jit.py only
    needs to import this class instead of each test class individually.
    """

    def __init__(self, name):
        super().__init__(name)
        self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
        self.error_module_compiler_test = ErrorMessagesWithCompiler(name)
        self.comp_module_compiler_test = CompModuleTestWithCompiler(name)

    def setUp(self):
        super().setUp()
        # Order matters: the first fixture loads libbackend_with_compiler.so,
        # which the other two (JitBackendTestCase subclasses using
        # "backend_with_compiler_demo") presumably rely on.
        for fixture in (
            self.basic_module_compiler_test,
            self.error_module_compiler_test,
            self.comp_module_compiler_test,
        ):
            fixture.setUp()

    def test_execution(self):
        for fixture in (self.basic_module_compiler_test, self.comp_module_compiler_test):
            fixture.test_execution()

    def test_errors(self):
        self.error_module_compiler_test.test_errors()
729

730

731
class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules with same module name
    (both are lowered from the same ModuleAdd class).
    """

    class ModuleAdd(torch.nn.Module):
        """
        A simple Module used to test to_backend lowering machinery.
        """

        def forward(self, x, h):
            return x + h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules.
        """

        def __init__(self) -> None:
            super().__init__()
            compile_spec = {
                "forward": {
                    "some_other_option": "True",
                },
            }
            # A bare `ModuleAdd` is a NameError here: names in an enclosing class
            # body are not visible from methods of a nested class. Qualify it
            # through the outer test class instead.
            module_add = CompModuleTestSameNameWithCompiler.ModuleAdd
            self.add = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(module_add()),
                compile_spec,
            )
            self.sub = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(module_add()),
                compile_spec,
            )

        def forward(self, a, b, s: int):
            c = self.add.forward(a, b)
            d = self.sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()

        self.module = CompModuleTestSameNameWithCompiler.CompModule()
        self.scripted_module = torch.jit.script(self.module)
        # CompModule has no lowered version of its own (only its children are
        # lowered). check_function requires lowered_module, so reuse the
        # scripted module.
        self.lowered_module = self.scripted_module
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        a = torch.ones(1)
        b = 3 * torch.ones(1)
        s = 3
        # Test forward.
        self.check_function("forward", (a, b, s))
        # Two lowered submodules with the same name must also run correctly
        # on mobile.
        self.assertEqual(self.mobile_module(a, b, s), self.scripted_module(a, b, s))
788

789

790
class AddedAttributesTest(JitBackendTestCase):
    """
    Tests for adding attributes to a model after lowering.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def test_attribute(self):
        # Import just the function: a bare `import torch.utils.bundled_inputs`
        # here would rebind `torch` as a local name for the whole method.
        from torch.utils.bundled_inputs import augment_model_with_bundled_inputs

        # BasicModule.forward takes (x, h), so each bundled input tuple must
        # hold two tensors.
        inputs = [(torch.ones(5), torch.ones(5))]
        pre_bundled = self.lowered_module(*inputs[0])
        # Attach bundled inputs, which adds several attributes and functions to
        # the model. The call mutates the model in place and returns None, so
        # its result must not be assigned back to self.lowered_module.
        augment_model_with_bundled_inputs(self.lowered_module, inputs)
        post_bundled = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        # Save and load the lowered module.
        self.save_load()
        # Use bundled after save and load to prove its preserved
        post_load = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        self.assertEqual(pre_bundled, post_bundled)
        self.assertEqual(post_bundled, post_load)
825

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.