4
from torch.testing import FileCheck
5
from torch.testing._internal.jit_utils import JitTestCase
8
if __name__ == "__main__":
10
"This test file is not meant to be run directly, use:\n\n"
11
"\tpython test/test_jit.py TESTNAME\n\n"
16
class TestBatchMM(JitTestCase):
18
def _get_test_tensors(n: int):
20
torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
22
else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
26
def test_batch_mm_no_mutation(self):
44
test_batch_mm_scripted = torch.jit.script(test_batch_mm)
46
tensors = TestBatchMM._get_test_tensors(8)
47
expected = test_batch_mm(*tensors)
49
FileCheck().check_count("aten::mm", 4, exactly=True).run(
50
test_batch_mm_scripted.graph
52
self.run_pass("batch_mm", test_batch_mm_scripted.graph)
53
FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
54
test_batch_mm_scripted.graph
57
actual = test_batch_mm_scripted(*tensors)
58
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
60
def test_batch_mm_permitted_mutation(self):
78
result["constant"] = torch.tensor([42.0])
81
test_batch_mm_scripted = torch.jit.script(test_batch_mm)
83
tensors = TestBatchMM._get_test_tensors(8)
84
expected = test_batch_mm(*tensors)
86
FileCheck().check_count("aten::mm", 4, exactly=True).run(
87
test_batch_mm_scripted.graph
89
self.run_pass("batch_mm", test_batch_mm_scripted.graph)
90
FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
91
test_batch_mm_scripted.graph
94
actual = test_batch_mm_scripted(*tensors)
95
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
97
def test_batch_mm_prohibited_mutation(self):
99
def test_batch_mm(n: int):
100
T1 = torch.zeros((n, n))
101
T2 = torch.zeros((n, n))
102
T3 = torch.zeros((n, n))
103
T4 = torch.zeros((n, n))
104
T5 = torch.zeros((n, n))
105
T6 = torch.zeros((n, n))
106
T7 = torch.zeros((n, n))
107
T8 = torch.zeros((n, n))
117
FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph)
118
self.run_pass("batch_mm", test_batch_mm.graph)
119
FileCheck().check_count("aten::mm", 4, exactly=True).check_not(
121
).run(test_batch_mm.graph)
123
def test_batch_mm_prohibited_mutation_multiple_adds(self):
125
def test_batch_mm(n: int):
126
T1 = torch.zeros((n, n))
127
T2 = torch.zeros((n, n))
128
T3 = torch.zeros((n, n))
129
T4 = torch.zeros((n, n))
130
T5 = torch.zeros((n, n))
131
T6 = torch.zeros((n, n))
132
T7 = torch.zeros((n, n))
133
T8 = torch.zeros((n, n))
134
T9 = torch.zeros((n, n))
135
T10 = torch.zeros((n, n))
138
result["no_mutated_parameters"] = (
144
result["all_parameters"] = (
153
self.run_pass("batch_mm", test_batch_mm.graph)
154
FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count(
155
"aten::mm", 5, exactly=True
156
).run(test_batch_mm.graph)
158
def test_batch_mm_prohibited_mutation_if_node(self):
160
def test_batch_mm(n: int, use_t1: bool):
161
T1 = torch.zeros((n, n))
162
T2 = torch.zeros((n, n))
163
T3 = torch.zeros((n, n))
164
T4 = torch.zeros((n, n))
165
T5 = torch.zeros((n, n))
166
T6 = torch.zeros((n, n))
167
T7 = torch.zeros((n, n))
168
T8 = torch.zeros((n, n))
169
T9 = torch.zeros((n, n))
170
T10 = torch.zeros((n, n))
188
self.run_pass("batch_mm", test_batch_mm.graph)
189
FileCheck().check_count("aten::mm", 5, exactly=True).check_count(
190
"prim::MMTreeReduce", 1, exactly=True
191
).run(test_batch_mm.graph)
193
def test_batch_mm_side_permitted_mutation(self):
195
def test_batch_mm(n: int):
197
A = torch.zeros((n, n))
198
T1 = torch.zeros((n, n))
199
T2 = torch.zeros((n, n))
200
T3 = torch.zeros((n, n))
201
T4 = torch.zeros((n, n))
202
T5 = torch.zeros((n, n))
203
T6 = torch.zeros((n, n))
204
T7 = torch.zeros((n, n))
205
T8 = torch.zeros((n, n))
206
result["T1"] = torch.mm(A, T1)
207
result["T2"] = torch.mm(A, T2)
208
result["T3"] = torch.mm(A, T3)
209
result["T4"] = torch.mm(A, T4)
210
result["T5"] = torch.mm(A, T5)
211
result["T6"] = torch.mm(A, T6)
212
result["T7"] = torch.mm(A, T7)
213
result["T8"] = torch.mm(A, T8)
216
FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph)
217
self.run_pass("batch_mm", test_batch_mm.graph)
218
FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not(
220
).run(test_batch_mm.graph)
222
def test_batch_mm_side_prohibited_mutation_uncommon_side(self):
224
def test_batch_mm(n: int):
225
A = torch.zeros((n, n))
226
T1 = torch.zeros((n, n))
227
T2 = torch.zeros((n, n))
228
T3 = torch.zeros((n, n))
229
T4 = torch.zeros((n, n))
230
T5 = torch.zeros((n, n))
231
T6 = torch.zeros((n, n))
232
T7 = torch.zeros((n, n))
233
T8 = torch.zeros((n, n))
234
T9 = torch.zeros((n, n))
235
T10 = torch.zeros((n, n))
238
result["T1"] = torch.mm(A, T1)
239
result["T2"] = torch.mm(A, T2)
240
result["T3"] = torch.mm(A, T3)
241
result["T4"] = torch.mm(A, T4)
242
result["T5"] = torch.mm(A, T5)
243
result["T6"] = torch.mm(A, T6)
244
result["T7"] = torch.mm(A, T7)
245
result["T8"] = torch.mm(A, T8)
246
result["T9"] = torch.mm(A, T9)
247
result["T10"] = torch.mm(A, T10)
250
FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
251
self.run_pass("batch_mm", test_batch_mm.graph)
253
FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph)
254
FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run(
258
def test_batch_mm_side_prohibited_mutation_common_side(self):
260
def test_batch_mm(n: int):
261
A = torch.zeros((n, n))
262
T1 = torch.zeros((n, n))
263
T2 = torch.zeros((n, n))
264
T3 = torch.zeros((n, n))
265
T4 = torch.zeros((n, n))
266
T5 = torch.zeros((n, n))
267
T6 = torch.zeros((n, n))
268
T7 = torch.zeros((n, n))
269
T8 = torch.zeros((n, n))
270
T9 = torch.zeros((n, n))
271
T10 = torch.zeros((n, n))
274
result["T1"] = torch.mm(A, T1)
275
result["T2"] = torch.mm(A, T2)
276
result["T3"] = torch.mm(A, T3)
277
result["T4"] = torch.mm(A, T4)
278
result["T5"] = torch.mm(A, T5)
279
result["T6"] = torch.mm(A, T6)
280
result["T7"] = torch.mm(A, T7)
281
result["T8"] = torch.mm(A, T8)
282
result["T9"] = torch.mm(A, T9)
283
result["T10"] = torch.mm(A, T10)
286
FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
287
self.run_pass("batch_mm", test_batch_mm.graph)
288
FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
290
).run(test_batch_mm.graph)