# Owner(s): ["module: functionalization"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.fx.passes.reinplace import reinplace
from torch.fx.experimental.proxy_tensor import make_fx

try:
    from functorch.experimental import functionalize
    HAS_FUNCTIONALIZATION = True
except Exception as e:
    HAS_FUNCTIONALIZATION = False
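
# The reinplace pass takes the FX graph produced by make_fx() and rewrites
# out-of-place ops into their in-place variants (e.g. aten.add.Tensor ->
# aten.add_.Tensor) whenever it can prove the rewrite is safe (no other alias
# of the mutated tensor is used afterwards, view metadata matches, etc.).
# A minimal sketch of the pattern every test below follows ("gm" is just a
# placeholder name for the traced module):
#
#   gm = make_fx(f)(inpt)      # trace f into an FX GraphModule
#   gm = reinplace(gm, inpt)   # rewrite out-of-place ops in place where safe
#   gm(inpt)                   # must behave exactly like the original f
#
# Each test then pins down the exact post-pass graph with assertExpectedInline.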

class TestReinplacePass(TestCase):

    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_()
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    add = torch.ops.aten.add_.Tensor(clone, 1)
    return clone
    """)


    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
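            # (a_view shares a's storage and is read again below, so an in-place
            # a.add_(1) here would change the values that a_view.add(1) sees)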
            b = a.add(1)
            # Second add() is fine to re-inplace
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
    add_1 = torch.ops.aten.add_.Tensor(view, 1)
    return view
    """)

    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # Naively, we shouldn't try to inplace the .ge() call,
            # because that would require resizing "b" (from a float to a bool tensor).
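            # (torch.ge produces a torch.bool tensor, so its result can never
            # reuse b's float32 storage in place)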
            c = torch.ge(b, a)
            return c
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    add = torch.ops.aten.add.Tensor(clone, 1)
    ge = torch.ops.aten.ge.Tensor(add, clone);  add = clone = None
    return ge
    """)

    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't reinplace because b has overlapping memory.
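            # (expanding a 1-element tensor to (4, 4) yields a stride-(0, 0) view:
            # all 16 elements alias the same storage location, so an in-place add
            # would write to overlapping memory)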
            c = b.add(1)
            return c
        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    expand = torch.ops.aten.expand.default(clone, [4, 4]);  clone = None
    add = torch.ops.aten.add.Tensor(expand, 1);  expand = None
    return add
    """)

    # This test won't actually run in CI, because it requires functionalize() from functorch.
    # I'm planning on testing more comprehensively with torchbench models,
    # but we can make this testing better once functorch moves into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        if not HAS_FUNCTIONALIZATION:
            return
        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
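        # (in the expected graph below, e.g. view and view_1 are identical
        # views of clone and could in principle be CSE'd away)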
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    view_1 = torch.ops.aten.view.default(clone, [-1])
    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = None
    view_2 = torch.ops.aten.view.default(select, [-1]);  select = None
    add = torch.ops.aten.add_.Tensor(view_2, 1)
    view_3 = torch.ops.aten.view.default(clone, [-1]);  clone = None
    select_1 = torch.ops.aten.select.int(view_3, 0, 0)
    view_4 = torch.ops.aten.view.default(view_2, []);  view_2 = None
    view_5 = torch.ops.aten.view.default(view_3, [4]);  view_3 = None
    view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0);  view_6 = None
    view_7 = torch.ops.aten.view.default(select_2, [-1]);  select_2 = None
    view_8 = torch.ops.aten.view.default(view_5, [-1])
    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8);  view_8 = None
    return view_5
    """)

    def test_reinplace_scatter_twice(self):
        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        if not HAS_FUNCTIONALIZATION:
            return

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = None
    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_2 = torch.ops.aten.select.int(slice_2, 1, 1);  slice_2 = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_3 = torch.ops.aten.select.int(slice_3, 1, 1);  slice_3 = None
    select_4 = torch.ops.aten.select.int(select_3, 0, 1);  select_3 = None
    return clone
    """)

    # Test example where we have a scatter op, where the base tensor
    # has the same size/stride/storage offset (even though it is a different view),
    # making it valid to re-inplace
    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b,
            # and the scatter op below tries to scatter c_updated into the same region
            # that c currently takes up.
            # The reinplacing logic checks this by confirming that:
            #   c_updated
            #   good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
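            # (after c.add(1) is re-inplaced, c_updated is the same 0-dim view of
            # the clone at storage_offset 5 that good_mirror_of_b.select(0, 1)
            # refers to, so the scatter becomes a no-op and can be dropped)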
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    return as_strided
    """)

    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            #
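            # (c lives at index 1 of b, i.e. storage_offset 5 of the clone, but the
            # scatter below writes index 0, i.e. storage_offset 1, so the pass keeps
            # the out-of-place add and materializes the write with a copy_ instead)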
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = None
    return as_strided
    """)  # noqa: B950

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different region of memory than c's base.
            # This makes it invalid to re-inplace.
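            # (bad_mirror_of_b starts at storage_offset 0 and covers elements 0, 4, 8, 12
            # of the clone, while b and c live at offsets 1, 5, 9, 13, so the scatter
            # writes memory that c never aliased; the pass again falls back to an
            # out-of-place add followed by a copy_)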
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = None
    return as_strided
    """)  # noqa: B950


    def test_out_node_updated(self):
        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
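            # (once the add is re-inplaced, the diagonal_scatter is redundant:
            # x already holds the updated values through its diagonal view, so the
            # graph's output list has to be rewritten to return x instead of z)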
            return [z]

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal = torch.ops.aten.diagonal.default(zeros)
    add = torch.ops.aten.add_.Tensor(diagonal, 1);  diagonal = None
    return [zeros]
    """)

    def test_reinplace_index_mutation(self):
        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
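            # (this index assignment mutates a through a view; after functionalize()
            # and reinplace() it should show up in the graph as a copy_ into a slice
            # of the zeros tensor, with the zeros tensor itself returned)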
            return a

        if not HAS_FUNCTIONALIZATION:
            return
        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807);  slice_1 = None
    copy = torch.ops.aten.copy_.default(slice_2, ones);  slice_2 = ones = None
    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807);  slice_4 = None
    return zeros
    """)

if __name__ == '__main__':
    run_tests()