pytorch

Форк
0
/
test_with.py 
648 строк · 19.4 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5
from typing import Any, List
6

7
import torch
8
from torch.testing._internal.common_utils import skipIfTorchDynamo
9
from torch.testing._internal.jit_utils import JitTestCase, make_global
10

11

12
# Make the helper files in test/ importable
# Two dirname() hops: the directory holding this file, then its parent.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Refuse direct execution; the message names test/test_jit.py as the entry
# point to use instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
22

23

24
class TestWith(JitTestCase):
25
    """
26
    A suite of tests for with statements.
27
    """
28

29
    def test_with_as(self):
        """
        Check that with statements that use the 'as' keyword to bind expressions
        to targets work as expected.
        """

        @torch.jit.script
        class Context:
            """
            This class implements a basic context manager interface for use in
            the unit tests. The stateful part of this class is a Tensor that is
            mutated in-place so that modifications made in the JIT interpreter
            are visible outside of it.
            """

            def __init__(self, start: int):
                self.count = torch.tensor([start], dtype=torch.double)

            def __enter__(self):
                # Bump the counter on entry; this tensor is what `as` binds.
                self.count.add_(0.3)
                return self.count

            def __exit__(self, type: Any, value: Any, tb: Any) -> bool:
                # Undo the bump applied by __enter__.
                self.count.sub_(0.3)
                return True

        make_global(Context)

        def test_basic(x: torch.Tensor) -> torch.Tensor:
            """Basic test with one with-statement."""

            c = Context(1)

            with c as mult:
                y = x + mult

            y *= c.count
            return y

        def test_pass(x: torch.Tensor) -> torch.Tensor:
            """
            Test with a pass statement inside a with-statement. Although
            the body of the with is empty, __enter__ and __exit__ should
            still be called.
            """
            c = Context(1)

            with c as mult:
                pass

            x *= c.count
            return x

        def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test that returning early from inside a with-statement works
            as expected.
            """
            with c as mult:
                y = x + mult
                return y

            x = y + y
            return x

        def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test that conditionally returning early from inside a with-statement works
            as expected.
            """
            with c as mult:
                y = x + mult
                if mult > 0:
                    return y

            x = y + y
            return x

        def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
            """
            Test that breaking early from inside a with-statement works
            as expected.
            """
            with c as mult:
                for a in l:
                    if a == 0:
                        break
                    x += a * mult

            return x

        def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
            """
            Test that using continue inside a with-statement works
            as expected.
            """
            with c as mult:
                for a in l:
                    if a == 0:
                        continue
                    x += a * mult

            return x

        def test_serial(x: torch.Tensor) -> torch.Tensor:
            """
            Test two with-statements in a row.
            """
            c = Context(1)

            with c as mult:
                y = x + mult

            with c as mult:
                y *= mult

            return y

        def test_nested(x: torch.Tensor) -> torch.Tensor:
            """
            Test nested with-statements.
            """
            c = Context(1)

            with c as m:
                with c as n:
                    y = x + n

                y *= m

            return y

        def test_combined(x: torch.Tensor) -> torch.Tensor:
            """
            Test a with-statement with multiple with items.
            """
            c = Context(1)
            d = Context(2)

            with c as m, d as n:
                y = x + (m + n)

            return y

        test_input = torch.randn(2, 2)
        test_context = Context(2)
        test_list = [2, 0, 1, 3, 0, 2]

        # NOTE(review): test_conditional_early_return is defined above but is
        # never passed to checkScript below - confirm whether that is intentional.
        self.checkScript(test_basic, (test_input,))
        self.checkScript(test_pass, (test_input,))
        self.checkScript(test_early_return, (test_input, test_context))
        self.checkScript(test_break, (test_input, test_context, test_list))
        self.checkScript(test_continue, (test_input, test_context, test_list))
        # test_context was created with start=2 and is shared by the three
        # calls above; __exit__ subtracts what __enter__ added, so the count
        # should be back at its starting value.
        self.assertEqual(test_context.count, 2)
        self.checkScript(test_serial, (test_input,))
        self.checkScript(test_nested, (test_input,))
        self.checkScript(test_combined, (test_input,))
186

187
    def test_with_no_as(self):
        """
        Check that with statements that do not use the 'as' keyword to bind expressions
        to targets work as expected.
        """

        @torch.jit.script
        class Context:
            """
            This class implements a basic context manager interface for use in
            the unit tests. The stateful part of this class is a Tensor that is
            mutated in-place so that modifications made in the JIT interpreter
            are visible outside of it.
            """

            def __init__(self, start: int):
                self.count = torch.tensor([start], dtype=torch.double)

            def __enter__(self):
                # Bump the counter on entry; the helpers below read it via c.count.
                self.count.add_(0.3)
                return self.count

            def __exit__(self, type: Any, value: Any, tb: Any):
                # Undo the bump applied by __enter__.
                self.count.sub_(0.3)

        make_global(Context)

        def test_basic(x: torch.Tensor) -> torch.Tensor:
            """Basic test with one with-statement."""

            c = Context(1)

            with c:
                y = x + c.count

            y *= c.count
            return y

        def test_pass(x: torch.Tensor) -> torch.Tensor:
            """
            Test with a pass statement inside a with-statement. Although
            the body of the with is empty, __enter__ and __exit__ should
            still be called.
            """
            c = Context(1)

            with c:
                pass

            x *= c.count
            return x

        def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test that returning early from inside a with-statement works
            as expected.
            """
            with c:
                y = x + c.count
                return y

            x = y + y
            return x

        def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test that conditionally returning early from inside a with-statement works
            as expected.
            """
            with c:
                y = x + c.count
                if c.count > 0:
                    return y

            x = y + y
            return x

        def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
            """
            Test that breaking early from inside a with-statement works
            as expected.
            """
            with c:
                for a in l:
                    if a == 0:
                        break
                    x += a * c.count

            return x

        def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
            """
            Test that using continue inside a with-statement works
            as expected.
            """
            with c:
                for a in l:
                    if a == 0:
                        continue
                    x += a * c.count

            return x

        def test_serial(x: torch.Tensor) -> torch.Tensor:
            """
            Test two with-statements in a row.
            """
            c = Context(1)

            with c:
                y = x + c.count

            with c:
                y *= c.count

            return y

        def test_nested(x: torch.Tensor) -> torch.Tensor:
            """
            Test nested with-statements.
            """
            c = Context(1)

            with c:
                with c:
                    y = x + c.count

                y *= c.count

            return y

        def test_combined(x: torch.Tensor) -> torch.Tensor:
            """
            Test a with-statement with multiple with items.
            """
            c = Context(1)
            d = Context(2)

            with c, d:
                y = x + (c.count + d.count)

            return y

        test_input = torch.randn(2, 2)
        test_context = Context(2)
        test_list = [2, 0, 1, 3, 0, 2]

        # NOTE(review): test_conditional_early_return is defined above but is
        # never passed to checkScript below - confirm whether that is intentional.
        self.checkScript(test_basic, (test_input,))
        self.checkScript(test_pass, (test_input,))
        self.checkScript(test_early_return, (test_input, test_context))
        self.checkScript(test_break, (test_input, test_context, test_list))
        self.checkScript(test_continue, (test_input, test_context, test_list))
        # test_context was created with start=2 and is shared by the three
        # calls above; __exit__ subtracts what __enter__ added, so the count
        # should be back at its starting value.
        self.assertEqual(test_context.count, 2)
        self.checkScript(test_serial, (test_input,))
        self.checkScript(test_nested, (test_input,))
        self.checkScript(test_combined, (test_input,))
343

344
    def test_with_exceptions(self):
        """
        Check that exceptions thrown in the bodies of with-statements are
        handled correctly.
        """

        @torch.jit.script
        class Context:
            """
            This class implements a basic context manager interface for use in
            the unit tests. The stateful part of this class is a Tensor that is
            mutated in-place so that modifications made in the JIT interpreter
            are visible outside of it.
            """

            def __init__(self, start: int):
                self.count = torch.tensor([start], dtype=torch.double)

            def __enter__(self):
                self.count.add_(0.3)
                return self.count

            def __exit__(self, type: Any, value: Any, tb: Any):
                self.count.sub_(0.3)

        make_global(Context)

        @torch.jit.script
        def method_that_raises() -> torch.Tensor:
            raise Exception("raised exception")  # noqa: TRY002

        @torch.jit.script
        def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test the case in which an exception is thrown while executing the body of a with-statement.
            """
            with c as _:
                x += method_that_raises()

            return x

        @torch.jit.script
        def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test the case in which an exception is thrown while executing the body of a nested with-statement.
            """
            with c as _:
                with c as _:
                    x += method_that_raises()

            return x

        @torch.jit.script
        def with_that_raises(c: Context) -> torch.Tensor:
            a = torch.tensor([1])

            with c as _:
                a += method_that_raises()

            return a

        @torch.jit.script
        def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
            """
            Test the case in which an exception is thrown while there are active with-statements in two different
            frames.
            """
            with c as _:
                x += with_that_raises(c)

            return x

        c = Context(1)

        # checkScript and checkScriptRaisesRegex cannot be used because the string frontend will
        # not compile class types (of which Context, the context manager being used for this test
        # is one).
        raising_fns = (test_exception, test_exception_nested, test_exception_fn_call)
        for raising_fn in raising_fns:
            with self.assertRaisesRegexWithHighlight(
                Exception, r"raised exception", 'raise Exception("raised exception'
            ):
                raising_fn(torch.randn(2), c)
            # The shared context started at 1; it must be back at 1 after the
            # exception propagated out of the with-statement(s).
            self.assertEqual(c.count, 1)
438

439
    def test_with_errors(self):
        """
        Check that errors related to with-statements are detected and reported correctly.
        """

        @torch.jit.script
        class NoEnterNoExit:
            """
            This class is missing __enter__ and __exit__ methods.
            """

            def __init__(self) -> None:
                self.count = 1

        @torch.jit.script
        class BadEnter:
            """
            This class has an __enter__ method with an incorrect signature.
            """

            def __init__(self) -> None:
                self.count = 1

            def __enter__(self, incr: int):  # noqa: PLE0302
                self.count += incr

            def __exit__(self, type: Any, value: Any, tb: Any):
                pass

        @torch.jit.script
        class BadExit:
            """
            This class has an __exit__ method with an incorrect signature.
            """

            def __init__(self) -> None:
                self.count = 1

            def __enter__(self):
                self.count += 1

            def __exit__(self, type: Any, value: Any):  # noqa: PLE0302
                pass

        @torch.jit.script
        class ExitIncorrectTypes:
            """
            This class has an __exit__ method with unsupported argument types.
            """

            def __init__(self) -> None:
                self.count = 1

            def __enter__(self):
                self.count += 1

            def __exit__(self, type: Any, value: int, tb: int):
                pass

        # The helpers below are deliberately near-identical: each one only
        # differs in the (broken) context manager type it receives. The name
        # `cm` is also the source text the assertions below expect to be
        # highlighted, so it must not be renamed.
        def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor:
            with cm as _:
                pass

            return x

        def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor:
            with cm as _:
                pass

            return x

        def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor:
            with cm as _:
                pass

            return x

        def test_exit_incorrect_types(
            x: torch.Tensor, cm: ExitIncorrectTypes
        ) -> torch.Tensor:
            with cm as _:
                pass

            return x

        def test_enter_without_object():
            # A string literal is used as the with item on purpose; the
            # literal's source text is what the last assertion expects to be
            # highlighted.
            with "not_object" as obj:
                pass

        test_tensor = torch.randn(5, dtype=torch.double)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"does not define __enter__ and __exit__ methods", "cm"
        ):
            self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"__enter__ must have only one argument and one return value",
            "cm",
        ):
            self.checkScript(test_bad_enter, (test_tensor, BadEnter()))

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"__exit__ must have four arguments", "cm"
        ):
            self.checkScript(test_bad_exit, (test_tensor, BadExit()))

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"argument 2 of __exit__ must have Any type", "cm"
        ):
            self.checkScript(
                test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
            )

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"must return an object", '"not_object"'
        ):
            self.checkScript(test_enter_without_object, ())
558

559
    def test_with_no_grad(self):
        """
        Check that torch.no_grad() works. Most of these are adapted from
        corresponding tests for eager-mode no_grad.
        """

        # Basic no_grad test.
        def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                w = x + y

            return w

        scripted_no_grad = torch.jit.script(test_no_grad)
        x = torch.ones(5, 5, requires_grad=True)
        y = torch.ones(5, 5) * 4
        w = scripted_no_grad(x, y)

        self.assertFalse(w.requires_grad)
        with self.assertRaises(RuntimeError):
            w.backward(torch.ones(5, 5))
        self.assertIsNone(w.grad_fn)

        # Test assignment of a grad-less Tensor to a Tensor with gradients
        # in a no_grad block.
        def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                x[0] = y

            return x

        scripted_assignment = torch.jit.script(test_no_grad_assignment)
        z = torch.randn(5)
        w = scripted_assignment(x, z)
        self.assertTrue(w.requires_grad)
        self.assertIsNone(w.grad_fn)

        # Check that @torch.jit.ignored functions respect no_grad when it is
        # called in JIT mode.
        class NoGradModule(torch.nn.Module):
            @torch.jit.ignore
            def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                w = x + y
                return w

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                with torch.no_grad():
                    w = self.adder(x, y)

                return w

        scripted_module = torch.jit.script(NoGradModule())
        # x still has requires_grad=True here, so a grad-free result shows
        # that no_grad applied to the ignored method call.
        w = scripted_module(x, y)

        self.assertFalse(w.requires_grad)
613

614
    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
    def test_with_record_function(self):
        """
        Check that torch.autograd.profiler.record_function context manager is
        torchscriptable.
        """

        def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            with torch.autograd.profiler.record_function("foo"):
                # Nested record_function.
                with torch.autograd.profiler.record_function("nested"):
                    a = x + y
            return a

        scripted = torch.jit.script(with_rf)
        x, y = torch.ones(2), torch.ones(2)
        with torch.autograd.profiler.profile() as p:
            scripted(x, y)

        # Need to call below to populate CPU children.
        p.key_averages()
        function_events = p.function_events
        # Event with name "foo" should be recorded.
        rf_events = [evt for evt in function_events if evt.name == "foo"]
        self.assertEqual(len(rf_events), 1)
        rf_event = rf_events[0]
        # Ensure we find nested record_function event. Names are collected
        # into a list so a failure reports what was actually recorded.
        child_names = [child.name for child in rf_event.cpu_children]
        self.assertIn("nested", child_names)
        nested_function_event = next(
            evt for evt in function_events if evt.name == "nested"
        )
        # Nested record function should have child "aten::add"
        nested_child_names = [
            child.name for child in nested_function_event.cpu_children
        ]
        self.assertIn("aten::add", nested_child_names)
649

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.