1
# Owner(s): ["oncall: jit"]
5
from collections import OrderedDict
6
from typing import Any, List, Tuple
10
from torch.testing._internal.jit_utils import JitTestCase
13
# Make the helper files in test/ importable
14
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
15
sys.path.append(pytorch_test_dir)
17
if __name__ == "__main__":
19
"This test file is not meant to be run directly, use:\n\n"
20
"\tpython test/test_jit.py TESTNAME\n\n"
25
class TestModuleContainers(JitTestCase):
26
def test_sequential_intermediary_types(self):
27
class A(torch.nn.Module):
31
class B(torch.nn.Module):
35
class C(torch.nn.Module):
36
def __init__(self) -> None:
38
self.foo = torch.nn.Sequential(A(), B())
43
self.checkModule(C(), (torch.tensor(1),))
45
def test_moduledict(self):
46
class Inner(torch.nn.Module):
50
class Inner2(torch.nn.Module):
54
class Inner3(torch.nn.Module):
58
class M(torch.nn.Module):
59
def __init__(self) -> None:
61
modules = OrderedDict(
68
self.moduledict = nn.ModuleDict(modules)
70
def forward(self, x, skip_name):
72
names = torch.jit.annotate(List[str], [])
74
for name in self.moduledict:
77
for name, mod in self.moduledict.items():
83
for mod in self.moduledict.values():
87
for key in self.moduledict.keys():
93
def forward(self, x, skip_name):
95
names = torch.jit.annotate(List[str], [])
99
for name in self.moduledict:
102
for i, (name, mod) in enumerate(self.moduledict.items()):
104
if name != skip_name:
109
for i, mod in enumerate(self.moduledict.values()):
114
for i, key in enumerate(self.moduledict.keys()):
118
for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
122
return x, x2, names, iter
124
for name in ["", "one", "two", "three"]:
125
inp = torch.tensor(1)
126
self.checkModule(M(), (inp, name))
127
self.checkModule(M2(), (inp, name))
129
def test_custom_container_forward(self):
130
class Inner(torch.nn.Module):
131
def forward(self, x):
134
class CustomSequential(nn.Sequential):
135
def __init__(self) -> None:
136
super().__init__(nn.ReLU(), Inner())
138
def forward(self, x):
144
self.checkModule(CustomSequential(), (torch.tensor(0.5),))
146
class CustomModuleList(nn.ModuleList):
147
def __init__(self) -> None:
148
super().__init__([nn.ReLU(), Inner()])
150
def forward(self, x):
156
self.checkModule(CustomModuleList(), (torch.tensor(0.5),))
158
class CustomModuleDict(nn.ModuleDict):
159
def __init__(self) -> None:
170
def forward(self, x):
172
names = torch.jit.annotate(List[str], [])
173
for name, mod in self.items():
178
self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))
180
def test_script_module_list_sequential(self):
181
class M(torch.jit.ScriptModule):
182
def __init__(self, mod_list):
186
@torch.jit.script_method
187
def forward(self, v):
192
with torch.jit.optimized_execution(False):
193
m = M(nn.Sequential(nn.ReLU()))
194
self.assertExportImportModule(m, (torch.randn(2, 2),))
196
def test_script_modulelist_index(self):
197
class Sub(torch.nn.Module):
198
def __init__(self, i):
202
def forward(self, thing):
    """Shift the input down by this submodule's stored offset ``self.i``."""
    offset = self.i
    return thing - offset
205
class M(torch.nn.Module):
206
def __init__(self) -> None:
208
self.mods = nn.ModuleList([Sub(i) for i in range(10)])
210
def forward(self, v):
211
v = self.mods[4].forward(v)
212
v = self.mods[-1].forward(v)
213
v = self.mods[-9].forward(v)
217
self.checkModule(M(), (x,))
219
class MForward(torch.nn.Module):
220
def __init__(self) -> None:
222
self.mods = nn.ModuleList([Sub(i) for i in range(10)])
224
def forward(self, v):
230
self.checkModule(MForward(), (torch.tensor(1),))
233
def forward(self, v):
    """Index the module list with an out-of-range negative literal.

    The surrounding test scripts this and expects an "Index -11 out of
    range" error highlighting ``self.mods[-11]``.
    """
    chosen = self.mods[-11]
    return chosen.forward(v)
236
with self.assertRaisesRegexWithHighlight(
237
Exception, "Index -11 out of range", "self.mods[-11]"
239
torch.jit.script(M2())
242
def forward(self, v):
244
return self.mods[i].forward(v)
246
with self.assertRaisesRegexWithHighlight(
247
Exception, "Enumeration is supported", "self.mods[i]"
249
torch.jit.script(M3())
252
def forward(self, v):
254
return self.mods[i].forward(v)
256
with self.assertRaisesRegex(Exception, "will fail because i is not a literal"):
257
torch.jit.script(M4())
259
def test_module_interface_special_methods(self):
260
class CustomModuleInterface(torch.nn.Module):
263
class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
264
def __init__(self, modules=None):
265
CustomModuleInterface.__init__(self)
266
torch.nn.ModuleList.__init__(self, modules)
268
class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
269
def __init__(self, modules=None):
270
CustomModuleInterface.__init__(self)
271
torch.nn.Sequential.__init__(self, modules)
273
class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
274
def __init__(self, modules=None):
275
CustomModuleInterface.__init__(self)
276
torch.nn.ModuleDict.__init__(self, modules)
278
class MyModule(torch.nn.Module):
279
def __init__(self) -> None:
281
# work around aliasing issue for 'is' operator by scripting ReLU up front
282
self.submod = torch.jit.script(torch.nn.ReLU())
283
self.modulelist = CustomModuleList([self.submod])
284
self.sequential = CustomSequential(self.submod)
285
self.moduledict = CustomModuleDict({"submod": self.submod})
287
def forward(self, inputs):
289
self.modulelist[0] is self.submod
290
), "__getitem__ failing for ModuleList"
291
assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
292
for module in self.modulelist:
293
assert module is self.submod, "__iter__ failing for ModuleList"
296
self.sequential[0] is self.submod
297
), "__getitem__ failing for Sequential"
298
assert len(self.sequential) == 1, "__len__ failing for Sequential"
299
for module in self.sequential:
300
assert module is self.submod, "__iter__ failing for Sequential"
303
self.moduledict["submod"] is self.submod
304
), "__getitem__ failing for ModuleDict"
305
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
307
# note: unable to index moduledict with a string variable currently
309
for key in self.moduledict:
311
assert i == len(self.moduledict), "iteration failing for ModuleDict"
313
assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"
315
for key in self.moduledict.keys():
316
assert key == "submod", "keys() fails for ModuleDict"
318
for item in self.moduledict.items():
319
assert item[0] == "submod", "items() fails for ModuleDict"
320
assert item[1] is self.submod, "items() fails for ModuleDict"
322
for value in self.moduledict.values():
323
assert value is self.submod, "values() fails for ModuleDict"
328
self.checkModule(m, [torch.randn(2, 2)])
330
def test_special_method_with_override(self):
331
class CustomModuleInterface(torch.nn.Module):
334
class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
335
def __init__(self, modules=None):
336
CustomModuleInterface.__init__(self)
337
torch.nn.ModuleList.__init__(self, modules)
340
# this is arbitrary, just to check that the overridden py __len__ from
341
# CustomModuleList takes precedence over the automatically generated
342
# __len__ added by the jit compiler
345
class MyModule(torch.nn.Module):
346
def __init__(self) -> None:
348
# work around aliasing issue for 'is' operator by scripting ReLU up front
349
self.submod = torch.jit.script(torch.nn.ReLU())
350
self.modulelist = CustomModuleList([self.submod])
352
def forward(self, inputs):
353
assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
357
self.checkModule(m, [torch.randn(2, 2)])
358
mm = torch.jit.script(m)
360
def test_moduledict_getitem(self):
361
class MyModule(torch.nn.Module):
362
def __init__(self) -> None:
364
self.relu = torch.jit.script(torch.nn.ReLU())
365
self.tanh = torch.jit.script(torch.nn.Tanh())
366
self.moduledict = torch.nn.ModuleDict(
367
{"relu": self.relu, "tanh": self.tanh}
370
def forward(self, input):
371
assert self.moduledict["relu"] is self.relu
372
assert self.moduledict["tanh"] is self.tanh
376
self.checkModule(m, [torch.randn(2, 2)])
378
def test_moduledict_keyerror(self):
379
class BadModule(torch.nn.Module):
380
def __init__(self) -> None:
382
self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
384
def forward(self, input):
385
assert self.moduledict["blah"] == "blah", "this is a keyerror"
387
with self.assertRaisesRegexWithHighlight(
388
RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
393
class AnotherBadModule(torch.nn.Module):
394
def __init__(self) -> None:
396
self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
398
def forward(self, input):
400
assert self.moduledict[idx] == "blah", "this is a string literal error"
402
with self.assertRaisesRegexWithHighlight(
404
"Unable to extract string literal index. "
405
"ModuleDict indexing is only supported with string literals. "
406
"For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
407
"because i is not a literal.",
408
"self.moduledict[idx]",
410
b = AnotherBadModule()
413
def test_normal_list_attribute_with_modules_error(self):
415
Test that an attempt to script a module with a regular list attribute
416
containing other modules fails with a relevant error message.
419
class Mod(torch.nn.Module):
420
def __init__(self) -> None:
422
self.a = [torch.nn.ReLU(), torch.nn.ReLU()]
427
error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module"
428
with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"):
429
torch.jit.script(Mod())
431
def test_empty_dict_override_contains(self):
432
class CustomModuleInterface(torch.nn.Module):
435
class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
436
def __init__(self, modules=None):
437
CustomModuleInterface.__init__(self)
438
torch.nn.ModuleDict.__init__(self, modules)
440
class MyModule(torch.nn.Module):
441
def __init__(self) -> None:
443
# work around aliasing issue for 'is' operator by scripting ReLU up front
444
self.submod = torch.jit.script(torch.nn.ReLU())
445
self.moduledict = CustomModuleDict()
447
def forward(self, inputs):
449
"submod" not in self.moduledict
450
), "__contains__ fails for ModuleDict"
454
self.checkModule(m, [torch.randn(2, 2)])
456
def test_typed_module_dict(self):
458
Test that a type annotation can be provided for a ModuleDict that allows
463
class ModuleInterface(torch.nn.Module):
464
def forward(self, inp: Any) -> Any:
467
class ImplementsInterface(torch.nn.Module):
468
def forward(self, inp: Any) -> Any:
469
if isinstance(inp, torch.Tensor):
470
return torch.max(inp, dim=0)
474
class DoesNotImplementInterface(torch.nn.Module):
    """Module whose forward returns a tensor pair, so its signature does not
    match the ``Any -> Any`` interface used by the surrounding test."""

    def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # torch.max with ``dim`` yields a (values, indices) pair.
        result = torch.max(inp, dim=0)
        return result
478
# Test annotation of submodule.
479
class Mod(torch.nn.Module):
480
def __init__(self) -> None:
482
self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})
484
def forward(self, x: torch.Tensor, key: str) -> Any:
    """Fetch a submodule from ``self.d`` by key and invoke it.

    The ``ModuleInterface`` annotation is what lets TorchScript call
    ``.forward`` on the looked-up entry.
    """
    picked: ModuleInterface = self.d[key]
    return picked.forward(x)
489
self.checkModule(m, (torch.randn(2, 2), "module"))
491
# Test annotation of self.
492
class ModDict(torch.nn.ModuleDict):
    """ModuleDict subclass whose own forward annotates the looked-up entry
    as ``ModuleInterface`` (tests annotating ``self[key]`` directly)."""

    def __init__(self) -> None:
        super().__init__({"module": ImplementsInterface()})

    def forward(self, x: torch.Tensor, key: str) -> Any:
        # The interface annotation allows the dynamic .forward call under scripting.
        entry: ModuleInterface = self[key]
        return entry.forward(x)
501
self.checkModule(m, (torch.randn(2, 2), "module"))
503
# Test error message thrown when annotated attribute does not comply with the
505
class ModWithWrongAnnotation(torch.nn.ModuleDict):
506
def __init__(self) -> None:
508
self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()})
510
def forward(self, x: torch.Tensor, key: str) -> Any:
    """Annotate the fetched entry as ``ModuleInterface``; the surrounding test
    expects scripting to reject this (the stored module does not comply),
    highlighting ``self.d[key]``."""
    fetched: ModuleInterface = self.d[key]
    return fetched.forward(x)
514
with self.assertRaisesRegexWithHighlight(
515
RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
517
torch.jit.script(ModWithWrongAnnotation())
519
def test_typed_module_list(self):
521
Test that a type annotation can be provided for a ModuleList that allows
526
class ModuleInterface(torch.nn.Module):
527
def forward(self, inp: Any) -> Any:
530
class ImplementsInterface(torch.nn.Module):
531
def forward(self, inp: Any) -> Any:
532
if isinstance(inp, torch.Tensor):
533
return torch.max(inp, dim=0)
537
class DoesNotImplementInterface(torch.nn.Module):
    """Module returning a (values, indices) tuple, deliberately incompatible
    with the ``Any -> Any`` interface used by the surrounding test."""

    def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Reduction over dim 0 produces the max values together with their indices.
        pair = torch.max(inp, dim=0)
        return pair
541
# Test annotation of submodule.
542
class Mod(torch.nn.Module):
543
def __init__(self) -> None:
545
self.l = torch.nn.ModuleList([ImplementsInterface()])
547
def forward(self, x: torch.Tensor, idx: int) -> Any:
    """Fetch a submodule from ``self.l`` by index and run it.

    The ``ModuleInterface`` annotation enables the ``.forward`` call on the
    indexed entry under TorchScript.
    """
    entry: ModuleInterface = self.l[idx]
    return entry.forward(x)
552
self.checkModule(m, (torch.randn(2, 2), 0))
554
# Test annotation of self.
555
class ModList(torch.nn.ModuleList):
    """ModuleList subclass whose own forward annotates the indexed entry as
    ``ModuleInterface`` (tests annotating ``self[idx]`` directly)."""

    def __init__(self) -> None:
        super().__init__([ImplementsInterface()])

    def forward(self, x: torch.Tensor, idx: int) -> Any:
        # The interface annotation allows calling .forward on the entry when scripted.
        entry: ModuleInterface = self[idx]
        return entry.forward(x)
564
self.checkModule(m, (torch.randn(2, 2), 0))
566
# Test error message thrown when annotated attribute does not comply with the
568
class ModWithWrongAnnotation(torch.nn.ModuleList):
569
def __init__(self) -> None:
571
self.l = torch.nn.ModuleList([DoesNotImplementInterface()])
573
def forward(self, x: torch.Tensor, idx: int) -> Any:
    """Annotate the indexed entry as ``ModuleInterface``; the surrounding test
    expects scripting to reject this mismatch, highlighting ``self.l[idx]``."""
    fetched: ModuleInterface = self.l[idx]
    return fetched.forward(x)
577
with self.assertRaisesRegexWithHighlight(
578
RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
580
torch.jit.script(ModWithWrongAnnotation())
582
def test_module_properties(self):
583
class ModuleWithProperties(torch.nn.Module):
584
__jit_unused_properties__ = ["ignored_attr"]
586
def __init__(self, a: int):
590
def forward(self, a: int, b: int):
599
def ignored_attr(self):
604
def ignored_attr_2(self):
607
@ignored_attr_2.setter
608
def ignored_attr_2(self, value):
609
self.a = sum([self.a])
612
def attr(self, a: int):
618
class ModuleWithNoSetter(torch.nn.Module):
619
def __init__(self, a: int):
623
def forward(self, a: int, b: int):
631
ModuleWithProperties(5),
638
ModuleWithProperties(5),
645
ModuleWithNoSetter(5),
652
ModuleWithNoSetter(5),
659
mod = ModuleWithProperties(3)
660
scripted_mod = torch.jit.script(mod)
662
with self.assertRaisesRegex(AttributeError, "has no attribute"):
663
scripted_mod.ignored_attr
665
def test_module_inplace_construct(self):
667
def __init__(self, start: int):
669
self.linear = nn.Linear(3, 3)
670
self.attribute = start
671
self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float))
673
def method(self) -> int:
    """Return the module's stored integer attribute."""
    value = self.attribute
    return value
677
def unused_method(self):
    """Return the stored attribute added to itself."""
    doubled = self.attribute + self.attribute
    return doubled
680
def forward(self, x):
    """Apply the linear layer twice in succession."""
    hidden = self.linear(x)
    return self.linear(hidden)
684
def __init__(self) -> None:
686
self.linear = nn.Linear(4, 4)
689
def ignored_method(self, x):
692
def forward(self, x):
    """Run the input through the single linear layer."""
    out = self.linear(x)
    return out
695
m = torch.jit.script(M(3))
696
n = torch.jit.script(N())
700
inp = torch.rand((3))
702
# Check that both modules produce the same output.
703
with torch.no_grad():
706
self.assertEqual(m_out, n_out)
708
# Check that ignored method is still intact.
709
self.assertEqual(inp, n.ignored_method(inp))
711
def test_parameterlist_script_getitem(self):
712
class MyModule(nn.Module):
713
def __init__(self) -> None:
715
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
716
self.parameter_list = nn.ParameterList(
717
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
720
def forward(self, x):
722
self.parameter_list[0]
725
self.checkModule(MyModule(), (torch.zeros(1)))
727
def test_parameterlist_script_iter(self):
728
class MyModule(nn.Module):
729
def __init__(self) -> None:
731
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
732
self.parameter_list = nn.ParameterList(
733
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
736
def forward(self, x):
738
for i, p in enumerate(self.parameter_list):
742
self.checkModule(MyModule(), (torch.zeros(1),))
744
def test_parameterdict_script_getitem(self):
745
class MyModule(nn.Module):
746
def __init__(self) -> None:
748
self.parameter_dict = nn.ParameterDict(
749
{k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
752
def forward(self, x):
754
self.parameter_dict["a"] * x
755
+ self.parameter_dict["b"] * self.parameter_dict["c"]
758
self.checkModule(MyModule(), (torch.ones(1),))