1
# mypy: allow-untyped-defs
2
# Owner(s): ["module: unknown"]
8
from torch.futures import Future
9
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
10
from typing import TypeVar
19
class TestFuture(TestCase):
20
# Stores a ValueError on a Future with set_exception() and expects it to be
# raised back to the caller; the inline comments name wait() and value() as
# the two access paths being checked.
# NOTE(review): the bodies of all three `with` blocks are not present in this
# view, the indentation has been flattened, and the bare integers between
# statements look like line-number residue. Restore before running.
def test_set_exception(self) -> None:
21
# This test is to ensure errors can propagate across futures.
22
error_msg = "Intentional Value Error"
23
value_error = ValueError(error_msg)
25
f = Future[T]() # type: ignore[valid-type]
27
f.set_exception(value_error)
28
# Exception should throw on wait
29
with self.assertRaisesRegex(ValueError, "Intentional"):
32
# Exception should also throw on value
33
f = Future[T]() # type: ignore[valid-type]
34
f.set_exception(value_error)
35
with self.assertRaisesRegex(ValueError, "Intentional"):
41
# Third scenario expects a RuntimeError rather than the original ValueError.
# NOTE(review): presumably raised through a chained callback - the call that
# triggers it is not visible here; confirm.
f = Future[T]() # type: ignore[valid-type]
42
f.set_exception(value_error)
44
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
48
# Two scenarios, each building a Future, a threading.Thread that receives the
# future as its only argument, and then calling set_exception() on the future.
# NOTE(review): `wait_future` and `then_future` (the thread targets) are not
# defined in the visible lines, no start()/join() call on `t` is visible, and
# the `with` blocks have no body here. Restore the missing lines before running.
def test_set_exception_multithreading(self) -> None:
49
# Ensure errors can propagate when one thread waits on future result
50
# and the other sets it with an error.
51
error_msg = "Intentional Value Error"
52
value_error = ValueError(error_msg)
55
with self.assertRaisesRegex(ValueError, "Intentional"):
58
f = Future[T]() # type: ignore[valid-type]
59
t = threading.Thread(target=wait_future, args=(f, ))
61
f.set_exception(value_error)
69
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
72
f = Future[T]() # type: ignore[valid-type]
73
t = threading.Thread(target=then_future, args=(f, ))
75
f.set_exception(value_error)
78
def test_done(self) -> None:
79
f = Future[torch.Tensor]()
80
self.assertFalse(f.done())
82
f.set_result(torch.ones(2, 2))
83
self.assertTrue(f.done())
85
def test_done_exception(self) -> None:
86
err_msg = "Intentional Value Error"
88
def raise_exception(unused_future):
89
raise RuntimeError(err_msg)
91
f1 = Future[torch.Tensor]()
92
self.assertFalse(f1.done())
93
f1.set_result(torch.ones(2, 2))
94
self.assertTrue(f1.done())
96
f2 = f1.then(raise_exception)
97
self.assertTrue(f2.done())
98
with self.assertRaisesRegex(RuntimeError, err_msg):
101
def test_wait(self) -> None:
    """Check that ``wait()`` returns the tensor handed to ``set_result``."""
    f = Future[torch.Tensor]()
    f.set_result(torch.ones(2, 2))

    # The future is already complete, so wait() is compared directly.
    self.assertEqual(f.wait(), torch.ones(2, 2))
107
# Builds a Future and a threading.Thread whose target sets the future's result,
# then asserts that f.wait() yields torch.ones(2, 2).
# NOTE(review): no call that starts `t` is visible in this view; if the thread
# is never started, nothing completes `f` before f.wait(). The helper's name
# suggests a delay before set_result, but none is visible - confirm.
def test_wait_multi_thread(self) -> None:
109
def slow_set_future(fut, value):
111
fut.set_result(value)
113
f = Future[torch.Tensor]()
115
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
118
self.assertEqual(f.wait(), torch.ones(2, 2))
121
# Expects an error whose message matches "Future can only be marked completed
# once".
# NOTE(review): the assertRaisesRegex call is cut off in this view - the
# exception-type argument, the closing parenthesis, the `with` body and the
# setup of the future are all missing. Restore before running.
def test_mark_future_twice(self) -> None:
124
with self.assertRaisesRegex(
126
"Future can only be marked completed once"
130
# Expects torch.save() of a Future into a temporary file to fail with a
# RuntimeError matching errMsg.
# NOTE(review): `fut` is passed to torch.save() before any visible assignment;
# the line that creates it appears to be missing from this view.
def test_pickle_future(self):
132
errMsg = "Can not pickle torch.futures.Future"
133
with TemporaryFileName() as fname:
134
with self.assertRaisesRegex(RuntimeError, errMsg):
135
torch.save(fut, fname)
138
# NOTE(review): from here on the statements exercise then(): a callback adding
# 1 to the parent's value is chained, the parent is completed with ones(2, 2),
# and both results are checked. This presumably belongs to a separate test
# whose `def` line is not present in this view - confirm.
fut = Future[torch.Tensor]()
139
then_fut = fut.then(lambda x: x.wait() + 1)
141
fut.set_result(torch.ones(2, 2))
142
self.assertEqual(fut.wait(), torch.ones(2, 2))
143
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
145
# Chains then(add_one) futures, collects them in `futs`, completes the root
# future with ones(2, 2) and asserts that the i-th collected future yields
# ones(2, 2) + i + 1.
# NOTE(review): `futs`, the initial `last_fut`, the loop header around the
# then/append pair and `add_one` are not defined in the visible lines.
def test_chained_then(self):
146
fut = Future[torch.Tensor]()
150
last_fut = last_fut.then(add_one)
151
futs.append(last_fut)
153
fut.set_result(torch.ones(2, 2))
155
for i in range(len(futs)):
156
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
158
# Shared helper: attaches `cb` to a future with then(), asserts the parent
# future still yields 5, and expects a RuntimeError matching `errMsg`.
#   cb:     callback handed to fut.then()
#   errMsg: regex the RuntimeError message must match
# NOTE(review): the creation/completion of `fut` and the body of the `with`
# block are not present in this view.
def _test_then_error(self, cb, errMsg):
160
then_fut = fut.then(cb)
163
self.assertEqual(5, fut.wait())
164
with self.assertRaisesRegex(RuntimeError, errMsg):
167
# Runs _test_then_error with a callback named wrong_arg and an error pattern
# about an unsupported operand between a Future and an int.
# NOTE(review): the body of wrong_arg is not present in this view.
def test_then_wrong_arg(self):
169
def wrong_arg(tensor):
172
self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")
174
# Runs _test_then_error with `no_arg` and expects the "takes 0 positional
# arguments but 1 was given" error text.
# NOTE(review): the definition of `no_arg` is not present in this view;
# presumably a zero-parameter function - confirm.
def test_then_no_arg(self):
179
self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")
181
def test_then_raise(self):
    """Run ``_test_then_error`` with a callback that raises ``ValueError``.

    The message passed as the expected pattern is the same text the
    callback raises with.
    """

    def raise_value_error(fut):
        raise ValueError("Expected error")

    self._test_then_error(raise_value_error, "Expected error")
188
# Registers a callback with add_done_callback() and checks that the
# callback_result flag is False before set_result() and True afterwards.
# NOTE(review): the `def` line of the callback (referenced below as
# `callback`) is not present in this view; the `nonlocal` statement and the
# assignment to True presumably form its body - confirm.
def test_add_done_callback_simple(self):
189
callback_result = False
192
nonlocal callback_result
194
callback_result = True
196
fut = Future[torch.Tensor]()
197
fut.add_done_callback(callback)
199
self.assertFalse(callback_result)
200
fut.set_result(torch.ones(2, 2))
201
self.assertEqual(fut.wait(), torch.ones(2, 2))
202
self.assertTrue(callback_result)
204
# Registers callback_set1 and then callback_set2 on the same future and, after
# completing it, asserts callback_result == 2 (i.e. set2 ran last).
# NOTE(review): the initial assignment of callback_result and the assignments
# inside both callbacks are not present in this view.
def test_add_done_callback_maintains_callback_order(self):
207
def callback_set1(fut):
208
nonlocal callback_result
212
def callback_set2(fut):
213
nonlocal callback_result
217
fut = Future[torch.Tensor]()
218
fut.add_done_callback(callback_set1)
219
fut.add_done_callback(callback_set2)
221
fut.set_result(torch.ones(2, 2))
222
self.assertEqual(fut.wait(), torch.ones(2, 2))
223
# set2 called last, callback_result = 2
224
self.assertEqual(callback_result, 2)
226
# Shared helper: registers `cb` with add_done_callback() and asserts the
# future still yields 5, i.e. whatever `cb` does must not affect the result.
#   cb: callback handed to fut.add_done_callback()
# NOTE(review): the creation and completion of `fut` are not present in this
# view.
def _test_add_done_callback_error_ignored(self, cb):
228
fut.add_done_callback(cb)
231
# error msg logged to stdout
232
self.assertEqual(5, fut.wait())
234
def test_add_done_callback_error_is_ignored(self):
    """Run ``_test_add_done_callback_error_ignored`` with a raising callback.

    The callback always raises ``ValueError``; the helper is responsible
    for asserting that the future's own result is unaffected.
    """

    def raise_value_error(fut):
        raise ValueError("Expected error")

    self._test_add_done_callback_error_ignored(raise_value_error)
241
# Runs _test_add_done_callback_error_ignored with `no_arg` as the callback.
# NOTE(review): the definition of `no_arg` is not present in this view;
# presumably a zero-parameter function defined just above the comment below -
# confirm.
def test_add_done_callback_no_arg_error_is_ignored(self):
246
# Adding another level of function indirection here on purpose.
247
# Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
248
self._test_add_done_callback_error_ignored(no_arg)
250
# Registers, in order, callback_set1 (add_done_callback), callback_then (then)
# and callback_set2 (add_done_callback) on one future. After completion it
# checks that then_fut saw callback_result == 1 and that the final value is 2.
# NOTE(review): the initial assignment of callback_result and the assignments
# inside callback_set1 / callback_set2 are not present in this view.
def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
253
def callback_set1(fut):
254
nonlocal callback_result
258
def callback_set2(fut):
259
nonlocal callback_result
263
def callback_then(fut):
264
nonlocal callback_result
265
return fut.wait() + callback_result
267
fut = Future[torch.Tensor]()
268
fut.add_done_callback(callback_set1)
269
then_fut = fut.then(callback_then)
270
fut.add_done_callback(callback_set2)
272
fut.set_result(torch.ones(2, 2))
273
self.assertEqual(fut.wait(), torch.ones(2, 2))
274
# then_fut's callback is called with callback_result = 1
275
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
276
# set2 called last, callback_result = 2
277
self.assertEqual(callback_result, 2)
279
def test_interleaving_then_and_add_done_callback_propagates_error(self):
    """Check error handling when ``then`` and ``add_done_callback`` share a callback.

    The same raising callback is attached through both APIs. The parent
    future must still yield its result, while the future returned by
    ``then`` must surface the callback's error.
    """

    def raise_value_error(fut):
        raise ValueError("Expected error")

    fut = Future[torch.Tensor]()
    then_fut = fut.then(raise_value_error)
    fut.add_done_callback(raise_value_error)
    fut.set_result(torch.ones(2, 2))

    # error from add_done_callback's callback is swallowed
    # error from then's callback is not
    self.assertEqual(fut.wait(), torch.ones(2, 2))
    with self.assertRaisesRegex(RuntimeError, "Expected error"):
        # then_fut is the future that carries the then-callback's error.
        then_fut.wait()
294
# Combines fut1 and fut2 with torch.futures.collect_all(), sets fut1's result
# from a helper run via threading.Thread, and asserts the two collected
# futures yield 1 and 2.
# NOTE(review): the creation of fut1/fut2, the start of `t`, the completion of
# fut2 and the assignment of `res` are not present in this view.
def test_collect_all(self):
297
fut_all = torch.futures.collect_all([fut1, fut2])
299
def slow_in_thread(fut, value):
301
fut.set_result(value)
303
t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
308
self.assertEqual(res[0].wait(), 1)
309
self.assertEqual(res[1].wait(), 2)
312
# Checks torch.futures.wait_all(): first that it returns [1, 2] for two
# futures, then that it raises RuntimeError when one of the futures carries an
# error from a then() callback. Skipped when IS_WINDOWS is true.
# NOTE(review): the creation and completion of fut1/fut2 are not present in
# this view.
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
313
def test_wait_all(self):
320
res = torch.futures.wait_all([fut1, fut2])
322
self.assertEqual(res, [1, 2])
324
# Version with an exception
325
def raise_in_fut(fut):
326
raise ValueError("Expected error")
327
fut3 = fut1.then(raise_in_fut)
328
with self.assertRaisesRegex(RuntimeError, "Expected error"):
329
torch.futures.wait_all([fut3, fut2])
331
# Expects a RuntimeError matching "Future can't be None" when None is passed
# where a future is required: wait_all((None,)) and collect_all((fut1, None,)).
# NOTE(review): the body of the first `with` block and the creation of `fut1`
# are not present in this view.
def test_wait_none(self):
333
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
335
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
336
torch.futures.wait_all((None,)) # type: ignore[arg-type]
337
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
338
torch.futures.collect_all((fut1, None,)) # type: ignore[arg-type]
340
if __name__ == '__main__':