# pytorch / Fork: 0 / test_python_builtins.py — 475 lines · 15.5 KB
# Owner(s): ["oncall: jit"]

import os
import random
import sys
import tempfile
from textwrap import dedent

import torch
from torch.testing._internal.jit_utils import execWrapper, JitTestCase


# Make the helper files in test/ importable: the test root is the grandparent
# directory of this file's resolved (symlink-free) location.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # This module is meant to be collected through the JIT test entry point,
    # never executed on its own.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


def get_fn(file_name, script_path):
    """Import the Python source file at ``script_path`` and return its ``fn``.

    The file is executed as a fresh module named ``file_name``; this helper
    does not register that module in ``sys.modules``.

    Args:
        file_name: module name to give the loaded file.
        script_path: filesystem path of the source file to execute.

    Returns:
        The object bound to the name ``fn`` in the executed module.

    Raises:
        ImportError: if no loader can be chosen for ``script_path`` (for
            example, its suffix is not a recognized Python source suffix).
        AttributeError: if the executed module does not define ``fn``.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location(file_name, script_path)
    # spec_from_file_location returns None when it cannot pick a loader for the
    # path; report that clearly instead of failing with AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {file_name!r} from {script_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.fn


class TestPythonBuiltinOP(JitTestCase):
    """TorchScript tests for Python builtin operators, indexing and slicing.

    Most tests define a plain Python function and pass it to
    ``JitTestCase.checkScript``. The indexing tests instead build source text
    and use ``_check_code``, which runs it both as Python and as TorchScript
    and asserts equal results.
    """

    @staticmethod
    def _consec(size, start=0):
        """Return an integer tensor of shape ``size`` filled with consecutive values.

        Values run from ``start`` to ``start + numel - 1`` in row-major order.
        """
        numel = torch.tensor(size).prod().item()
        return torch.arange(start, start + numel).view(size)

    def test_add(self):
        def func(a, b):
            c = a + b
            c += a
            return c

        a = torch.rand(1, requires_grad=True)
        b = torch.rand(1, requires_grad=True)
        self.checkScript(func, (a, b), optimize=True)

    def test_mul(self):
        def func(a, b):
            return a * b

        a = torch.rand(1, requires_grad=True)
        b = torch.rand(1, requires_grad=True)
        self.checkScript(func, (a, b), optimize=True)

    def test_matmul_py3(self):
        # The function is written to a temporary source file and imported
        # from there, so the scripted function has real source to compile.
        code = dedent(
            """
        def fn(a, b):
            return a @ b
        """
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            script_path = os.path.join(tmp_dir, "script.py")
            with open(script_path, "w") as f:
                f.write(code)
            fn = get_fn("test_matmul_py3", script_path)

            a = torch.rand(4, 3, requires_grad=True)
            b = torch.rand(3, 2, requires_grad=True)
            self.checkScript(fn, (a, b), optimize=True)

    def test_pow(self):
        def func(a, b):
            return a**b

        def func2(a, b, c, d):
            return c + a**b**d

        def func3(a, b):
            # type: (int, float) -> float
            return a**b

        def func4():
            # type: () -> float
            return 2**-2

        def func5(x, y):
            return x.item() ** y.item()

        a = torch.rand(1, requires_grad=True)
        b = torch.rand(1, requires_grad=True)
        c = torch.rand(1, requires_grad=True)
        d = torch.rand(1, requires_grad=True)
        self.checkScript(func, (a, b), optimize=True)
        self.checkScript(func2, (a, b, c, d), optimize=True)
        self.checkScript(func3, (4, -0.5), optimize=True)
        self.checkScript(func4, ())

        inputs = [
            torch.tensor(2),
            torch.tensor(-2),
            torch.tensor(0.5),
            torch.tensor(0.2),
        ]
        for x in inputs:
            # Skip negative bases: in Python a negative number raised to a
            # fractional power yields a complex result.
            if x < 0:
                continue
            for y in inputs:
                self.checkScript(func5, (x, y))

    def test_triple(self):
        def func(x):
            return 3.0 * x

        x = torch.rand(1, dtype=torch.float, requires_grad=True)
        self.checkScript(func, [x], optimize=True)

    def test_slice(self):
        def func(x):
            return x[:5]

        x = torch.rand(10, dtype=torch.float, requires_grad=True)
        self.checkScript(func, [x], optimize=True)

        def func2(x):
            return x[5:]

        self.checkScript(func2, [x], optimize=True)

        def func3(x):
            return x[:8:2]

        self.checkScript(func3, [x], optimize=True)

        def func4(x):
            return x[1::4]

        self.checkScript(func4, [x], optimize=True)

    def test_gather(self):
        def func(x):
            return x[0]

        x = torch.rand(10, dtype=torch.float, requires_grad=True)
        self.checkScript(func, [x], optimize=True)

    def test_random(self):
        @torch.jit.script
        def f(mean, std):
            return torch.normal(mean, std)

        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
        # fork_rng restores the RNG state on exit, so both draws start from
        # the same state and should match.
        with torch.random.fork_rng(devices=[]):
            output = torch.normal(mean, std)
        with torch.random.fork_rng(devices=[]):
            script_output = f(mean, std)
        self.assertEqual(output, script_output)

    def _check_code(self, code_str, fn_name, inputs):
        """Run ``fn_name`` from ``code_str`` as Python and as TorchScript; assert equal.

        Args:
            code_str: source text defining a function called ``fn_name``.
            fn_name: name of the function to call in both environments.
            inputs: positional arguments passed to both calls.
        """
        scope = {}
        exec(code_str, globals(), scope)
        cu = torch.jit.CompilationUnit(code_str)
        scripted_fn = getattr(cu, fn_name)
        self.assertEqual(scripted_fn(*inputs), scope[fn_name](*inputs))

    def test_stepped_tuple_slicing(self):
        def check_slicing_tuple(slicing, tuple_type, tup):
            template = dedent(
                """
            def func(x):
                # type: ({}) -> Any
                return x{}
            """
            )
            self._check_code(template.format(tuple_type, slicing), "func", [tup])

        check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2))
        check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
        check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
        check_slicing_tuple(
            "[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
        )
        check_slicing_tuple(
            "[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
        )
        check_slicing_tuple(
            "[5:7:-2]",
            "Tuple[int, int, int, int, int, int, int]",
            (0, 1, 2, 3, 4, 5, 6),
        )
        check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
        check_slicing_tuple(
            "[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)
        )
        check_slicing_tuple(
            "[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)
        )

    def test_index(self):
        consec = self._consec

        def check_indexing(indexing, tensor):
            template = dedent(
                """
            def func(x):
                return x{}
            """
            )

            self._check_code(template.format(indexing), "func", [tensor])

        def check_dynamic_indexing(indexing, tensor, value1, value2):
            # The index values are passed in as tensors and converted to int
            # inside the compiled function, so they are not constants.
            value1 = torch.tensor(value1)
            value2 = torch.tensor(value2)

            template = dedent(
                """
            def func(x, value1, value2):
                i = int(value1)
                j = int(value2)
                return x{}
            """
            )

            self._check_code(
                template.format(indexing), "func", [tensor, value1, value2]
            )

        # basic slices
        check_indexing("[0]", consec((3, 3)))
        check_indexing("[1]", consec((3, 3), 10))
        check_indexing("[2]", consec((3, 3), 19))
        check_indexing("[2]", consec((3,)))
        check_indexing("[-1]", consec((3, 3), 19))
        check_indexing("[0:2]", consec((3, 3, 3)))
        check_indexing("[1:-1]", consec((3, 3, 3)))
        check_indexing("[-3:-1]", consec((6, 3)))
        check_indexing("[1:]", consec((3, 3)))
        check_indexing("[:1]", consec((3, 3)))
        check_indexing("[:]", consec((3, 2)))

        # multi-dim: indexes
        check_indexing("[0, 1]", consec((3, 3)))
        check_indexing("[0, 1]", consec((3, 3, 2)))
        check_indexing("[1, 0, 2]", consec((3, 3, 3)))
        check_indexing("[2, -1]", consec((3, 3)))

        # multi-dim: mixed slicing and indexing
        check_indexing("[0, 1:2]", consec((3, 3)))
        check_indexing("[0, :1]", consec((3, 3, 2)))
        check_indexing("[1, 2:]", consec((3, 3, 3)))
        check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
        check_indexing("[1:, -1, 0]", consec((3, 3, 3, 3)))
        check_indexing("[-1, 2:, 1:2]", consec((3, 3, 3, 3)))
        check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
        check_indexing("[-1, :, 0, 2]", consec((3, 3, 3, 3)))

        # zero-sized slices
        check_indexing("[0:0]", consec((2, 2)))
        check_indexing("[0:0, 1]", consec((3, 3)))

        # trivial expression usage
        check_indexing("[1+1]", consec((3, 3)))
        check_indexing("[1:(0 + 2)]", consec((3, 3, 3)))

        # None for new dimensions
        check_indexing("[None, 0]", consec((3, 3)))
        check_indexing("[1, None]", consec((3, 3), 10))
        check_indexing("[None, None, 2]", consec((3, 3), 19))
        check_indexing("[None, 2, None]", consec((3,)))
        check_indexing("[0:2, None]", consec((3, 3, 3)))
        check_indexing("[None, 1:-1]", consec((3, 3, 3)))
        check_indexing("[None, -3:-1, None]", consec((6, 3)))
        check_indexing("[-1, None, 2:, None, 1:2]", consec((3, 3, 3, 3)))
        check_indexing("[None, -1, None, 2:, None, 1:2, None]", consec((3, 3, 3, 3)))

        # dynamic expression usage
        check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
        check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)

    def test_advancedindex(self):
        consec = self._consec

        def check_indexing(indexing, tensor, **kwargs):
            template = dedent(
                """
            def func(x{formals}):
                return x{expr}
            """
            )

            # Each keyword argument becomes an extra formal parameter of
            # ``func`` and is passed positionally in the same order.
            formals = "".join(f", {name}" for name in kwargs)
            inputs = [tensor, *kwargs.values()]
            self._check_code(
                template.format(formals=formals, expr=indexing), "func", inputs
            )

        # Indexing with tensor (basic)
        check_indexing("[i]", consec((3, 3)), i=torch.tensor([0]))
        check_indexing("[i]", consec((3, 3)), i=torch.tensor(1))
        check_indexing("[i]", consec((3, 3)), i=torch.tensor([-2]))
        check_indexing("[i]", consec((3, 3), 2), i=torch.tensor([0, 0]))
        check_indexing("[i]", consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))

        # NB: indexing with tensors and indexing with sequences can be implemented
        # in a very similar way (sequences are converted to tensors), so only one
        # case needs to be tested extensively.
        # XXX: When we can index with sequences, replace these cases with
        # sequence indexing expressions; those are much easier to read.

        # Misc sequence advanced indexing
        inp = consec((4, 8, 5))
        to_check = [
            # [[0, 1, 3]]
            ["[i]", {"i": [0, 1, 3]}],
            # [[0, 2], [1, 3]]
            ["[i, j]", {"i": [0, 2], "j": [1, 3]}],
            # [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
            ["[i, j]", {"i": [[0, 1], [0, 1]], "j": [[0, 1], [0, 1]]}],
            # [[0, 2], [1, 3], [1, 1]]
            ["[i, j, k]", {"i": [0, 2], "j": [1, 3], "k": [1, 1]}],
            # [[0, 2], 1, [1, 1]]
            ["[i, j, k]", {"i": [0, 2], "j": 1, "k": [1, 1]}],
            # [:, :, [0, 3, 4]]
            ["[:, :, i]", {"i": [0, 3, 4]}],
            # [:, [0, 2, 3], 2:4]
            ["[:, i, 2:4]", {"i": [0, 2, 3]}],
            # [[2, 3], :, :]
            ["[i, :, :]", {"i": [2, 3]}],
            # [:, [0, 2, 3], [1, 3, 4]]
            ["[:, i, j]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
            # [:, [0], [1, 2, 4]]
            ["[:, i, j]", {"i": [0], "j": [1, 2, 4]}],
            # [:, [0, 1, 3], [4]]
            ["[:, i, j]", {"i": [0, 1, 3], "j": [4]}],
            # [:, [[0, 1], [1, 0]], [[2, 3]]]
            ["[:, i, j]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
            # [:, [[0, 1], [2, 3]], [[0]]]
            ["[:, i, j]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
            # [:, [[5, 6]], [[0, 3], [4, 4]]]
            ["[:, i, j]", {"i": [[5, 6]], "j": [[0, 3], [4, 4]]}],
            # [[0, 2, 3], [1, 3, 4], :]
            ["[i, j, :]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
            # [0, [1, 2, 4], :]
            ["[i, j, :]", {"i": 0, "j": [1, 2, 4]}],
            # [[0, 1, 3], 4, :]
            ["[i, j, :]", {"i": [0, 1, 3], "j": 4}],
            # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
            ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 1], [3, 5]]}],
            # [[[0, 1], [1, 0]], [[2, 3]], :]
            ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
            # [[[0, 1], [2, 3]], [[0]], :]
            ["[i, j, :]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
            # [[[2, 1]], [[0, 3], [4, 4]], :]
            ["[i, j, :]", {"i": [[2, 1]], "j": [[0, 3], [4, 4]]}],
            # [[[2]], [[0, 3], [4, 1]], 0:2]
            ["[i, j, 0:2]", {"i": [[2]], "j": [[0, 3], [4, 1]]}],
        ]

        for expr, argdict in to_check:
            tensordict = {k: torch.tensor(v) for (k, v) in argdict.items()}
            check_indexing(expr, inp, **tensordict)

    def test_adv_indexing_list(self):
        # indexing with list is equivalent to indexing with tensor
        def func1(x):
            return x[[0, 1, 5]]

        def func2(x):
            return x[[0, 1], [0, 1]]

        def func3(x):
            return x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]

        def func4(x):
            ls = [0]
            ls.append(1)
            ls.append(2)
            return x[ls]

        def func5(x):
            ls = [0.1, 1.2, 2.3]
            return x[ls]

        inp = torch.rand((6, 2))
        self.checkScript(func1, (inp,))
        self.checkScript(func2, (inp,))
        self.checkScript(func3, (inp,))
        self.checkScript(func4, (inp,))
        self.checkScript(func5, (inp,))

    def test_index_ellipses(self):
        vals = [":", 1, None]
        for _ in range(100):
            # Four random index components with exactly one replaced by "...".
            indices = [random.choice(vals) for _ in range(4)]
            indices[random.randint(0, len(indices) - 1)] = "..."
            test_str = dedent(
                """
            def f():
                x = torch.ones(10, 9, 8, 7, 6)
                return x{indices}.shape
            """.format(
                    indices=indices
                )
            )
            # str(list) quotes the ":" and "..." entries; strip the quotes so
            # they become slice / Ellipsis syntax in the generated source.
            test_str = test_str.replace(r"'", r"")
            scope = {}
            execWrapper(test_str, globals(), scope)
            cu = torch.jit.CompilationUnit(test_str)
            res1 = cu.f()
            res2 = scope["f"]()
            self.assertEqual(res1, res2)

    def test_inf(self):
        @torch.jit.script
        def foo(a):
            return a < float("inf")

        s = torch.rand(1)
        self.assertTrue(foo(s))

        @torch.jit.script
        def bar(a):
            return a > float("-inf")

        s = torch.rand(1)
        self.assertTrue(bar(s))

        # test re-assignment on imported source
        code = """
        def foo(x):
            # type: (bool)
            a = float("-inf")
            if not x:
                a = float(torch.tensor([5]))
            return a < 4
        """
        cu = torch.jit.CompilationUnit(code)
        self.assertTrue(cu.foo(True))
        self.assertFalse(cu.foo(False))

    def test_str_to_float(self):
        @torch.jit.script
        def foo(a):
            return 0.5 == float("0.5 hello")

        s = torch.rand(1)
        with self.assertRaisesRegex(RuntimeError, "could not convert string to float"):
            foo(s)

        @torch.jit.script
        def foo(a):
            return 0.5 == float("0.5")

        s = torch.rand(1)
        self.assertTrue(foo(s))

        @torch.jit.script
        def foo(a):
            return 0.0 == float("0")

        s = torch.rand(1)
        self.assertTrue(foo(s))

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal
# data in order to improve our website and the GitVerse service, and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.