8
import typing_extensions
9
from collections import OrderedDict
10
from typing import Dict, List, Optional, Tuple
13
import torch.jit.frontend
15
from torch import Tensor
16
from torch.testing import FileCheck
20
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
21
sys.path.append(pytorch_test_dir)
22
from torch.testing._internal.jit_utils import (
23
_tmp_donotuse_dont_inline_everything,
28
if __name__ == "__main__":
30
"This test file is not meant to be run directly, use:\n\n"
31
"\tpython test/test_jit.py TESTNAME\n\n"
36
class TestRecursiveScript(JitTestCase):
37
def test_inferred_nonetype(self):
39
def __init__(self) -> None:
46
m = torch.jit.script(M())
47
self.checkModule(M(), ())
49
def test_script_function_attribute(self):
58
class M(torch.nn.Module):
59
def __init__(self, fn):
69
self.checkModule(fn1_mod, (torch.randn(2, 2),))
70
self.checkModule(fn2_mod, (torch.randn(2, 2),))
72
def test_python_function_attribute(self):
73
class M(torch.nn.Module):
74
def __init__(self, fn):
81
mod = M(torch.sigmoid)
83
self.checkModule(mod, (torch.randn(2, 2),))
85
def test_failed_function_compilation(self):
89
class M(torch.nn.Module):
90
def __init__(self, fn):
98
with self.assertRaisesRegexWithHighlight(
99
RuntimeError, "failed to compile", "i_dont_exist"
103
def test_init_error(self):
105
def __init__(self) -> None:
111
with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
112
torch.jit.script(M())
114
def test_script_after_eval(self):
123
sm1 = torch.jit.script(m)
125
sm2 = torch.jit.script(m)
128
self.assertFalse(m.training)
131
self.assertTrue(sm1.training)
132
self.assertEqual(sm1.training, sm1._c.getattr("training"))
133
self.assertEqual(sm1(), 2)
136
self.assertFalse(sm2.training)
137
self.assertEqual(sm2.training, sm2._c.getattr("training"))
138
self.assertEqual(sm2(), 0)
140
def test_module_name(self):
141
class MyModule(torch.nn.Module):
142
def __init__(self) -> None:
146
def forward(self, t):
149
m = torch.jit.script(MyModule())
150
FileCheck().check("MyModule").run(m.graph)
152
def test_repeated_error_stack(self):
167
except Exception as e:
168
FileCheck().check_count("is being compiled", 2).run(str(e))
172
except Exception as e:
174
FileCheck().check_count("is being compiled", 2).run(str(e))
176
def test_constants_with_final(self):
177
class M1(torch.nn.Module):
178
x: torch.jit.Final[int]
180
def __init__(self) -> None:
184
def forward(self, t):
187
self.checkModule(M1(), (torch.randn(2, 2),))
189
class M2(torch.nn.Module):
190
x: typing_extensions.Final[int]
192
def __init__(self) -> None:
196
def forward(self, t):
199
self.checkModule(M2(), (torch.randn(2, 2),))
201
class M3(torch.nn.Module):
204
def __init__(self) -> None:
208
def forward(self, t):
211
self.checkModule(M3(), (torch.randn(2, 2),))
213
def test_ignore_class(self):
216
def unscriptable(self):
219
class TestModule(torch.nn.Module):
220
def forward(self, x):
221
return MyScriptClass()
223
with self.assertRaisesRegexWithHighlight(
224
torch.jit.frontend.FrontendError,
225
"Cannot instantiate class",
228
t = torch.jit.script(TestModule())
230
def test_method_call(self):
235
def forward(self, z):
239
self.checkModule(M(), (torch.randn(2, 2),))
241
def test_module_repr(self):
242
class Submodule(nn.Module):
243
def forward(self, x):
246
class MyModule(nn.Module):
247
def __init__(self) -> None:
249
self.conv = nn.Conv2d(10, 10, 3)
250
self.lin = nn.Linear(10, 10)
251
self.sub = Submodule()
253
def forward(self, x):
254
return self.lin(x) + self.sub(x) + self.conv(x)
256
m = torch.jit.script(MyModule())
258
with self.capture_stdout() as out:
268
self.assertEqual(m.original_name, "MyModule")
271
def test_module_dir(mod):
273
scripted_mod = torch.jit.script(mod)
274
dir_scripted = set(dir(scripted_mod))
288
if attr in ignore_set:
290
self.assertTrue(attr in dir_scripted, attr)
292
class MyModule(nn.Module):
293
def __init__(self) -> None:
295
self.conv = nn.Conv2d(10, 10, 3)
296
self.lin = nn.Linear(10, 10)
298
def forward(self, x):
299
return self.lin(x) + self.conv(x)
301
test_module_dir(MyModule())
304
conv = nn.Conv2d(10, 10, 3)
305
linear = nn.Linear(10, 10)
307
test_module_dir(nn.Sequential(conv, linear))
309
nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
312
def test_class_compile(self):
313
def other_fn(a: int, b: Tensor) -> Tensor:
317
def __init__(self, x):
321
return self.x + a + other_fn(self.x, a)
323
class N(torch.nn.Module):
324
def forward(self, x):
328
self.checkModule(N(), (torch.randn(2, 2),))
330
def test_error_stack(self):
331
def d(x: int) -> int:
335
return d("hello") + d(x)
344
scripted = torch.jit.script(a)
345
except RuntimeError as e:
346
checker = FileCheck()
347
checker.check("Expected a value of type 'int'")
348
checker.check("def c(x)")
349
checker.check("def b(x)")
350
checker.check("def a(x)")
353
def test_error_stack_module(self):
354
def d(x: int) -> int:
358
return d("hello") + d(x)
363
class Submodule(torch.nn.Module):
364
def forward(self, x):
367
class M(torch.nn.Module):
368
def __init__(self) -> None:
370
self.submodule = Submodule()
372
def some_method(self, y):
373
return y + self.submodule(y)
375
def forward(self, x):
376
return self.some_method(x)
379
scripted = torch.jit.script(M())
380
except RuntimeError as e:
381
checker = FileCheck()
382
checker.check("Expected a value of type 'int'")
383
checker.check("'c' is being compiled since it was called from 'b'")
384
checker.check("'b' is being compiled since it was called from")
387
@_tmp_donotuse_dont_inline_everything
388
def test_script_basic(self):
389
def a_python_fn(a, b, c):
393
def a_script_fn(d, e, f):
394
return a_python_fn(d, e, f)
396
graph = str(a_script_fn.graph)
397
FileCheck().check("prim::CallFunction").run(graph)
398
FileCheck().check_not("^a_python_fn").run(graph)
400
self.assertEqual(a_script_fn(t, t, t), t + t + t)
402
def test_error_stack_class(self):
412
except Exception as e:
413
checker = FileCheck()
414
checker.check("import statements")
415
checker.check("is being compiled since it was called from")
418
def test_error_stack_annotation(self):
428
except Exception as e:
429
checker = FileCheck()
430
checker.check("import statements")
431
checker.check("is being compiled since it was called from")
432
checker.check("-> X")
435
def test_module_basic(self):
436
class Other(torch.nn.Module):
437
__constants__ = ["x"]
439
def __init__(self, x):
442
self.param = torch.nn.Parameter(torch.ones(2, 2))
444
def some_unscriptable_method(self):
449
def forward(self, t):
450
return t + self.x + self.param
452
class M(torch.nn.Module):
453
def __init__(self) -> None:
455
self.other = Other(200)
457
def forward(self, t):
458
return self.other(t) * 2
460
self.checkModule(M(), (torch.ones(2, 2),))
462
def test_module_function_export(self):
463
class Other(torch.nn.Module):
464
__constants__ = ["x"]
466
def __init__(self, x):
469
self.param = torch.nn.Parameter(torch.ones(2, 2))
472
def some_entry_point(self, y):
475
def forward(self, t):
476
return t + self.x + self.param
478
class M(torch.nn.Module):
479
def __init__(self) -> None:
481
self.other = Other(200)
483
def forward(self, t):
484
return self.other(t) * 2
486
self.checkModule(M(), (torch.ones(2, 2),))
488
def test_iterable_modules(self):
489
class Inner(torch.nn.Module):
490
def forward(self, x):
493
class M(torch.nn.Module):
494
def __init__(self) -> None:
496
self.sequential = nn.Sequential(
497
Inner(), Inner(), nn.Sequential(Inner(), Inner())
499
self.module_list = nn.ModuleList([Inner(), Inner()])
501
def forward(self, x):
502
for mod in self.module_list:
504
x += self.sequential(x)
507
self.checkModule(M(), (torch.randn(5, 5),))
509
def test_prepare_scriptable_basic(self):
510
class SeluButReluWhenScripted(torch.nn.SELU):
511
def __prepare_scriptable__(self):
514
t = torch.randn(5, 5)
515
m = SeluButReluWhenScripted()
516
sm = torch.jit.script(m)
519
self.assertNotEqual(eager_out, script_out)
521
def test_prepare_scriptable_iterable_modules(self):
522
class SeluButReluWhenScripted(torch.nn.SELU):
523
def __prepare_scriptable__(self):
526
class M(torch.nn.Module):
527
def __init__(self) -> None:
529
shared = SeluButReluWhenScripted()
530
self.sequential = nn.Sequential(
531
SeluButReluWhenScripted(),
532
SeluButReluWhenScripted(),
534
SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
538
self.module_list = nn.ModuleList(
539
[SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
542
def forward(self, x):
543
for mod in self.module_list:
545
x += self.sequential(x)
548
t = torch.randn(5, 5)
550
eager_out = m(t.clone())
551
sm = torch.jit.script(m)
552
script_out = sm(t.clone())
553
self.assertNotEqual(eager_out, script_out)
555
def test_prepare_scriptable_cycle(self):
556
t = torch.randn(5, 5)
557
c = torch.nn.Module()
558
p = torch.nn.Module()
562
sm = torch.jit.script(p)
564
def test_prepare_scriptable_escape_hatch(self):
565
class NonJitableClass:
566
def __call__(self, int1, int2, *args):
572
obj = NonJitableClass()
574
self.assertEqual(obj(1, 2), 3)
575
self.assertEqual(obj(1, 2, 3, 4), 10)
576
with self.assertRaisesRegex(
577
torch.jit.frontend.NotSupportedError,
578
expected_regex="can't take variable number of arguments",
580
torch.jit.script(obj)
582
def escape_hatch(int1: int, int2: int) -> int:
585
class NonJitableClassWithEscapeHatch(NonJitableClass):
586
def __prepare_scriptable__(self):
589
jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch())
591
self.assertEqual(jit_obj(1, 2), 3)
592
with self.assertRaisesRegex(
594
expected_regex=re.escape(
595
"expected at most 2 argument(s) but received 4 argument(s)"
600
def test_attributes(self):
603
def __init__(self) -> None:
608
def __init__(self) -> None:
610
self.inner = Inner2()
614
def __init__(self) -> None:
616
self.inner = Inner2()
618
def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
623
def __getstate__(self):
624
return (self.a, self.inner)
627
("my_dict", {"I": "am", "a test": "test"}),
631
("my_tuple", (1, 2, 3, 4)),
632
("my_list", [(1, 2), (3, 4)]),
634
("my_int_list", [1, 2, 3, 4]),
636
("my_bool_list", [True, True, False, True]),
637
("my_float_list", [1.0, 2.0, 3.0, 4.0]),
638
("my_str_list", ["hello", "bye"]),
641
("my_empty_list", []),
642
("my_empty_dict", {}),
644
("my_object", Foo()),
645
("my_object2", SFoo()),
648
class M(torch.nn.Module):
655
def forward(self, x):
671
self.my_object.inner.b,
673
self.my_object2.inner.b,
685
M.__annotations__ = {
686
"my_empty_list": List[int],
687
"my_empty_dict": Dict[str, int],
688
"my_none": Optional[int],
694
for name, value in untyped_values + typed_values:
695
setattr(m, name, value)
697
self.checkModule(m, (torch.randn(5, 5),))
699
def test_function_attribute_in_submodule(self):
701
def __init__(self, norm):
703
self.activation = torch.nn.functional.relu
706
def forward(self, src):
708
output = self.norm(output)
712
def __init__(self) -> None:
714
encoder_norm = nn.ReLU()
715
self.encoder = N(encoder_norm)
717
def forward(self, x):
718
return self.encoder(x)
721
self.checkModule(m, (torch.randn(5, 5),))
723
def test_inner_traced_module(self):
724
class Dummy(nn.Module):
725
def forward(self, x):
728
class Model(nn.Module):
729
def __init__(self, dummies):
731
self._dummies = dummies
733
def forward(self, x):
735
for dummy in self._dummies:
739
dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
740
dummies = nn.ModuleList([dummy])
741
model = Model(dummies)
742
self.checkModule(model, (torch.rand(5, 5),))
744
def test_script_loaded_module(self):
746
Test that we can hold a loaded ScriptModule as a submodule.
749
class Dummy(nn.Module):
750
def forward(self, x):
753
dummy = torch.jit.script(Dummy())
754
dummy = self.getExportImportCopy(dummy)
756
class ContainsLoaded(torch.nn.Module):
757
def __init__(self) -> None:
761
def forward(self, input):
762
return self.encoder(input)
764
self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))
766
def test_optional_module(self):
767
class Dummy(nn.Module):
768
def __init__(self) -> None:
770
self.foo = nn.Linear(2, 2)
772
def forward(self, x):
773
if self.foo is not None:
778
self.checkModule(mod, (torch.rand(2, 2),))
780
self.checkModule(mod, (torch.rand(2, 2),))
782
def test_override_instance_method_ignore(self):
783
class M(torch.nn.Module):
785
def i_am_ignored(self):
792
def i_am_ignored(self):
795
m.i_am_ignored = types.MethodType(i_am_ignored, m)
796
self.assertEqual(m.i_am_ignored(), "new")
799
s = torch.jit.script(m)
800
self.assertEqual(s.i_am_ignored(), "new")