7
from textwrap import dedent
8
from typing import Dict, List, Optional, Tuple, Union
11
from torch.testing import FileCheck
15
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
16
sys.path.append(pytorch_test_dir)
17
from torch.testing._internal.jit_utils import JitTestCase, make_global
20
if __name__ == "__main__":
22
"This test file is not meant to be run directly, use:\n\n"
23
"\tpython test/test_jit.py TESTNAME\n\n"
28
class TestUnion(JitTestCase):
30
This class tests the functionality of `Union`.
32
Note: It's important to be able to refine the type of a `Union` to
33
one of its internal types. Currently, there are differences in the
34
way Python expects `isinstance` checks and the way TorchScript
35
expects `isinstance` checks. This means that we can't use
36
`checkScript` in our test cases because either the eager mode or the
37
script mode wouldn't run! So, some test cases have separate but
38
equivalent functions to emulate `checkScript`.
41
def test_check_union_annotation(self):
42
def test_func(a: Union[int, float], b: Optional[int]):
45
scripted_func = torch.jit.script(test_func)
46
graph_rep = str(scripted_func.graph)
47
code_rep = str(scripted_func.code)
49
FileCheck().check("Union(").check("int?").run(graph_rep)
51
FileCheck().check("Union[").check("Optional[int]").run(code_rep)
52
self.checkScript(test_func, (5, 6))
54
torch._C.parse_ir(str(scripted_func.graph))
56
def test_union_with_scalar_values(self):
57
def fn(x: Union[int, float]) -> str:
60
self.checkScript(fn, (1,))
61
self.checkScript(fn, (1.0,))
63
scripted = torch.jit.script(fn)
65
with self.assertRaisesRegex(
67
"Expected a member of"
68
r" Union\[float, int\] but "
69
"instead found type str",
73
def test_union_with_collections(self):
74
def fn(x: Union[Dict[str, int], List[int]]) -> str:
77
self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
78
self.checkScript(fn, ([1, 2, 3],))
80
scripted = torch.jit.script(fn)
82
with self.assertRaisesRegex(
84
"Expected a member of"
85
r" Union\[List\[int\], Dict\[str, "
86
r"int\]\] but instead found type "
89
scripted({"foo": "bar", "baz": "qux"})
91
with self.assertRaisesRegex(
93
"Expected a member of"
94
r" Union\[List\[int\], Dict\[str, "
95
r"int\]\] but instead found type "
98
scripted(["foo", "bar", "baz"])
100
with self.assertRaisesRegex(
102
"Expected a member of"
103
r" Union\[List\[int\], Dict\[str, "
104
r"int\]\] but instead found type "
109
def test_union_with_enum(self):
116
def fn(x: Union[str, Color]) -> str:
119
self.checkScript(fn, (Color.RED,))
120
self.checkScript(fn, ("red",))
122
scripted = torch.jit.script(fn)
124
with self.assertRaisesRegex(
126
"Expected a member of"
127
r" Union\[__torch__.jit.test_union."
128
r"Color, str\] but instead found "
133
def test_union_in_class_constructor(self):
136
def __init__(self, x: Union[int, str]) -> None:
139
def fn(x: Union[str, int]) -> A:
142
self.assertEqual(fn("foo").x, "foo")
143
self.assertEqual(fn(1).x, 1)
145
scripted = torch.jit.script(fn)
147
with self.assertRaisesRegex(
149
"Expected a member of"
150
r" Union\[int, str\] but instead "
151
r"found type List\[str\]",
153
scripted(["foo", "bar", "baz"])
155
def test_union_return_type(self):
156
def fn(x: int) -> Union[int, str]:
159
self.checkScript(fn, (1,))
161
def test_union_as_annotation(self):
162
def fn() -> Union[int, str]:
163
x: Union[int, str] = "foo"
166
self.checkScript(fn, ())
168
def test_union_as_annotation_in_typed_container(self):
170
l: List[Union[int, str]] = []
171
u1: Union[int, str] = "foo"
172
u2: Union[int, str] = 1
176
self.checkScript(fn, ())
178
def test_union_as_annotation_py2(self):
181
x: Union[int, str] = "foo"
184
self.checkScript(fn, ())
186
def test_union_as_internal_tuple_type(self):
188
t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
191
self.checkScript(fn, ())
193
def test_union_variable_can_be_reassigned(self):
202
def fn() -> Union[int, str]:
203
x: Union[int, str] = "foo"
207
z: str = aux2(str(y))
211
self.checkScript(fn, ())
213
def test_union_does_not_replace_existing_annotated_type(self):
215
x: List[int] = [1, 2, 3]
219
with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
220
scripted = torch.jit.script(fn)
223
def test_union_does_not_replace_existing_annotated_type_union(self):
225
x: List[Union[int, str]] = [1, "foo", 3]
229
with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
230
scripted = torch.jit.script(fn)
233
def test_union_does_not_replace_existing_annotated_type_empty_container(self):
239
with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
240
scripted = torch.jit.script(fn)
243
def test_unions_of_unions_are_flattened(self):
245
def fn(x: Union[Union[int, str], float]) -> str:
250
FileCheck().check("x : Union(float, int, str)").run(s)
252
def test_unions_of_a_single_argument_vanish(self):
254
def fn(x: Union[int]) -> str:
259
FileCheck().check("x : int").run(s)
261
def test_union_redundant_arguments_are_skipped(self):
263
def fn(x: Union[int, str, int]) -> str:
268
FileCheck().check("x : Union(int, str)").run(s)
270
def test_union_redundant_arguments_are_skipped_optional(self):
272
def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
277
FileCheck().check("x : Union(float, int, NoneType)").run(s)
279
def test_union_redundant_arguments_are_skipped_subtyping(self):
281
def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
286
FileCheck().check("x : Union((int?, int), str)").run(s)
288
def test_union_redundant_arguments_are_skipped_container(self):
290
def fn(x: Union[List[str], List[float], List[str]]) -> str:
295
FileCheck().check("x : Union(float[], str[])").run(s)
297
def test_union_argument_order_is_ignored(self):
299
def fn1(x: Union[int, str]) -> str:
303
def fn2(x: Union[str, int]) -> str:
306
for s in (fn1.graph, fn2.graph):
307
FileCheck().check("x : Union(int, str)").run(s)
309
def test_union_argument_order_is_ignored_container(self):
311
def fn1(x: Union[List[str], List[int]]) -> str:
315
def fn2(x: Union[List[int], List[str]]) -> str:
318
for s in (fn1.graph, fn2.graph):
319
FileCheck().check("x : Union(int[], str[])").run(s)
321
def test_union_T_None_is_equivalent_to_optional_T(self):
323
def inner(x: Union[int, None]) -> int:
332
b: Optional[int] = None
337
self.assertEqual(fn1(), 10)
340
def inner2(x: Optional[int]) -> int:
348
a: Union[int, None] = 5
349
b: Union[int, None] = None
354
self.assertEqual(fn2(), 10)
356
def test_union_optional_of_union_is_flattened(self):
358
def fn(flag: int) -> Union[str, int, None]:
359
y: Union[int, str, None] = "foo"
361
x: Optional[Union[int, str]] = y
363
x: Optional[Union[int, str]] = 1
365
x: Optional[Union[int, str]] = None
372
self.assertEqual(fn(0), "foo")
373
self.assertEqual(fn(1), 1)
374
self.assertEqual(fn(2), None)
376
buffer = io.BytesIO()
377
torch.jit.save(fn, buffer)
378
buffer = io.BytesIO(buffer.getvalue())
379
l = torch.jit.load(buffer)
383
FileCheck().check("Union[int, NoneType, str]").check(
384
"Union[int, NoneType, str]"
387
def test_union_subclasses_larger_union(self):
388
def fn() -> Union[int, str, torch.Tensor]:
389
x: Union[int, str] = "foo"
392
self.checkScript(fn, ())
396
def test_union_as_dict_key(self):
398
x: Dict[Union[int, str], str] = {}
403
with self.assertRaisesRegex(
406
"complex, Tensor, device and string keys "
411
def test_union_as_dict_value(self):
413
x: Dict[str, Union[int, str]] = {}
418
self.checkScript(fn, ())
420
def test_union_module_with_union_instance_variable(self):
421
class M(torch.nn.Module):
424
def __init__(self, x: Union[int, str]):
426
self.x: Union[int, str] = x
428
def forward(self, y: Union[int, str]):
438
self.checkModule(M("bar"), ("foo",))
440
def test_union_module_with_union_class_variable(self):
441
class M(torch.nn.Module):
442
x: Union[int, str] = "foo"
444
def __init__(self, y: int):
448
def forward(self, z: str):
452
self.checkModule(M(1), ("foo",))
454
def test_union_type_refinement(self):
455
def fn(x: Union[int, str]) -> str:
456
if isinstance(x, str):
462
self.checkScript(fn, ("foo",))
463
self.checkScript(fn, (1,))
465
def test_union_type_refinement_union_rhs(self):
466
def fn(x: int) -> str:
467
if torch.jit.isinstance(x, Union[int, str]):
472
self.checkScript(fn, (1,))
474
def test_union_type_refinement_tuple_rhs(self):
475
def fn(x: Union[int, float, List[str]]) -> str:
476
if isinstance(x, (int, float)):
477
if isinstance(x, int):
487
self.checkScript(fn, (1,))
488
self.checkScript(fn, (1.0,))
489
self.checkScript(fn, (["a", "b", "c"],))
491
def test_union_type_refinement_tuple_rhs_noncontained_type(self):
492
def fn(x: Union[int, List[str]]) -> str:
493
if isinstance(x, (int, float)):
502
self.checkScript(fn, (1,))
503
self.checkScript(fn, (["a", "b", "c"],))
505
def test_union_type_refinement_tuple_rhs_union(self):
507
def fn(x: int) -> str:
508
if torch.jit.isinstance(x, (Union[int, str], float)):
518
self.assertEqual(fn(1), "2")
520
def test_union_type_refinement_statically_false(self):
522
def fn(x: int) -> str:
523
if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
532
FileCheck().check_not("block0()").check_not("block1()").run(s)
534
def test_union_type_refinement_statically_true(self):
536
def fn(x: Union[List[int], int]) -> Union[List[int], int]:
537
if not torch.jit.isinstance(x, (int, List[int])):
541
y: Union[List[int], int] = l
547
FileCheck().check_not("block0()").check_not("block1()").run(s)
549
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
550
def fn(x: Union[List[int], int]) -> int:
551
if torch.jit.isinstance(x, (int, float, str)):
558
self.checkScript(fn, ([1, 2, 3],))
559
self.checkScript(fn, (1,))
561
def test_union_type_refinement_partial_static_refinement_union_rhs(self):
562
def fn(x: Union[List[int], int]) -> int:
563
if torch.jit.isinstance(x, Union[int, float, str]):
570
self.checkScript(fn, ([1, 2, 3],))
571
self.checkScript(fn, (1,))
573
def test_union_type_refinement_internal_declaration(self):
574
def fn(flag: bool) -> str:
575
x: Union[int, str, None] = None
580
if isinstance(x, str):
585
self.checkScript(fn, (True,))
586
self.checkScript(fn, (False,))
588
def test_union_branching_with_union_return_and_homogenous_types(self):
589
def fn(x: int) -> Union[int, str]:
595
self.checkScript(fn, (1,))
596
self.checkScript(fn, (8,))
598
def test_union_branching_does_not_autoinfer_undeclared_union(self):
599
def fn(x: int) -> str:
604
if isinstance(y, str):
609
with self.assertRaisesRegex(
611
"y is set to type str"
612
" in the true branch and type int "
613
"in the false branch",
617
def test_union_branching_does_not_widen_existing_inferred_type(self):
618
def fn(x: int) -> str:
624
if isinstance(y, str):
629
with self.assertRaisesRegex(
631
"previously had type "
632
"str but is now being assigned to a"
633
" value of type int",
637
def test_union_schema_matching_on_internal_type(self):
638
def fn(x: Union[List[int], Dict[str, int]]) -> int:
639
if torch.jit.isinstance(x, List[int]):
642
return list(x.values())[0]
644
self.checkScript(fn, ([1, 2, 3],))
645
self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
647
def test_union_subtractive_refinement(self):
648
def fn(x: Union[List[int], int]) -> int:
649
if not isinstance(x, int):
655
self.checkScript(fn, (1,))
656
self.checkScript(fn, ([1, 2, 3],))
658
def test_union_subtractive_refinement_with_container(self):
659
def fn(x: Union[List[int], int]) -> int:
660
if not torch.jit.isinstance(x, List[int]):
666
self.checkScript(fn, (1,))
667
self.checkScript(fn, ([1, 2, 3],))
669
def test_union_memory_aliasing(self):
671
x: List[torch.Tensor] = []
672
z: List[Optional[List[torch.Tensor]]] = []
675
if torch.jit.isinstance(x_alias, List[torch.Tensor]):
676
x_alias.append(torch.tensor(3))
679
self.checkScript(fn, ())
681
def test_union_serialization_preserves_type_annotations(self):
687
def fn(x: int) -> str:
689
y: Union[str, int] = "bar"
691
y: Union[str, int] = x
692
if isinstance(y, str):
697
self.checkScript(fn, (1,))
698
self.checkScript(fn, (8,))
700
def _assert_passes(self, template: str, ann: str, lhs: str):
701
code = template.format(ann=ann, lhs=lhs)
702
self.checkScript(code, (), name="fn")
704
def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
705
code = template.format(ann=ann, lhs=lhs)
706
with self.assertRaisesRegex(RuntimeError, msg):
707
cu = torch.jit.CompilationUnit(code, _frames_up=1)
708
string_frontend = getattr(cu, "fn")
710
def test_union_with_list_assignment(self):
715
if torch.jit.isinstance(x, List[torch.Tensor]):
716
x.append(torch.tensor(3))
722
"list_literal_empty": "[]",
723
"list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
724
"list_literal_of_str": '["foo", "bar", "baz"]',
725
"list_literal_of_mixed": "[torch.arange(5), 1]",
726
"list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
727
"list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
728
"list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
732
Union[List[str], List[torch.Tensor]]
736
"Union[List[str], List[torch.Tensor]]",
737
lhs["list_literal_empty"],
738
"there are multiple possible List type "
739
"candidates in the Union annotation",
744
"Union[List[str], List[torch.Tensor]]",
745
lhs["list_literal_of_tensor"],
749
template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
754
"Union[List[str], List[torch.Tensor]]",
755
lhs["list_literal_of_mixed"],
756
"none of those types match the types of the" " given list elements",
761
"Union[List[str], List[torch.Tensor]]",
762
lhs["list_comprehension_of_tensor"],
767
"Union[List[str], List[torch.Tensor]]",
768
lhs["list_comprehension_of_str"],
774
"Union[List[str], List[torch.Tensor]]",
775
lhs["list_comprehension_of_mixed"],
776
"Arguments for call are not valid",
780
Union[int, torch.Tensor]
784
"Union[int, torch.Tensor]",
785
lhs["list_literal_empty"],
786
"Expected an Union type annotation with an " "inner List type",
791
"Union[int, torch.Tensor]",
792
lhs["list_literal_of_tensor"],
793
"Expected an Union type annotation with an " "inner List type",
798
"Union[int, torch.Tensor]",
799
lhs["list_comprehension_of_tensor"],
800
"Expected an Union type annotation with an " "inner List type",
804
Union[List[torch.Tensor], int]
807
template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
811
template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
816
"Union[List[torch.Tensor], int]",
817
lhs["list_literal_of_str"],
818
r"List type annotation `List\[Tensor\]` did "
819
"not match the types of the given list "
825
"Union[List[torch.Tensor], int]",
826
lhs["list_literal_of_mixed"],
827
r"List type annotation `List\[Tensor\]` did "
828
"not match the types of the given list "
834
"Union[List[torch.Tensor], int]",
835
lhs["list_comprehension_of_tensor"],
840
"Union[List[torch.Tensor], int]",
841
lhs["list_comprehension_of_str"],
842
r"List type annotation `List\[Tensor\]` did "
843
"not match the types of the given list "
850
"Union[List[torch.Tensor], int]",
851
lhs["list_comprehension_of_mixed"],
852
"Arguments for call are not valid",
855
def test_union_with_dict_assignment(self):
860
if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
861
x["foo"] = torch.tensor(3)
867
"dict_literal_empty": "{}",
868
"dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
869
"dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
870
"dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
871
"dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
872
zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
873
"dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
874
zip(["foo", "bar"], [1, 2]}',
875
"dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
876
zip(["foo", "bar"], [torch.arange(3), 2])}',
877
"dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
878
"dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
879
"dict_keyword_with_empty_iterable": "dict([])",
880
"dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
881
"dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
882
"dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
886
Union[Dict[str, torch.Tensor], Dict[str, int]]
890
"Union[List[str], List[torch.Tensor]]",
891
lhs["dict_literal_empty"],
892
"Expected an Union type annotation with an " "inner Dict type",
897
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
898
lhs["dict_literal_of_str_tensor"],
903
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
904
lhs["dict_literal_of_str_int"],
909
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
910
lhs["dict_literal_of_mixed"],
911
"none of those dict types can hold the "
912
"types of the given keys and values",
936
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
938
"full type inference is not yet supported",
943
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
944
lhs["dict_keyword_with_iterable"],
945
"full type inference is not yet supported",
950
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
951
lhs["dict_keyword_with_empty_iterable"],
952
"full type inference is not yet supported",
957
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
958
lhs["dict_keyword_with_mapping"],
959
"full type inference is not yet supported",
964
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
965
lhs["dict_keyword_with_mapping_and_kwargs"],
966
"full type inference is not yet supported",
970
Union[int, torch.Tensor]
974
"Union[int, torch.Tensor]",
975
lhs["dict_literal_empty"],
976
"Expected an Union type annotation with " "an inner Dict type",
981
"Union[int, torch.Tensor]",
982
lhs["dict_literal_of_str_tensor"],
983
"Expected an Union type annotation with " "an inner Dict type",
992
Union[Dict[str, torch.Tensor], int]
995
template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
1000
"Union[Dict[str, torch.Tensor], int]",
1001
lhs["dict_literal_of_str_tensor"],
1004
self._assert_raises(
1006
"Union[Dict[str, torch.Tensor], int]",
1007
lhs["dict_literal_of_str_int"],
1008
"Type annotation was inferred to be "
1009
r"`Dict\[str, Tensor\]`, but the type of "
1010
"values given by the dict literal is",
1013
self._assert_raises(
1015
"Union[Dict[str, torch.Tensor], int]",
1016
lhs["dict_literal_of_mixed"],
1017
"Type annotation was inferred to be "
1018
r"`Dict\[str, Tensor\]`, but the type of "
1019
"values given by the dict literal is",
1022
self._assert_passes(
1023
template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
1026
self._assert_passes(
1028
"Union[Dict[str, torch.Tensor], int]",
1029
lhs["dict_keyword_with_iterable"],
1032
self._assert_passes(
1034
"Union[Dict[str, torch.Tensor], int]",
1035
lhs["dict_keyword_with_empty_iterable"],
1038
self._assert_passes(
1040
"Union[Dict[str, torch.Tensor], int]",
1041
lhs["dict_keyword_with_mapping"],
1044
self._assert_passes(
1046
"Union[Dict[str, torch.Tensor], int]",
1047
lhs["dict_keyword_with_mapping_and_kwargs"],