1
# Owner(s): ["oncall: jit"]
6
from typing import Any, Dict, List, Optional, Tuple
10
import torch.testing._internal.jit_utils
11
from jit.test_module_interface import TestModuleInterface # noqa: F401
13
from torch.testing import FileCheck
14
from torch.testing._internal.common_utils import freeze_rng_state
15
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
18
# Make the helper files in test/ importable
19
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
20
sys.path.append(pytorch_test_dir)
22
if __name__ == "__main__":
24
"This test file is not meant to be run directly, use:\n\n"
25
"\tpython test/test_jit.py TESTNAME\n\n"
30
class TestMisc(JitTestCase):
31
def test_joined_str(self):
33
hello, test = "Hello", "test"
34
print(f"{hello + ' ' + test}, I'm a {test}")
37
print(f"stuff before {hi}")
38
print(f"{hi} stuff after")
41
x = torch.arange(4.0, requires_grad=True)
42
# TODO: Add support for f-strings in string parser frontend
43
# self.checkScript(func, [x], optimize=True, capture_output=True)
45
with self.capture_stdout() as captured:
48
scripted = torch.jit.script(func)
49
with self.capture_stdout() as captured_script:
52
self.assertEqual(out, out_script)
53
self.assertEqual(captured, captured_script)
55
def test_kwarg_support(self):
56
with self.assertRaisesRegex(
57
torch.jit.frontend.NotSupportedError, "variable number of arguments"
60
class M(torch.nn.Module):
61
def forward(self, *, n_tokens: int, device_name: str = 2):
66
class M(torch.nn.Module):
67
def forward(self, *, n_tokens: int, device_name: str):
68
return n_tokens, device_name
70
sm = torch.jit.script(M())
72
with self.assertRaisesRegex(
73
RuntimeError, "missing value for argument 'n_tokens'"
77
with self.assertRaisesRegex(RuntimeError, "positional arg"):
80
self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
82
def test_tuple_subscripted_assign(self):
83
with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
86
def foo(a: Tuple[int, int]) -> None:
89
with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
92
def bar(a: Tuple[int, int]) -> None:
95
def test_subexpression_List_Future(self):
97
def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
100
FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
102
def test_subexpression_Future_annotate(self):
104
def fn() -> torch.jit.Future[int]:
105
x: List[torch.jit.Future[int]] = []
108
FileCheck().check("Future[int][]").run(fn.graph)
110
def test_future_isinstance(self):
112
def fn(x: Any) -> torch.jit.Future[int]:
113
assert isinstance(x, jit.Future[int])
116
FileCheck().check("Future[int]").run(fn.graph)
118
def test_str_refine_any(self):
119
def forward(x: Any) -> str:
120
if isinstance(x, str):
124
forward = torch.jit.script(forward)
125
self.assertEqual(forward(1), "foo")
126
self.assertEqual(forward("bar"), "bar")
128
def test_subexpression_Tuple_int_int_Future(self):
131
x: Tuple[int, int, torch.jit.Future[int]]
132
) -> Tuple[int, torch.jit.Future[int]]:
135
FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
139
def test_subexpression_Dict_int_Future(self):
141
def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
144
FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
146
def test_subexpression_Optional(self):
149
x: Optional[Dict[int, torch.jit.Future[int]]]
150
) -> Optional[torch.jit.Future[int]]:
156
FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
158
def test_if_returning_any(self):
160
Check that an if statement can return different
161
types early from each branch when the return
162
type of the function is Any.
165
def if_function(inp: torch.Tensor) -> Any:
166
if inp.shape[0] == 1:
171
self.checkScript(if_function, (torch.randn(5),))
173
def test_hacked_twin(self):
175
with freeze_rng_state():
176
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
188
out1 = torch.ops.aten.index_put.hacked_twin(
189
input, [index], value, accumulate=False
191
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
192
self.assertEqual(out1, out2)
194
torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
195
torch.index_put_(input1, [index1], value1, accumulate=False)
196
self.assertEqual(input, input1)
198
def test_unsafe_hacked_twin(self):
200
with freeze_rng_state():
201
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
213
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
214
input, [index], value, accumulate=False
216
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
217
self.assertEqual(out1, out2)
219
torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
220
torch.index_put(input1, [index1], value1, accumulate=False)
221
self.assertEqual(input, input1)
223
def index_put_fn(input, index, value):
224
return torch.ops.aten._unsafe_index_put(
225
input, [index], value, accumulate=False
228
input2, index2, value2 = gen_data()
229
script_index_put_fn = torch.jit.script(index_put_fn)
230
expect = index_put_fn(input2.clone(), index2, value2)
231
actual = script_index_put_fn(input2.clone(), index2, value2)
232
self.assertEqual(expect, actual)
234
def index_fn(input, index, value):
235
return torch.ops.aten._unsafe_index_put(
236
input, [index], value, accumulate=False
239
script_index_fn = torch.jit.script(index_fn)
240
expect = index_fn(input2.clone(), index2, value2)
241
actual = script_index_fn(input2.clone(), index2, value2)
242
self.assertEqual(expect, actual)
244
def test_export_opnames_interface(self):
246
class OneTwoModule(nn.Module):
247
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
250
def two(self, x: torch.Tensor) -> torch.Tensor:
253
def forward(self, x: torch.Tensor) -> torch.Tensor:
256
class FooMod(nn.Module):
257
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
260
def two(self, x: torch.Tensor) -> torch.Tensor:
263
def forward(self, x: torch.Tensor) -> torch.Tensor:
264
return self.one(self.two(x), x)
266
class BarMod(nn.Module):
267
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
270
def two(self, x: torch.Tensor) -> torch.Tensor:
273
def forward(self, x: torch.Tensor) -> torch.Tensor:
274
return self.two(self.one(x, x))
276
make_global(OneTwoModule)
281
def __init__(self) -> None:
285
def forward(self, x: torch.Tensor) -> torch.Tensor:
286
return self.sub.forward(x)
288
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
289
return mod_list[0].forward(x) + mod_list[1].forward(x)
291
torch._C._enable_mobile_interface_call_export()
292
scripted_M_mod = torch.jit.script(M())
294
{"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
295
set(torch.jit.export_opnames(scripted_M_mod))
299
scripted_M_mod.sub = torch.jit.script(FooMod())
301
{"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
302
set(torch.jit.export_opnames(scripted_M_mod))
306
def test_math_inf(self):
312
self.checkScript(foo, ())
314
def test_list_literal_infer(self):
315
def expects_intlist(x: List[int]):
320
return expects_intlist([])
322
self.checkScript(foo, ())
324
def annotated_list_fail():
325
return expects_intlist(torch.jit.annotate([], List[Tensor])) # noqa: F821
327
with self.assertRaises(RuntimeError):
328
torch.jit.script(annotated_list_fail)
330
def non_temporary_fail():
332
return expects_intlist(a)
334
with self.assertRaises(RuntimeError):
335
torch.jit.script(non_temporary_fail)
341
FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)
343
def test_legacy_tensor_constructor(self):
344
# testing PyObject overload
345
def test_all_dtypes():
347
torch.BoolTensor([2]),
348
torch.LongTensor([3]),
349
torch.ByteTensor([4]),
350
torch.CharTensor([5]),
351
torch.DoubleTensor([6]),
352
torch.FloatTensor([7]),
353
torch.IntTensor([8]),
354
torch.ShortTensor([1]),
355
torch.HalfTensor([1]),
358
self.checkScript(test_all_dtypes, ())
360
# now test empty overload
361
def empty_overload():
362
return torch.LongTensor(2, 3, 4)
364
eager = empty_overload()
365
jit = torch.jit.script(empty_overload)()
368
self.assertEqual(eager, jit)
371
return torch.DoubleTensor()
373
self.checkScript(no_inputs, ())
377
return torch.LongTensor(1, [2])
379
with self.assertRaisesRegex(
380
RuntimeError, "multiple positional arguments that were not all integers"
382
torch.jit.script(multiple_args)
386
return torch.LongTensor(hello="1")
388
with self.assertRaisesRegex(RuntimeError, "hello"):
389
torch.jit.script(bad_kwarg)
391
def test_broadcasting_list(self):
393
Test BroadcastingList and torch.nn._size_N_t alias
395
from torch._jit_internal import BroadcastingList2
396
from torch.nn.common_types import _size_2_t
398
def sum_i(x: _size_2_t) -> int:
401
def sum_f(x: BroadcastingList2[float]) -> float:
404
self.assertTrue(torch.jit.script(sum_i)(4) == 8)
405
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
407
def test_parse_ir_annotate(self):
410
%3 : int[] = prim::Constant[value=annotate(List[int], [])]()
413
graph = torch._C.parse_ir(ir, True)
414
func = torch._C._create_function_from_graph("forward", graph)
416
self.assertTrue(ret == [])
418
def test_parse_ir_single_element_tensor_positive(self):
421
%7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
424
graph = torch._C.parse_ir(ir, True)
425
func = torch._C._create_function_from_graph("forward", graph)
427
self.assertTrue(ret.numel() == 1)
428
self.assertTrue(len(ret.size()) == 1)
430
def test_parse_ir_single_element_tensor_negative(self):
433
%7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
436
graph = torch._C.parse_ir(ir, True)
437
func = torch._C._create_function_from_graph("forward", graph)
439
self.assertTrue(ret.numel() == 1)
440
self.assertTrue(len(ret.size()) == 1)
442
def test_script_many_decorators(self):
443
def no_op_decorator(f):
451
def foo(x, dim: int):
452
return x.unsqueeze(dim)
458
scripted = torch.jit.script(foo)
459
actual = scripted(x, 0)
460
torch.testing.assert_close(expected, actual)
462
@unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
463
def test_pow_multiple_dtype(self):
464
# https://github.com/pytorch/pytorch/issues/75476
465
def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
470
x = torch.rand((2, 2), dtype=torch.half, device="cuda")
474
script_fn = torch.jit.script(fn)
478
self.assertEqual(ref, res)
480
def test_jit_get_operation_order(self):
481
# See https://github.com/pytorch/pytorch/pull/107138.
482
# Depending on order of operator registration, you can get different
483
# order of overloads in the JIT operator registry.
484
# This is to verify that the order of operators returned by
485
# _jit_get_operation always puts aten ops first (i.e. by sorting
488
# Make sure that this chooses a "scalar" overload not a "complex" overload
489
ret = torch.ops.aten.add(4, 3.3)
490
self.assertFalse("complex" in str(ret.dtype))
492
# "Scalar" overload is a normal aten op; "complex" is added by torchscript.
493
# We want "Scalar" to come before "complex".
494
op, override_names = torch._C._jit_get_operation("aten::add")
495
print(override_names)
497
i for i, name in enumerate(override_names) if name == "complex"
500
i for i, name in enumerate(override_names) if name == "Scalar"
503
self.assertTrue(len(complex_indices) > 0)
504
self.assertTrue(len(Scalar_indices) > 0)
505
self.assertTrue(complex_indices[0] > Scalar_indices[0])