# Owner(s): ["oncall: jit"]

import os
import sys
from typing import Any, List, Tuple

import torch
import torch.nn as nn


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch import Tensor
from torch.jit import Future
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase


class TestAsync(JitTestCase):
    def test_async_python(self):
        @torch.jit.script
        def foo(x):
            return torch.neg(x)

        x = torch.rand(3, 4)
        fut = torch.jit.fork(foo, x)
        y_hat = foo(x)
        y = torch.jit.wait(fut)
        # assert nothing; this only checks that the eager Python fork path works

    def test_async_future_type_python(self):
        def foo(inp):
            futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
            for _ in range(5):
                futures.append(torch.jit.fork(lambda x: x, inp))
            all_outputs = []
            for future in futures:
                all_outputs.append(torch.jit.wait(future))
            return all_outputs

        # assert nothing, just to make sure python type parsing works
        foo(torch.randn(3, 4))

    def test_async_parsing(self):
        @torch.jit.script
        def foo(x: Tensor) -> List[Tensor]:
            return [torch.neg(x), x.t()]

        @torch.jit.script
        def bar(x):
            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
            for _ in range(3):
                future = torch.jit.annotate(
                    Future[List[Tensor]], torch.jit.fork(foo, x)
                )
                futures.append(future)

            output = torch.jit.annotate(List[List[Tensor]], [])
            for i in range(3):
                output.append(torch.jit.wait(futures[i]))
            return output

        x = torch.rand(3, 3)
        result = bar(x)
        self.assertEqual(len(result), 3)

    def test_async_script(self):
        @torch.jit.script
        def foo(x):
            return torch.neg(x), x

        x = torch.rand(3, 4)

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit.fork(foo, x)
            y_hat = foo(x)
            y = torch.jit.wait(fut)
            return y, y_hat

        y, y_hat = wait_script(x)

        self.assertEqual(y, y_hat)

    def test_async_script_capture(self):
        class Mod(torch.jit.ScriptModule):
            __constants__ = ["const"]

            def __init__(self) -> None:
                super().__init__()
                self.const = 42
                self.param = nn.Parameter(torch.randn(2, 2))

            @torch.jit.script_method
            def foo(self, x1, x2):
                return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param

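            # Forking a bound method captures `self`, so self.param and
            # self.const stay visible inside the forked call.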
            @torch.jit.script_method
            def forward(self, x1, x2):
                fut = torch.jit.fork(self.foo, x1, x2)
                y_hat = self.foo(x1, x2)
                y = torch.jit.wait(fut)
                return y, y_hat

        x1 = torch.rand(3, 4)
        x2 = torch.rand(5, 6)

        m = Mod()

        with torch.jit.optimized_execution(False):
            y, y_hat = m.forward(x1, x2)

        self.assertEqual(y, y_hat)

    def test_async_script_nested(self):
        @torch.jit.script
        def foo(x):
            return torch.neg(x), x

        x = torch.rand(3, 4)

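        # torch.jit._fork / torch.jit._wait are the internal aliases of the
        # public torch.jit.fork / torch.jit.wait used throughout these tests.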
        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)
            y_hat = foo(x)
            y = torch.jit._wait(fut)
            return y, y_hat

        @torch.jit.script
        def wait_script_nest(x):
            fut = torch.jit._fork(wait_script, x)
            return torch.jit._wait(fut)

        y, y_hat = wait_script_nest(x)

        self.assertEqual(y, y_hat)

    def test_async_script_no_script_mod(self):
        x = torch.rand(3, 4)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "cannot call a value", "torch.jit._fork(x"
        ):

            @torch.jit.script
            def wait_script(x):
                fut = torch.jit._fork(x)
                return fut

    def test_async_script_multi_waits(self):
        @torch.jit.script
        def foo(x):
            return torch.neg(x).t() + x

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)

            # wait twice on the same future
            y1 = torch.jit._wait(fut)
            y2 = torch.jit._wait(fut)
            return y1, y2

        x = torch.rand(2, 2)
        y1, y2 = wait_script(x)
        self.assertEqual(y1, y2)

    def test_async_script_multi_forks(self):
        @torch.jit.script
        def foo1(x):
            return torch.neg(x).t() + x

        @torch.jit.script
        def foo2(x, y):
            return torch.neg(x).t() + x + torch.neg(y).t()

        @torch.jit.script
        def foo3(x, y, z):
            return torch.neg(z).t() + y.t() + x

        x1 = torch.rand(10, 10)
        x2 = torch.rand(10, 10)
        x3 = torch.rand(10, 10)

        @torch.jit.script
        def wait_script(x1, x2, x3):
            f1 = torch.jit._fork(foo1, x1)
            f2 = torch.jit._fork(foo2, x1, x2)
            f3 = torch.jit._fork(foo3, x1, x2, x3)
            f4 = torch.jit._fork(foo1, x2)
            f5 = torch.jit._fork(foo2, x2, x3)

            # f4 and f5 are intentionally never waited on
            y1 = torch.jit._wait(f1)
            y2 = torch.jit._wait(f2)
            y3 = torch.jit._wait(f3)

            return y1, y2, y3

        y1, y2, y3 = wait_script(x1, x2, x3)
        self.assertEqual(y1, foo1(x1))
        self.assertEqual(y2, foo2(x1, x2))
        self.assertEqual(y3, foo3(x1, x2, x3))

    def test_async_kwargs(self):
        def foo(x1, x2):
            return 2 * x1 + x2

        x1 = torch.rand(3, 4)
        x2 = torch.rand(3, 4)
        y_hat = foo(x1, x2)

        # Cover tracing and bare functions with permutations of args, kwargs
        for func in [
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
        ]:
            for wrapper in [
                func,
                torch.jit.trace(func, (x1, x2)),
            ]:
                self.assertEqual(wrapper(x1, x2), y_hat)
                self.assertEqual(wrapper(x1, x2=x2), y_hat)
                self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
                self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)

        # Cover scripting
        @torch.jit.script
        def foo_script_args(x1, x2):
            return torch.jit._wait(torch.jit._fork(foo, x1, x2))

        @torch.jit.script
        def foo_script_kwargs(x1, x2):
            return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))

        for wrapper in [
            foo_script_args,
            foo_script_kwargs,
        ]:
            self.assertEqual(wrapper(x1, x2), y_hat)
            self.assertEqual(wrapper(x1, x2=x2), y_hat)
            self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
            self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)

    @_inline_everything
    def test_async_script_trace(self):
        class Traced(nn.Module):
            def forward(self, x):
                return (torch.neg(x), x)

        class Mod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                x = torch.rand(3, 3)
                self.traced = torch.jit.trace(Traced(), (x,), _force_outplace=True)

            @torch.jit.script_method
            def forward(
                self, x: Tensor
            ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
                future1 = torch.jit._fork(self.traced, x)
                future2 = torch.jit._fork(torch.neg, x)

                tensor_tuple = torch.jit._wait(future1)
                tensor_single = torch.jit._wait(future2)

                tensor_list = []
                tensor_list.append(tensor_tuple[0])
                tensor_list.append(tensor_single)

                # return a nested structure of tensors
                return (tensor_list, tensor_tuple, tensor_tuple[1])

        class TupleCl(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module = Mod()

            def forward(self, x):
                z = torch.neg(x)
                y = self.module(x)
                outputs = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
                return tuple(outputs)

        x = torch.rand(3, 3)
        module = torch.jit.trace(TupleCl(), (x,), _force_outplace=True)

        # Make sure we have forks
        self.assertGraphContainsExactly(
            module.graph, kind="prim::fork", num_kind_nodes=2
        )
        # Make sure one aten::neg is in the root graph and two more are in the fork subgraphs
        self.assertGraphContainsExactly(
            module.graph, kind="aten::neg", num_kind_nodes=1
        )
        self.assertGraphContainsExactly(
            module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
        )

        y = torch.neg(x)
        self.assertEqual(module(x), (y, y, y, y, x, x))

    def test_async_script_error(self):
        x = torch.rand(3, 4)

        @torch.jit.script
        def foo(x):
            # fails for non-square x: x.t() and x have mismatched shapes
            return x.t() + x

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)
            return torch.jit._wait(fut)

        @torch.jit.script
        def wait_script_nest(x):
            fut = torch.jit._fork(wait_script, x)
            return torch.jit._wait(fut)

        # no future
        error_msg = "The size.*must match the size of tensor"
        with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
            foo(x)

        # one future
        with self.assertRaisesRegexWithHighlight(
            Exception, error_msg, "torch.jit._fork(foo, x"
        ):
            wait_script(x)

        # two futures with a different error
        x = torch.rand(3, 4, 5)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            "expects a tensor with <= 2 dimensions",
            "torch.jit._fork(wait_script, x",
        ):
            wait_script_nest(x)

    def test_async_grad_guard_with_grad(self):
        @torch.jit.script
        def foo(x):
            y = x * 2
            return y.requires_grad

        @torch.jit.script
        def bar(x):
            fut = torch.jit._fork(foo, x)
            requires_grad_in_fork = torch.jit._wait(fut)
            z = x * 2
            return (requires_grad_in_fork, z.requires_grad)

        x = torch.randn(3, requires_grad=True)

        with torch.enable_grad():
            (inside_fork, after_wait) = bar(x)

        self.assertEqual(inside_fork, True)
        self.assertEqual(after_wait, True)

    def test_async_grad_guard_no_grad(self):
        @torch.jit.script
        def foo(x):
            y = x * 2
            return y.requires_grad

        @torch.jit.script
        def bar(x):
            fut = torch.jit._fork(foo, x)
            requires_grad_in_fork = torch.jit._wait(fut)
            z = x * 2
            return (requires_grad_in_fork, z.requires_grad)

        x = torch.randn(3, requires_grad=True)

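        # The grad mode active at fork time should propagate into the forked task.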
        with torch.no_grad():
            (inside_fork, after_wait) = bar(x)

        self.assertEqual(inside_fork, False)
        self.assertEqual(after_wait, False)

    def test_trace_fork_wait(self):
        def fork_body(x):
            return x.neg(), x.neg() + 1

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            vals = torch.jit._wait(fut)
            return vals[0], vals[1], x - 1

        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
        x = torch.rand(3, 4)
        self.assertEqual(fn(x), traced(x))

        self.assertGraphContainsExactly(
            traced.graph, kind="prim::fork", num_kind_nodes=1
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::wait", num_kind_nodes=1
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
        )

    def test_trace_fork_wait_leaking(self):
        my_list = []

        def fork_body(x):
            my_list.append(x + 1)
            return x + 1

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            val = torch.jit._wait(fut)
            return my_list[0]

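        # fork_body leaks its result through the closed-over Python list, so the
        # traced output has no observable data dependence on the trace inputs.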
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "did not have observable data dependence with trace inputs; "
            "this probably indicates your program cannot be understood "
            "by the tracer.",
            "",
        ):
            traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)

    def test_trace_fork_wait_inline(self):
        def fork_body(x):
            return x + 1, x + 2

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            val = torch.jit._wait(fut)
            return val[1]

        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
        torch._C._jit_pass_inline_fork_wait(traced.graph)
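        # The inline_fork_wait pass splices the forked subgraph back into the
        # main graph, so no fork/wait nodes should remain afterwards.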
        self.assertGraphContainsExactly(
            traced.graph, kind="prim::fork", num_kind_nodes=0
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::wait", num_kind_nodes=0
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::add", num_kind_nodes=2
        )

    def test_trace_fork_wait_list_modulecalls(self):
        def add_one(input):
            return input + torch.ones(input.size())

        class TestListFutureModule(nn.Module):
            def forward(self, input):
                input_list = []
                for _ in range(3):
                    input_list.append(input)

                fut_list: List[Future[torch.Tensor]] = []
                for input_tensor in input_list:
                    fut_list.append(torch.jit._fork(add_one, input_tensor))
                # return a List[Future[Tensor]] here to ensure that traced
                # module calls return the correct types
                return fut_list

        class TestModuleWrapper(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.list_fut_mod = TestListFutureModule()

            def forward(self, input):
                fut_list = self.list_fut_mod(input)
                res = input
                for fut in fut_list:
                    res = res + fut.wait()
                return res

        self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),))

    def test_trace_modulecalls_with_different_output_types(self):
        def add_one(input):
            return input + torch.ones(input.size())

        class DifferentOutputModule(nn.Module):
            def forward(self, input):
                fut_res = torch.jit._fork(add_one, input)

                # return different types from module call
                return input, fut_res

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gen_output = DifferentOutputModule()

            def forward(self, input):
                res, fut_res = self.gen_output(input)
                res = res + fut_res.wait()
                return res

        self.checkTrace(TestModule(), (torch.randn(5, 5),))

    def test_no_future_subtype_message(self):
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Future without a contained type", ""
        ):

            @torch.jit.script
            def forward(self, x):
                futs = torch.jit.annotate(List[torch.jit.Future], [])

    def test_future_subtyping(self):
        """
        Test that futures subtype each other properly.
        """

        # Successful subtyping.
        def returns_int(x: int) -> int:
            return x + x + 1

        def returns_future_any(x: int) -> torch.jit.Future[Any]:
            return torch.jit._fork(returns_int, x)

        @torch.jit.script
        def fn_int(x: int) -> Any:
            fut = returns_future_any(x)
            return fut.wait()

        # Unsuccessful subtyping.
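        # Future[int] is not a subtype of Future[float], so the annotation
        # below is rejected when fn_float is scripted.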
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
            "fut = returns_future_float(x",
        ):

            def returns_future_float(x: int) -> torch.jit.Future[float]:
                return torch.jit._fork(returns_int, x)

            @torch.jit.script
            def fn_float(x: int) -> Any:
                fut = returns_future_float(x)
                return fut.wait()


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )