# pytorch (Fork: 0) / test_futures.py — 341 lines · 10.3 KB
# mypy: allow-untyped-defs
2
# Owner(s): ["module: unknown"]
3

4
import threading
5
import time
6
import torch
7
import unittest
8
from torch.futures import Future
9
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
10
from typing import TypeVar
11

12
# Generic placeholder type parameter, used to create ``Future[T]`` instances
# whose result type is left unconstrained in the tests below.
T = TypeVar("T")


def add_one(fut):
    """Return the value of ``fut`` plus one.

    Written as a ``Future.then`` callback: ``then`` passes the completed
    parent future, and the returned value becomes the chained future's result.

    Args:
        fut: a future-like object exposing ``wait()``.

    Returns:
        ``fut.wait() + 1``.
    """
    return fut.wait() + 1


class TestFuture(TestCase):
    """Tests for ``torch.futures.Future`` and the ``torch.futures`` helpers."""

    @staticmethod
    def _start_checked_thread(target, *args):
        """Run ``target(*args)`` on a new thread and return a callable that joins it.

        ``threading.Thread`` does not propagate exceptions raised on the worker
        thread (including failed ``assert*`` calls) to the thread that joins
        it; they are only reported via ``threading.excepthook``, so the test
        would still pass. The returned callable joins the thread and re-raises
        the first exception captured on it, making such failures fail the test.
        """
        errors = []

        def run():
            try:
                target(*args)
            except BaseException as exc:  # re-raised on the joining thread
                errors.append(exc)

        thread = threading.Thread(target=run)
        thread.start()

        def join():
            thread.join()
            if errors:
                raise errors[0]

        return join

    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()  # type: ignore[valid-type]
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        def cb(fut):
            fut.value()

        # A then-callback that reads the errored parent's value fails; waiting
        # on the chained future surfaces that as a RuntimeError.
        f = Future[T]()  # type: ignore[valid-type]
        f.set_exception(value_error)

        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()

    def test_set_exception_multithreading(self) -> None:
        # Ensure errors can propagate when one thread waits on future result
        # and the other sets it with an error.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        def wait_future(f):
            with self.assertRaisesRegex(ValueError, "Intentional"):
                f.wait()

        f = Future[T]()  # type: ignore[valid-type]
        join = self._start_checked_thread(wait_future, f)
        f.set_exception(value_error)
        # Re-raises any assertion failure from the worker thread.
        join()

        def cb(fut):
            fut.value()

        def then_future(f):
            fut = f.then(cb)
            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
                fut.wait()

        f = Future[T]()  # type: ignore[valid-type]
        join = self._start_checked_thread(then_future, f)
        f.set_exception(value_error)
        join()

    def test_done(self) -> None:
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())

        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        # A future whose callback raised is still reported as done; the error
        # is only observed when waiting on it.
        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()

        join = self._start_checked_thread(slow_set_future, f, torch.ones(2, 2))

        # Blocks until the worker thread completes the future.
        self.assertEqual(f.wait(), torch.ones(2, 2))
        join()

    def test_mark_future_twice(self) -> None:
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self) -> None:
        fut = Future[int]()
        err_msg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                torch.save(fut, fname)

    def test_then(self) -> None:
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self) -> None:
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The i-th chained future has had add_one applied (i + 1) times.
        for i, chained in enumerate(futs):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + i + 1)

    def _test_then_error(self, cb, errMsg) -> None:
        """Check that an error raised by then-callback ``cb`` surfaces on wait.

        The parent future still completes with its own value (5); only the
        chained future raises a RuntimeError matching ``errMsg``.
        """
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self) -> None:

        def wrong_arg(tensor):
            return tensor + 1

        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self) -> None:

        def no_arg():
            return True

        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self) -> None:

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_then_error(raise_value_error, "Expected error")

    def test_add_done_callback_simple(self) -> None:
        callback_result = False

        def callback(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = True

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback)

        self.assertFalse(callback_result)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertTrue(callback_result)

    def test_add_done_callback_maintains_callback_order(self) -> None:
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def _test_add_done_callback_error_ignored(self, cb) -> None:
        """Check that an error raised by done-callback ``cb`` does not propagate."""
        fut = Future[int]()
        fut.add_done_callback(cb)

        fut.set_result(5)
        # The callback's error is not raised here; the future keeps its value.
        self.assertEqual(5, fut.wait())

    def test_add_done_callback_error_is_ignored(self) -> None:

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_add_done_callback_error_ignored(raise_value_error)

    def test_add_done_callback_no_arg_error_is_ignored(self) -> None:

        def no_arg():
            return True

        # Adding another level of function indirection here on purpose.
        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
        self._test_add_done_callback_error_ignored(no_arg)

    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self) -> None:
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        def callback_then(fut):
            # Only reads the enclosing variable, so no ``nonlocal`` is needed.
            return fut.wait() + callback_result

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        then_fut = fut.then(callback_then)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # then_fut's callback is called with callback_result = 1
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def test_interleaving_then_and_add_done_callback_propagates_error(self) -> None:
        def raise_value_error(fut):
            raise ValueError("Expected error")

        fut = Future[torch.Tensor]()
        then_fut = fut.then(raise_value_error)
        fut.add_done_callback(raise_value_error)
        fut.set_result(torch.ones(2, 2))

        # error from add_done_callback's callback is swallowed
        # error from then's callback is not
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            then_fut.wait()

    def test_collect_all(self) -> None:
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        # fut2 completes immediately and fut1 later on a worker thread;
        # fut_all.wait() must return both completed futures in input order.
        fut2.set_result(2)
        join = self._start_checked_thread(slow_in_thread, fut1, 1)

        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self) -> None:
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")
        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])

    def test_wait_none(self) -> None:
        fut1 = Future[int]()
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.jit.wait(None)
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.wait_all((None,))  # type: ignore[arg-type]
        with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
            torch.futures.collect_all((fut1, None,))  # type: ignore[arg-type]


# Entry point when this file is run directly as a script.
if __name__ == '__main__':
    run_tests()

# [Web-page residue from the GitVerse file viewer; not part of the test module]
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data for the purpose of improving our website and the GitVerse Service, and
# of making them more convenient to use.
#
# You can disable the use of cookies yourself in your browser settings.