pytorch

Форк
0
/
test_slice.py 
180 строк · 5.2 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5
from typing import List
6

7
import torch
8

9

10
# Make the helper files in test/ importable
11
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
12
sys.path.append(pytorch_test_dir)
13
from torch.testing._internal.jit_utils import JitTestCase
14

15

16
if __name__ == "__main__":
17
    raise RuntimeError(
18
        "This test file is not meant to be run directly, use:\n\n"
19
        "\tpython test/test_jit.py TESTNAME\n\n"
20
        "instead."
21
    )
22

23

24
# Tests that Python slice class is supported in TorchScript
25
class TestSlice(JitTestCase):
26
    def test_slice_kwarg(self):
27
        def slice_kwarg(x: List[int]):
28
            return x[slice(1, stop=2)]
29

30
        with self.assertRaisesRegex(
31
            RuntimeError, "Slice does not accept any keyword arguments"
32
        ):
33
            torch.jit.script(slice_kwarg)
34

35
    def test_slice_three_nones(self):
36
        def three_nones(x: List[int]):
37
            return x[slice(None, None, None)]
38

39
        self.checkScript(three_nones, (range(10),))
40

41
    def test_slice_two_nones(self):
42
        def two_nones(x: List[int]):
43
            return x[slice(None, None)]
44

45
        self.checkScript(two_nones, (range(10),))
46

47
    def test_slice_one_none(self):
48
        def one_none(x: List[int]):
49
            return x[slice(None)]
50

51
        self.checkScript(one_none, (range(10),))
52

53
    def test_slice_stop_only(self):
54
        def fn(x: List[int]):
55
            return x[slice(5)]
56

57
        self.checkScript(fn, (range(10),))
58

59
    def test_slice_stop_only_with_nones(self):
60
        def fn(x: List[int]):
61
            return x[slice(None, 5, None)]
62

63
        self.checkScript(fn, (range(10),))
64

65
    def test_slice_start_stop(self):
66
        def fn(x: List[int]):
67
            return x[slice(1, 5)]
68

69
        self.checkScript(fn, (range(10),))
70

71
    def test_slice_start_stop_with_none(self):
72
        def fn(x: List[int]):
73
            return x[slice(1, 5, None)]
74

75
        self.checkScript(fn, (range(10),))
76

77
    def test_slice_start_stop_step(self):
78
        def fn(x: List[int]):
79
            return x[slice(0, 6, 2)]
80

81
        self.checkScript(fn, (range(10),))
82

83
    def test_slice_string(self):
84
        def fn(x: str):
85
            return x[slice(None, 3, 1)]
86

87
        self.checkScript(fn, ("foo_bar",))
88

89
    def test_slice_tensor(self):
90
        def fn(x: torch.Tensor):
91
            return x[slice(None, 3, 1)]
92

93
        self.checkScript(fn, (torch.ones(10),))
94

95
    def test_slice_tensor_multidim(self):
96
        def fn(x: torch.Tensor):
97
            return x[slice(None, 3, 1), 0]
98

99
        self.checkScript(fn, (torch.ones((10, 10)),))
100

101
    def test_slice_tensor_multidim_with_dots(self):
102
        def fn(x: torch.Tensor):
103
            return x[slice(None, 3, 1), ...]
104

105
        self.checkScript(fn, (torch.ones((10, 10)),))
106

107
    def test_slice_as_variable(self):
108
        def fn(x: List[int]):
109
            a = slice(1)
110
            return x[a]
111

112
        self.checkScript(fn, (range(10),))
113

114
    def test_slice_stop_clipped(self):
115
        def fn(x: List[int]):
116
            return x[slice(1000)]
117

118
        self.checkScript(fn, (range(10),))
119

120
    def test_slice_dynamic_index(self):
121
        def t(x):
122
            slice1 = x[0:1]
123
            zero = 0
124
            one = zero + 1
125
            slice2 = x[zero:one]
126
            return slice1 + slice2
127

128
        self.checkScript(t, (torch.zeros(3, 2, 3),))
129

130
    def test_tuple_slicing(self):
131
        def tuple_slice(a):
132
            if bool(a):
133
                b = (1, 2, 3, 4)
134
            else:
135
                b = (4, 3, 2, 1)
136
            c = b[-4:4]
137
            e = c[1:-1]
138
            return e
139

140
        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
141
        scripted_fn = torch.jit.script(tuple_slice)
142
        self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
143
        tuple_graph = scripted_fn.graph
144
        slices = tuple_graph.findAllNodes("prim::TupleConstruct")
145
        num_outputs = {len(x.output().type().elements()) for x in slices}
146
        # there should be only one tupleSlice with length of 2
147
        self.assertTrue(num_outputs == {2})
148
        self.run_pass("lower_all_tuples", tuple_graph)
149
        self.assertTrue("Tuple" not in str(tuple_graph))
150

151
    def test_module_list_slicing(self):
152
        class Bar(torch.nn.Module):
153
            def __init__(self, identifier: str):
154
                super().__init__()
155
                self.identifier = identifier
156

157
            def forward(self):
158
                return 0
159

160
        class Foo(torch.nn.Module):
161
            def __init__(self) -> None:
162
                super().__init__()
163
                module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
164
                self.test = torch.nn.ModuleList(module_list)
165

166
            def forward(self):
167
                return self.test[::-2], self.test[1:4:]
168

169
        scripted_foo = torch.jit.script(Foo())
170
        result1, result2 = scripted_foo()
171

172
        self.assertEqual(len(result1), 3)
173
        self.assertEqual(result1[0].identifier, "E")
174
        self.assertEqual(result1[1].identifier, "C")
175
        self.assertEqual(result1[2].identifier, "A")
176

177
        self.assertEqual(len(result2), 3)
178
        self.assertEqual(result2[0].identifier, "B")
179
        self.assertEqual(result2[1].identifier, "C")
180
        self.assertEqual(result2[2].identifier, "D")
181

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.