1
# Owner(s): ["oncall: jit"]
5
from textwrap import dedent
8
from torch.testing._internal import jit_utils
11
# Make the helper files in test/ importable
12
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13
sys.path.append(pytorch_test_dir)
14
from torch.testing._internal.jit_utils import JitTestCase
17
if __name__ == "__main__":
19
"This test file is not meant to be run directly, use:\n\n"
20
"\tpython test/test_jit.py TESTNAME\n\n"
25
# Tests various JIT-related utility functions.
26
class TestJitUtils(JitTestCase):
27
# Tests that POSITIONAL_OR_KEYWORD arguments are captured.
28
def test_get_callable_argument_names_positional_or_keyword(self):
29
def fn_positional_or_keyword_args_only(x, y):
34
torch._jit_internal.get_callable_argument_names(
35
fn_positional_or_keyword_args_only
39
# Tests that POSITIONAL_ONLY arguments are ignored.
40
def test_get_callable_argument_names_positional_only(self):
43
def fn_positional_only_arg(x, /, y):
48
fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
51
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
54
# Tests that VAR_POSITIONAL arguments are ignored.
55
def test_get_callable_argument_names_var_positional(self):
56
# Tests that VAR_POSITIONAL arguments are ignored.
57
def fn_var_positional_arg(x, *arg):
62
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
65
# Tests that KEYWORD_ONLY arguments are ignored.
66
def test_get_callable_argument_names_keyword_only(self):
67
def fn_keyword_only_arg(x, *, y):
71
["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
74
# Tests that VAR_KEYWORD arguments are ignored.
75
def test_get_callable_argument_names_var_keyword(self):
76
def fn_var_keyword_arg(**args):
77
return args["x"] + args["y"]
80
[], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
83
# Tests that a function signature containing various different types of
84
# arguments are ignored.
85
def test_get_callable_argument_names_hybrid(self):
88
def fn_hybrid_args(x, /, y, *args, **kwargs):
89
return x + y + args[0] + kwargs['z']
92
fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
94
["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
97
def test_checkscriptassertraisesregex(self):
102
self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
112
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
114
def test_no_tracer_warn_context_manager(self):
115
torch._C._jit_set_tracer_state_warn(True)
116
with jit_utils.NoTracerWarnContextManager() as no_warn:
117
self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
118
self.assertEqual(True, torch._C._jit_get_tracer_state_warn())