pytorch

Форк
0
/
test_fx_reinplace_pass.py 
401 строка · 14.9 Кб
1
# Owner(s): ["module: functionalization"]
2
import torch
3
from torch.testing._internal.common_utils import TestCase, run_tests
4
from torch.fx.passes.reinplace import reinplace
5
from torch.fx.experimental.proxy_tensor import make_fx
6
from torch.fx.experimental.symbolic_shapes import ShapeEnv
7
from torch._dynamo.source import ConstantSource
8
from torch.fx.experimental.sym_node import SymNode
9

10
# functionalize() comes from functorch; the tests that need it check
# HAS_FUNCTIONALIZATION before running.
try:
    from functorch.experimental import functionalize
    HAS_FUNCTIONALIZATION = True
except Exception:
    # Deliberately broad: any failure while importing functorch only disables
    # the functionalization-dependent tests instead of breaking module import.
    HAS_FUNCTIONALIZATION = False
15

16
class TestReinplacePass(TestCase):
    """Tests for the FX re-inplacing pass (``reinplace``).

    Each test traces a function with ``make_fx`` (optionally after
    ``functionalize``), runs ``reinplace`` on the traced graph, checks that the
    result numerically matches eager execution, and pins the generated code.
    """

    def test_reinplace_basic(self):
        # Basic test: the out-of-place add() call should be converted
        # into add_()
        def f(x):
            a = x.clone()
            b = a.add(1)
            return b

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    add = torch.ops.aten.add_.Tensor(clone, 1);  add = None
    return clone
    """)

    def test_reinplace_with_view(self):
        def f(x):
            a = x.clone()
            a_view = a.view(-1)
            # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program
            b = a.add(1)
            # Second add() is fine to re-inplace
            c = a_view.add(1)
            return c

        inpt = torch.ones(2)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1):
    clone = torch.ops.aten.clone.default(x_1);  x_1 = None
    view = torch.ops.aten.view.default(clone, [-1])
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = add = None
    add_1 = torch.ops.aten.add_.Tensor(view, 1);  add_1 = None
    return view
    """)

    def test_reinplace_different_metadata(self):
        def f(a_):
            a = a_.clone()
            b = a + 1
            # We shouldn't try to inplace the .ge() call, because that would
            # require changing the dtype of "b" (from a float to a bool tensor).
            c = torch.ge(b, a)
            return c

        inpt = torch.ones(4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # The .ge() should not be reinplaced.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    add = torch.ops.aten.add.Tensor(clone, 1)
    ge = torch.ops.aten.ge.Tensor(add, clone);  add = clone = None
    return ge
    """)

    def test_reinplace_overlapping_memory(self):
        def f(a_):
            a = a_.clone()
            b = a.expand(4, 4)
            # Can't reinplace because b has overlapping memory.
            c = b.add(1)
            return c

        inpt = torch.ones(1)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    expand = torch.ops.aten.expand.default(clone, [4, 4]);  clone = None
    add = torch.ops.aten.add.Tensor(expand, 1);  expand = None
    return add
    """)

    # This test requires functionalize() from functorch. It is skipped (rather
    # than silently passing) when that import fails; see HAS_FUNCTIONALIZATION.
    # The original author planned more comprehensive testing with torchbench
    # models once functorch moved into pytorch/pytorch.
    def test_reinplace_scatter_op(self):
        if not HAS_FUNCTIONALIZATION:
            self.skipTest("functorch.experimental.functionalize is not available")

        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            e = a.view(-1)
            b = a.view(-1)
            c = b[0]
            d = c.view(-1)
            d.add_(1)
            return a + e

        inpt = torch.ones(4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        # NOTE: one slight pessimization here is the fact that
        # there are a bunch of redundant views in the graph.
        # Technically, half of these views are duplicates that we could de-dup.
        # This shouldn't really hurt performance though, since creating an extra view
        # is effectively just moving some metadata around (and allocating a new TensorImpl).
        # We can/should update the pass in the future to clean this up.
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    view = torch.ops.aten.view.default(clone, [-1]);  view = None
    view_1 = torch.ops.aten.view.default(clone, [-1])
    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = None
    view_2 = torch.ops.aten.view.default(select, [-1]);  select = None
    add = torch.ops.aten.add_.Tensor(view_2, 1);  add = None
    view_3 = torch.ops.aten.view.default(clone, [-1]);  clone = None
    select_1 = torch.ops.aten.select.int(view_3, 0, 0);  select_1 = None
    view_4 = torch.ops.aten.view.default(view_2, []);  view_2 = view_4 = None
    view_5 = torch.ops.aten.view.default(view_3, [4]);  view_3 = None
    view_6 = torch.ops.aten.view.default(view_5, [-1])
    select_2 = torch.ops.aten.select.int(view_6, 0, 0);  view_6 = None
    view_7 = torch.ops.aten.view.default(select_2, [-1]);  select_2 = view_7 = None
    view_8 = torch.ops.aten.view.default(view_5, [-1])
    add_1 = torch.ops.aten.add_.Tensor(view_5, view_8);  view_8 = add_1 = None
    return view_5
    """)

    def test_reinplace_scatter_twice(self):
        if not HAS_FUNCTIONALIZATION:
            self.skipTest("functorch.experimental.functionalize is not available")

        def f(a_):
            # for now, don't test mutations to inputs
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c.add_(1)
            return a

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(functionalize(f))(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = add = None
    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_2 = torch.ops.aten.select.int(slice_2, 1, 1);  slice_2 = select_2 = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select_3 = torch.ops.aten.select.int(slice_3, 1, 1);  slice_3 = None
    select_4 = torch.ops.aten.select.int(select_3, 0, 1);  select_3 = select_4 = None
    return clone
    """)

    # Test example where we have a scatter op, where the base tensor
    # has the same size/stride/storage offset (even though it is a different view),
    # making it valid to re-inplace
    def test_reinplace_scatter_twice_with_different_view_op_valid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # good_mirror_of_b points to the same region of memory as b,
            # and this scatter op below tries to scatter c_updated into the same region
            # that c currently takes up.
            # reinplacing logic checks this by confirming that:
            #   c_updated
            #   good_mirror_of_b.select(0, 1)
            # have the same size/stride/storage_offset.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = add = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    return as_strided
    """)

    # Test example where the scatter op's base tensor is an equivalent view of b,
    # but the scatter writes into a different slice than the one c occupies,
    # so the add() must stay out-of-place.
    def test_reinplace_scatter_twice_with_different_view_op_invalid(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            good_mirror_of_b = a.as_strided((4,), (4,), 1)
            # The first arg to select_scatter is an equivalent view to b.
            # However, the select_scatter call below tries to put c_updated
            # into a different slice of "b" than what "c" currently occupies.
            b_updated = torch.select_scatter(good_mirror_of_b, c_updated, 0, 0)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 0)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = copy__default = None
    return as_strided
    """)  # noqa: B950

    def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
        def f(a_):
            a = a_.clone()
            b = a[:, 1]
            c = b[1]
            c_updated = c.add(1)
            bad_mirror_of_b = a.as_strided((4,), (4,), 0)
            # The first arg to select_scatter points to a different region of
            # memory than b (c's base). This makes it invalid to re-inplace.
            b_updated = torch.select_scatter(bad_mirror_of_b, c_updated, 0, 1)
            return b_updated

        inpt = torch.ones(4, 4)
        f2 = reinplace(make_fx(f)(inpt), inpt)
        expected_out = f(inpt)
        actual_out = f2(inpt)
        # The graph below writes add into as_strided[1] via copy_ and returns
        # as_strided, which matches the eager select_scatter result.
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, a__1):
    clone = torch.ops.aten.clone.default(a__1);  a__1 = None
    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
    select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
    add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
    as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0);  clone = None
    select_int = torch.ops.aten.select.int(as_strided, 0, 1)
    copy__default = torch.ops.aten.copy_.default(select_int, add);  select_int = add = copy__default = None
    return as_strided
    """)  # noqa: B950

    def test_out_node_updated(self):
        if not HAS_FUNCTIONALIZATION:
            self.skipTest("functorch.experimental.functionalize is not available")

        def f():
            x = torch.zeros(2, 2)
            y = x.diagonal()
            y_updated = y.add(1)
            z = torch.diagonal_scatter(x, y_updated)
            # reinplace needs to know to replace output [z] with [x]
            return [z]

        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
    diagonal = torch.ops.aten.diagonal.default(zeros)
    add = torch.ops.aten.add_.Tensor(diagonal, 1);  diagonal = add = None
    return [zeros]
    """)

    def test_reinplace_index_mutation(self):
        if not HAS_FUNCTIONALIZATION:
            self.skipTest("functorch.experimental.functionalize is not available")

        def f():
            a = torch.zeros(4, 4, 4)
            a[:, 2:] = torch.ones(4, 2, 4)
            return a

        f2 = reinplace(make_fx(functionalize(f))())
        expected_out = f()
        actual_out = f2()
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807);  slice_1 = None
    copy = torch.ops.aten.copy_.default(slice_2, ones);  slice_2 = ones = copy = None
    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807);  slice_3 = None
    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807);  slice_4 = slice_5 = None
    return zeros
    """)

    def test_reinplace_sym_input(self):
        # Symbolic input test: the out-of-place add() call should be converted
        # into add_(), and symbolic input won't cause any error.
        def f(x, index):
            a = torch.select(x, 0, index)
            clone = a.clone()
            b = clone.add(1)
            return b

        x = torch.randn((4, 8, 16, 16), requires_grad=False)
        index = 2
        shape_env = ShapeEnv()
        symbol = shape_env.create_symbol(index, source=ConstantSource(
            f"__testing_only{len(shape_env.var_to_val)}"))
        sym_index = torch.SymInt(SymNode(symbol, shape_env, int, hint=index))

        # Trace with the symbolic index, then run with the concrete one.
        inpt = [x, sym_index]
        f2 = reinplace(make_fx(f)(*inpt), *inpt)

        real_inpt = [x, index]
        expected_out = f(*real_inpt)
        actual_out = f2(*real_inpt)
        self.assertEqual(actual_out, expected_out)
        self.assertExpectedInline(f2.code, """\



def forward(self, x_1, index_1):
    select = torch.ops.aten.select.int(x_1, 0, index_1);  x_1 = index_1 = None
    clone = torch.ops.aten.clone.default(select);  select = None
    add = torch.ops.aten.add_.Tensor(clone, 1);  add = None
    return clone
    """)
398

399

400
if __name__ == "__main__":
    # Executed directly (not imported): hand control to PyTorch's shared
    # test runner for the tests defined in this module.
    run_tests()
402

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.