10
from torch.jit.mobile import _load_for_lite_interpreter
11
from torch.testing import FileCheck
12
from torch.testing._internal.common_utils import (
13
find_library_location,
21
from torch.testing._internal.jit_utils import JitTestCase
25
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
26
sys.path.append(pytorch_test_dir)
28
if __name__ == "__main__":
30
"This test file is not meant to be run directly, use:\n\n"
31
"\tpython test/test_jit.py TESTNAME\n\n"
36
def to_test_backend(module, method_compile_spec):
37
return torch._C._jit_to_backend(
38
"test_backend", module, {"forward": method_compile_spec}
42
def to_test_backend_multi(module, method_compile_spec):
43
return torch._C._jit_to_backend("test_backend", module, method_compile_spec)
46
def to_test_backend_selective(module, method_compile_spec, submodules):
47
def _to_test_backend(module):
48
return to_test_backend(module, method_compile_spec)
50
return torch._C._jit_to_backend_selective(module, _to_test_backend, submodules)
53
class BasicModule(torch.nn.Module):
55
A simple Module used to test to_backend lowering machinery.
58
def forward(self, x, h):
59
return self.accum(x, h), self.sub_accum(x, h)
61
def accum(self, x, h):
64
def sub_accum(self, x, h):
70
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
71
"Non-portable load_library call used in test",
73
class JitBackendTestCase(JitTestCase):
75
A common base class for JIT backend tests that contains common utility
76
functions for output comparison and serialization/deserialization.
81
lib_file_path = find_library_location("libjitbackend_test.so")
82
torch.ops.load_library(str(lib_file_path))
88
def check_function(self, function_name, input):
90
Check that the function named 'function_name' produces the same output using
91
Python, regular JIT and the backend for the given 'input'.
94
python_method = self.module.__getattribute__(function_name)
95
jit_method = self.scripted_module.__getattr__(function_name)
96
backend_method = self.lowered_module.__getattr__(function_name)
99
python_output = python_method(*input)
100
jit_output = jit_method(*input)
101
backend_output = backend_method(*input)
104
self.assertEqual(python_output, backend_output)
105
self.assertEqual(jit_output, backend_output)
109
Save and load the lowered module.
111
self.lowered_module = self.getExportImportCopy(self.lowered_module)
113
def test_execution(self):
115
Stub for correctness tests.
118
def test_save_load(self):
120
Stub for serialization tests.
123
def test_errors(self):
125
Stub for testing error checking.
129
class BasicModuleTest(JitBackendTestCase):
131
Tests for BasicModule.
137
self.module = BasicModule()
138
self.scripted_module = torch.jit.script(BasicModule())
139
self.lowered_module = to_test_backend_multi(
140
self.scripted_module,
141
{"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
144
def test_execution(self):
146
input = torch.randn(5)
149
self.check_function("accum", (input, input))
150
self.check_function("sub_accum", (input, input))
151
self.check_function("forward", (input, input))
154
def test_save_load(self):
156
self.test_execution()
159
pre_compile_spec = self.lowered_module.__getattr__(
161
).__getattr__("__method_compile_spec")
167
post_compile_spec = self.lowered_module.__getattr__(
169
).__getattr__("__method_compile_spec")
172
self.assertEqual(pre_compile_spec, post_compile_spec)
175
self.test_execution()
178
class BasicModuleUnavailableTest(JitBackendTestCase):
180
Tests for BasicModule with a backend that is not available.
182
* _jit_to_backend is successful.
183
* Execution fails with an exception.
184
* Saving is successful.
185
* Loading fails with an exception.
191
self.module = BasicModule()
192
self.scripted_module = torch.jit.script(BasicModule())
193
self.lowered_module = torch._C._jit_to_backend(
194
"test_backend_unavailable",
195
self.scripted_module,
196
{"forward": {"": ""}},
199
def test_execution(self):
201
input = torch.randn(5)
204
with self.assertRaisesRegexWithHighlight(
206
r"Backend is not available.",
207
'raise Exception("Backend is not available."',
209
backend_method = self.lowered_module.__getattr__("forward")
210
backend_output = backend_method(*(input, input))
213
def test_save_load(self):
215
buffer = io.BytesIO()
216
torch.jit.save(self.lowered_module, buffer)
218
with self.assertRaisesRegexWithHighlight(
220
r"Backend is not available.",
221
'raise Exception("Backend is not available."',
223
imported = torch.jit.load(buffer)
226
class NestedModuleTest(JitBackendTestCase):
228
Tests for NestedModule that check that a module lowered to a backend can be used
232
class NestedModule(torch.nn.Module):
234
A Module with one submodule that is used to test that lowered Modules
235
can be used as submodules.
238
def __init__(self, submodule):
240
self.submodule = submodule
242
def forward(self, x, h):
243
return self.submodule.forward(x, h)
249
self.module = NestedModuleTest.NestedModule(BasicModule())
251
self.scripted_module = torch.jit.script(
252
NestedModuleTest.NestedModule(BasicModule())
257
lowered_module = to_test_backend_multi(
258
torch.jit.script(BasicModule()),
259
{"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
262
self.lowered_module = torch.jit.script(
263
NestedModuleTest.NestedModule(lowered_module)
266
def test_execution(self):
268
input = torch.randn(5)
271
self.check_function("forward", (input, input))
273
def test_save_load(self):
275
self.test_execution()
281
self.test_execution()
284
class SelectiveLoweringTest(JitBackendTestCase):
286
Tests for the selective lowering API.
289
class OuterModule(torch.nn.Module):
290
def __init__(self, sub1, sub2, other):
296
def forward(self, x, y):
299
a, b = self.sub1.submodule.forward(x, y)
300
c, d = self.sub2.forward(x, y)
301
e, f = self.other.forward(x, y)
302
return a + c + e, b + d + f
304
class MiddleModule(torch.nn.Module):
305
def __init__(self, submodule):
307
self.submodule = submodule
309
def forward(self, x, y):
310
return self.submodule.forward(x, y)
314
OuterModule = SelectiveLoweringTest.OuterModule
315
MiddleModule = SelectiveLoweringTest.MiddleModule
317
def script_without_type_sharing(mod):
318
return torch.jit._recursive.create_script_module(
319
mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
330
self.module = OuterModule(
331
MiddleModule(BasicModule()),
332
MiddleModule(BasicModule()),
333
MiddleModule(BasicModule()),
335
self.scripted_module = script_without_type_sharing(
337
MiddleModule(BasicModule()),
338
MiddleModule(BasicModule()),
339
MiddleModule(BasicModule()),
342
self.lowered_module = script_without_type_sharing(
344
MiddleModule(BasicModule()),
345
MiddleModule(BasicModule()),
346
MiddleModule(BasicModule()),
349
self.lowered_module = to_test_backend_selective(
350
self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
353
def test_execution(self):
354
input = torch.randn(5)
355
self.check_function("forward", (input, input))
357
self.test_selective_lowering_type_remap()
359
def test_save_load(self):
360
self.test_execution()
362
self.test_execution()
364
self.test_selective_lowering_type_remap()
366
def test_selective_lowering_type_remap(self):
368
Check that type remapping and replacement occurred during selective lowering.
372
FileCheck().check("OuterModule").check("BasicModule").run(
373
self.scripted_module.graph
375
FileCheck().check("OuterModule").check_not(
376
"__torch__.torch.classes.__backends__.test_backend"
377
).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)
380
FileCheck().check("MiddleModule").check("BasicModule").check_not(
381
"LoweredWrapper.test_backend"
382
).run(self.scripted_module.sub1.graph)
383
FileCheck().check("MiddleModule").check_not(
384
"__torch__.torch.classes.__backends__.test_backend"
385
).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)
387
FileCheck().check("MiddleModule").check("BasicModule").check_not(
388
"LoweredWrapper.test_backend"
389
).run(self.scripted_module.sub2.graph)
390
FileCheck().check("MiddleModule").check_not(
391
"__torch__.torch.classes.__backends__.test_backend"
392
).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)
397
FileCheck().check("LoweredModule.test_backend").check(
398
"__torch__.torch.classes.__backends__.test_backend"
399
).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
401
FileCheck().check("LoweredModule.test_backend").check(
402
"__torch__.torch.classes.__backends__.test_backend"
403
).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
406
FileCheck().check("MiddleModule").check("BasicModule").check_not(
407
"__torch__.torch.classes.__backends__.test_backend"
408
).check_not("LoweredWrapper.test_backend").run(self.scripted_module.other.graph)
409
FileCheck().check("BasicModule").check_not(
410
"__torch__.torch.classes.__backends__.test_backend"
411
).check_not("LoweredModule.test_backend").run(
412
self.scripted_module.other.submodule.graph
415
def test_errors(self):
417
Check errors associated with selective lowering.
420
with self.assertRaisesRegexWithHighlight(
421
RuntimeError, r"Object .* is not a ScriptModule", ""
423
to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])
425
MiddleModule = SelectiveLoweringTest.MiddleModule
426
mod = MiddleModule(BasicModule())
429
with self.assertRaisesRegexWithHighlight(
430
RuntimeError, r"Attribute named new_attr is not a Module", ""
432
to_test_backend_selective(
433
torch.jit.script(mod), {"forward": ""}, ["new_attr"]
437
OuterModule = SelectiveLoweringTest.OuterModule
439
MiddleModule(BasicModule()),
440
MiddleModule(BasicModule()),
441
MiddleModule(BasicModule()),
444
with self.assertRaisesRegexWithHighlight(
446
r"Selective lowering is only supported for module hierarchies with unique types",
449
to_test_backend_selective(
450
torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
456
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
457
"Non-portable load_library call used in test",
459
class TestBackends(JitTestCase):
461
This class wraps and invokes all subclasses of JitBackendTestCase so that each one
462
does not have to be individually imported in test_jit.py.
465
def __init__(self, name):
466
super().__init__(name)
467
self.basic_module_test = BasicModuleTest(name)
468
self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
469
self.nested_module_test = NestedModuleTest(name)
470
self.selective_lowering_test = SelectiveLoweringTest(name)
474
if not TEST_WITH_ROCM:
475
self.basic_module_test.setUp()
476
self.basic_module_unavailable_test.setUp()
477
self.nested_module_test.setUp()
478
self.selective_lowering_test.setUp()
481
def test_execution(self):
482
self.basic_module_test.test_execution()
483
self.basic_module_unavailable_test.test_execution()
484
self.nested_module_test.test_execution()
485
self.selective_lowering_test.test_execution()
488
def test_save_load(self):
489
self.basic_module_test.test_save_load()
490
self.basic_module_unavailable_test.test_save_load()
491
self.nested_module_test.test_save_load()
492
self.selective_lowering_test.test_save_load()
495
def test_errors(self):
496
self.selective_lowering_test.test_errors()
500
Unit Tests for backend with compiler
501
This test case and the existing TestBackends are separate because they cover different aspects.
502
The actual backend implementation in this test is different.
503
It has a simple demo compiler to test the end-to-end flow in mobile.
504
However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
508
class BasicModuleAdd(torch.nn.Module):
510
A simple add Module used to test to_backend lowering machinery.
513
def forward(self, x, h):
519
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
520
"Non-portable load_library call used in test",
522
class JitBackendTestCaseWithCompiler(JitTestCase):
524
A common base class for JIT backend tests with compilers that contains common utility
525
functions for output comparison.
530
lib_file_path = find_library_location("libbackend_with_compiler.so")
531
torch.ops.load_library(str(lib_file_path))
538
def check_forward(self, input):
540
Check that the forward function produces the same output using
541
Python, regular JIT, the backend, and mobile for the given 'input'.
545
python_output = self.module.forward(*input)
546
jit_output = self.scripted_module.forward(*input)
547
backend_output = self.lowered_module(*input)
548
mobile_output = self.mobile_module(*input)
551
self.assertEqual(python_output, backend_output)
552
self.assertEqual(jit_output, backend_output)
553
self.assertEqual(mobile_output, backend_output)
555
def test_execution(self):
557
Stub for correctness tests.
560
def test_errors(self):
562
Stub for testing error checking.
566
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
568
Tests for BasicModuleAdd.
574
self.module = BasicModuleAdd()
575
self.scripted_module = torch.jit.script(BasicModuleAdd())
578
"input_shapes": "((1, 1, 320, 240), (1, 3))",
579
"some_other_option": "True",
582
self.lowered_module = torch._C._jit_to_backend(
583
"backend_with_compiler_demo", self.scripted_module, compile_spec
586
buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
588
self.mobile_module = _load_for_lite_interpreter(buffer)
590
def test_execution(self):
592
input = torch.ones(1, dtype=torch.float)
593
self.check_forward((input, input))
596
class ErrorMessagesWithCompiler(JitBackendTestCase):
598
Tests for errors that occur with compiler, specifically:
599
* an operator is not supported by the backend
602
class ModuleNotSupported(torch.nn.Module):
604
A module with an operator that is not supported.
607
def forward(self, x, h):
609
self._loweredmodule.forward()
611
def test_errors(self):
612
scripted_module_n = torch.jit.script(
613
ErrorMessagesWithCompiler.ModuleNotSupported()
616
with self.assertRaisesRegexWithHighlight(
619
r"""The node of aten::mul is not supported in this compiler. .*
620
def forward.self, x, h.:
623
self._loweredmodule.forward..
627
lowered_module_n = torch._C._jit_to_backend(
628
"backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
632
class CompModuleTestWithCompiler(JitBackendTestCase):
634
Tests for CompModule, which is a module with two lowered submodules
637
class BasicModuleSub(torch.nn.Module):
639
A simple subtraction Module to be used in CompModule.
642
def forward(self, x, h):
645
class CompModule(torch.nn.Module):
647
A module with two lowered submodules.
650
def __init__(self, addmodule, submodule):
652
self.lowered_add = addmodule
653
self.lowered_sub = submodule
655
def forward(self, a, b, s):
656
c = self.lowered_add.forward(a, b)
657
d = self.lowered_sub.forward(a, b)
666
"input_shapes": "((1, 1, 320, 240), (1, 3))",
667
"some_other_option": "True",
670
lowered_add = torch._C._jit_to_backend(
671
"backend_with_compiler_demo",
672
torch.jit.script(BasicModuleAdd()),
675
lowered_sub = torch._C._jit_to_backend(
676
"backend_with_compiler_demo",
677
torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
678
{"forward": {"": ""}},
680
self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
681
self.scripted_module = torch.jit.script(
682
CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
685
self.lowered_module = self.scripted_module
687
buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
689
self.mobile_module = _load_for_lite_interpreter(buffer)
691
def test_execution(self):
693
input1 = torch.ones(1, dtype=torch.float)
694
input2 = torch.ones(1, dtype=torch.float)
697
self.check_function("forward", (input1, input2, input2))
702
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
703
"Non-portable load_library call used in test",
705
class TestBackendsWithCompiler(JitTestCase):
707
This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
708
so that each one does not have to be individually imported in test_jit.py.
711
def __init__(self, name):
712
super().__init__(name)
713
self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
714
self.error_module_compiler_test = ErrorMessagesWithCompiler(name)
715
self.comp_module_compiler_test = CompModuleTestWithCompiler(name)
719
self.basic_module_compiler_test.setUp()
720
self.error_module_compiler_test.setUp()
721
self.comp_module_compiler_test.setUp()
723
def test_execution(self):
724
self.basic_module_compiler_test.test_execution()
725
self.comp_module_compiler_test.test_execution()
727
def test_errors(self):
728
self.error_module_compiler_test.test_errors()
731
class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
733
Tests for CompModule, which is a module with two lowered submodules with same module name
736
class ModuleAdd(torch.nn.Module):
738
A simple Module used to test to_backend lowering machinery.
741
def forward(self, x, h):
744
class CompModule(torch.nn.Module):
746
A module with two lowered submodules.
749
def __init__(self) -> None:
753
"some_other_option": "True",
756
self.add = torch._C._jit_to_backend(
757
"backend_with_compiler_demo",
758
torch.jit.script(ModuleAdd()),
761
self.sub = torch._C._jit_to_backend(
762
"backend_with_compiler_demo",
763
torch.jit.script(ModuleAdd()),
767
def forward(self, a, b, s: int):
768
c = self.add.forward(a, b)
769
d = self.sub.forward(a, b)
776
self.module = CompModule()
777
self.scripted_module = torch.jit.script(self.module)
778
buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
780
self.mobile_module = _load_for_lite_interpreter(buffer)
782
def test_execution(self):
784
b = 3 * torch.ones(1)
787
self.check_function("forward", (a, b, s))
790
class AddedAttributesTest(JitBackendTestCase):
792
Tests for adding attributes to a model after lowering.
798
self.module = BasicModule()
799
self.scripted_module = torch.jit.script(BasicModule())
800
self.lowered_module = to_test_backend_multi(
801
self.scripted_module,
802
{"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
805
def test_attribute(self):
806
input = [(torch.ones(5),)]
807
pre_bundled = self.lowered_module(*input[0])
809
self.lowered_module = (
810
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
811
lowered_module, input
814
post_bundled = self.lowered_module(
815
*self.lowered_module.get_all_bundled_inputs()[0]
820
post_load = self.lowered_module(
821
*self.lowered_module.get_all_bundled_inputs()[0]
823
self.assertEqual(pre_bundled, post_bundled)
824
self.assertEqual(post_bundled, post_load)