pytorch

Форк
0
/
test_batch_mm.py 
290 строк · 9.9 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import torch
4
from torch.testing import FileCheck
5
from torch.testing._internal.jit_utils import JitTestCase
6

7

8
if __name__ == "__main__":
    # This module is meant to be driven through test/test_jit.py; when it is
    # executed directly, refuse and tell the user how to run it instead.
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)
14

15

16
class TestBatchMM(JitTestCase):
    """Tests for the TorchScript ``batch_mm`` graph pass.

    Each test scripts a function containing several ``aten::mm`` calls, runs
    the ``batch_mm`` pass over its graph and uses FileCheck to assert the
    rewritten graph. The tests expect:

    * sums of matrix products to be collapsed into ``prim::MMTreeReduce``;
    * many products sharing the same left operand to be collapsed into
      ``prim::MMBatchSide``;
    * products whose operands are mutated in place (``torch.relu_``) to be
      left (at least partly) as plain ``aten::mm``.
    """

    @staticmethod
    def _get_test_tensors(n: int):
        """Return ``n`` integer tensors alternating between 2x3 and 3x2.

        Even indices get a 2x3 tensor and odd indices a 3x2 tensor, each
        offset by its index. Consecutive pairs ``(t[2k], t[2k + 1])`` can
        therefore be multiplied with ``torch.mm``, and every such product is
        2x2, so the products can be summed together.
        """
        return [
            torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
            if x % 2 == 0
            else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
            for x in range(n)
        ]

    def test_batch_mm_no_mutation(self):
        """A sum of four products with no in-place ops becomes one MMTreeReduce.

        Also checks that the optimized graph computes the same value as eager.
        """

        def test_batch_mm(
            T1: torch.Tensor,
            T2: torch.Tensor,
            T3: torch.Tensor,
            T4: torch.Tensor,
            T5: torch.Tensor,
            T6: torch.Tensor,
            T7: torch.Tensor,
            T8: torch.Tensor,
        ):
            return (
                torch.mm(T1, T2)
                + torch.mm(T3, T4)
                + torch.mm(T5, T6)
                + torch.mm(T7, T8)
            )

        test_batch_mm_scripted = torch.jit.script(test_batch_mm)

        tensors = TestBatchMM._get_test_tensors(8)
        # Reference result from the plain (unscripted) Python function.
        expected = test_batch_mm(*tensors)

        FileCheck().check_count("aten::mm", 4, exactly=True).run(
            test_batch_mm_scripted.graph
        )
        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
            test_batch_mm_scripted.graph
        )

        actual = test_batch_mm_scripted(*tensors)
        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)

    def test_batch_mm_permitted_mutation(self):
        """Mutating an unrelated container does not block MMTreeReduce fusion.

        The only mutation is writing into the ``result`` dict; no mm operand
        is modified, so the four products are still expected to be fused.
        """

        def test_batch_mm(
            T1: torch.Tensor,
            T2: torch.Tensor,
            T3: torch.Tensor,
            T4: torch.Tensor,
            T5: torch.Tensor,
            T6: torch.Tensor,
            T7: torch.Tensor,
            T8: torch.Tensor,
        ):
            result = {}
            result["product"] = (
                torch.mm(T1, T2)
                + torch.mm(T3, T4)
                + torch.mm(T5, T6)
                + torch.mm(T7, T8)
            )
            result["constant"] = torch.tensor([42.0])
            return result

        test_batch_mm_scripted = torch.jit.script(test_batch_mm)

        tensors = TestBatchMM._get_test_tensors(8)
        expected = test_batch_mm(*tensors)

        FileCheck().check_count("aten::mm", 4, exactly=True).run(
            test_batch_mm_scripted.graph
        )
        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
            test_batch_mm_scripted.graph
        )

        actual = test_batch_mm_scripted(*tensors)
        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)

    def test_batch_mm_prohibited_mutation(self):
        """An in-place op on an mm operand (T1) prevents any fusion.

        All four ``aten::mm`` must survive the pass, and no
        ``prim::MMTreeReduce`` may appear.
        """

        @torch.jit.script
        def test_batch_mm(n: int):
            T1 = torch.zeros((n, n))
            T2 = torch.zeros((n, n))
            T3 = torch.zeros((n, n))
            T4 = torch.zeros((n, n))
            T5 = torch.zeros((n, n))
            T6 = torch.zeros((n, n))
            T7 = torch.zeros((n, n))
            T8 = torch.zeros((n, n))
            torch.relu_(T1)
            result = (
                torch.mm(T1, T2)
                + torch.mm(T3, T4)
                + torch.mm(T5, T6)
                + torch.mm(T7, T8)
            )
            return result

        FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph)
        self.run_pass("batch_mm", test_batch_mm.graph)
        FileCheck().check_count("aten::mm", 4, exactly=True).check_not(
            "prim::MMTreeReduce"
        ).run(test_batch_mm.graph)

    def test_batch_mm_prohibited_mutation_multiple_adds(self):
        """Only one of two sums is fused when T1 is mutated in place.

        The graph holds a 4-product sum that never touches T1 and a 5-product
        sum that does. After the pass, exactly one ``prim::MMTreeReduce`` and
        five ``aten::mm`` are expected. Presumably the T1-free sum is the
        fused one (as the ``no_mutated_parameters`` key suggests).
        """

        @torch.jit.script
        def test_batch_mm(n: int):
            T1 = torch.zeros((n, n))
            T2 = torch.zeros((n, n))
            T3 = torch.zeros((n, n))
            T4 = torch.zeros((n, n))
            T5 = torch.zeros((n, n))
            T6 = torch.zeros((n, n))
            T7 = torch.zeros((n, n))
            T8 = torch.zeros((n, n))
            T9 = torch.zeros((n, n))
            T10 = torch.zeros((n, n))
            torch.relu_(T1)
            result = {}
            result["no_mutated_parameters"] = (
                torch.mm(T2, T3)
                + torch.mm(T4, T5)
                + torch.mm(T6, T7)
                + torch.mm(T8, T9)
            )
            result["all_parameters"] = (
                torch.mm(T1, T2)
                + torch.mm(T3, T4)
                + torch.mm(T5, T6)
                + torch.mm(T7, T8)
                + torch.mm(T9, T10)
            )
            return result

        # Sanity-check the unoptimized graph (4 + 5 products), as the other
        # tests do, so the post-pass counts are known to come from the pass.
        FileCheck().check_count("aten::mm", 9, exactly=True).run(test_batch_mm.graph)
        self.run_pass("batch_mm", test_batch_mm.graph)
        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count(
            "aten::mm", 5, exactly=True
        ).run(test_batch_mm.graph)

    def test_batch_mm_prohibited_mutation_if_node(self):
        """Mutation inside one branch of an ``if`` only blocks that branch.

        The true branch mutates T1 and sums 5 products. The false branch
        sums 4 products without T1. After the pass, five ``aten::mm`` and one
        ``prim::MMTreeReduce`` are expected. Presumably the false branch is
        the fused one.
        """

        @torch.jit.script
        def test_batch_mm(n: int, use_t1: bool):
            T1 = torch.zeros((n, n))
            T2 = torch.zeros((n, n))
            T3 = torch.zeros((n, n))
            T4 = torch.zeros((n, n))
            T5 = torch.zeros((n, n))
            T6 = torch.zeros((n, n))
            T7 = torch.zeros((n, n))
            T8 = torch.zeros((n, n))
            T9 = torch.zeros((n, n))
            T10 = torch.zeros((n, n))
            if use_t1:
                torch.relu_(T1)
                return (
                    torch.mm(T1, T2)
                    + torch.mm(T3, T4)
                    + torch.mm(T5, T6)
                    + torch.mm(T7, T8)
                    + torch.mm(T9, T10)
                )
            else:
                return (
                    torch.mm(T2, T3)
                    + torch.mm(T4, T5)
                    + torch.mm(T6, T7)
                    + torch.mm(T8, T9)
                )

        # Both branches appear in the printed graph: 5 + 4 products before
        # the pass (same sanity check the other tests perform).
        FileCheck().check_count("aten::mm", 9, exactly=True).run(test_batch_mm.graph)
        self.run_pass("batch_mm", test_batch_mm.graph)
        FileCheck().check_count("aten::mm", 5, exactly=True).check_count(
            "prim::MMTreeReduce", 1, exactly=True
        ).run(test_batch_mm.graph)

    def test_batch_mm_side_permitted_mutation(self):
        """Eight products sharing left operand A become one MMBatchSide.

        No operand is mutated, so no ``aten::mm`` may remain after the pass.
        """

        @torch.jit.script
        def test_batch_mm(n: int):
            result = {}
            A = torch.zeros((n, n))
            T1 = torch.zeros((n, n))
            T2 = torch.zeros((n, n))
            T3 = torch.zeros((n, n))
            T4 = torch.zeros((n, n))
            T5 = torch.zeros((n, n))
            T6 = torch.zeros((n, n))
            T7 = torch.zeros((n, n))
            T8 = torch.zeros((n, n))
            result["T1"] = torch.mm(A, T1)
            result["T2"] = torch.mm(A, T2)
            result["T3"] = torch.mm(A, T3)
            result["T4"] = torch.mm(A, T4)
            result["T5"] = torch.mm(A, T5)
            result["T6"] = torch.mm(A, T6)
            result["T7"] = torch.mm(A, T7)
            result["T8"] = torch.mm(A, T8)
            return result

        FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph)
        self.run_pass("batch_mm", test_batch_mm.graph)
        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not(
            "aten::mm"
        ).run(test_batch_mm.graph)

    def test_batch_mm_side_prohibited_mutation_uncommon_side(self):
        """Mutating one non-shared operand (T1) excludes only that product.

        Ten products share A, and only T1 is mutated. After the pass, one
        ``aten::mm`` (presumably ``mm(A, T1)``) and one ``prim::MMBatchSide``
        are expected.
        """

        @torch.jit.script
        def test_batch_mm(n: int):
            A = torch.zeros((n, n))
            T1 = torch.zeros((n, n))
            T2 = torch.zeros((n, n))
            T3 = torch.zeros((n, n))
            T4 = torch.zeros((n, n))
            T5 = torch.zeros((n, n))
            T6 = torch.zeros((n, n))
            T7 = torch.zeros((n, n))
            T8 = torch.zeros((n, n))
            T9 = torch.zeros((n, n))
            T10 = torch.zeros((n, n))
            torch.relu_(T1)
            result = {}
            result["T1"] = torch.mm(A, T1)
            result["T2"] = torch.mm(A, T2)
            result["T3"] = torch.mm(A, T3)
            result["T4"] = torch.mm(A, T4)
            result["T5"] = torch.mm(A, T5)
            result["T6"] = torch.mm(A, T6)
            result["T7"] = torch.mm(A, T7)
            result["T8"] = torch.mm(A, T8)
            result["T9"] = torch.mm(A, T9)
            result["T10"] = torch.mm(A, T10)
            return result

        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
        self.run_pass("batch_mm", test_batch_mm.graph)

        FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph)
        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run(
            test_batch_mm.graph
        )

    def test_batch_mm_side_prohibited_mutation_common_side(self):
        """Mutating the shared operand A prevents any side batching.

        All ten ``aten::mm`` must survive, and no ``prim::MMBatchSide`` may
        appear.
        """

        @torch.jit.script
        def test_batch_mm(n: int):
            A = torch.zeros((n, n))
            T1 = torch.zeros((n, n))
            T2 = torch.zeros((n, n))
            T3 = torch.zeros((n, n))
            T4 = torch.zeros((n, n))
            T5 = torch.zeros((n, n))
            T6 = torch.zeros((n, n))
            T7 = torch.zeros((n, n))
            T8 = torch.zeros((n, n))
            T9 = torch.zeros((n, n))
            T10 = torch.zeros((n, n))
            torch.relu_(A)
            result = {}
            result["T1"] = torch.mm(A, T1)
            result["T2"] = torch.mm(A, T2)
            result["T3"] = torch.mm(A, T3)
            result["T4"] = torch.mm(A, T4)
            result["T5"] = torch.mm(A, T5)
            result["T6"] = torch.mm(A, T6)
            result["T7"] = torch.mm(A, T7)
            result["T8"] = torch.mm(A, T8)
            result["T9"] = torch.mm(A, T9)
            result["T10"] = torch.mm(A, T10)
            return result

        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
        self.run_pass("batch_mm", test_batch_mm.graph)
        FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
            "prim::MMBatchSide"
        ).run(test_batch_mm.graph)
291

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.