1
# Owner(s): ["oncall: jit"]
5
from typing import Any, Tuple
11
# Make the helper files in test/ importable
12
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13
sys.path.append(pytorch_test_dir)
14
from typing import List
16
from torch import Tensor
17
from torch.jit import Future
18
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
21
class TestAsync(JitTestCase):
    """Tests for torch.jit fork/wait futures under scripting and tracing.

    NOTE(review): this chunk is a partial extraction of a larger test file.
    The bare integer lines interleaved below (e.g. ``22``, ``28``) are source
    line numbers leaked by the extraction, and many interior lines are missing
    (decorators such as ``@torch.jit.script``, helper definitions, closing
    parentheses, ``return`` statements). All code tokens are preserved
    byte-for-byte here; recover the complete original file before attempting
    any behavioral edit.
    """

22
    # Forks a call and waits on the future; intentionally assert-free.
    # (The definitions of `foo` and `x` are among the missing lines.)
    def test_async_python(self):
28
        fut = torch.jit.fork(foo, x)
30
        y = torch.jit.wait(fut)
31
        # assert nothing; only to make sure the fake python path works
33
    # List[Future[Tensor]] annotation parsing in eager Python (assert-free).
    def test_async_future_type_python(self):
35
            futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
37
                futures.append(torch.jit.fork(lambda x: x, inp))
39
            for future in futures:
40
                all_outputs.append(torch.jit.wait(future))
43
        # assert nothing, just to make sure python type parsing works
44
        foo(torch.randn(3, 4))
46
    # Future[...] annotations inside TorchScript; expects 3 waited results.
    def test_async_parsing(self):
48
        def foo(x: Tensor) -> List[Tensor]:
49
            return [torch.neg(x), x.t()]
53
            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
55
                future = torch.jit.annotate(
56
                    Future[List[Tensor]], torch.jit.fork(foo, x)
58
                futures.append(future)
60
            output = torch.jit.annotate(List[List[Tensor]], [])
62
                output.append(torch.jit.wait(futures[i]))
67
        self.assertEqual(len(result), 3)
69
    # Scripted fork/wait must match a direct call to the same fn.
    def test_async_script(self):
72
            return torch.neg(x), x
78
            fut = torch.jit.fork(foo, x)
80
            y = torch.jit.wait(fut)
83
        y, y_hat = wait_script(x)
85
        self.assertEqual(y, y_hat)
87
    # ScriptModule capturing self.param / self.const across a fork.
    def test_async_script_capture(self):
88
        class Mod(torch.jit.ScriptModule):
89
            __constants__ = ["const"]
91
            def __init__(self) -> None:
94
                self.param = nn.Parameter(torch.randn(2, 2))
96
            @torch.jit.script_method
97
            def foo(self, x1, x2):
98
                return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param
100
            @torch.jit.script_method
101
            def forward(self, x1, x2):
102
                fut = torch.jit.fork(self.foo, x1, x2)
103
                y_hat = self.foo(x1, x2)
104
                y = torch.jit.wait(fut)
107
        x1 = torch.rand(3, 4)
108
        x2 = torch.rand(5, 6)
112
        with torch.jit.optimized_execution(False):
113
            y, y_hat = m.forward(x1, x2)
115
            self.assertEqual(y, y_hat)
117
    # Nested scripted forks: a forked fn that itself forks and waits.
    def test_async_script_nested(self):
120
            return torch.neg(x), x
126
            fut = torch.jit._fork(foo, x)
128
            y = torch.jit._wait(fut)
132
        def wait_script_nest(x):
133
            fut = torch.jit._fork(wait_script, x)
134
            return torch.jit._wait(fut)
136
        y, y_hat = wait_script_nest(x)
138
        self.assertEqual(y, y_hat)
140
    # Forking a plain tensor (not a callable) must fail script compilation.
    def test_async_script_no_script_mod(self):
143
        with self.assertRaisesRegexWithHighlight(
144
            RuntimeError, "cannot call a value", "torch.jit._fork(x"
149
                fut = torch.jit._fork(x)
152
    # Waiting twice on the same future yields equal results.
    def test_async_script_multi_waits(self):
155
            return torch.neg(x).t() + x
159
            fut = torch.jit._fork(foo, x)
161
            # wait twice on the same future
162
            y1 = torch.jit._wait(fut)
163
            y2 = torch.jit._wait(fut)
167
        y1, y2 = wait_script(x)
168
        self.assertEqual(y1, y2)
170
    # Several forks in one scripted fn; f4/f5 are deliberately never waited.
    def test_async_script_multi_forks(self):
173
            return torch.neg(x).t() + x
177
            return torch.neg(x).t() + x + torch.neg(y).t()
181
            return torch.neg(z).t() + y.t() + x
183
        x1 = torch.rand(10, 10)
184
        x2 = torch.rand(10, 10)
185
        x3 = torch.rand(10, 10)
188
        def wait_script(x1, x2, x3):
189
            f1 = torch.jit._fork(foo1, x1)
190
            f2 = torch.jit._fork(foo2, x1, x2)
191
            f3 = torch.jit._fork(foo3, x1, x2, x3)
192
            f4 = torch.jit._fork(foo1, x2)
193
            f5 = torch.jit._fork(foo2, x2, x3)
196
            y1 = torch.jit._wait(f1)
197
            y2 = torch.jit._wait(f2)
198
            y3 = torch.jit._wait(f3)
202
        y1, y2, y3 = wait_script(x1, x2, x3)
203
        self.assertEqual(y1, foo1(x1))
204
        self.assertEqual(y2, foo2(x1, x2))
205
        self.assertEqual(y3, foo3(x1, x2, x3))
207
    # fork/wait with every positional/keyword permutation, bare, traced,
    # and scripted; all must equal the eager result y_hat.
    def test_async_kwargs(self):
211
        x1 = torch.rand(3, 4)
212
        x2 = torch.rand(3, 4)
215
        # Cover tracing and bare functions with permutations of args, kwargs
217
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
218
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
219
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
220
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
224
                torch.jit.trace(func, (x1, x2)),
226
                self.assertEqual(wrapper(x1, x2), y_hat)
227
                self.assertEqual(wrapper(x1, x2=x2), y_hat)
228
                self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
229
                self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)
233
        def foo_script_args(x1, x2):
234
            return torch.jit._wait(torch.jit._fork(foo, x1, x2))
237
        def foo_script_kwargs(x1, x2):
238
            return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))
244
            self.assertEqual(wrapper(x1, x2), y_hat)
245
            self.assertEqual(wrapper(x1, x2=x2), y_hat)
246
            self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
247
            self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)
250
    # Traced submodule forked from a ScriptModule; checks fork/neg node
    # counts in the final traced graph.
    def test_async_script_trace(self):
251
        class Traced(nn.Module):
252
            def forward(self, x):
253
                return (torch.neg(x), x)
255
        class Mod(torch.jit.ScriptModule):
256
            def __init__(self) -> None:
259
                self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
261
            @torch.jit.script_method
264
            ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
265
                future1 = torch.jit._fork(self.traced, x)
266
                future2 = torch.jit._fork(torch.neg, x)
268
                tensor_tuple = torch.jit._wait(future1)
269
                tensor_single = torch.jit._wait(future2)
272
                tensor_list.append(tensor_tuple[0])
273
                tensor_list.append(tensor_single)
275
                # return a nested structure of tensors
276
                return (tensor_list, tensor_tuple, tensor_tuple[1])
278
        class TupleCl(nn.Module):
279
            def __init__(self) -> None:
283
            def forward(self, x):
286
                list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
290
        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
292
        # Make sure we have forks
293
        self.assertGraphContainsExactly(
294
            module.graph, kind="prim::fork", num_kind_nodes=2
296
        # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
297
        self.assertGraphContainsExactly(
298
            module.graph, kind="aten::neg", num_kind_nodes=1
300
        self.assertGraphContainsExactly(
301
            module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
305
        self.assertEqual(module(x), (y, y, y, y, x, x))
307
    # Errors raised inside forked fns must propagate through _wait(),
    # including through a nested fork.
    def test_async_script_error(self):
317
            fut = torch.jit._fork(foo, x)
318
            return torch.jit._wait(fut)
321
        def wait_script_nest(x):
322
            fut = torch.jit._fork(wait_script, x)
323
            return torch.jit._wait(fut)
326
        error_msg = "The size.*must match the size of tensor"
327
        with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
331
        with self.assertRaisesRegexWithHighlight(
332
            Exception, error_msg, "torch.jit._fork(foo, x"
336
        # two futures with a different error
337
        x = torch.rand(3, 4, 5)
338
        with self.assertRaisesRegexWithHighlight(
340
            "expects a tensor with <= 2 dimensions",
341
            "torch.jit._fork(wait_script, x",
345
    # The caller's enable_grad() context is observed inside the fork.
    def test_async_grad_guard_with_grad(self):
349
            return y.requires_grad
353
            fut = torch.jit._fork(foo, x)
354
            requires_grad_in_fork = torch.jit._wait(fut)
356
            return (requires_grad_in_fork, z.requires_grad)
358
        x = torch.randn(3, requires_grad=True)
360
        with torch.enable_grad():
361
            (inside_fork, after_wait) = bar(x)
363
        self.assertEqual(inside_fork, True)
364
        self.assertEqual(after_wait, True)
366
    # The caller's no_grad() context is observed inside the fork.
    def test_async_grad_guard_no_grad(self):
370
            return y.requires_grad
374
            fut = torch.jit._fork(foo, x)
375
            requires_grad_in_fork = torch.jit._wait(fut)
377
            return (requires_grad_in_fork, z.requires_grad)
379
        x = torch.randn(3, requires_grad=True)
381
        with torch.no_grad():
382
            (inside_fork, after_wait) = bar(x)
384
        self.assertEqual(inside_fork, False)
385
        self.assertEqual(after_wait, False)
387
    # Tracing records prim::fork / aten::wait nodes; traced matches eager.
    def test_trace_fork_wait(self):
389
            return x.neg(), x.neg() + 1
392
            fut = torch.jit._fork(fork_body, x)
393
            vals = torch.jit._wait(fut)
394
            return vals[0], vals[1], x - 1
396
        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
398
        self.assertEqual(fn(x), traced(x))
400
        self.assertGraphContainsExactly(
401
            traced.graph, kind="prim::fork", num_kind_nodes=1
403
        self.assertGraphContainsExactly(
404
            traced.graph, kind="aten::wait", num_kind_nodes=1
406
        self.assertGraphContainsExactly(
407
            traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
410
    # A fork whose result escapes via a closure list has no data dependence
    # on the trace inputs and must be rejected by the tracer.
    def test_trace_fork_wait_leaking(self):
414
            my_list.append(x + 1)
418
            fut = torch.jit._fork(fork_body, x)
419
            val = torch.jit._wait(fut)
422
        with self.assertRaisesRegexWithHighlight(
424
            "did not have observable data dependence with trace inputs; "
425
            "this probably indicates your program cannot be understood "
429
            traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
431
    # _jit_pass_inline_fork_wait removes fork/wait nodes from a traced graph.
    def test_trace_fork_wait_inline(self):
436
            fut = torch.jit._fork(fork_body, x)
437
            val = torch.jit._wait(fut)
440
        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
441
        torch._C._jit_pass_inline_fork_wait(traced.graph)
442
        self.assertGraphContainsExactly(
443
            traced.graph, kind="prim::fork", num_kind_nodes=0
445
        self.assertGraphContainsExactly(
446
            traced.graph, kind="aten::wait", num_kind_nodes=0
448
        self.assertGraphContainsExactly(
449
            traced.graph, kind="aten::add", num_kind_nodes=2
452
    # Module calls returning List[Future[Tensor]] must trace with the
    # correct types (verified via checkTrace on a wrapper module).
    def test_trace_fork_wait_list_modulecalls(self):
454
            return input + torch.ones(input.size())
456
        class TestListFutureModule(nn.Module):
457
            def forward(self, input):
460
                    input_list.append(input)
462
                fut_list: List[Future[torch.Tensor]] = []
463
                for input_tensor in input_list:
464
                    fut_list.append(torch.jit._fork(add_one, input_tensor))
465
                # return list[future[tensor]] here to ensure tracing
466
                # module calls return the correct types
469
        class TestModuleWrapper(nn.Module):
470
            def __init__(self) -> None:
472
                self.list_fut_mod = TestListFutureModule()
474
            def forward(self, input):
475
                fut_list = self.list_fut_mod(input)
478
                    res = res + fut.wait()
481
        self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),))
483
    # A module call returning a mixed (Tensor, Future) tuple traces correctly.
    def test_trace_modulecalls_with_different_output_types(self):
485
            return input + torch.ones(input.size())
487
        class DifferentOutputModule(nn.Module):
488
            def forward(self, input):
489
                fut_res = torch.jit._fork(add_one, (input))
491
                # return different types from module call
492
                return input, fut_res
494
        class TestModule(nn.Module):
495
            def __init__(self) -> None:
497
                self.gen_output = DifferentOutputModule()
499
            def forward(self, input):
500
                res, fut_res = self.gen_output(input)
501
                res = res + fut_res.wait()
504
        self.checkTrace(TestModule(), (torch.randn(5, 5),))
506
    # A bare (unparameterized) Future annotation produces a clear error.
    def test_no_future_subtype_message(self):
507
        with self.assertRaisesRegexWithHighlight(
508
            RuntimeError, "Future without a contained type", ""
512
            def forward(self, x):
513
                futs = torch.jit.annotate(List[torch.jit.Future], [])
515
    # Future[int] subtypes Future[Any]; annotating it as Future[float] fails.
    # (The line below is the interior of a docstring whose quotes are missing.)
    def test_future_subtyping(self):
517
        Test that futures subtype each other properly.
520
        # Successful subtyping.
521
        def returns_int(x: int) -> int:
524
        def returns_future_any(x: int) -> torch.jit.Future[Any]:
525
            return torch.jit._fork(returns_int, (x))
528
        def fn_int(x: int) -> Any:
529
            fut = returns_future_any(x)
532
        # Unsuccessful subtyping.
533
        with self.assertRaisesRegexWithHighlight(
535
            r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
536
            "fut = returns_future_float(x",
539
            def returns_future_float(x: int) -> torch.jit.Future[float]:
540
                return torch.jit._fork(returns_int, (x))
543
            def fn_float(x: int) -> Any:
544
                fut = returns_future_float(x)
548
if __name__ == "__main__":
550
"This test file is not meant to be run directly, use:\n\n"
551
"\tpython test/test_jit.py TESTNAME\n\n"