# pytorch · Fork: 0 · test_union.py · 1067 lines · 33.3 KB
1
# Owner(s): ["oncall: jit"]
2

3
import io
4
import os
5
import sys
6
from enum import Enum
7
from textwrap import dedent
8
from typing import Dict, List, Optional, Tuple, Union
9

10
import torch
11
from torch.testing import FileCheck
12

13

14
# Make the helper files in test/ importable: `pytorch_test_dir` is the parent
# of the directory containing this file, and it is appended to `sys.path`
# before the `torch.testing._internal` helpers below are imported.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global

# Refuse direct execution; the error message points users at
# test/test_jit.py as the entry point for these tests.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

class TestUnion(JitTestCase):
    """
    This class tests the functionality of `Union`.

    Note: It's important to be able to refine the type of a `Union` to
    one of its internal types. Currently, Python and TorchScript expect
    `isinstance` checks to behave differently. This means that we can't
    use `checkScript` in some of our test cases, because either the
    eager-mode or the script-mode version of the function wouldn't run!
    So, those test cases use separate but equivalent functions to
    emulate `checkScript`.
    """

    def test_check_union_annotation(self):
        """`Union`/`Optional` parameter annotations survive scripting.

        Checks the graph IR spelling (`Union(`, `int?`), the serialized-code
        spelling (`Union[`, `Optional[int]`), eager/script agreement, and that
        the printed graph can be parsed back with `torch._C.parse_ir`.
        """

        def test_func(a: Union[int, float], b: Optional[int]):
            return 0

        scripted_func = torch.jit.script(test_func)
        graph_rep = str(scripted_func.graph)
        code_rep = str(scripted_func.code)
        # TS graph IR for Union should be annotated as Union()
        FileCheck().check("Union(").check("int?").run(graph_rep)
        # Serialized code for Union should be annotated as Union[]
        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
        self.checkScript(test_func, (5, 6))
        # this shouldn't error out
        torch._C.parse_ir(str(scripted_func.graph))

    def test_union_with_scalar_values(self):
        """A `Union[int, float]` parameter accepts both members and rejects `str`."""

        def fn(x: Union[int, float]) -> str:
            return "foo"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))

        scripted = torch.jit.script(fn)

        # The expected message lists the members as `float, int`, i.e. not in
        # the order they were declared.
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[float, int\] but "
            "instead found type str",
        ):
            scripted("1")

    def test_union_with_collections(self):
        """A Union of container types accepts each member and rejects others.

        Calls with a `Dict[str, str]`, a `List[str]`, and a plain `str` must
        all fail the argument type check of the scripted function.
        """

        def fn(x: Union[Dict[str, int], List[int]]) -> str:
            return "foo"

        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
        self.checkScript(fn, ([1, 2, 3],))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[List\[int\], Dict\[str, "
            r"int\]\] but instead found type "
            r"Dict\[str, str\]",
        ):
            scripted({"foo": "bar", "baz": "qux"})

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[List\[int\], Dict\[str, "
            r"int\]\] but instead found type "
            r"List\[str\]",
        ):
            scripted(["foo", "bar", "baz"])

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[List\[int\], Dict\[str, "
            r"int\]\] but instead found type "
            "str",
        ):
            scripted("1")

    def test_union_with_enum(self):
        """A Union containing an `Enum` class accepts enum members and `str`."""

        class Color(Enum):
            RED = 1
            GREEN = 2

        # NOTE(review): presumably needed so TorchScript can resolve the
        # locally defined `Color` by name; see `make_global` in jit_utils.
        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        self.checkScript(fn, (Color.RED,))
        self.checkScript(fn, ("red",))

        scripted = torch.jit.script(fn)

        # The qualified enum name in the message depends on this module's path.
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[__torch__.jit.test_union."
            r"Color, str\] but instead found "
            "type int",
        ):
            scripted(1)

    def test_union_in_class_constructor(self):
        """A scripted class whose constructor takes a `Union` argument."""

        @torch.jit.script  # noqa: B903
        class A:  # noqa: B903
            def __init__(self, x: Union[int, str]) -> None:
                self.x = x

        def fn(x: Union[str, int]) -> A:
            return A(x)

        self.assertEqual(fn("foo").x, "foo")
        self.assertEqual(fn(1).x, 1)

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[int, str\] but instead "
            r"found type List\[str\]",
        ):
            scripted(["foo", "bar", "baz"])

    def test_union_return_type(self):
        """A `Union` return annotation accepts a returned `str`."""

        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))

    def test_union_as_annotation(self):
        """A local variable can be annotated with a `Union` and returned."""

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_annotation_in_typed_container(self):
        """Union-typed values can be appended to a `List[Union[...]]`."""

        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())

    def test_union_as_annotation_py2(self):
        """A `Union` return type given as a Python 2 style `# type:` comment."""

        def fn():
            # type: () -> Union[int, str]
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_internal_tuple_type(self):
        """`Union` can appear as the element types of a `Tuple` annotation."""

        def fn():
            t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
            return t

        self.checkScript(fn, ())

    def test_union_variable_can_be_reassigned(self):
        """A Union-annotated variable can be rebound to each of its member types.

        `x` starts as a `str`, is rebound to an `int` and passed to an
        `int`-only function, then is rebound to a `str` again.
        """

        @torch.jit.script
        def aux1(i: int):
            return int(i**2)

        @torch.jit.script
        def aux2(s: str):
            return s + s

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())

    def test_union_does_not_replace_existing_annotated_type(self):
        """Appending a `str` to a `List[int]` is an error, not a widening to a Union."""

        def fn():
            x: List[int] = [1, 2, 3]
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_union(self):
        """Appending a `float` to a `List[Union[int, str]]` is an error."""

        def fn():
            x: List[Union[int, str]] = [1, "foo", 3]
            x.append(2.0)
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
        """An empty `List[int]` keeps its annotated type; appending `str` fails."""

        def fn():
            x: List[int] = []
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_unions_of_unions_are_flattened(self):
        """A nested `Union[Union[int, str], float]` appears as one flat Union in the graph."""

        @torch.jit.script
        def fn(x: Union[Union[int, str], float]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, str)").run(s)

    def test_unions_of_a_single_argument_vanish(self):
        """`Union[int]` is typed as plain `int` in the graph."""

        @torch.jit.script
        def fn(x: Union[int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : int").run(s)

    def test_union_redundant_arguments_are_skipped(self):
        """A duplicated member (`int` twice) appears only once in the graph type."""

        @torch.jit.script
        def fn(x: Union[int, str, int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(int, str)").run(s)

    def test_union_redundant_arguments_are_skipped_optional(self):
        """Two `Optional` members contribute `NoneType` to the graph type only once."""

        @torch.jit.script
        def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, NoneType)").run(s)

    def test_union_redundant_arguments_are_skipped_subtyping(self):
        """`Tuple[int, int]` is dropped because `Tuple[Optional[int], int]` already covers it."""

        @torch.jit.script
        def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union((int?, int), str)").run(s)

    def test_union_redundant_arguments_are_skipped_container(self):
        """A duplicated container member (`List[str]` twice) appears only once."""

        @torch.jit.script
        def fn(x: Union[List[str], List[float], List[str]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float[], str[])").run(s)

    def test_union_argument_order_is_ignored(self):
        """`Union[int, str]` and `Union[str, int]` print identically in the graph."""

        @torch.jit.script
        def fn1(x: Union[int, str]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[str, int]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int, str)").run(s)

    def test_union_argument_order_is_ignored_container(self):
        """Member order of container Unions does not change the graph type."""

        @torch.jit.script
        def fn1(x: Union[List[str], List[int]]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[List[int], List[str]]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int[], str[])").run(s)

    def test_union_T_None_is_equivalent_to_optional_T(self):
        """`Union[int, None]` and `Optional[int]` are interchangeable both ways.

        `fn1` passes `Optional[int]` values to a `Union[int, None]` parameter.
        `fn2` passes `Union[int, None]` values to an `Optional[int]` parameter.
        """

        @torch.jit.script
        def inner(x: Union[int, None]) -> int:
            if x is not None:
                return x
            else:
                return 5

        @torch.jit.script
        def fn1() -> int:
            a: Optional[int] = 5
            b: Optional[int] = None
            a_ = inner(a)
            b_ = inner(b)
            return a_ + b_

        self.assertEqual(fn1(), 10)

        @torch.jit.script
        def inner2(x: Optional[int]) -> int:
            if x is not None:
                return x
            else:
                return 5

        # Previously this called `inner` again, leaving `inner2` unused, so the
        # Union -> Optional direction was never actually exercised.
        @torch.jit.script
        def fn2() -> int:
            a: Union[int, None] = 5
            b: Union[int, None] = None
            a_ = inner2(a)
            b_ = inner2(b)
            return a_ + b_

        self.assertEqual(fn2(), 10)

    def test_union_optional_of_union_is_flattened(self):
        """`Optional[Union[int, str]]` is flattened to `Union[int, NoneType, str]`.

        Checks eager results for all three branches, then round-trips the
        function through `torch.jit.save`/`torch.jit.load` and inspects the
        serialized code.
        """

        @torch.jit.script
        def fn(flag: int) -> Union[str, int, None]:
            y: Union[int, str, None] = "foo"
            if flag == 0:
                x: Optional[Union[int, str]] = y
            elif flag == 1:
                x: Optional[Union[int, str]] = 1
            else:
                x: Optional[Union[int, str]] = None
            return x

        # Can't use `checkScript` because it will flag the fact that
        # the original code has `Optional[Union[int, str]]` but the
        # saved/loaded code has `Union[int, NoneType, str]` (even
        # though this is exactly what we want)
        self.assertEqual(fn(0), "foo")
        self.assertEqual(fn(1), 1)
        self.assertEqual(fn(2), None)

        buffer = io.BytesIO()
        torch.jit.save(fn, buffer)
        buffer = io.BytesIO(buffer.getvalue())
        loaded = torch.jit.load(buffer)

        s = loaded.code

        FileCheck().check("Union[int, NoneType, str]").check(
            "Union[int, NoneType, str]"
        ).run(s)

    def test_union_subclasses_larger_union(self):
        """A `Union[int, str]` value can be returned as `Union[int, str, Tensor]`."""

        def fn() -> Union[int, str, torch.Tensor]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        """A `Union` used as a `Dict` key type is currently rejected at compile time."""

        def fn():
            x: Dict[Union[int, str], str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        with self.assertRaisesRegex(
            RuntimeError,
            "only int, float, "
            "complex, Tensor, device and string keys "
            "are supported",
        ):
            torch.jit.script(fn)

    def test_union_as_dict_value(self):
        """A `Dict` whose values are a `Union` can hold both member types."""

        def fn():
            x: Dict[str, Union[int, str]] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())

    def test_union_module_with_union_instance_variable(self):
        """A module attribute annotated as `Union` can be (re)assigned in `forward`."""

        class M(torch.nn.Module):
            x: Union[int, str]

            def __init__(self, x: Union[int, str]):
                super().__init__()
                self.x: Union[int, str] = x

            def forward(self, y: Union[int, str]):
                self.x = y
                return self.x

        self.checkModule(M(2), (1,))
        self.checkModule(M("bar"), ("foo",))

    def test_union_module_with_union_class_variable(self):
        """A module with a class-level `Union`-annotated attribute can be scripted."""

        class M(torch.nn.Module):
            x: Union[int, str] = "foo"

            def __init__(self, y: int):
                super().__init__()
                # NOTE(review): this binds a local `x`, not `self.x`, so the
                # class attribute is never reassigned here. Confirm whether
                # that is intended before changing it.
                x = y

            def forward(self, z: str):
                # NOTE(review): as above, this is a local, not the attribute.
                x = z
                return x

        self.checkModule(M(1), ("foo",))

    def test_union_type_refinement(self):
        """`isinstance(x, str)` refines a `Union[int, str]` so `str` ops compile."""

        def fn(x: Union[int, str]) -> str:
            if isinstance(x, str):
                # `z` is unused; the `+ "bar"` only has to type-check.
                z = x + "bar"
                return x
            else:
                return "baz"

        self.checkScript(fn, ("foo",))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_union_rhs(self):
        """`torch.jit.isinstance` accepts a `Union` as the type argument."""

        def fn(x: int) -> str:
            if torch.jit.isinstance(x, Union[int, str]):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))

    def test_union_type_refinement_tuple_rhs(self):
        """`isinstance` with a tuple of types refines a Union, including nested checks."""

        def fn(x: Union[int, float, List[str]]) -> str:
            if isinstance(x, (int, float)):
                if isinstance(x, int):
                    return str(x)
                else:
                    return "foo"
            else:
                # Only `List[str]` remains here.
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
        """An `isinstance` tuple may name a type (`float`) absent from the Union."""

        def fn(x: Union[int, List[str]]) -> str:
            if isinstance(x, (int, float)):
                y = x + x
                return str(y)
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_union(self):
        """`torch.jit.isinstance` with a tuple that itself contains a `Union`."""

        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[int, str], float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        self.assertEqual(fn(1), "2")

    def test_union_type_refinement_statically_false(self):
        """A refinement that can never match an `int` is resolved at compile time.

        `x` is an `int`, so none of the listed types can match. The graph must
        contain no branch blocks. Presumably `x + "foo"` in the true branch only
        compiles because that branch is discarded.
        """

        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
                z = x + "foo"
                return z
            else:
                return "bar"

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()").check_not("block1()").run(s)

    def test_union_type_refinement_statically_true(self):
        """A check covering every Union member is resolved at compile time.

        `(int, List[int])` covers all of `x`'s possible types, so the negated
        check is always False. The graph must contain no branch blocks.
        """

        @torch.jit.script
        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: Union[List[int], int] = l
                return y

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()").check_not("block1()").run(s)

    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
        """A tuple check overlapping the Union in only `int` refines `x` to `int`."""

        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, (int, float, str)):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
        """Like the tuple variant, but the type argument is written as a `Union`."""

        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, Union[int, float, str]):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_internal_declaration(self):
        """Refinement of a Union-annotated local declared inside the function.

        `x` is always `None` here, so both calls take the `"bar"` path. `y` is
        assigned on both branches but never read.
        """

        def fn(flag: bool) -> str:
            x: Union[int, str, None] = None
            if flag:
                y = "foo"
            else:
                y = 1
            if isinstance(x, str):
                return x
            else:
                return "bar"

        self.checkScript(fn, (True,))
        self.checkScript(fn, (False,))

    def test_union_branching_with_union_return_and_homogenous_types(self):
        """Both branches return `str` under a `Union[int, str]` return annotation."""

        def fn(x: int) -> Union[int, str]:
            if x % 2:
                return "foo"
            else:
                return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        """An unannotated variable set to different types per branch is an error.

        No `Union` is inferred for `y`; scripting must fail and name both
        branch types in the message.
        """

        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        with self.assertRaisesRegex(
            RuntimeError,
            "y is set to type str"
            " in the true branch and type int "
            "in the false branch",
        ):
            torch.jit.script(fn)

    def test_union_branching_does_not_widen_existing_inferred_type(self):
        """Rebinding an inferred-`str` variable to an `int` is an error, not a widening."""

        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        with self.assertRaisesRegex(
            RuntimeError,
            "previously had type "
            "str but is now being assigned to a"
            " value of type int",
        ):
            torch.jit.script(fn)

    def test_union_schema_matching_on_internal_type(self):
        """After refinement, container methods (`x[0]`, `x.values()`) resolve correctly."""

        def fn(x: Union[List[int], Dict[str, int]]) -> int:
            if torch.jit.isinstance(x, List[int]):
                return x[0]
            else:
                return list(x.values())[0]

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))

    def test_union_subtractive_refinement(self):
        """`not isinstance(x, int)` leaves `List[int]` as the only possible type."""

        def fn(x: Union[List[int], int]) -> int:
            if not isinstance(x, int):
                x.append(1)
                return x[0]
            else:
                return x

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_subtractive_refinement_with_container(self):
        """`not torch.jit.isinstance(x, List[int])` leaves `int` as the only type."""

        def fn(x: Union[List[int], int]) -> int:
            if not torch.jit.isinstance(x, List[int]):
                return x
            else:
                x.append(1)
                return x[0]

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_memory_aliasing(self):
        """A mutation through a refined alias is visible through the original list.

        `x` is stored in `z`, read back as `x_alias` (typed `Optional[List[Tensor]]`),
        refined, and appended to. The function returns `x`, so eager and script
        mode must agree that `x` was mutated.
        """

        def fn():
            x: List[torch.Tensor] = []
            z: List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())

    def test_union_serialization_preserves_type_annotations(self):
        """Union annotations on branch-local variables survive save/load."""

        # This function will fail after being torch.jit.save'd and
        # torch.jit.load'd if the type annotations aren't preserved
        # for Union during serialization. We need the `Union[str, int]`
        # annotation to make sure that `y` is typed as a Union instead
        # of as a str in one branch and an int in the other
        def fn(x: int) -> str:
            if x % 2:
                y: Union[str, int] = "bar"
            else:
                y: Union[str, int] = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def _assert_passes(self, template: str, ann: str, lhs: str):
        """Fill `template` with `ann`/`lhs` and require the result to pass `checkScript`.

        The filled-in source must define a zero-argument function named `fn`.
        """
        code = template.format(ann=ann, lhs=lhs)
        self.checkScript(code, (), name="fn")

    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
        """Fill `template` with `ann`/`lhs` and require compiling it to raise.

        Args:
            template: source containing `{ann}` and `{lhs}` placeholders that
                defines a function named `fn`.
            ann: type annotation text substituted for `{ann}`.
            lhs: expression text substituted for `{lhs}`.
            msg: regex the `RuntimeError` message must match.
        """
        code = template.format(ann=ann, lhs=lhs)
        with self.assertRaisesRegex(RuntimeError, msg):
            cu = torch.jit.CompilationUnit(code, _frames_up=1)
            # The original stored this lookup in an unused local; the lookup
            # itself is kept so any error it raises is still checked.
            cu.fn  # noqa: B018

    def test_union_with_list_assignment(self):
        """Assign list literals/comprehensions to variables annotated with a Union.

        For each Union annotation, each right-hand side from `lhs` must either
        compile and agree with eager mode (`_assert_passes`) or fail with the
        given error (`_assert_raises`).
        """
        template = dedent(
            """
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, List[torch.Tensor]):
                    x.append(torch.tensor(3))
                return x
        """
        )

        lhs = {
            "list_literal_empty": "[]",
            "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
            "list_literal_of_str": '["foo", "bar", "baz"]',
            "list_literal_of_mixed": "[torch.arange(5), 1]",
            "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
            "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
            "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
        }

        no_inner_list_msg = "Expected an Union type annotation with an inner List type"
        tensor_list_mismatch_msg = (
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements"
        )

        # Two List candidates: an empty literal is ambiguous, while
        # homogeneous contents select one of the candidates.
        ann = "Union[List[str], List[torch.Tensor]]"
        self._assert_raises(
            template,
            ann,
            lhs["list_literal_empty"],
            "there are multiple possible List type "
            "candidates in the Union annotation",
        )
        self._assert_passes(template, ann, lhs["list_literal_of_tensor"])
        self._assert_passes(template, ann, lhs["list_literal_of_str"])
        self._assert_raises(
            template,
            ann,
            lhs["list_literal_of_mixed"],
            "none of those types match the types of the given list elements",
        )
        self._assert_passes(template, ann, lhs["list_comprehension_of_tensor"])
        self._assert_passes(template, ann, lhs["list_comprehension_of_str"])
        # TODO: Support mixed list comprehensions
        self._assert_raises(
            template,
            ann,
            lhs["list_comprehension_of_mixed"],
            "Arguments for call are not valid",
        )

        # No List member at all: every list right-hand side is rejected.
        ann = "Union[int, torch.Tensor]"
        self._assert_raises(
            template, ann, lhs["list_literal_empty"], no_inner_list_msg
        )
        self._assert_raises(
            template, ann, lhs["list_literal_of_tensor"], no_inner_list_msg
        )
        self._assert_raises(
            template, ann, lhs["list_comprehension_of_tensor"], no_inner_list_msg
        )

        # Exactly one List member: it is used as the annotation for the list.
        ann = "Union[List[torch.Tensor], int]"
        self._assert_passes(template, ann, lhs["list_literal_empty"])
        self._assert_passes(template, ann, lhs["list_literal_of_tensor"])
        self._assert_raises(
            template, ann, lhs["list_literal_of_str"], tensor_list_mismatch_msg
        )
        self._assert_raises(
            template, ann, lhs["list_literal_of_mixed"], tensor_list_mismatch_msg
        )
        self._assert_passes(template, ann, lhs["list_comprehension_of_tensor"])
        self._assert_raises(
            template, ann, lhs["list_comprehension_of_str"], tensor_list_mismatch_msg
        )
        # TODO(@ansley): Support mixed list comprehensions
        self._assert_raises(
            template,
            ann,
            lhs["list_comprehension_of_mixed"],
            "Arguments for call are not valid",
        )

    def test_union_with_dict_assignment(self):
856
        template = dedent(
857
            """
858
            def fn():
859
                x: {ann} = {lhs}
860
                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
861
                    x["foo"] = torch.tensor(3)
862
                return x
863
        """
864
        )
865

866
        lhs = {
867
            "dict_literal_empty": "{}",
868
            "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
869
            "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
870
            "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
871
            "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
872
                    zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
873
            "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
874
                    zip(["foo", "bar"], [1, 2]}',
875
            "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
876
                    zip(["foo", "bar"], [torch.arange(3), 2])}',
877
            "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
878
            "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
879
            "dict_keyword_with_empty_iterable": "dict([])",
880
            "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
881
            "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
882
            "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
883
        }
884

885
        """
886
        Union[Dict[str, torch.Tensor], Dict[str, int]]
887
        """
888
        self._assert_raises(
889
            template,
890
            "Union[List[str], List[torch.Tensor]]",
891
            lhs["dict_literal_empty"],
892
            "Expected an Union type annotation with an " "inner Dict type",
893
        )
894

895
        self._assert_passes(
896
            template,
897
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
898
            lhs["dict_literal_of_str_tensor"],
899
        )
900

901
        self._assert_passes(
902
            template,
903
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
904
            lhs["dict_literal_of_str_int"],
905
        )
906

907
        self._assert_raises(
908
            template,
909
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
910
            lhs["dict_literal_of_mixed"],
911
            "none of those dict types can hold the "
912
            "types of the given keys and values",
913
        )
914

915
        # TODO: String frontend does not support tuple unpacking
916
        # https://github.com/pytorch/pytorch/issues/64096
917
        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
918
        #              lhs["dict_comprehension_of_str_tensor"])
919

920
        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
921
        #              lhs["dict_comprehension_of_str_int"])
922

923
        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
924
        #              lhs["dict_comprehension_of_mixed"],
925
        #              "foobar")
926

927
        # self._assert_passes(template,
928
        #                    "Union[Dict[str, torch.Tensor], Dict[str, int]]",
929
        #                    lhs["dict_keyword_with_internal_aggregate_function"])
930

931
        # TODO(@ansley): Follow-up project needed for full type
932
        # inference with dict keyword (supported for dict comprehension
933
        # and dict literal already; should not be a blocker for anyone)
934
        self._assert_raises(
935
            template,
936
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
937
            lhs["dict_keyword"],
938
            "full type inference is not yet supported",
939
        )
940

941
        self._assert_raises(
942
            template,
943
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
944
            lhs["dict_keyword_with_iterable"],
945
            "full type inference is not yet supported",
946
        )
947

948
        self._assert_raises(
949
            template,
950
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
951
            lhs["dict_keyword_with_empty_iterable"],
952
            "full type inference is not yet supported",
953
        )
954

955
        self._assert_raises(
956
            template,
957
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
958
            lhs["dict_keyword_with_mapping"],
959
            "full type inference is not yet supported",
960
        )
961

962
        self._assert_raises(
963
            template,
964
            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
965
            lhs["dict_keyword_with_mapping_and_kwargs"],
966
            "full type inference is not yet supported",
967
        )
968

969
        """
970
        Union[int, torch.Tensor]
971
        """
972
        self._assert_raises(
973
            template,
974
            "Union[int, torch.Tensor]",
975
            lhs["dict_literal_empty"],
976
            "Expected an Union type annotation with " "an inner Dict type",
977
        )
978

979
        self._assert_raises(
980
            template,
981
            "Union[int, torch.Tensor]",
982
            lhs["dict_literal_of_str_tensor"],
983
            "Expected an Union type annotation with " "an inner Dict type",
984
        )
985

986
        # See above--string frontend does not support tuple unpacking
987
        # self._assert_raises(template, "Union[int, torch.Tensor]",
988
        #              lhs["dict_comprehension_of_tensor"],
989
        #              "foobar")
990

991
        """
992
        Union[Dict[str, torch.Tensor], int]
993
        """
994
        self._assert_passes(
995
            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
996
        )
997

998
        self._assert_passes(
999
            template,
1000
            "Union[Dict[str, torch.Tensor], int]",
1001
            lhs["dict_literal_of_str_tensor"],
1002
        )
1003

1004
        self._assert_raises(
1005
            template,
1006
            "Union[Dict[str, torch.Tensor], int]",
1007
            lhs["dict_literal_of_str_int"],
1008
            "Type annotation was inferred to be "
1009
            r"`Dict\[str, Tensor\]`, but the type of "
1010
            "values given by the dict literal is",
1011
        )
1012

1013
        self._assert_raises(
1014
            template,
1015
            "Union[Dict[str, torch.Tensor], int]",
1016
            lhs["dict_literal_of_mixed"],
1017
            "Type annotation was inferred to be "
1018
            r"`Dict\[str, Tensor\]`, but the type of "
1019
            "values given by the dict literal is",
1020
        )
1021

1022
        self._assert_passes(
1023
            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
1024
        )
1025

1026
        self._assert_passes(
1027
            template,
1028
            "Union[Dict[str, torch.Tensor], int]",
1029
            lhs["dict_keyword_with_iterable"],
1030
        )
1031

1032
        self._assert_passes(
1033
            template,
1034
            "Union[Dict[str, torch.Tensor], int]",
1035
            lhs["dict_keyword_with_empty_iterable"],
1036
        )
1037

1038
        self._assert_passes(
1039
            template,
1040
            "Union[Dict[str, torch.Tensor], int]",
1041
            lhs["dict_keyword_with_mapping"],
1042
        )
1043

1044
        self._assert_passes(
1045
            template,
1046
            "Union[Dict[str, torch.Tensor], int]",
1047
            lhs["dict_keyword_with_mapping_and_kwargs"],
1048
        )
1049

1050
        # See above--string frontend does not support tuple unpacking
1051
        # self._assert_passes(template,
1052
        #                    "Union[Dict[str, torch.Tensor], int]",
1053
        #                    lhs["dict_keyword_with_internal_aggregate_function"])
1054
        #
1055
        # self._assert_passes(template,
1056
        #                    "Union[Dict[str, torch.Tensor], int]",
1057
        #                    lhs["dict_comprehension_of_str_tensor"])
1058

1059
        # self._assert_raises(template,
1060
        #                    "Union[Dict[str, torch.Tensor], int]",
1061
        #                    lhs["dict_comprehension_of_str_int"],
1062
        #                    "foobar")
1063

1064
        # self._assert_raises(template,
1065
        #                    "Union[Dict[str, torch.Tensor], int]",
1066
        #                    lhs["dict_comprehension_of_mixed"],
1067
        #                    "foobar")
1068

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.