pytorch

Форк
0
/
test_jit_utils.py 
118 строк · 3.7 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5
from textwrap import dedent
6

7
import torch
8
from torch.testing._internal import jit_utils
9

10

11
# Make the helper files in test/ importable
12
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13
sys.path.append(pytorch_test_dir)
14
from torch.testing._internal.jit_utils import JitTestCase
15

16

17
if __name__ == "__main__":
18
    raise RuntimeError(
19
        "This test file is not meant to be run directly, use:\n\n"
20
        "\tpython test/test_jit.py TESTNAME\n\n"
21
        "instead."
22
    )
23

24

25
# Tests various JIT-related utility functions.
26
class TestJitUtils(JitTestCase):
27
    # Tests that POSITIONAL_OR_KEYWORD arguments are captured.
28
    def test_get_callable_argument_names_positional_or_keyword(self):
29
        def fn_positional_or_keyword_args_only(x, y):
30
            return x + y
31

32
        self.assertEqual(
33
            ["x", "y"],
34
            torch._jit_internal.get_callable_argument_names(
35
                fn_positional_or_keyword_args_only
36
            ),
37
        )
38

39
    # Tests that POSITIONAL_ONLY arguments are ignored.
40
    def test_get_callable_argument_names_positional_only(self):
41
        code = dedent(
42
            """
43
            def fn_positional_only_arg(x, /, y):
44
                return x + y
45
        """
46
        )
47

48
        fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
49
        self.assertEqual(
50
            ["y"],
51
            torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
52
        )
53

54
    # Tests that VAR_POSITIONAL arguments are ignored.
55
    def test_get_callable_argument_names_var_positional(self):
56
        # Tests that VAR_POSITIONAL arguments are ignored.
57
        def fn_var_positional_arg(x, *arg):
58
            return x + arg[0]
59

60
        self.assertEqual(
61
            ["x"],
62
            torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
63
        )
64

65
    # Tests that KEYWORD_ONLY arguments are ignored.
66
    def test_get_callable_argument_names_keyword_only(self):
67
        def fn_keyword_only_arg(x, *, y):
68
            return x + y
69

70
        self.assertEqual(
71
            ["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
72
        )
73

74
    # Tests that VAR_KEYWORD arguments are ignored.
75
    def test_get_callable_argument_names_var_keyword(self):
76
        def fn_var_keyword_arg(**args):
77
            return args["x"] + args["y"]
78

79
        self.assertEqual(
80
            [], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
81
        )
82

83
    # Tests that a function signature containing various different types of
84
    # arguments are ignored.
85
    def test_get_callable_argument_names_hybrid(self):
86
        code = dedent(
87
            """
88
            def fn_hybrid_args(x, /, y, *args, **kwargs):
89
                return x + y + args[0] + kwargs['z']
90
        """
91
        )
92
        fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
93
        self.assertEqual(
94
            ["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
95
        )
96

97
    def test_checkscriptassertraisesregex(self):
98
        def fn():
99
            tup = (1, 2)
100
            return tup[2]
101

102
        self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
103

104
        s = dedent(
105
            """
106
        def fn():
107
            tup = (1, 2)
108
            return tup[2]
109
        """
110
        )
111

112
        self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
113

114
    def test_no_tracer_warn_context_manager(self):
115
        torch._C._jit_set_tracer_state_warn(True)
116
        with jit_utils.NoTracerWarnContextManager() as no_warn:
117
            self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
118
        self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.