5
from typing import Any, List
8
from torch.testing._internal.common_utils import skipIfTorchDynamo
9
from torch.testing._internal.jit_utils import JitTestCase, make_global
13
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
14
sys.path.append(pytorch_test_dir)
16
if __name__ == "__main__":
18
"This test file is not meant to be run directly, use:\n\n"
19
"\tpython test/test_jit.py TESTNAME\n\n"
24
class TestWith(JitTestCase):
26
A suite of tests for with statements.
29
def test_with_as(self):
31
Check that with statements that use the 'as' keyword to bind expressions
32
to targets work as expected.
38
This class implements a basic context manager interface for use in
39
the unit tests. Unlike Context, the stateful part of this class
40
is a Tensor that is mutated in-place so that modifications made in the
41
JIT interpreter are visible outside of it.
44
def __init__(self, start: int):
45
self.count = torch.tensor([start], dtype=torch.double)
51
def __exit__(self, type: Any, value: Any, tb: Any) -> bool:
57
def test_basic(x: torch.Tensor) -> torch.Tensor:
58
"""Basic test with one with-statement."""
68
def test_pass(x: torch.Tensor) -> torch.Tensor:
70
Test with a pass statement inside a with-statement. Although
71
the body of the with is empty, __enter__ and __exit__ should
82
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
84
Test that returning early from inside a with-statement works
94
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
96
Test that conditionally returning early from inside a with-statement works
107
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
109
Test that breaking early from inside a with-statement works
120
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
122
Test that using continue inside a with-statement works
133
def test_serial(x: torch.Tensor) -> torch.Tensor:
135
Test two with-statements in a row.
147
def test_nested(x: torch.Tensor) -> torch.Tensor:
149
Test nested with-statements.
161
def test_combined(x: torch.Tensor) -> torch.Tensor:
163
Test a with-statement with multiple with items.
173
test_input = torch.randn(2, 2)
174
test_context = Context(2)
175
test_list = [2, 0, 1, 3, 0, 2]
177
self.checkScript(test_basic, (test_input,))
178
self.checkScript(test_pass, (test_input,))
179
self.checkScript(test_early_return, (test_input, test_context))
180
self.checkScript(test_break, (test_input, test_context, test_list))
181
self.checkScript(test_continue, (test_input, test_context, test_list))
182
self.assertEqual(test_context.count, 2)
183
self.checkScript(test_serial, (test_input,))
184
self.checkScript(test_nested, (test_input,))
185
self.checkScript(test_combined, (test_input,))
187
def test_with_no_as(self):
189
Check that with statements that do not use the 'as' keyword to bind expressions
190
to targets work as expected.
196
This class implements a basic context manager interface for use in
197
the unit tests. Unlike Context, the stateful part of this class
198
is a Tensor that is mutated in-place so that modifications made in the
199
JIT interpreter are visible outside of it.
202
def __init__(self, start: int):
203
self.count = torch.tensor([start], dtype=torch.double)
209
def __exit__(self, type: Any, value: Any, tb: Any):
214
def test_basic(x: torch.Tensor) -> torch.Tensor:
215
"""Basic test with one with-statement."""
225
def test_pass(x: torch.Tensor) -> torch.Tensor:
227
Test with a pass statement inside a with-statement. Although
228
the body of the with is empty, __enter__ and __exit__ should
239
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
241
Test that returning early from inside a with-statement works
251
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
253
Test that conditionally returning early from inside a with-statement works
264
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
266
Test that breaking early from inside a with-statement works
277
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
279
Test that using continue inside a with-statement works
290
def test_serial(x: torch.Tensor) -> torch.Tensor:
292
Test two with-statements in a row.
304
def test_nested(x: torch.Tensor) -> torch.Tensor:
306
Test nested with-statements.
318
def test_combined(x: torch.Tensor) -> torch.Tensor:
320
Test a with-statement with multiple with items.
326
y = x + (c.count + d.count)
330
test_input = torch.randn(2, 2)
331
test_context = Context(2)
332
test_list = [2, 0, 1, 3, 0, 2]
334
self.checkScript(test_basic, (test_input,))
335
self.checkScript(test_pass, (test_input,))
336
self.checkScript(test_early_return, (test_input, test_context))
337
self.checkScript(test_break, (test_input, test_context, test_list))
338
self.checkScript(test_continue, (test_input, test_context, test_list))
339
self.assertEqual(test_context.count, 2)
340
self.checkScript(test_serial, (test_input,))
341
self.checkScript(test_nested, (test_input,))
342
self.checkScript(test_combined, (test_input,))
344
def test_with_exceptions(self):
346
Check that exceptions thrown in the bodies of with-statements are
353
This class implements a basic context manager interface for use in
354
the unit tests. Unlike Context, the stateful part of this class
355
is a Tensor that is mutated in-place so that modifications made in the
356
JIT interpreter are visible outside of it.
359
def __init__(self, start: int):
360
self.count = torch.tensor([start], dtype=torch.double)
366
def __exit__(self, type: Any, value: Any, tb: Any):
372
def method_that_raises() -> torch.Tensor:
373
raise Exception("raised exception")
376
def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
378
Test the case in which an exception is thrown while executing the body of a with-statement.
381
x += method_that_raises()
386
def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
388
Test the case in which an exception is thrown while executing the body of a nested with-statement.
392
x += method_that_raises()
397
def with_that_raises(c: Context) -> torch.Tensor:
398
a = torch.tensor([1])
401
a += method_that_raises()
406
def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
408
Test the case in which an exception is thrown while there are active with-statements in two different
412
x += with_that_raises(c)
421
with self.assertRaisesRegexWithHighlight(
422
Exception, r"raised exception", 'raise Exception("raised exception'
424
test_exception(torch.randn(2), c)
425
self.assertEqual(c.count, 1)
427
with self.assertRaisesRegexWithHighlight(
428
Exception, r"raised exception", 'raise Exception("raised exception'
430
test_exception_nested(torch.randn(2), c)
431
self.assertEqual(c.count, 1)
433
with self.assertRaisesRegexWithHighlight(
434
Exception, r"raised exception", 'raise Exception("raised exception'
436
test_exception_fn_call(torch.randn(2), c)
437
self.assertEqual(c.count, 1)
439
def test_with_errors(self):
441
Check that errors related to with-statements are detected and reported correctly.
447
This class is missing __enter__ and __exit__ methods.
450
def __init__(self) -> None:
456
This class has an __enter__ method with an incorrect signature.
459
def __init__(self) -> None:
462
def __enter__(self, incr: int):
465
def __exit__(self, type: Any, value: Any, tb: Any):
471
This class has an __exit__ method with an incorrect signature.
474
def __init__(self) -> None:
480
def __exit__(self, type: Any, value: Any):
484
class ExitIncorrectTypes:
486
This class has an __exit__ method with unsupported argument types.
489
def __init__(self) -> None:
495
def __exit__(self, type: Any, value: int, tb: int):
498
def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor:
504
def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor:
510
def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor:
516
def test_exit_incorrect_types(
517
x: torch.Tensor, cm: ExitIncorrectTypes
524
def test_enter_without_object():
525
with "not_object" as obj:
528
test_tensor = torch.randn(5, dtype=torch.double)
530
with self.assertRaisesRegexWithHighlight(
531
RuntimeError, r"does not define __enter__ and __exit__ methods", "cm"
533
self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
535
with self.assertRaisesRegexWithHighlight(
537
r"__enter__ must have only one argument and one return value",
540
self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
542
with self.assertRaisesRegexWithHighlight(
543
RuntimeError, r"__exit__ must have four arguments", "cm"
545
self.checkScript(test_bad_exit, (test_tensor, BadExit()))
547
with self.assertRaisesRegexWithHighlight(
548
RuntimeError, r"argument 2 of __exit__ must have Any type", "cm"
551
test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
554
with self.assertRaisesRegexWithHighlight(
555
RuntimeError, r"must return an object", '"not_object"'
557
self.checkScript(test_enter_without_object, ())
559
def test_with_no_grad(self):
561
Check that torch.no_grad() works. Most of these are adapted from
562
corresponding tests for eager-mode no_grad.
566
def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
567
with torch.no_grad():
572
s = torch.jit.script(test_no_grad)
573
x = torch.ones(5, 5, requires_grad=True)
574
y = torch.ones(5, 5) * 4
577
self.assertFalse(w.requires_grad)
578
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
579
self.assertIsNone(w.grad_fn)
583
def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
584
with torch.no_grad():
589
s = torch.jit.script(test_no_grad_assignment)
592
self.assertTrue(w.requires_grad)
593
self.assertIsNone(w.grad_fn)
597
class NoGradModule(torch.nn.Module):
599
def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
603
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
604
with torch.no_grad():
609
s = torch.jit.script(NoGradModule())
612
self.assertFalse(w.requires_grad)
614
@skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
615
def test_with_record_function(self):
617
Check that torch.autograd.profiler.record_function context manager is
621
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
622
with torch.autograd.profiler.record_function("foo"):
624
with torch.autograd.profiler.record_function("nested"):
628
scripted = torch.jit.script(with_rf)
629
x, y = torch.ones(2), torch.ones(2)
630
with torch.autograd.profiler.profile() as p:
635
function_events = p.function_events
637
rf_events = [evt for evt in function_events if evt.name == "foo"]
638
self.assertEqual(len(rf_events), 1)
639
rf_event = rf_events[0]
640
child_events = rf_event.cpu_children
642
self.assertTrue("nested" in (child.name for child in child_events))
643
nested_function_event = [
644
evt for evt in function_events if evt.name == "nested"
647
nested_child_events = nested_function_event.cpu_children
648
self.assertTrue("aten::add" in (child.name for child in nested_child_events))