1
# Owner(s): ["oncall: jit"]
10
# Make the helper files in test/ importable: the parent of the directory that
# contains this file is appended to sys.path.
11
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
12
sys.path.append(pytorch_test_dir)
13
from torch.testing._internal.jit_utils import JitTestCase
16
if __name__ == "__main__":
# Direct-execution guard. The string fragments below tell the user to run the
# tests through test/test_jit.py rather than executing this file directly.
# NOTE(review): the statement that consumes these fragments is not visible in
# this excerpt -- confirm against the full file.
18
"This test file is not meant to be run directly, use:\n\n"
19
"\tpython test/test_jit.py TESTNAME\n\n"
24
# Tests that Python's built-in `slice` class is supported in TorchScript
25
class TestSlice(JitTestCase):
26
def test_slice_kwarg(self):
    # slice() is called here with a keyword argument (stop=2).
    def slice_kwarg(x: List[int]):
        return x[slice(1, stop=2)]

    # Scripting the function is expected to raise a RuntimeError whose
    # message matches the pattern below.
    with self.assertRaisesRegex(
        RuntimeError, "Slice does not accept any keyword arguments"
    ):
        torch.jit.script(slice_kwarg)
35
def test_slice_three_nones(self):
    # Index a list with slice(None, None, None) and hand the function to
    # checkScript, using range(10) as the single input.
    def three_nones(x: List[int]):
        return x[slice(None, None, None)]

    script_args = (range(10),)
    self.checkScript(three_nones, script_args)
41
def test_slice_two_nones(self):
    # Same shape as the other slice tests, but with the two-argument form
    # slice(None, None); range(10) is the single checkScript input.
    def two_nones(x: List[int]):
        return x[slice(None, None)]

    script_args = (range(10),)
    self.checkScript(two_nones, script_args)
47
def test_slice_one_none(self):
# Passes the nested function `one_none` to checkScript with range(10) as its
# only input.
# NOTE(review): the body of `one_none` is not visible in this excerpt --
# confirm against the full file.
48
def one_none(x: List[int]):
51
self.checkScript(one_none, (range(10),))
53
def test_slice_stop_only(self):
# Passes `fn` to checkScript with range(10) as its only input.
# NOTE(review): the definition of `fn` is not visible in this excerpt --
# confirm against the full file.
57
self.checkScript(fn, (range(10),))
59
def test_slice_stop_only_with_nones(self):
# Exercises slice(None, 5, None): only the stop is given, start and step are
# explicit None. `fn` is passed to checkScript with range(10) as input.
# NOTE(review): the `def fn(...)` header is not visible in this excerpt --
# confirm against the full file.
61
return x[slice(None, 5, None)]
63
self.checkScript(fn, (range(10),))
65
def test_slice_start_stop(self):
# Passes `fn` to checkScript with range(10) as its only input.
# NOTE(review): the definition of `fn` is not visible in this excerpt --
# confirm against the full file.
69
self.checkScript(fn, (range(10),))
71
def test_slice_start_stop_with_none(self):
# Exercises slice(1, 5, None): start and stop given, step explicitly None.
# `fn` is passed to checkScript with range(10) as input.
# NOTE(review): the `def fn(...)` header is not visible in this excerpt --
# confirm against the full file.
73
return x[slice(1, 5, None)]
75
self.checkScript(fn, (range(10),))
77
def test_slice_start_stop_step(self):
# Exercises slice(0, 6, 2): start, stop and step are all given.
# `fn` is passed to checkScript with range(10) as input.
# NOTE(review): the `def fn(...)` header is not visible in this excerpt --
# confirm against the full file.
79
return x[slice(0, 6, 2)]
81
self.checkScript(fn, (range(10),))
83
def test_slice_string(self):
# Applies slice(None, 3, 1) to the input; checkScript is called with the
# string "foo_bar" as the only argument.
# NOTE(review): the `def fn(...)` header (and its parameter annotation) is
# not visible in this excerpt -- confirm against the full file.
85
return x[slice(None, 3, 1)]
87
self.checkScript(fn, ("foo_bar",))
89
def test_slice_tensor(self):
    # Apply slice(None, 3, 1) to a tensor argument; the checkScript input is
    # a single torch.ones(10) tensor.
    def fn(x: torch.Tensor):
        return x[slice(None, 3, 1)]

    sample = torch.ones(10)
    self.checkScript(fn, (sample,))
95
def test_slice_tensor_multidim(self):
    # Combine a slice object with an integer index in one subscript; the
    # checkScript input is a single torch.ones((10, 10)) tensor.
    def fn(x: torch.Tensor):
        return x[slice(None, 3, 1), 0]

    sample = torch.ones((10, 10))
    self.checkScript(fn, (sample,))
101
def test_slice_tensor_multidim_with_dots(self):
    # Combine a slice object with an Ellipsis in one subscript; the
    # checkScript input is a single torch.ones((10, 10)) tensor.
    def fn(x: torch.Tensor):
        return x[slice(None, 3, 1), ...]

    sample = torch.ones((10, 10))
    self.checkScript(fn, (sample,))
107
def test_slice_as_variable(self):
# Passes the nested function `fn` (taking a List[int]) to checkScript with
# range(10) as its only input.
# NOTE(review): the body of `fn` is not visible in this excerpt -- confirm
# against the full file.
108
def fn(x: List[int]):
112
self.checkScript(fn, (range(10),))
114
def test_slice_stop_clipped(self):
    # The stop value (1000) is larger than the length of the 10-element
    # range handed to checkScript.
    def fn(x: List[int]):
        return x[slice(1000)]

    script_args = (range(10),)
    self.checkScript(fn, script_args)
120
def test_slice_dynamic_index(self):
# The visible tail of the scripted function returns slice1 + slice2; `t` is
# passed to checkScript with a torch.zeros(3, 2, 3) tensor as input.
# NOTE(review): the `def t(...)` header and the lines that build slice1 and
# slice2 are not visible in this excerpt -- confirm against the full file.
126
return slice1 + slice2
128
self.checkScript(t, (torch.zeros(3, 2, 3),))
130
def test_tuple_slicing(self):
# NOTE(review): the definition of `tuple_slice` is not visible in this
# excerpt -- confirm against the full file.
140
self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
141
scripted_fn = torch.jit.script(tuple_slice)
142
# For the scalar tensor 1 the scripted function must return (2, 3).
self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
143
tuple_graph = scripted_fn.graph
144
slices = tuple_graph.findAllNodes("prim::TupleConstruct")
145
# Collect the element count of each TupleConstruct node's output type.
num_outputs = {len(x.output().type().elements()) for x in slices}
146
# every prim::TupleConstruct node found must produce a 2-element tuple
# (the set of output sizes is exactly {2})
147
self.assertTrue(num_outputs == {2})
148
# After the lower_all_tuples pass, the graph's string form must no longer
# contain "Tuple".
self.run_pass("lower_all_tuples", tuple_graph)
149
self.assertTrue("Tuple" not in str(tuple_graph))
151
def test_module_list_slicing(self):
# Slices a torch.nn.ModuleList inside a scripted module and checks the
# length and order of the two returned slices via each item's identifier.
# NOTE(review): several lines of the nested classes (including the method
# header that owns the `return` below) are not visible in this excerpt --
# confirm against the full file.
152
class Bar(torch.nn.Module):
153
def __init__(self, identifier: str):
155
# Label used by the assertions below to identify each module.
self.identifier = identifier
160
class Foo(torch.nn.Module):
161
def __init__(self) -> None:
163
# Five labelled modules, "A" through "E", held in a ModuleList.
module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
164
self.test = torch.nn.ModuleList(module_list)
167
# Two slices: a reversed every-second-item slice, and items 1 through 3.
return self.test[::-2], self.test[1:4:]
169
scripted_foo = torch.jit.script(Foo())
170
result1, result2 = scripted_foo()
172
# [::-2] over A..E is expected to yield E, C, A.
self.assertEqual(len(result1), 3)
173
self.assertEqual(result1[0].identifier, "E")
174
self.assertEqual(result1[1].identifier, "C")
175
self.assertEqual(result1[2].identifier, "A")
177
# [1:4:] over A..E is expected to yield B, C, D.
self.assertEqual(len(result2), 3)
178
self.assertEqual(result2[0].identifier, "B")
179
self.assertEqual(result2[1].identifier, "C")
180
self.assertEqual(result2[2].identifier, "D")