# pytorch (fork, 0 forks) / test_await.py — 392 lines · 11.8 KB
# Owner(s): ["oncall: jit"]

import io
import math
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch._awaits import _Await as Await
from torch.testing._internal.jit_utils import JitTestCase, make_global


class TestAwait(JitTestCase):
13
    def test_await_python(self):
14
        def foo(x: int) -> int:
15
            return x + 13
16

17
        aw: Await[int] = torch.jit._awaitable(foo, 13)
18
        self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
19
        nw = torch.jit._awaitable_nowait(33)
20
        self.assertTrue(nw.is_nowait())
21
        self.assertTrue(nw.args() == (33,))
22

23
    def test_await_type_python(self):
24
        def foo() -> Tensor:
25
            return torch.randn()
26

27
        awaits = torch.jit.annotate(List[Await[Tensor]], [])
28
        awaits.append(torch.jit._awaitable(foo))
29

30
    def test_script(self):
31
        def delayed(z: int) -> int:
32
            return z + 3
33

34
        def fn(x: Tensor):
35
            aw: Await[int] = torch.jit._awaitable(delayed, 99)
36
            a = torch.eye(2)
37
            b = torch.jit._awaitable_wait(aw)
38
            return a + b + x
39

40
        inp = torch.zeros(2)
41

42
        sm = torch.jit.script(fn)
43
        out = fn(inp)
44
        script_out = sm(inp)
45
        self.assertTrue(torch.allclose(torch.eye(2) + 102, script_out))
46
        self.assertTrue(torch.allclose(script_out, out))
47

48
    def test_nowait(self):
49
        def fn(x: Tensor):
50
            aw = torch.jit._awaitable_nowait(13)
51
            a = torch.eye(2)
52
            b = torch.jit._awaitable_wait(aw)
53
            return a + b + x
54

55
        inp = torch.zeros(2)
56

57
        sm = torch.jit.script(fn)
58
        out = fn(inp)
59
        script_out = sm(inp)
60
        self.assertTrue(torch.allclose(torch.eye(2) + 13, script_out))
61
        self.assertTrue(torch.allclose(script_out, out))
62

63
    def test_nowait_class(self):
64
        class C:
65
            def __init__(self, a: Tensor, b: Tensor):
66
                self._a = a
67
                self._b = b
68

69
            def a(self) -> Tensor:
70
                return self._a
71

72
        def fn(x: Tensor):
73
            aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2)))
74
            _a = torch.eye(2)
75
            c = torch.jit._awaitable_wait(aw)
76
            return _a + c.a() + x
77

78
        make_global(C)
79
        inp = torch.zeros(2)
80

81
        sm = torch.jit.script(fn)
82
        out = fn(inp)
83
        script_out = sm(inp)
84
        self.assertTrue(torch.allclose(torch.eye(2), script_out))
85
        self.assertTrue(torch.allclose(script_out, out))
86

87
    def test_await_class_arg(self):
88
        class C:
89
            def __init__(self, a: Tensor, b: Tensor):
90
                self.__a = a
91
                self.__b = b
92

93
            def a(self) -> Tensor:
94
                return self.__a
95

96
        make_global(C)
97

98
        def delayed(c: C) -> Tensor:
99
            return c.a()
100

101
        def fn(x: Tensor):
102
            c = C(torch.zeros(2), torch.ones(2))
103
            aw = torch.jit._awaitable(delayed, c)
104
            _a = torch.eye(2)
105
            c2_t = torch.jit._awaitable_wait(aw)
106
            return _a + c2_t + x
107

108
        inp = torch.zeros(2)
109

110
        sm = torch.jit.script(fn)
111
        out = fn(inp)
112
        script_out = sm(inp)
113
        self.assertTrue(torch.allclose(torch.eye(2), script_out))
114
        self.assertTrue(torch.allclose(script_out, out))
115

116
    def test_awaitable_to_await(self):
117
        class C:
118
            __slots__ = ["_a", "_b"]
119

120
            def __init__(self, a: Tensor, b: Tensor):
121
                self._a = a
122
                self._b = b
123

124
        make_global(C)
125

126
        # Can not stay in the class as Jit does not support Recursive annotations
127
        # (self in wait_impl can not be annotated as C as C is not defined by this time)
128
        def C_wait_impl(self: C):
129
            return self._a + self._b
130

131
        def fn(x: Tensor):
132
            aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2)))
133
            _a = torch.eye(2)
134
            c_wait_impl_res = torch.jit._awaitable_wait(aw)
135
            return _a + c_wait_impl_res + x
136

137
        inp = torch.ones(2)
138

139
        sm = torch.jit.script(fn)
140
        out = fn(inp)
141
        script_out = sm(inp)
142
        self.assertTrue(torch.allclose(torch.eye(2) + 2 * torch.ones(2), script_out))
143
        self.assertTrue(torch.allclose(script_out, out))
144

145
    def test_await_class_return(self):
146
        class C:
147
            __slots__ = ["a", "b"]
148

149
            def __init__(self, a: Tensor, b: Tensor):
150
                self.a = a
151
                self.b = b
152

153
        make_global(C)
154

155
        # Can not stay in the class as Jit does not support Recursive annotations
156
        # (self in wait_impl can not be annotated as C as C is not defined by this time)
157
        def C_wait_impl(self: C) -> C:
158
            return C(self.a * 2, self.b * 3)
159

160
        def fn_arg_C(x: C) -> Tensor:
161
            return x.a + x.b
162

163
        def fn(x: Tensor):
164
            aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
165
            _a = torch.eye(2)
166
            y = fn_arg_C(torch.jit._awaitable_wait(aw))
167
            return _a + y + x
168

169
        inp = torch.ones(2)
170

171
        sm = torch.jit.script(fn)
172
        out = fn(inp)
173
        script_out = sm(inp)
174
        self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
175
        self.assertTrue(torch.allclose(script_out, out))
176
        self.assertGraphContainsExactly(
177
            sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
178
        )
179

180
    def test_await_getattr_implicit_convertion(self):
181
        class C:
182
            def __init__(self, a: Tensor, b: Tensor):
183
                self._a = a
184
                self._b = b
185

186
            def b(self):
187
                return self._b
188

189
        make_global(C)
190

191
        # Can not stay in the class as Jit does not support Recursive annotations
192
        # (self in wait_impl can not be annotated as C as C is not defined by this time)
193
        def C_wait_impl(self: C) -> C:
194
            return C(self._a * 2, self._b * 3)
195

196
        def fn_arg_C(x: C) -> Tensor:
197
            return x._a + x._b
198

199
        def fn(x: Tensor):
200
            aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
201
            _a = torch.eye(2)
202
            ai = aw._a
203
            awb = aw.b()
204
            c = C(2 * x, 2 * x)
205
            return _a + ai + x + c._a + c.b()
206

207
        inp = torch.ones(2)
208

209
        sm = torch.jit.script(fn)
210
        out = fn(inp)
211
        script_out = sm(inp)
212
        self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
213
        self.assertTrue(torch.allclose(script_out, out))
214
        self.assertGraphContainsExactly(
215
            sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
216
        )
217

218
    def test_await_nested(self):
219
        class C:
220
            def __init__(self, a: Tensor, b: Tensor):
221
                self.__a = a
222
                self.__b = b
223

224
            def a(self) -> Tensor:
225
                return self.__a
226

227
        make_global(C)
228

229
        def delayed(c: C) -> Await[Tensor]:
230
            return torch.jit._awaitable_nowait(3 * c.a())
231

232
        def fn(x: Tensor) -> Await[Await[Tensor]]:
233
            return torch.jit._awaitable(delayed, C(2 * x, x))
234

235
        def main(x: Tensor) -> Tensor:
236
            awaw = fn(x)
237
            return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw))
238

239
        inp = torch.eye(2)
240

241
        sm = torch.jit.script(main)
242
        out = main(inp)
243
        script_out = sm(inp)
244
        self.assertTrue(torch.allclose(6 * torch.eye(2), script_out))
245
        self.assertTrue(torch.allclose(script_out, out))
246

247
    def test_eager_await_non_scriptable(self):
248
        # Tree type can not be compiled (Recursive type)
249
        class Tree:
250
            def __init__(self, v):
251
                self.parent = torch.jit.annotate(Optional[Tree], None)
252
                self.v = v
253

254
        make_global(Tree)
255

256
        def delayed(t: Tree):
257
            t.v = t.v + 1
258
            return t
259

260
        aw = torch.jit._awaitable(delayed, Tree(2))
261
        t = torch.jit._awaitable_wait(aw)
262
        self.assertTrue(t.v == 3)
263

264
    def test_await_isinstance(self):
265
        def delayed(x: Tensor) -> Tensor:
266
            return 2 * (x + 1)
267

268
        def main(x: Tensor) -> Tensor:
269
            aw = torch.jit._awaitable(delayed, x)
270
            if torch.jit.is_scripting():
271
                assert isinstance(aw, torch.jit._Await)
272
            return torch.jit._awaitable_wait(aw)
273

274
        inp = torch.eye(2)
275

276
        sm = torch.jit.script(main)
277
        out = main(inp)
278
        script_out = sm(inp)
279
        self.assertTrue(
280
            torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
281
        )
282
        self.assertTrue(torch.allclose(script_out, out))
283

284
    def test_await_eager_lazy(self):
285
        def delayed(x: Tensor) -> Tensor:
286
            return 2 * (x + 1)
287

288
        t = torch.ones(2, dtype=torch.int64)
289
        aw = torch.jit._awaitable(delayed, t)
290
        self.assertTrue(isinstance(aw, torch._C._Await))
291
        self.assertTrue(t.dtype == aw.dtype)
292

293
    def test_await_out_of_interpreter(self):
294
        def delayed(x: Tensor) -> Tensor:
295
            return 2 * (x + 1)
296

297
        def main(x: Tensor) -> Await[Tensor]:
298
            aw = torch.jit._awaitable(delayed, x)
299
            return aw
300

301
        inp = torch.eye(2)
302

303
        sm = torch.jit.script(main)
304
        out_aw = main(inp)
305
        out = torch.jit._awaitable_wait(out_aw)
306

307
        script_out_aw = sm(inp)
308
        script_out = torch.jit._awaitable_wait(script_out_aw)
309
        self.assertTrue(
310
            torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
311
        )
312
        self.assertTrue(torch.allclose(script_out, out))
313

314
    def test_jit_trace(self):
315
        def gap(x: Tensor):
316
            return torch.relu(x) + torch.sin(x)
317

318
        def delayed(x: Tensor) -> Tensor:
319
            return 2 * (torch.cos(x) + 1)
320

321
        def main(x: Tensor, y: Tensor) -> Tensor:
322
            aw = torch.jit._awaitable(delayed, x)
323
            z = gap(y)
324
            k = torch.jit._awaitable_wait(aw)
325
            return y + k
326

327
        inp = torch.randn(2)
328
        tm = torch.jit.trace(main, (inp, inp))
329
        inp_check = torch.ones(2)
330
        self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check))
331

332
    def test_await_multiout_save(self):
333
        def gap(x: Tensor):
334
            return torch.relu(x) + torch.sin(x)
335

336
        def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]:
337
            l = [x * i for i in range(5)]
338
            return (100 * x, l)
339

340
        def main(x: Tensor) -> Tensor:
341
            aw = torch.jit._awaitable(delayed, x)
342
            z = gap(x)
343
            (_, l) = torch.jit._awaitable_wait(aw)
344
            return l[3] + z
345

346
        inp = torch.eye(2)
347

348
        sm = torch.jit.script(main)
349
        out = main(inp)
350
        script_out = sm(inp)
351
        expected = 4.8415 * torch.eye(2)
352
        self.assertTrue(torch.allclose(expected, script_out))
353
        self.assertTrue(torch.allclose(script_out, out))
354

355
        iofile = io.BytesIO()
356
        torch.jit.save(sm, iofile)
357
        iofile.seek(0)
358
        sm = torch.jit.load(iofile)
359
        script_out_load = sm(inp)
360
        self.assertTrue(torch.allclose(expected, script_out_load))
361

362
    def test_await_func_arg(self):
363
        def gap(x: Tensor):
364
            return torch.relu(x) + torch.sin(x)
365

366
        def delayed(x: Tensor) -> Tensor:
367
            return -1 * x
368

369
        def fn(aw: Await[Tensor]) -> Tensor:
370
            return 3 * torch.jit._awaitable_wait(aw)
371

372
        def main(x: Tensor) -> Tensor:
373
            aw = torch.jit._awaitable(delayed, x)
374
            z = gap(x)
375
            y = fn(aw)
376
            return y + x
377

378
        inp = torch.eye(2)
379

380
        sm = torch.jit.script(main)
381
        out = main(inp)
382
        script_out = sm(inp)
383
        expected = -2 * torch.eye(2)
384
        self.assertTrue(torch.allclose(expected, script_out))
385
        self.assertTrue(torch.allclose(script_out, out))
386

387
        iofile = io.BytesIO()
388
        torch.jit.save(sm, iofile)
389
        iofile.seek(0)
390
        sm = torch.jit.load(iofile)
391
        script_out_load = sm(inp)
392
        self.assertTrue(torch.allclose(expected, script_out_load))
393

# (Scraped web-page footer, not part of the test module.)
# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.