3
from torch.testing._internal.common_utils import TestCase, run_tests
4
from torch.fx.passes.reinplace import reinplace
5
from torch.fx.experimental.proxy_tensor import make_fx
6
from torch.fx.experimental.symbolic_shapes import ShapeEnv
7
from torch._dynamo.source import ConstantSource
8
from torch.fx.experimental.sym_node import SymNode
11
from functorch.experimental import functionalize
12
HAS_FUNCTIONALIZATION = True
14
HAS_FUNCTIONALIZATION = False
16
class TestReinplacePass(TestCase):
18
def test_reinplace_basic(self):
27
f2 = reinplace(make_fx(f)(inpt), inpt)
28
expected_out = f(inpt)
30
self.assertEqual(actual_out, expected_out)
31
self.assertExpectedInline(f2.code, """\
35
def forward(self, x_1):
36
clone = torch.ops.aten.clone.default(x_1); x_1 = None
37
add = torch.ops.aten.add_.Tensor(clone, 1); add = None
42
def test_reinplace_with_view(self):
53
f2 = reinplace(make_fx(f)(inpt), inpt)
54
expected_out = f(inpt)
56
self.assertEqual(actual_out, expected_out)
57
self.assertExpectedInline(f2.code, """\
61
def forward(self, x_1):
62
clone = torch.ops.aten.clone.default(x_1); x_1 = None
63
view = torch.ops.aten.view.default(clone, [-1])
64
add = torch.ops.aten.add.Tensor(clone, 1); clone = add = None
65
add_1 = torch.ops.aten.add_.Tensor(view, 1); add_1 = None
69
def test_reinplace_different_metadata(self):
78
f2 = reinplace(make_fx(f)(inpt), inpt)
79
expected_out = f(inpt)
81
self.assertEqual(actual_out, expected_out)
83
self.assertExpectedInline(f2.code, """\
87
def forward(self, a__1):
88
clone = torch.ops.aten.clone.default(a__1); a__1 = None
89
add = torch.ops.aten.add.Tensor(clone, 1)
90
ge = torch.ops.aten.ge.Tensor(add, clone); add = clone = None
94
def test_reinplace_overlapping_memory(self):
102
f2 = reinplace(make_fx(f)(inpt), inpt)
103
expected_out = f(inpt)
104
actual_out = f2(inpt)
105
self.assertEqual(actual_out, expected_out)
106
self.assertExpectedInline(f2.code, """\
110
def forward(self, a__1):
111
clone = torch.ops.aten.clone.default(a__1); a__1 = None
112
expand = torch.ops.aten.expand.default(clone, [4, 4]); clone = None
113
add = torch.ops.aten.add.Tensor(expand, 1); expand = None
120
def test_reinplace_scatter_op(self):
131
if not HAS_FUNCTIONALIZATION:
134
f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
135
expected_out = f(inpt)
136
actual_out = f2(inpt)
137
self.assertEqual(actual_out, expected_out)
144
self.assertExpectedInline(f2.code, """\
148
def forward(self, a__1):
149
clone = torch.ops.aten.clone.default(a__1); a__1 = None
150
view = torch.ops.aten.view.default(clone, [-1]); view = None
151
view_1 = torch.ops.aten.view.default(clone, [-1])
152
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
153
view_2 = torch.ops.aten.view.default(select, [-1]); select = None
154
add = torch.ops.aten.add_.Tensor(view_2, 1); add = None
155
view_3 = torch.ops.aten.view.default(clone, [-1]); clone = None
156
select_1 = torch.ops.aten.select.int(view_3, 0, 0); select_1 = None
157
view_4 = torch.ops.aten.view.default(view_2, []); view_2 = view_4 = None
158
view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
159
view_6 = torch.ops.aten.view.default(view_5, [-1])
160
select_2 = torch.ops.aten.select.int(view_6, 0, 0); view_6 = None
161
view_7 = torch.ops.aten.view.default(select_2, [-1]); select_2 = view_7 = None
162
view_8 = torch.ops.aten.view.default(view_5, [-1])
163
add_1 = torch.ops.aten.add_.Tensor(view_5, view_8); view_8 = add_1 = None
167
def test_reinplace_scatter_twice(self):
176
if not HAS_FUNCTIONALIZATION:
179
inpt = torch.ones(4, 4)
180
f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
181
expected_out = f(inpt)
182
actual_out = f2(inpt)
183
self.assertEqual(actual_out, expected_out)
184
self.assertExpectedInline(f2.code, """\
188
def forward(self, a__1):
189
clone = torch.ops.aten.clone.default(a__1); a__1 = None
190
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
191
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
192
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
193
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
194
slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
195
select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = select_2 = None
196
slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
197
select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None
198
select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = select_4 = None
202
def test_reinplace_scatter_twice_with_different_view_op_valid(self):
208
good_mirror_of_b = a.as_strided((4,), (4,), 1)
216
b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
219
inpt = torch.ones(4, 4)
220
f2 = reinplace(make_fx(f)(inpt), inpt)
221
expected_out = f(inpt)
222
actual_out = f2(inpt)
223
self.assertEqual(actual_out, expected_out)
224
self.assertExpectedInline(f2.code, """\
228
def forward(self, a__1):
229
clone = torch.ops.aten.clone.default(a__1); a__1 = None
230
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
231
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
232
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
233
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
234
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
241
def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
247
good_mirror_of_b = a.as_strided((4,), (4,), 1)
252
b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
255
inpt = torch.ones(4, 4)
256
f2 = reinplace(make_fx(f)(inpt), inpt)
257
expected_out = f(inpt)
258
actual_out = f2(inpt)
259
self.assertEqual(actual_out, expected_out)
260
self.assertExpectedInline(f2.code, """\
264
def forward(self, a__1):
265
clone = torch.ops.aten.clone.default(a__1); a__1 = None
266
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
267
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
268
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
269
add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
270
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
271
select_int = torch.ops.aten.select.int(as_strided, 0, 0)
272
copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None
276
def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
282
bad_mirror_of_b = a.as_strided((4,), (4,), 0)
285
b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
288
inpt = torch.ones(4, 4)
289
f2 = reinplace(make_fx(f)(inpt), inpt)
290
expected_out = f(inpt)
291
actual_out = f2(inpt)
293
self.assertExpectedInline(f2.code, """\
297
def forward(self, a__1):
298
clone = torch.ops.aten.clone.default(a__1); a__1 = None
299
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
300
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
301
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
302
add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
303
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
304
select_int = torch.ops.aten.select.int(as_strided, 0, 1)
305
copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None
310
def test_out_node_updated(self):
312
x = torch.zeros(2, 2)
315
z = torch.diagonal_scatter(x, y_updated)
319
if not HAS_FUNCTIONALIZATION:
321
f2 = reinplace(make_fx(functionalize(f))())
324
self.assertEqual(actual_out, expected_out)
325
self.assertExpectedInline(f2.code, """\
330
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
331
diagonal = torch.ops.aten.diagonal.default(zeros)
332
add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = add = None
336
def test_reinplace_index_mutation(self):
338
a = torch.zeros(4, 4, 4)
339
a[:, 2:] = torch.ones(4, 2, 4)
342
if not HAS_FUNCTIONALIZATION:
344
f2 = reinplace(make_fx(functionalize(f))())
347
self.assertEqual(actual_out, expected_out)
348
self.assertExpectedInline(f2.code, """\
353
zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
354
ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
355
slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
356
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
357
copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = copy = None
358
slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807); slice_3 = None
359
slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
360
slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = slice_5 = None
364
def test_reinplace_sym_input(self):
368
a = torch.select(x, 0, index)
373
x = torch.randn((4, 8, 16, 16), requires_grad=False)
375
shape_env = ShapeEnv()
376
symbol = shape_env.create_symbol(index, source=ConstantSource(
377
f"__testing_only{len(shape_env.var_to_val)}"))
378
sym_index = torch.SymInt(SymNode(symbol, shape_env, int, hint=index))
380
inpt = [x, sym_index]
381
f2 = reinplace(make_fx(f)(*inpt), *inpt)
383
real_inpt = [x, index]
384
expected_out = f(*real_inpt)
385
actual_out = f2(*real_inpt)
386
self.assertEqual(actual_out, expected_out)
388
self.assertExpectedInline(f2.code, """\
392
def forward(self, x_1, index_1):
393
select = torch.ops.aten.select.int(x_1, 0, index_1); x_1 = index_1 = None
394
clone = torch.ops.aten.clone.default(select); select = None
395
add = torch.ops.aten.add_.Tensor(clone, 1); add = None
400
if __name__ == '__main__':