4
from typing import List, Optional, Tuple
7
from torch import Tensor
8
from torch._awaits import _Await as Await
9
from torch.testing._internal.jit_utils import JitTestCase, make_global
12
class TestAwait(JitTestCase):
13
def test_await_python(self):
14
def foo(x: int) -> int:
17
aw: Await[int] = torch.jit._awaitable(foo, 13)
18
self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
19
nw = torch.jit._awaitable_nowait(33)
20
self.assertTrue(nw.is_nowait())
21
self.assertTrue(nw.args() == (33,))
23
def test_await_type_python(self):
27
awaits = torch.jit.annotate(List[Await[Tensor]], [])
28
awaits.append(torch.jit._awaitable(foo))
30
def test_script(self):
31
def delayed(z: int) -> int:
35
aw: Await[int] = torch.jit._awaitable(delayed, 99)
37
b = torch.jit._awaitable_wait(aw)
42
sm = torch.jit.script(fn)
45
self.assertTrue(torch.allclose(torch.eye(2) + 102, script_out))
46
self.assertTrue(torch.allclose(script_out, out))
48
def test_nowait(self):
50
aw = torch.jit._awaitable_nowait(13)
52
b = torch.jit._awaitable_wait(aw)
57
sm = torch.jit.script(fn)
60
self.assertTrue(torch.allclose(torch.eye(2) + 13, script_out))
61
self.assertTrue(torch.allclose(script_out, out))
63
def test_nowait_class(self):
65
def __init__(self, a: Tensor, b: Tensor):
69
def a(self) -> Tensor:
73
aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2)))
75
c = torch.jit._awaitable_wait(aw)
81
sm = torch.jit.script(fn)
84
self.assertTrue(torch.allclose(torch.eye(2), script_out))
85
self.assertTrue(torch.allclose(script_out, out))
87
def test_await_class_arg(self):
89
def __init__(self, a: Tensor, b: Tensor):
93
def a(self) -> Tensor:
98
def delayed(c: C) -> Tensor:
102
c = C(torch.zeros(2), torch.ones(2))
103
aw = torch.jit._awaitable(delayed, c)
105
c2_t = torch.jit._awaitable_wait(aw)
110
sm = torch.jit.script(fn)
113
self.assertTrue(torch.allclose(torch.eye(2), script_out))
114
self.assertTrue(torch.allclose(script_out, out))
116
def test_awaitable_to_await(self):
118
__slots__ = ["_a", "_b"]
120
def __init__(self, a: Tensor, b: Tensor):
128
def C_wait_impl(self: C):
129
return self._a + self._b
132
aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2)))
134
c_wait_impl_res = torch.jit._awaitable_wait(aw)
135
return _a + c_wait_impl_res + x
139
sm = torch.jit.script(fn)
142
self.assertTrue(torch.allclose(torch.eye(2) + 2 * torch.ones(2), script_out))
143
self.assertTrue(torch.allclose(script_out, out))
145
def test_await_class_return(self):
147
__slots__ = ["a", "b"]
149
def __init__(self, a: Tensor, b: Tensor):
157
def C_wait_impl(self: C) -> C:
158
return C(self.a * 2, self.b * 3)
160
def fn_arg_C(x: C) -> Tensor:
164
aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
166
y = fn_arg_C(torch.jit._awaitable_wait(aw))
171
sm = torch.jit.script(fn)
174
self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
175
self.assertTrue(torch.allclose(script_out, out))
176
self.assertGraphContainsExactly(
177
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
180
def test_await_getattr_implicit_convertion(self):
182
def __init__(self, a: Tensor, b: Tensor):
193
def C_wait_impl(self: C) -> C:
194
return C(self._a * 2, self._b * 3)
196
def fn_arg_C(x: C) -> Tensor:
200
aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
205
return _a + ai + x + c._a + c.b()
209
sm = torch.jit.script(fn)
212
self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
213
self.assertTrue(torch.allclose(script_out, out))
214
self.assertGraphContainsExactly(
215
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
218
def test_await_nested(self):
220
def __init__(self, a: Tensor, b: Tensor):
224
def a(self) -> Tensor:
229
def delayed(c: C) -> Await[Tensor]:
230
return torch.jit._awaitable_nowait(3 * c.a())
232
def fn(x: Tensor) -> Await[Await[Tensor]]:
233
return torch.jit._awaitable(delayed, C(2 * x, x))
235
def main(x: Tensor) -> Tensor:
237
return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw))
241
sm = torch.jit.script(main)
244
self.assertTrue(torch.allclose(6 * torch.eye(2), script_out))
245
self.assertTrue(torch.allclose(script_out, out))
247
def test_eager_await_non_scriptable(self):
250
def __init__(self, v):
251
self.parent = torch.jit.annotate(Optional[Tree], None)
256
def delayed(t: Tree):
260
aw = torch.jit._awaitable(delayed, Tree(2))
261
t = torch.jit._awaitable_wait(aw)
262
self.assertTrue(t.v == 3)
264
def test_await_isinstance(self):
265
def delayed(x: Tensor) -> Tensor:
268
def main(x: Tensor) -> Tensor:
269
aw = torch.jit._awaitable(delayed, x)
270
if torch.jit.is_scripting():
271
assert isinstance(aw, torch.jit._Await)
272
return torch.jit._awaitable_wait(aw)
276
sm = torch.jit.script(main)
280
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
282
self.assertTrue(torch.allclose(script_out, out))
284
def test_await_eager_lazy(self):
285
def delayed(x: Tensor) -> Tensor:
288
t = torch.ones(2, dtype=torch.int64)
289
aw = torch.jit._awaitable(delayed, t)
290
self.assertTrue(isinstance(aw, torch._C._Await))
291
self.assertTrue(t.dtype == aw.dtype)
293
def test_await_out_of_interpreter(self):
294
def delayed(x: Tensor) -> Tensor:
297
def main(x: Tensor) -> Await[Tensor]:
298
aw = torch.jit._awaitable(delayed, x)
303
sm = torch.jit.script(main)
305
out = torch.jit._awaitable_wait(out_aw)
307
script_out_aw = sm(inp)
308
script_out = torch.jit._awaitable_wait(script_out_aw)
310
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
312
self.assertTrue(torch.allclose(script_out, out))
314
def test_jit_trace(self):
316
return torch.relu(x) + torch.sin(x)
318
def delayed(x: Tensor) -> Tensor:
319
return 2 * (torch.cos(x) + 1)
321
def main(x: Tensor, y: Tensor) -> Tensor:
322
aw = torch.jit._awaitable(delayed, x)
324
k = torch.jit._awaitable_wait(aw)
328
tm = torch.jit.trace(main, (inp, inp))
329
inp_check = torch.ones(2)
330
self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check))
332
def test_await_multiout_save(self):
334
return torch.relu(x) + torch.sin(x)
336
def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]:
337
l = [x * i for i in range(5)]
340
def main(x: Tensor) -> Tensor:
341
aw = torch.jit._awaitable(delayed, x)
343
(_, l) = torch.jit._awaitable_wait(aw)
348
sm = torch.jit.script(main)
351
expected = 4.8415 * torch.eye(2)
352
self.assertTrue(torch.allclose(expected, script_out))
353
self.assertTrue(torch.allclose(script_out, out))
355
iofile = io.BytesIO()
356
torch.jit.save(sm, iofile)
358
sm = torch.jit.load(iofile)
359
script_out_load = sm(inp)
360
self.assertTrue(torch.allclose(expected, script_out_load))
362
def test_await_func_arg(self):
364
return torch.relu(x) + torch.sin(x)
366
def delayed(x: Tensor) -> Tensor:
369
def fn(aw: Await[Tensor]) -> Tensor:
370
return 3 * torch.jit._awaitable_wait(aw)
372
def main(x: Tensor) -> Tensor:
373
aw = torch.jit._awaitable(delayed, x)
380
sm = torch.jit.script(main)
383
expected = -2 * torch.eye(2)
384
self.assertTrue(torch.allclose(expected, script_out))
385
self.assertTrue(torch.allclose(script_out, out))
387
iofile = io.BytesIO()
388
torch.jit.save(sm, iofile)
390
sm = torch.jit.load(iofile)
391
script_out_load = sm(inp)
392
self.assertTrue(torch.allclose(expected, script_out_load))