pytorch

Форк
0
/
test_misc.py 
505 строк · 15.8 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5
import unittest
6
from typing import Any, Dict, List, Optional, Tuple
7

8
import torch
9
import torch.nn as nn
10
import torch.testing._internal.jit_utils
11
from jit.test_module_interface import TestModuleInterface  # noqa: F401
12
from torch import jit
13
from torch.testing import FileCheck
14
from torch.testing._internal.common_utils import freeze_rng_state
15
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
16

17

18
# Make the helper files in test/ importable: put the grandparent directory of
# this file (symlinks resolved) at the end of sys.path.
_this_file_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_file_dir)
sys.path.append(pytorch_test_dir)
del _this_file_dir
21

22
if __name__ == "__main__":
    # This module only defines test cases; refuse to run it standalone and
    # point the user at the test/test_jit.py entry point instead.
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
28

29

30
class TestMisc(JitTestCase):
31
    def test_joined_str(self):
32
        def func(x):
33
            hello, test = "Hello", "test"
34
            print(f"{hello + ' ' + test}, I'm a {test}")
35
            print("format blank")
36
            hi = "hi"
37
            print(f"stuff before {hi}")
38
            print(f"{hi} stuff after")
39
            return x + 1
40

41
        x = torch.arange(4.0, requires_grad=True)
42
        # TODO: Add support for f-strings in string parser frontend
43
        # self.checkScript(func, [x], optimize=True, capture_output=True)
44

45
        with self.capture_stdout() as captured:
46
            out = func(x)
47

48
        scripted = torch.jit.script(func)
49
        with self.capture_stdout() as captured_script:
50
            out_script = func(x)
51

52
        self.assertEqual(out, out_script)
53
        self.assertEqual(captured, captured_script)
54

55
    def test_kwarg_support(self):
56
        with self.assertRaisesRegex(
57
            torch.jit.frontend.NotSupportedError, "variable number of arguments"
58
        ):
59

60
            class M(torch.nn.Module):
61
                def forward(self, *, n_tokens: int, device_name: str = 2):
62
                    pass
63

64
            torch.jit.script(M())
65

66
        class M(torch.nn.Module):
67
            def forward(self, *, n_tokens: int, device_name: str):
68
                return n_tokens, device_name
69

70
        sm = torch.jit.script(M())
71

72
        with self.assertRaisesRegex(
73
            RuntimeError, "missing value for argument 'n_tokens'"
74
        ):
75
            sm()
76

77
        with self.assertRaisesRegex(RuntimeError, "positional arg"):
78
            sm(3, "hello")
79

80
        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
81

82
    def test_tuple_subscripted_assign(self):
83
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
84

85
            @torch.jit.script
86
            def foo(a: Tuple[int, int]) -> None:
87
                a[0] = a[1]
88

89
        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
90

91
            @torch.jit.script
92
            def bar(a: Tuple[int, int]) -> None:
93
                a[0] += a[1]
94

95
    def test_subexpression_List_Future(self):
96
        @torch.jit.script
97
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
98
            return x[0]
99

100
        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
101

102
    def test_subexpression_Future_annotate(self):
103
        @torch.jit.script
104
        def fn() -> torch.jit.Future[int]:
105
            x: List[torch.jit.Future[int]] = []
106
            return x[0]
107

108
        FileCheck().check("Future[int][]").run(fn.graph)
109

110
    def test_future_isinstance(self):
111
        @torch.jit.script
112
        def fn(x: Any) -> torch.jit.Future[int]:
113
            assert isinstance(x, jit.Future[int])
114
            return x
115

116
        FileCheck().check("Future[int]").run(fn.graph)
117

118
    def test_str_refine_any(self):
119
        def forward(x: Any) -> str:
120
            if isinstance(x, str):
121
                return x
122
            return "foo"
123

124
        forward = torch.jit.script(forward)
125
        self.assertEqual(forward(1), "foo")
126
        self.assertEqual(forward("bar"), "bar")
127

128
    def test_subexpression_Tuple_int_int_Future(self):
129
        @torch.jit.script
130
        def fn(
131
            x: Tuple[int, int, torch.jit.Future[int]]
132
        ) -> Tuple[int, torch.jit.Future[int]]:
133
            return x[0], x[2]
134

135
        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
136
            fn.graph
137
        )
138

139
    def test_subexpression_Dict_int_Future(self):
140
        @torch.jit.script
141
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
142
            return x[y]
143

144
        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
145

146
    def test_subexpression_Optional(self):
147
        @torch.jit.script
148
        def fn(
149
            x: Optional[Dict[int, torch.jit.Future[int]]]
150
        ) -> Optional[torch.jit.Future[int]]:
151
            if x is not None:
152
                return x[0]
153
            else:
154
                return None
155

156
        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
157

158
    def test_if_returning_any(self):
159
        """
160
        Check that an if statement can return different
161
        types early from each branch when the return
162
        type of the function is Any.
163
        """
164

165
        def if_function(inp: torch.Tensor) -> Any:
166
            if inp.shape[0] == 1:
167
                return inp * inp
168
            else:
169
                return "str"
170

171
        self.checkScript(if_function, (torch.randn(5),))
172

173
    def test_hacked_twin(self):
174
        def gen_data():
175
            with freeze_rng_state():
176
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
177

178
        (
179
            input,
180
            index,
181
            value,
182
        ) = gen_data()
183
        (
184
            input1,
185
            index1,
186
            value1,
187
        ) = gen_data()
188
        out1 = torch.ops.aten.index_put.hacked_twin(
189
            input, [index], value, accumulate=False
190
        )
191
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
192
        self.assertEqual(out1, out2)
193

194
        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
195
        torch.index_put_(input1, [index1], value1, accumulate=False)
196
        self.assertEqual(input, input1)
197

198
    def test_unsafe_hacked_twin(self):
199
        def gen_data():
200
            with freeze_rng_state():
201
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
202

203
        (
204
            input,
205
            index,
206
            value,
207
        ) = gen_data()
208
        (
209
            input1,
210
            index1,
211
            value1,
212
        ) = gen_data()
213
        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
214
            input, [index], value, accumulate=False
215
        )
216
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
217
        self.assertEqual(out1, out2)
218

219
        torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
220
        torch.index_put(input1, [index1], value1, accumulate=False)
221
        self.assertEqual(input, input1)
222

223
        def index_put_fn(input, index, value):
224
            return torch.ops.aten._unsafe_index_put(
225
                input, [index], value, accumulate=False
226
            )
227

228
        input2, index2, value2 = gen_data()
229
        script_index_put_fn = torch.jit.script(index_put_fn)
230
        expect = index_put_fn(input2.clone(), index2, value2)
231
        actual = script_index_put_fn(input2.clone(), index2, value2)
232
        self.assertEqual(expect, actual)
233

234
        def index_fn(input, index, value):
235
            return torch.ops.aten._unsafe_index_put(
236
                input, [index], value, accumulate=False
237
            )
238

239
        script_index_fn = torch.jit.script(index_fn)
240
        expect = index_fn(input2.clone(), index2, value2)
241
        actual = script_index_fn(input2.clone(), index2, value2)
242
        self.assertEqual(expect, actual)
243

244
    def test_export_opnames_interface(self):
245
        @torch.jit.interface
246
        class OneTwoModule(nn.Module):
247
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
248
                pass
249

250
            def two(self, x: torch.Tensor) -> torch.Tensor:
251
                pass
252

253
            def forward(self, x: torch.Tensor) -> torch.Tensor:
254
                pass
255

256
        class FooMod(nn.Module):
257
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
258
                return x + y
259

260
            def two(self, x: torch.Tensor) -> torch.Tensor:
261
                return 2 * x
262

263
            def forward(self, x: torch.Tensor) -> torch.Tensor:
264
                return self.one(self.two(x), x)
265

266
        class BarMod(nn.Module):
267
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
268
                return x * y
269

270
            def two(self, x: torch.Tensor) -> torch.Tensor:
271
                return 2 / x
272

273
            def forward(self, x: torch.Tensor) -> torch.Tensor:
274
                return self.two(self.one(x, x))
275

276
        make_global(OneTwoModule)
277

278
        class M(nn.Module):
279
            sub: OneTwoModule
280

281
            def __init__(self) -> None:
282
                super().__init__()
283
                self.sub = BarMod()
284

285
            def forward(self, x: torch.Tensor) -> torch.Tensor:
286
                return self.sub.forward(x)
287

288
        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
289
            return mod_list[0].forward(x) + mod_list[1].forward(x)
290

291
        torch._C._enable_mobile_interface_call_export()
292
        scripted_M_mod = torch.jit.script(M())
293
        self.assertTrue(
294
            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
295
                set(torch.jit.export_opnames(scripted_M_mod))
296
            )
297
        )
298

299
        scripted_M_mod.sub = torch.jit.script(FooMod())
300
        self.assertTrue(
301
            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
302
                set(torch.jit.export_opnames(scripted_M_mod))
303
            )
304
        )
305

306
    def test_math_inf(self):
307
        from math import inf
308

309
        def foo():
310
            return inf
311

312
        self.checkScript(foo, ())
313

314
    def test_list_literal_infer(self):
315
        def expects_intlist(x: List[int]):
316
            x.append(3)
317
            return x
318

319
        def foo():
320
            return expects_intlist([])
321

322
        self.checkScript(foo, ())
323

324
        def annotated_list_fail():
325
            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821
326

327
        with self.assertRaises(RuntimeError):
328
            torch.jit.script(annotated_list_fail)
329

330
        def non_temporary_fail():
331
            a = []
332
            return expects_intlist(a)
333

334
        with self.assertRaises(RuntimeError):
335
            torch.jit.script(non_temporary_fail)
336

337
        @torch.jit.script
338
        def test_return():
339
            return []
340

341
        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)
342

343
    def test_legacy_tensor_constructor(self):
344
        # testing PyObject overload
345
        def test_all_dtypes():
346
            return (
347
                torch.BoolTensor([2]),
348
                torch.LongTensor([3]),
349
                torch.ByteTensor([4]),
350
                torch.CharTensor([5]),
351
                torch.DoubleTensor([6]),
352
                torch.FloatTensor([7]),
353
                torch.IntTensor([8]),
354
                torch.ShortTensor([1]),
355
                torch.HalfTensor([1]),
356
            )
357

358
        self.checkScript(test_all_dtypes, ())
359

360
        # now test empty overload
361
        def empty_overload():
362
            return torch.LongTensor(2, 3, 4)
363

364
        eager = empty_overload()
365
        jit = torch.jit.script(empty_overload)()
366
        eager[:] = 1
367
        jit[:] = 1
368
        self.assertEqual(eager, jit)
369

370
        def no_inputs():
371
            return torch.DoubleTensor()
372

373
        self.checkScript(no_inputs, ())
374

375
        # bad schema
376
        def multiple_args():
377
            return torch.LongTensor(1, [2])
378

379
        with self.assertRaisesRegex(
380
            RuntimeError, "multiple positional arguments that were not all integers"
381
        ):
382
            torch.jit.script(multiple_args)
383

384
        # kwarg bad schema
385
        def bad_kwarg():
386
            return torch.LongTensor(hello="1")
387

388
        with self.assertRaisesRegex(RuntimeError, "hello"):
389
            torch.jit.script(bad_kwarg)
390

391
    def test_broadcasting_list(self):
392
        """
393
        Test BroadcastingList and torch.nn._size_N_t alias
394
        """
395
        from torch._jit_internal import BroadcastingList2
396
        from torch.nn.common_types import _size_2_t
397

398
        def sum_i(x: _size_2_t) -> int:
399
            return x[0] + x[1]
400

401
        def sum_f(x: BroadcastingList2[float]) -> float:
402
            return x[0] + x[1]
403

404
        self.assertTrue(torch.jit.script(sum_i)(4) == 8)
405
        self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
406

407
    def test_parse_ir_annotate(self):
408
        ir = """
409
        graph():
410
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
411
          return (%3)
412
        """
413
        graph = torch._C.parse_ir(ir, True)
414
        func = torch._C._create_function_from_graph("forward", graph)
415
        ret = func()
416
        self.assertTrue(ret == [])
417

418
    def test_parse_ir_single_element_tensor_positive(self):
419
        ir = """
420
        graph():
421
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
422
          return (%7)
423
        """
424
        graph = torch._C.parse_ir(ir, True)
425
        func = torch._C._create_function_from_graph("forward", graph)
426
        ret = func()
427
        self.assertTrue(ret.numel() == 1)
428
        self.assertTrue(len(ret.size()) == 1)
429

430
    def test_parse_ir_single_element_tensor_negative(self):
431
        ir = """
432
        graph():
433
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
434
          return (%7)
435
        """
436
        graph = torch._C.parse_ir(ir, True)
437
        func = torch._C._create_function_from_graph("forward", graph)
438
        ret = func()
439
        self.assertTrue(ret.numel() == 1)
440
        self.assertTrue(len(ret.size()) == 1)
441

442
    def test_script_many_decorators(self):
443
        def no_op_decorator(f):
444
            return f
445

446
        @no_op_decorator
447
        @no_op_decorator
448
        @no_op_decorator
449
        @no_op_decorator
450
        @no_op_decorator
451
        def foo(x, dim: int):
452
            return x.unsqueeze(dim)
453

454
        x = torch.randn(
455
            1,
456
        )
457
        expected = foo(x, 0)
458
        scripted = torch.jit.script(foo)
459
        actual = scripted(x, 0)
460
        torch.testing.assert_close(expected, actual)
461

462
    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
463
    def test_pow_multiple_dtype(self):
464
        # https://github.com/pytorch/pytorch/issues/75476
465
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
466
            p = torch.sigmoid(p)
467
            result = p**gamma
468
            return result
469

470
        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
471

472
        ref = fn(x)
473

474
        script_fn = torch.jit.script(fn)
475
        for i in range(4):
476
            res = script_fn(x)
477

478
        self.assertEqual(ref, res)
479

480
    def test_jit_get_operation_order(self):
481
        # See https://github.com/pytorch/pytorch/pull/107138.
482
        # Depending on order of operator registration, you can get different
483
        # order of overloads in the JIT operator registry.
484
        # This is to verify that the order of operators returned by
485
        # _jit_get_operation always puts aten ops first (i.e. by sorting
486
        # to put them first)
487

488
        # Make sure that this chooses a "scalar" overload not a "complex" overload
489
        ret = torch.ops.aten.add(4, 3.3)
490
        self.assertFalse("complex" in str(ret.dtype))
491

492
        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
493
        # We want "Scalar" to come before "complex".
494
        op, override_names = torch._C._jit_get_operation("aten::add")
495
        print(override_names)
496
        complex_indices = [
497
            i for i, name in enumerate(override_names) if name == "complex"
498
        ]
499
        Scalar_indices = [
500
            i for i, name in enumerate(override_names) if name == "Scalar"
501
        ]
502

503
        self.assertTrue(len(complex_indices) > 0)
504
        self.assertTrue(len(Scalar_indices) > 0)
505
        self.assertTrue(complex_indices[0] > Scalar_indices[0])
506

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.