# pytorch / test_recursive_script.py
# (extracted from a web repository view: "Fork 0"; 800 lines, 22.8 KB)
# Owner(s): ["oncall: jit"]

import os
import re
import sys
import types
import typing
import typing_extensions
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit.frontend
import torch.nn as nn
from torch import Tensor
from torch.testing import FileCheck


# Make the helper files in test/ importable.
# pytorch_test_dir is the parent of the directory that contains this file.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import (
    _tmp_donotuse_dont_inline_everything,
    JitTestCase,
)


# This module is meant to be collected through test/test_jit.py; refuse to
# run standalone so the tests are not silently skipped or run out of context.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestRecursiveScript(JitTestCase):
    """Tests for recursive scripting.

    ``torch.jit.script`` applied to a module, function, or class also compiles
    the submodules, attributes, functions, and classes reachable from it.
    """

    def test_inferred_nonetype(self):
        """A module attribute initialized to None can be scripted and read."""

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = None

            def forward(self):
                assert self.x is None

        torch.jit.script(M())
        self.checkModule(M(), ())

    def test_script_function_attribute(self):
        """Already-scripted functions stored as module attributes are callable."""

        @torch.jit.script
        def fn1(x):
            return x + x

        @torch.jit.script
        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.checkModule(fn1_mod, (torch.randn(2, 2),))
        self.checkModule(fn2_mod, (torch.randn(2, 2),))

    def test_python_function_attribute(self):
        """A builtin torch function stored as a module attribute is callable."""

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        mod = M(torch.sigmoid)

        self.checkModule(mod, (torch.randn(2, 2),))

    def test_failed_function_compilation(self):
        """A function attribute that fails to compile reports the bad name."""

        def fn(x):
            return i_dont_exist  # noqa: F821

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        m = M(fn)
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "failed to compile", "i_dont_exist"
        ):
            torch.jit.script(m)

    def test_init_error(self):
        """Scripting a module whose __init__ skipped super().__init__() fails."""

        class M(nn.Module):
            def __init__(self) -> None:
                # Deliberately does not call super().__init__().
                self.x = 2

            def forward(self):
                pass

        with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
            torch.jit.script(M())

    def test_script_after_eval(self):
        """The training flag is captured at the time the module is scripted."""

        class M(nn.Module):
            def forward(self):
                if self.training:
                    return 2
                else:
                    return 0

        m = M()
        sm1 = torch.jit.script(m)
        m.eval()
        sm2 = torch.jit.script(m)

        # m is in eval mode, training should be False
        self.assertFalse(m.training)

        # sm1 was created while m had training = True
        self.assertTrue(sm1.training)
        self.assertEqual(sm1.training, sm1._c.getattr("training"))
        self.assertEqual(sm1(), 2)

        # sm2 was created after m was eval'ed
        self.assertFalse(sm2.training)
        self.assertEqual(sm2.training, sm2._c.getattr("training"))
        self.assertEqual(sm2(), 0)

    def test_module_name(self):
        """The original Python class name appears in the scripted graph."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        m = torch.jit.script(MyModule())
        FileCheck().check("MyModule").run(m.graph)

    def test_repeated_error_stack(self):
        """A failed compile leaves no stale call-stack entries behind."""

        def d(x):
            return "a" - 2

        def c(x):
            return d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        # assertRaises (rather than try/except) makes the test fail if
        # scripting unexpectedly succeeds.
        with self.assertRaises(Exception) as ctx:
            torch.jit.script(a)
        FileCheck().check_count("is being compiled", 2).run(str(ctx.exception))

        with self.assertRaises(Exception) as ctx:
            torch.jit.script(a)
        # Make sure that no entries are left over from the previous failure
        FileCheck().check_count("is being compiled", 2).run(str(ctx.exception))

    def test_constants_with_final(self):
        """Final-annotated attributes work with all three Final spellings."""

        class M1(torch.nn.Module):
            x: torch.jit.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M1(), (torch.randn(2, 2),))

        class M2(torch.nn.Module):
            x: typing_extensions.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M2(), (torch.randn(2, 2),))

        class M3(torch.nn.Module):
            x: typing.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M3(), (torch.randn(2, 2),))

    def test_ignore_class(self):
        """Instantiating a @torch.jit.ignore'd class from script is an error."""

        @torch.jit.ignore
        class MyScriptClass:
            def unscriptable(self):
                return "a" + 200

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return MyScriptClass()

        with self.assertRaisesRegexWithHighlight(
            torch.jit.frontend.FrontendError,
            "Cannot instantiate class",
            "MyScriptClass",
        ):
            torch.jit.script(TestModule())

    def test_method_call(self):
        """Methods called from forward are compiled along with it."""

        class M(nn.Module):
            def test(self, x):
                return x

            def forward(self, z):
                y = self.test(z)
                return z + 20 + y

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_module_repr(self):
        """repr() of a scripted module lists it and its submodules."""

        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        m = torch.jit.script(MyModule())

        with self.capture_stdout() as out:
            print(m)

        f = FileCheck()
        f.check("MyModule")
        f.check("Conv2d")
        f.check("Linear")
        f.check("Submodule")
        f.run(out[0])

        self.assertEqual(m.original_name, "MyModule")

    def test_dir(self):
        """dir() of a scripted module covers dir() of the eager module."""

        def test_module_dir(mod):
            dir_list = dir(mod)
            scripted_mod = torch.jit.script(mod)
            dir_scripted = set(dir(scripted_mod))
            # Attributes that are not currently copied over to the scripted module.
            ignore_set = {
                "training",
                "__delitem__",
                "__setitem__",
                "clear",
                "items",
                "keys",
                "pop",
                "update",
                "values",
            }
            for attr in dir_list:
                if attr in ignore_set:
                    continue
                self.assertIn(attr, dir_scripted)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin(x) + self.conv(x)

        test_module_dir(MyModule())

        # test custom __dir__ for containers
        conv = nn.Conv2d(10, 10, 3)
        linear = nn.Linear(10, 10)

        test_module_dir(nn.Sequential(conv, linear))
        test_module_dir(
            nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
        )

    def test_class_compile(self):
        """A Python class used in forward is compiled with the functions it calls."""

        def other_fn(a: int, b: Tensor) -> Tensor:
            return a * b

        class B:
            def __init__(self, x):
                self.x = 2

            def helper(self, a):
                return self.x + a + other_fn(self.x, a)

        class N(torch.nn.Module):
            def forward(self, x):
                b = B(x)
                return b.helper(x)

        self.checkModule(N(), (torch.randn(2, 2),))

    def test_error_stack(self):
        """A type error deep in a call chain reports every calling function."""

        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        with self.assertRaises(RuntimeError) as ctx:
            torch.jit.script(a)
        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("def c(x)")
        checker.check("def b(x)")
        checker.check("def a(x)")
        checker.run(str(ctx.exception))

    def test_error_stack_module(self):
        """The error call chain is reported when it starts from a module method."""

        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        with self.assertRaises(RuntimeError) as ctx:
            torch.jit.script(M())
        checker = FileCheck()
        checker.check("Expected a value of type 'int'")
        checker.check("'c' is being compiled since it was called from 'b'")
        checker.check("'b' is being compiled since it was called from")
        checker.run(str(ctx.exception))

    @_tmp_donotuse_dont_inline_everything
    def test_script_basic(self):
        """A Python function called from script is emitted as prim::CallFunction."""

        def a_python_fn(a, b, c):
            return a + b + c

        @torch.jit.script
        def a_script_fn(d, e, f):
            return a_python_fn(d, e, f)

        graph = str(a_script_fn.graph)
        FileCheck().check("prim::CallFunction").run(graph)
        FileCheck().check_not("^a_python_fn").run(graph)
        t = torch.ones(2, 2)
        self.assertEqual(a_script_fn(t, t, t), t + t + t)

    def test_error_stack_class(self):
        """A compile error inside a class method reports the calling function."""

        class X:
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        with self.assertRaises(Exception) as ctx:
            torch.jit.script(fn)
        checker = FileCheck()
        checker.check("import statements")
        checker.check("is being compiled since it was called from")
        checker.run(str(ctx.exception))

    def test_error_stack_annotation(self):
        """The error for a bad class includes the annotation that referenced it."""

        class X:
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        with self.assertRaises(Exception) as ctx:
            torch.jit.script(fn)
        checker = FileCheck()
        checker.check("import statements")
        checker.check("is being compiled since it was called from")
        checker.check("-> X")
        checker.run(str(ctx.exception))

    def test_module_basic(self):
        """Submodules with constants and parameters are compiled recursively.

        ``some_unscriptable_method`` is never called from ``forward``, and its
        presence must not block scripting.
        """

        class Other(torch.nn.Module):
            __constants__ = ["x"]

            def __init__(self, x):
                super().__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            def some_unscriptable_method(self):
                a = 2
                a = [2]
                return a

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_module_function_export(self):
        """A @torch.jit.export method on a submodule does not break scripting."""

        class Other(torch.nn.Module):
            __constants__ = ["x"]

            def __init__(self, x):
                super().__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            @torch.jit.export
            def some_entry_point(self, y):
                return y + 20

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

    def test_iterable_modules(self):
        """Nested Sequential and ModuleList containers can be iterated and called."""

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sequential = nn.Sequential(
                    Inner(), Inner(), nn.Sequential(Inner(), Inner())
                )
                self.module_list = nn.ModuleList([Inner(), Inner()])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))

    def test_prepare_scriptable_basic(self):
        """__prepare_scriptable__ swaps in a different module when scripting."""

        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        t = torch.randn(5, 5)
        m = SeluButReluWhenScripted()
        sm = torch.jit.script(m)
        eager_out = m(t)
        script_out = sm(t)
        self.assertNotEqual(eager_out, script_out)

    def test_prepare_scriptable_iterable_modules(self):
        """__prepare_scriptable__ is applied to modules inside containers."""

        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                shared = SeluButReluWhenScripted()
                self.sequential = nn.Sequential(
                    SeluButReluWhenScripted(),
                    SeluButReluWhenScripted(),
                    nn.Sequential(
                        SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
                    ),
                    shared,
                )
                self.module_list = nn.ModuleList(
                    [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
                )

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        t = torch.randn(5, 5)
        m = M()
        # forward mutates x in place, so each call gets its own copy.
        eager_out = m(t.clone())
        sm = torch.jit.script(m)
        script_out = sm(t.clone())
        self.assertNotEqual(eager_out, script_out)

    def test_prepare_scriptable_cycle(self):
        """Scripting terminates when modules reference each other in a cycle."""
        c = torch.nn.Module()
        p = torch.nn.Module()
        # Write straight into __dict__ to bypass nn.Module.__setattr__, so the
        # two modules reference each other as plain instance attributes.
        c.__dict__["_p"] = p
        p.__dict__["_c"] = c

        torch.jit.script(p)

    def test_prepare_scriptable_escape_hatch(self):
        """__prepare_scriptable__ can replace an unscriptable callable with a function."""

        class NonJitableClass:
            def __call__(self, int1, int2, *args):
                total = int1 + int2
                for arg in args:
                    total += arg
                return total

        obj = NonJitableClass()

        self.assertEqual(obj(1, 2), 3)
        self.assertEqual(obj(1, 2, 3, 4), 10)
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError,
            expected_regex="can't take variable number of arguments",
        ):
            torch.jit.script(obj)

        def escape_hatch(int1: int, int2: int) -> int:
            return int1 + int2

        class NonJitableClassWithEscapeHatch(NonJitableClass):
            def __prepare_scriptable__(self):
                return escape_hatch

        jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch())

        self.assertEqual(jit_obj(1, 2), 3)
        with self.assertRaisesRegex(
            RuntimeError,
            expected_regex=re.escape(
                "expected at most 2 argument(s) but received 4 argument(s)"
            ),
        ):
            jit_obj(1, 2, 3, 4)

    def test_attributes(self):
        """Module attributes of many types, inferred or annotated, are scriptable."""

        @torch.jit.script
        class Inner2:
            def __init__(self) -> None:
                self.b = "a string"

        @torch.jit.script
        class Foo:
            def __init__(self) -> None:
                self.a = 4
                self.inner = Inner2()

        @torch.jit.script
        class SFoo:
            def __init__(self) -> None:
                self.a = 4
                self.inner = Inner2()

            def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
                a, inner = obj
                self.a = a
                self.inner = inner

            def __getstate__(self):
                return (self.a, self.inner)

        # Attribute types that can be inferred from the value alone.
        untyped_values = (
            ("my_dict", {"I": "am", "a test": "test"}),
            ("my_float", 2.3),
            ("my_int", 99),
            ("my_bool", False),
            ("my_tuple", (1, 2, 3, 4)),
            ("my_list", [(1, 2), (3, 4)]),
            # ('my_tensor', torch.randn(2, 2)),
            ("my_int_list", [1, 2, 3, 4]),
            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
            ("my_bool_list", [True, True, False, True]),
            ("my_float_list", [1.0, 2.0, 3.0, 4.0]),
            ("my_str_list", ["hello", "bye"]),
        )
        # Attribute types supplied via M.__annotations__ below.
        typed_values = (
            ("my_empty_list", []),
            ("my_empty_dict", {}),
            ("my_none", None),
            ("my_object", Foo()),
            ("my_object2", SFoo()),
        )

        class M(torch.nn.Module):
            # TODO: re-enable this once this test is in a Python 3-only syntax
            # file
            # my_empty_list : List[int]
            # my_empty_dict : Dict[str, int]
            # my_none : Optional[int]

            def forward(self, x):
                return (
                    self.my_dict,
                    self.my_float,
                    self.my_int,
                    self.my_bool,
                    # self.my_tensor,
                    self.my_int_list,
                    # self.my_tensor_list,
                    self.my_bool_list,
                    self.my_float_list,
                    self.my_str_list,
                    self.my_empty_list,
                    self.my_empty_dict,
                    self.my_none,
                    self.my_object.a,
                    self.my_object.inner.b,
                    self.my_object2.a,
                    self.my_object2.inner.b,
                )

        # TODO: as a followup, fix this test
        # We can't define class attributes like we should be doing:
        #   class M(torch.nn.Module):
        #       my_empty_list : List[int]
        #       my_empty_dict : Dict[str, int]
        #       my_none : Optional[int]
        #       my_out_of_line_attribute: List[int] = [1, 2, 3]
        # since there's no string frontend for Python classes (so the `define`)
        # trick doesn't work.
        M.__annotations__ = {
            "my_empty_list": List[int],
            "my_empty_dict": Dict[str, int],
            "my_none": Optional[int],
            "my_object": Foo,
            "my_object2": SFoo,
        }

        m = M()
        for name, value in untyped_values + typed_values:
            setattr(m, name, value)

        self.checkModule(m, (torch.randn(5, 5),))

    def test_function_attribute_in_submodule(self):
        """An unused function attribute on a submodule does not break scripting."""

        class N(nn.Module):
            def __init__(self, norm):
                super().__init__()
                self.activation = torch.nn.functional.relu
                self.norm = norm

            def forward(self, src):
                output = src
                output = self.norm(output)
                return output

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                encoder_norm = nn.ReLU()
                self.encoder = N(encoder_norm)

            def forward(self, x):
                return self.encoder(x)

        m = M()
        self.checkModule(m, (torch.randn(5, 5),))

    def test_inner_traced_module(self):
        """A traced module inside a ModuleList can be iterated from script."""

        class Dummy(nn.Module):
            def forward(self, x):
                return x

        class Model(nn.Module):
            def __init__(self, dummies):
                super().__init__()
                self._dummies = dummies

            def forward(self, x):
                out = []
                for dummy in self._dummies:
                    out.append(dummy(x))
                return out

        dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
        dummies = nn.ModuleList([dummy])
        model = Model(dummies)
        self.checkModule(model, (torch.rand(5, 5),))

    def test_script_loaded_module(self):
        """
        Test that we can hold a loaded ScriptModule as a submodule.
        """

        class Dummy(nn.Module):
            def forward(self, x):
                return x

        dummy = torch.jit.script(Dummy())
        dummy = self.getExportImportCopy(dummy)

        class ContainsLoaded(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.encoder = dummy

            def forward(self, input):
                return self.encoder(input)

        self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))

    def test_optional_module(self):
        """A submodule attribute may be either a module or None."""

        class Dummy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = nn.Linear(2, 2)

            def forward(self, x):
                if self.foo is not None:
                    return self.foo(x)
                return x

        mod = Dummy()
        self.checkModule(mod, (torch.rand(2, 2),))
        mod.foo = None
        self.checkModule(mod, (torch.rand(2, 2),))

    def test_override_instance_method_ignore(self):
        """An ignored method rebound on the instance keeps its override."""

        class M(torch.nn.Module):
            @torch.jit.ignore
            def i_am_ignored(self):
                return "old"

        m = M()

        # Override the ignored method by binding a new method to this instance.
        @torch.jit.ignore
        def i_am_ignored(self):
            return "new"

        m.i_am_ignored = types.MethodType(i_am_ignored, m)
        self.assertEqual(m.i_am_ignored(), "new")

        # ScriptModule should correctly reflect the override.
        s = torch.jit.script(m)
        self.assertEqual(s.i_am_ignored(), "new")

# --- Web-page residue from the hosting site (not part of the test file) ---
# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.