1
# Owner(s): ["module: functionalization"]
3
from torch.testing._internal.common_utils import TestCase, run_tests
4
from torch.fx.passes.reinplace import reinplace
5
from torch.fx.experimental.proxy_tensor import make_fx
8
from functorch.experimental import functionalize
9
HAS_FUNCTIONALIZATION = True
11
HAS_FUNCTIONALIZATION = False
13
class TestReinplacePass(TestCase):
    """Tests for torch.fx.passes.reinplace.

    Each test traces a small function with make_fx (optionally after
    functionalize()), runs the reinplace pass, checks numerical equivalence
    against eager, and pins the exact post-pass graph with
    assertExpectedInline.

    NOTE(review): this file was recovered from a truncated copy — the inner
    f() definitions and trailing lines of each test were reconstructed from
    the surviving comments and expected graphs.  The expected-graph strings
    are brittle across torch versions; re-accept them with
    EXPECTTEST_ACCEPT=1 if they drift.
    """

    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_()
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1); x_1 = None
    add = torch.ops.aten.add_.Tensor(clone, 1)
    return add
    """)

    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
            b = a.add(1)
            # Second add() is fine to re-inplace
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1); x_1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    add = torch.ops.aten.add.Tensor(clone, 1); clone = None
    add_1 = torch.ops.aten.add_.Tensor(view, 1)
    return view
    """)

    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a.add(1)
            # Naively, we shouldn't try to inplace the .ge() call,
            # because that would require resizing "b" (from a float to a bool tensor).
            c = torch.ge(b, a)
            return c

        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    add = torch.ops.aten.add.Tensor(clone, 1)
    ge = torch.ops.aten.ge.Tensor(add, clone); add = clone = None
    return ge
    """)

    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't reinplace because b has overlapping memory.
            c = b.add(1)
            return c

        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    expand = torch.ops.aten.expand.default(clone, [4, 4]); clone = None
    add = torch.ops.aten.add.Tensor(expand, 1); expand = None
    return add
    """)

    # This test won't actually run in CI, because it requires functionalize() from functorch.
    # I'm planning on testing more comprehensively with torchbench models,
    # but we can make this testing better once functorch moves into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    view_1 = torch.ops.aten.view.default(clone, [-1])
    select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
    view_2 = torch.ops.aten.view.default(select, [-1]); select = None
    add = torch.ops.aten.add_.Tensor(view_2, 1)
    view_3 = torch.ops.aten.view.default(clone, [-1]); clone = None
    select_1 = torch.ops.aten.select.int(view_3, 0, 0)
    view_4 = torch.ops.aten.view.default(view_2, []); view_2 = None
    view_5 = torch.ops.aten.view.default(view_3, [4]); view_3 = None
    view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0); view_6 = None
    view_7 = torch.ops.aten.view.default(select_2, [-1]); select_2 = None
    view_8 = torch.ops.aten.view.default(view_5, [-1])
    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8); view_8 = None
    return view_5
    """)

    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None
    select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = None
    return clone
    """)

    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b.
            # and this scatter op below tries to scatter c_updated into the same region
            # that c currently takes up.
            # reinplacing logic checks this by confirming that:
            #   c_updated
            #   good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
    return as_strided
    """)

    # Test example where we have a scatter op, where the base tensor
    # has the same size/stride/storage offset (even though it is a different view),
    # making it valid to re-inplace
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None
    return as_strided
    """)

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different region of memory than c's base.
            # This makes it invalid to re-inplace.
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1); a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
    add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None
    return as_strided
    """)

    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal = torch.ops.aten.diagonal.default(zeros)
    add = torch.ops.aten.add_.Tensor(diagonal, 1); diagonal = None
    return [zeros]
    """)

    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
    copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = None
    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = None
    return zeros
    """)
if __name__ == '__main__':
    # run_tests() is the standard PyTorch test-suite entry point
    # (imported from torch.testing._internal.common_utils above).
    run_tests()