# Owner(s): ["module: codegen"]

import torch
from contextlib import nullcontext
from torch.testing._internal.common_utils import (
    TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, IS_WINDOWS,
    xfail_inherited_tests
)
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, dispatch_functionalize
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
from torch.utils._pytree import tree_map_only
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree

import unittest

def are_aliased(x, y):
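    # x and y alias if and only if they share the same underlying storage.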
    x_storage = StorageWeakRef(x.storage())
    y_storage = StorageWeakRef(y.storage())
    return x_storage == y_storage

# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
# This is basically a crappy version of `functionalize()`.
def _functionalize(f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False):
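    # Wrap `f` so that its tensor inputs are converted to functional tensors,
    # run it with functionalization enabled, propagate any input mutations back
    # to the original inputs, and unwrap the outputs.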
    def to_fun(t: torch.Tensor):
        func_t = torch._to_functional_tensor(t)
        func_t.requires_grad = t.requires_grad
        return func_t

    def wrapped(*inputs):
        ctx = nullcontext()
        if crossref:
            ctx = enable_crossref_functionalize()
        with ctx:
            inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs)
            torch._enable_functionalization(reapply_views=reapply_views)
            try:
                out = f(*inputs_functional)
            finally:
                torch._disable_functionalization()
            flat_inputs = pytree.tree_leaves(inputs)
            flat_inputs_functional = pytree.tree_leaves(inputs_functional)

            for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
                torch._sync(input_functional)
                inpt_new = torch._from_functional_tensor(input_functional)
                if inpt_new is not inpt and not skip_input_mutations:
                    # Existing deficiency in functionalize():
                    # we don't correctly mutate input metadata (yet?)
                    if inpt_new.shape == inpt.shape:
                        inpt.copy_(inpt_new)
            tree_map_only(torch.Tensor, torch._sync, out)
            out_unwrapped = tree_map_only(torch.Tensor, torch._from_functional_tensor, out)
            return out_unwrapped

    return wrapped

@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
class TestFunctionalization(TestCase):

    crossref = False

    def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False):
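        # Trace `func` under functionalization with make_fx and return the generated
        # FX code; optionally run the reinplacing pass over the traced graph first.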
        inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts)
        traced_f = make_fx(_functionalize(func, reapply_views=reapply_views, crossref=self.crossref))(*inpts)
        if run_reinplace:
            traced_f = reinplace(traced_f, *inpts_clone)
        return traced_f.code

    def assert_functionalization(self, func, *inpts, reapply_views=False, mutated_input_metadata=False):
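        # Run `func` eagerly, under functionalization, and through the
        # functionalize-then-reinplace pipeline, and check that the outputs
        # (and any input mutations) agree across all three.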
        clones1 = tree_map_only(torch.Tensor, torch.clone, inpts)
        clones2 = tree_map_only(torch.Tensor, torch.clone, inpts)
        clones3 = tree_map_only(torch.Tensor, torch.clone, inpts)

        # Compare outputs (and mutated inputs), with and without functionalization.
        out_ref = func(*inpts)
        out_functional = _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)(*clones1)

        # The reinplacing pass is only valid to run with reapply_views=True.
        functional_func = make_fx(_functionalize(func, reapply_views=True, crossref=self.crossref))(*clones2)
        reinplace_func = reinplace(functional_func, *clones2)

        # NOTE: for now, need to pass in fresh inputs here, because make_fx
        # will directly mutate the inputs that you trace with.
        # Once this is fixed we can clean this up.
        out_reinplace = reinplace_func(*clones3)

        # functionalize() deficiency: input metadata mutations aren't propagated properly,
        # so we just need to skip checks here for the tests that exercise that.
        if not mutated_input_metadata:
            flat_inpts = pytree.tree_leaves(inpts)
            flat_clones1 = pytree.tree_leaves(clones1)
            flat_clones3 = pytree.tree_leaves(clones3)
            for inpt, input_clone, input_clone3 in zip(flat_inpts, flat_clones1, flat_clones3):
                self.assertEqual(inpt, input_clone)  # input mutations should still occur
                self.assertEqual(inpt, input_clone3)

        # Handle tests with multi-tensor outputs
        if isinstance(out_ref, tuple):
            out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace)
        else:
            out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace]

        for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces):
            self.assertEqual(out_ref_, out_functional_)
            self.assertEqual(out_ref_, out_reinplace_)

    def test_save_for_backwards_segfault(self):
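        # Smoke test: calling an op that saves tensors for backward on a functional
        # tensor wrapping a LoggingTensor should not segfault.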
        inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
        inp.exp()

    def test_multiple_views_of_same_base(self):
        def f(x):
            y = x.view(-1)
            z = x.view(-1)
            x.add_(1)
            # y should have been updated.
            y2 = y + 1
            # z should have been updated too.
            z2 = z + 1
            return z2
        self.assert_functionalization(f, torch.ones(4))

    def test_freeze(self):
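        # After _freeze_functional_tensor(y), in-place mutation of y (or of a view of it
        # like z) should raise a RuntimeError; mutating the untouched input x is still fine.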
129
        def f(x):
130
            y = x.clone()
131
            z = y[0]
132
            torch._freeze_functional_tensor(y)
133
            x.add_(1)
134
            self.assertRaises(RuntimeError, lambda: y.add_(1))
135
            self.assertRaises(RuntimeError, lambda: z.add_(1))
136
            return z
137

138
        _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3))
139

140
    def test_copy_stride_mismatch(self):
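        # copy_() into a destination whose strides don't match the source:
        # functionalization should preserve the destination's strides (asserted below).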
141
        def f(x):
142
            y = torch.empty_strided((2, 2), (5, 1))
143
            y.copy_(x)
144
            return y
145

146
        r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
147
        self.assertEqual(r.stride(), (5, 1))
148

149
    def test_set_(self):
150
        def f(x):
151
            y = torch.ones(2)
152
            y.set_(x.storage())
153
            return y
154

155
        # We should probably get the crossref test to work,
156
        # but fixing it for Storage() objects is annoying.
157
        r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
158
        self.assertEqual(str(r.device), 'cpu')
159

160
    def test_advanced_indexing(self):
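        # In-place advanced-indexing assignment (x[:, idx] = val) should functionalize correctly.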
161
        def f():
162
            x = torch.zeros(3, 3)
163
            idx = torch.tensor([0])
164
            val = torch.ones(3, 1)
165
            x[:, idx] = val
166
            return x
167

168
        self.assert_functionalization(f)
169

170
    def test_view_clone_view_inplace(self):
171
        def f(input):
172
            shape = [1, 1024, 128, 128]
173
            input_reshaped = input.view(shape)
174
            out = input_reshaped.clone()
175
            r = out.view(input.shape)
176
            r.relu_()
177
            return r
178

179
        def g(x):
180
            loss = f(x).sum()
181
            from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
182
            import torch.fx.traceback as fx_traceback
183
            setup_stacktrace_preservation_hooks([loss.grad_fn])
184
            with fx_traceback.preserve_node_meta():
185
                loss.backward()
186
            return x.grad
187

188
        with torch.autograd.detect_anomaly(check_nan=False):
189
            logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True))
190
        self.assertExpectedInline(logs, """\
191

192

193

194
def forward(self, arg0_1):
195
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]);  arg0_1 = None
196
    clone = torch.ops.aten.clone.default(view_copy);  view_copy = None
197
    view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
198
    relu = torch.ops.aten.relu.default(view_copy_1);  view_copy_1 = None
199
    view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]);  relu = None
200
    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]);  view_copy_2 = None
201
    view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]);  clone = None
202
    sum_1 = torch.ops.aten.sum.default(view_copy_3)
203
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format);  sum_1 = None
204
    expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]);  ones_like = None
205
    view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]);  expand_copy = None
206
    new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
207
    copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5);  new_empty_strided = view_copy_5 = None
208
    view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
209
    view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
210
    clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format)
211
    threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0);  clone_1 = view_copy_3 = None
212
    copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward);  view_copy_7 = threshold_backward = None
213
    view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]);  copy_1 = None
214
    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128])
215
    view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]);  copy = None
216
    detach_copy = torch.ops.aten.detach_copy.default(view_copy_10);  view_copy_10 = None
217
    view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]);  view_copy_8 = None
218
    detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11);  view_copy_11 = None
219
    return detach_copy_1
220
    """)  # noqa: B950
221

222
    def test_simple(self):
223
        def f(x):
224
            # simple test: 1 view op, 1 inplace op
225
            tmp = torch.ones(4, 2)
226
            y = x.view(4, 2)
227
            y.add_(tmp)
228
            z = x * x
229
            return y
230
        self.assert_functionalization(f, torch.ones(4, 2))
231
        logs = self.get_logs(f, torch.ones(4, 2))
232
        self.assertExpectedInline(logs, """\
233

234

235

236
def forward(self, arg0_1):
237
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
238
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
239
    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
240
    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
241
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
242
    mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
243
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = None
244
    return view_copy_2
245
    """)
246

247
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
248
        self.assertExpectedInline(reinplaced_logs, """\
249

250

251

252
def forward(self, arg0_1):
253
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
254
    view = torch.ops.aten.view.default(arg0_1, [4, 2])
255
    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
256
    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
257
    view_2 = torch.ops.aten.view.default(view_1, [4, 2])
258
    mul = torch.ops.aten.mul.Tensor(view_1, view_1)
259
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = None
260
    return view_2
261
    """)
262

263
    def test_simple_out(self):
264
        def f(x):
265
            tmp = torch.ones(4, 2)
266
            y = x.view(4, 2)
267
            # the out= tensor will get resized, since it starts out as a 0-dim tensor.
268
            z = torch.empty(())
269
            torch.add(y, tmp, out=z)
270
            w = z * z
271
            return w
272
        self.assert_functionalization(f, torch.ones(4, 2))
273
        logs = self.get_logs(f, torch.ones(4, 2))
274
        self.assertExpectedInline(logs, """\
275

276

277

278
def forward(self, arg0_1):
279
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
280
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  arg0_1 = None
281
    empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
282
    add = torch.ops.aten.add.Tensor(view_copy, ones);  view_copy = ones = None
283
    mul = torch.ops.aten.mul.Tensor(add, add);  add = None
284
    return mul
285
    """)
286

287
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
288
        self.assertExpectedInline(reinplaced_logs, """\
289

290

291

292
def forward(self, arg0_1):
293
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
294
    view = torch.ops.aten.view.default(arg0_1, [4, 2]);  arg0_1 = None
295
    empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
296
    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
297
    mul = torch.ops.aten.mul.Tensor(add, add);  add = None
298
    return mul
299
    """)
300

301
    def test_multi_out(self):
302
        def f(x):
303
            # aminmax.out returns a tuple of tensors.
304
            # functionalization should properly handle the tuple.
305
            out_min = torch.empty(4)
306
            out_max = torch.empty(4)
307
            torch.aminmax(x, dim=0, out=(out_max, out_min))
308
            return out_max
309
        self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
310
        logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
311
        self.assertExpectedInline(logs, """\
312

313

314

315
def forward(self, arg0_1):
316
    empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
317
    empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
318
    aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0);  arg0_1 = None
319
    getitem = aminmax[0]
320
    getitem_1 = aminmax[1];  aminmax = None
321
    return getitem
322
    """)
323

324
        reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True)
325
        self.assertExpectedInline(reinplaced_logs, """\
326

327

328

329
def forward(self, arg0_1):
330
    empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
331
    empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
332
    aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0);  arg0_1 = None
333
    getitem = aminmax[0]
334
    getitem_1 = aminmax[1];  aminmax = None
335
    return getitem
336
    """)
337

338
    def test_tensor_ctr(self):
339
        def f(x):
340
            y = torch.tensor((1, 2, 3))
341
            z = y.view(-1)
342
            z.add_(1)
343
            return y
344

345
        inpt = torch.arange(3, dtype=torch.float32)
346
        self.assert_functionalization(f, inpt)
347

348
        logs = self.get_logs(f, inpt)
349
        self.assertExpectedInline(logs, """\
350

351

352

353
def forward(self, arg0_1):
354
    _tensor_constant0 = self._tensor_constant0
355
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
356
    view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]);  lift_fresh_copy = None
357
    add = torch.ops.aten.add.Tensor(view_copy, 1);  view_copy = None
358
    view_copy_1 = torch.ops.aten.view_copy.default(add, [3]);  add = None
359
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1])
360
    return view_copy_1
361
    """)
362

363
        reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
364
        self.assertExpectedInline(reinplaced_logs, """\
365

366

367

368
def forward(self, arg0_1):
369
    _tensor_constant0 = self._tensor_constant0
370
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
371
    view = torch.ops.aten.view.default(lift_fresh_copy, [-1]);  lift_fresh_copy = None
372
    add = torch.ops.aten.add_.Tensor(view, 1)
373
    view_1 = torch.ops.aten.view.default(view, [3]);  view = None
374
    view_2 = torch.ops.aten.view.default(view_1, [-1])
375
    return view_1
376
    """)
377

378
    def test_advanced_indexing_correct_strides(self):
379
        def f(a):
380
            # This test requires that *_scatter ops are able to return
381
            # non-contiguous tensors.
382
            b = a.clone()[:, 1]
383
            c = torch.ones_like(b, dtype=torch.bool)
384
            d = b.masked_fill_(c, 0)
385
            return d
386
        self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True)
387

388
    def test_tensor_list_mixed_functional_nonfunctional(self):
389
        nonfunctional_tensor = torch.ones(2, dtype=torch.long)
390

391
        def f(x):
392
            # test: advanced indexing with a mix of functional and non-functional index tensors
393
            functional_tensor = torch.ones(2, dtype=torch.long)
394
            out = x[functional_tensor, nonfunctional_tensor]
395
            return out
396
        out = f(torch.ones(2, 2))
397
        out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
398
        self.assertEqual(out, out_functional)
399

400
    def test_inplace_on_non_view(self):
401
        def f(x):
402
            # test for the case where we functionalize an inplace op on the base tensor itself - not a view.
403
            # This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
404
            tmp = torch.ones(4, 2)
405
            y = x.view(4, 2)
406
            x.add_(tmp)
407
            return y
408
        self.assert_functionalization(f, torch.ones(4, 2))
409
        logs = self.get_logs(f, torch.ones(4, 2))
410
        self.assertExpectedInline(logs, """\
411

412

413

414
def forward(self, arg0_1):
415
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
416
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
417
    add = torch.ops.aten.add.Tensor(arg0_1, ones);  ones = None
418
    copy_ = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = None
419
    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
420
    return view_copy_1
421
    """)
422

423
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
424
        self.assertExpectedInline(reinplaced_logs, """\
425

426

427

428
def forward(self, arg0_1):
429
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
430
    view = torch.ops.aten.view.default(arg0_1, [4, 2])
431
    add = torch.ops.aten.add.Tensor(arg0_1, ones);  ones = None
432
    copy_ = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = None
433
    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
434
    return view_1
435
    """)
436

437
    # Some ops that are mutable are neither inplace nor out= ops.
438
    # They also need special handling.
439
    def test_mutable_op_not_inplace_or_other(self):
440
        def f(x):
441
            return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
442

443
        logs = self.get_logs(f, torch.ones(1))
444
        self.assertExpectedInline(logs, """\
445

446

447

448
def forward(self, arg0_1):
449
    _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
450
    getitem = _fused_moving_avg_obs_fq_helper_functional[0]
451
    getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
452
    getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
453
    getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]
454
    getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]
455
    getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5];  _fused_moving_avg_obs_fq_helper_functional = None
456
    copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5);  arg0_1 = getitem_5 = None
457
    return (getitem, getitem_1)
458
    """)  # noqa: B950
459

460
    def test_as_strided(self):
461
        def f(x):
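            # Mutating an as_strided view: functionalization propagates the update back
            # to the base via as_strided_scatter (see the expected graph below).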
462
            y = x.as_strided((2,), (2,), 1)
463
            y.add_(1)
464
            return x
465
        self.assert_functionalization(f, torch.ones(9))
466
        logs = self.get_logs(f, torch.ones(9))
467
        self.assertExpectedInline(logs, """\
468

469

470

471
def forward(self, arg0_1):
472
    as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1)
473
    add = torch.ops.aten.add.Tensor(as_strided_copy, 1);  as_strided_copy = None
474
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1);  add = None
475
    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1)
476
    copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter);  arg0_1 = None
477
    return as_strided_scatter
478
    """)
479

480
        # NB: even with reapply_views=True, we expect to see scatter op
481
        reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False)
482
        self.assertExpectedInline(reinplaced_logs, """\
483

484

485

486
def forward(self, arg0_1):
487
    as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1)
488
    add = torch.ops.aten.add.Tensor(as_strided, 1);  as_strided = None
489
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1);  add = None
490
    as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1)
491
    copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter);  arg0_1 = None
492
    return as_strided_scatter
493
    """)
494

495
    def test_tensor_list_composite(self):
496
        def f(x):
497
            # Test an op with TensorList input
498
            y = torch.block_diag(x, x)
499
            return y
500
        self.assert_functionalization(f, torch.ones(2, 2))
501
        logs = self.get_logs(f, torch.ones(2, 2))
502
        self.assertExpectedInline(logs, """\
503

504

505

506
def forward(self, arg0_1):
507
    block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]);  arg0_1 = None
508
    return block_diag
509
    """)
510

511
    def test_cat(self):
512
        def f(x):
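            # cat with out=: `out` is a fresh local tensor here, so the functionalized
            # graph below can just call the functional cat and return its result.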
513
            out = torch.empty(0)
514
            torch.cat((x,), out=out)
515
            return out
516
        self.assert_functionalization(f, torch.ones(2, 2))
517
        logs = self.get_logs(f, torch.ones(2, 2))
518
        self.assertExpectedInline(logs, """\
519

520

521

522
def forward(self, arg0_1):
523
    empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
524
    cat = torch.ops.aten.cat.default([arg0_1]);  arg0_1 = None
525
    return cat
526
    """)
527

528
        reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
529
        self.assertExpectedInline(reinplaced_logs, """\
530

531

532

533
def forward(self, arg0_1):
534
    empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
535
    cat = torch.ops.aten.cat.default([arg0_1]);  arg0_1 = None
536
    return cat
537
    """)
538

539

540
    def test_diagonal(self):
541
        def f(x):
542
            # test: view ops that take a subset of the original tensor (select/diagonal)
543
            tmp = torch.ones(2)
544
            y = x.clone().diagonal()
545
            y.add_(tmp)
546
            z = x * x
547
            return z
548
        self.assert_functionalization(f, torch.ones(2, 2))
549
        logs = self.get_logs(f, torch.ones(2, 2))
550
        self.assertExpectedInline(logs, """\
551

552

553

554
def forward(self, arg0_1):
555
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
556
    clone = torch.ops.aten.clone.default(arg0_1)
557
    diagonal_copy = torch.ops.aten.diagonal_copy.default(clone)
558
    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
559
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add);  clone = add = None
560
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter);  diagonal_scatter = None
561
    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
562
    return mul
563
    """)
564

565
        reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
566
        self.assertExpectedInline(reinplaced_logs, """\
567

568

569

570
def forward(self, arg0_1):
571
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
572
    clone = torch.ops.aten.clone.default(arg0_1)
573
    diagonal = torch.ops.aten.diagonal.default(clone)
574
    add = torch.ops.aten.add_.Tensor(diagonal, ones);  diagonal = ones = None
575
    diagonal_1 = torch.ops.aten.diagonal.default(clone);  clone = None
576
    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
577
    return mul
578
    """)
579

580
    def test_diagonal_mutated_input(self):
581
        def f(x):
582
            # simple test: there are pending updates afterwards, which the test syncs manually
583
            tmp = torch.ones(2)
584
            y = x.diagonal()
585
            y.add_(tmp)
586
            return x
587
        x = torch.ones(2, 2)
588
        self.assert_functionalization(f, x)
589
        logs = self.get_logs(f, torch.ones(2, 2))
590
        self.assertExpectedInline(logs, """\
591

592

593

594
def forward(self, arg0_1):
595
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
596
    diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1)
597
    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
598
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add);  add = None
599
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
600
    copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter);  arg0_1 = None
601
    return diagonal_scatter
602
    """)
603

604
        # NB: even with reapply_views=True, we expect to see scatter op
605
        reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False)
606
        self.assertExpectedInline(reinplaced_logs, """\
607

608

609

610
def forward(self, arg0_1):
611
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
612
    diagonal = torch.ops.aten.diagonal.default(arg0_1)
613
    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
614
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add);  add = None
615
    diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter)
616
    copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter);  arg0_1 = None
617
    return diagonal_scatter
618
    """)
619

620
    def test_channels_last_contiguous(self):
621
        def f(x):
622
            return x.contiguous(memory_format=torch.channels_last)
623
            tmp = torch.ones(2)
624
            y = x.diagonal()
625
            y.add_(tmp)
626
            return x
627
        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
628
        self.assert_functionalization(f, x)
629
        logs = self.get_logs(f, x).strip()
630
        # There should be no clone in the graph
631
        self.assertExpectedInline(logs, """\
632
def forward(self, arg0_1):
633
    return arg0_1""")
634

635
    def test_split(self):
636
        def f(x):
637
            # test: view ops that return multiple tensors (split)
638
            tmp = torch.ones(2)
639
            y1, y2 = x.split(2)
640
            y3 = y2.diagonal()
641
            y3.add_(tmp)
642
            z = x * x
643
            return y3
644
        self.assert_functionalization(f, torch.ones(4, 2))
645
        logs = self.get_logs(f, torch.ones(4, 2))
646
        self.assertExpectedInline(logs, """\
647

648

649

650
def forward(self, arg0_1):
651
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
652
    split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
653
    getitem = split_copy[0]
654
    getitem_1 = split_copy[1];  split_copy = None
655
    diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1);  getitem_1 = None
656
    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
657
    split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
658
    getitem_2 = split_copy_1[0]
659
    getitem_3 = split_copy_1[1];  split_copy_1 = None
660
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add);  getitem_3 = add = None
661
    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4);  diagonal_scatter = None
662
    split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2)
663
    getitem_4 = split_copy_2[0]
664
    getitem_5 = split_copy_2[1];  split_copy_2 = None
665
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5);  getitem_5 = None
666
    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
667
    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = None
668
    return diagonal_copy_1
669
    """)  # noqa: B950
670

671
        # NB: even with reapply_views=True, we expect to see scatter op
672
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
673
        self.assertExpectedInline(reinplaced_logs, """\
674

675

676

677
def forward(self, arg0_1):
678
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
679
    split = torch.ops.aten.split.Tensor(arg0_1, 2)
680
    getitem = split[0]
681
    getitem_1 = split[1];  split = None
682
    diagonal = torch.ops.aten.diagonal.default(getitem_1);  getitem_1 = None
683
    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
684
    split_1 = torch.ops.aten.split.Tensor(arg0_1, 2)
685
    getitem_2 = split_1[0]
686
    getitem_3 = split_1[1];  split_1 = None
687
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add);  getitem_3 = add = None
688
    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4);  diagonal_scatter = None
689
    split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2)
690
    getitem_4 = split_2[0]
691
    getitem_5 = split_2[1];  split_2 = None
692
    diagonal_1 = torch.ops.aten.diagonal.default(getitem_5);  getitem_5 = None
693
    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
694
    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = None
695
    return diagonal_1
696
    """)  # noqa: B950
697

698
    def test_split_with_sizes(self):
699
        def f(x):
700
            # test: view ops that return multiple tensors (split_with_sizes)
701
            tmp = torch.ones(2)
702
            y1, y2 = x.split_with_sizes([2, 2])
703
            y3 = y1.diagonal()
704
            y3.add_(tmp)
705
            z = x * x
706
            return y3
707
        self.assert_functionalization(f, torch.ones(4, 2))
708
        logs = self.get_logs(f, torch.ones(4, 2))
709
        self.assertExpectedInline(logs, """\
710

711

712

713
def forward(self, arg0_1):
714
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
715
    split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
716
    getitem = split_with_sizes_copy[0]
717
    getitem_1 = split_with_sizes_copy[1];  split_with_sizes_copy = None
718
    diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem);  getitem = None
719
    add = torch.ops.aten.add.Tensor(diagonal_copy, ones);  diagonal_copy = ones = None
720
    split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
721
    getitem_2 = split_with_sizes_copy_1[0]
722
    getitem_3 = split_with_sizes_copy_1[1];  split_with_sizes_copy_1 = None
723
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add);  getitem_2 = add = None
724
    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2);  diagonal_scatter = None
725
    split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2])
726
    getitem_4 = split_with_sizes_copy_2[0]
727
    getitem_5 = split_with_sizes_copy_2[1];  split_with_sizes_copy_2 = None
728
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4);  getitem_4 = None
729
    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
730
    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = None
731
    return diagonal_copy_1
732
    """)  # noqa: B950
733

734
        # NB: even with reapply_views=True, we expect to see scatter op
735
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
736
        self.assertExpectedInline(reinplaced_logs, """\
737

738

739

740
def forward(self, arg0_1):
741
    ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
742
    split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
743
    getitem = split_with_sizes[0]
744
    getitem_1 = split_with_sizes[1];  split_with_sizes = None
745
    diagonal = torch.ops.aten.diagonal.default(getitem);  getitem = None
746
    add = torch.ops.aten.add.Tensor(diagonal, ones);  diagonal = ones = None
747
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
748
    getitem_2 = split_with_sizes_1[0]
749
    getitem_3 = split_with_sizes_1[1];  split_with_sizes_1 = None
750
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add);  getitem_2 = add = None
751
    slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2);  diagonal_scatter = None
752
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2])
753
    getitem_4 = split_with_sizes_2[0]
754
    getitem_5 = split_with_sizes_2[1];  split_with_sizes_2 = None
755
    diagonal_1 = torch.ops.aten.diagonal.default(getitem_4);  getitem_4 = None
756
    mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
757
    copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter);  arg0_1 = slice_scatter = None
758
    return diagonal_1
759
    """)  # noqa: B950
760

761
    def test_slice(self):
762
        def f(x):
763
            tmp = torch.ones(4)
764
            x.transpose_(1, 0)
765
            y = x[0:2]
766
            y.add_(tmp)
767
            return x
768
        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
769
        logs = self.get_logs(f, torch.ones(4, 2))
770
        self.assertExpectedInline(logs, """\
771

772

773

774
def forward(self, arg0_1):
775
    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
776
    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
777
    slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2);  transpose_copy = None
778
    add = torch.ops.aten.add.Tensor(slice_copy, ones);  slice_copy = ones = None
779
    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
780
    slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2);  transpose_copy_1 = add = None
781
    transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0);  slice_scatter = None
782
    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
783
    slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2);  transpose_copy_3 = None
784
    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
785
    return transpose_copy_4
786
    """)  # noqa: B950
787

788
        # NB: even with reapply_views=True, we expect to see scatter op
789
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
790
        self.assertExpectedInline(reinplaced_logs, """\
791

792

793

794
def forward(self, arg0_1):
795
    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
796
    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
797
    slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2);  transpose = None
798
    add = torch.ops.aten.add.Tensor(slice_1, ones);  slice_1 = ones = None
799
    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
800
    slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2);  transpose_1 = add = None
801
    transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0);  slice_scatter = None
802
    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
803
    slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2);  transpose_3 = None
804
    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
805
    return transpose_4
806
    """)  # noqa: B950
807

808
    def test_view_inplace(self):
809
        def f(x):
810
            # test: view + inplace op (transpose_)
811
            tmp = torch.ones(4)
812
            x.transpose_(1, 0)
813
            y = x[0]
814
            y.add_(tmp)
815
            return x
816
        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
817
        logs = self.get_logs(f, torch.ones(4, 2))
818
        self.assertExpectedInline(logs, """\
819

820

821

822
def forward(self, arg0_1):
823
    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
824
    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
825
    select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0);  transpose_copy = None
826
    add = torch.ops.aten.add.Tensor(select_copy, ones);  select_copy = ones = None
827
    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
828
    select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0);  transpose_copy_1 = add = None
829
    transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0);  select_scatter = None
830
    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
831
    select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0);  transpose_copy_3 = None
832
    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
833
    return transpose_copy_4
834
    """)  # noqa: B950
835

836
        # NB: even with reapply_views=True, we expect to see scatter op
837
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
838
        self.assertExpectedInline(reinplaced_logs, """\
839

840

841

842
def forward(self, arg0_1):
843
    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
844
    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
845
    select = torch.ops.aten.select.int(transpose, 0, 0);  transpose = None
846
    add = torch.ops.aten.add.Tensor(select, ones);  select = ones = None
847
    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
848
    select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0);  transpose_1 = add = None
849
    transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0);  select_scatter = None
850
    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
851
    select_1 = torch.ops.aten.select.int(transpose_3, 0, 0);  transpose_3 = None
852
    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
853
    return transpose_4
854
    """)  # noqa: B950
855

856
    def test_unbind(self):
857
        def f(x):
858
            # test: a view op that returns multiple tensors (unbind), combined with an inplace op (transpose_)
859
            tmp = torch.ones(4)
860
            x.transpose_(1, 0)
861
            y, _ = x.unbind(0)
862
            y.add_(tmp)
863
            return x
864
        self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
865
        logs = self.get_logs(f, torch.ones(4, 2))
866
        self.assertExpectedInline(logs, """\
867

868

869

870
def forward(self, arg0_1):
871
    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
872
    transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
873
    unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy);  transpose_copy = None
874
    getitem = unbind_copy[0]
875
    getitem_1 = unbind_copy[1];  unbind_copy = None
876
    add = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
877
    transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0);  arg0_1 = None
878
    select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0);  transpose_copy_1 = add = None
879
    transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0);  select_scatter = None
880
    transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
881
    unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3);  transpose_copy_3 = None
882
    getitem_2 = unbind_copy_1[0]
883
    getitem_3 = unbind_copy_1[1];  unbind_copy_1 = None
884
    transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0);  transpose_copy_2 = None
885
    return transpose_copy_4
886
    """)  # noqa: B950
887

888
        # NB: even with reapply_views=True, we expect to see scatter op
889
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
890
        self.assertExpectedInline(reinplaced_logs, """\
891

892

893

894
def forward(self, arg0_1):
895
    ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
896
    transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
897
    unbind = torch.ops.aten.unbind.int(transpose);  transpose = None
898
    getitem = unbind[0]
899
    getitem_1 = unbind[1];  unbind = None
900
    add = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
901
    transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0);  arg0_1 = None
902
    select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0);  transpose_1 = add = None
903
    transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0);  select_scatter = None
904
    transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
905
    unbind_1 = torch.ops.aten.unbind.int(transpose_3);  transpose_3 = None
906
    getitem_2 = unbind_1[0]
907
    getitem_3 = unbind_1[1];  unbind_1 = None
908
    transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0);  transpose_2 = None
909
    return transpose_4
910
    """)  # noqa: B950
911

912
    def test_optional_tensor_list(self):
913
        def f(x):
914
            # test: an operator that takes in a List[Optional[Tensor]] argument
915
            # (index_put)
916
            y = x.view(8)
917
            indices = torch.arange(4)
918
            values = torch.arange(4, dtype=y.dtype)
919
            y.index_put_((indices,), values, accumulate=False)
920
            return y
921
        self.assert_functionalization(f, torch.ones(4, 2))
922
        logs = self.get_logs(f, torch.ones(4, 2))
923
        self.assertExpectedInline(logs, """\
924

925

926

927
def forward(self, arg0_1):
928
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [8])
929
    arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
930
    arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
931
    index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1);  view_copy = arange = arange_1 = None
932
    view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]);  index_put = None
933
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8])
934
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = None
935
    return view_copy_2
936
    """)  # noqa: B950
937

938
    def test_scalars(self):
939
        def f(x):
940
            # test: the pass can handle scalar inputs properly
941
            tmp = torch.ones(4, 2)
942
            y = x.view(4, 2)
943
            y.add_(1)
944
            z = 2 * y
945
            z.div_(1)
946
            return z
947
        self.assert_functionalization(f, torch.ones(4, 2))
948
        logs = self.get_logs(f, torch.ones(4, 2))
949
        self.assertExpectedInline(logs, """\
950

951

952

953
def forward(self, arg0_1):
954
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
955
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
956
    add = torch.ops.aten.add.Tensor(view_copy, 1);  view_copy = None
957
    view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]);  add = None
958
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
959
    mul = torch.ops.aten.mul.Tensor(view_copy_2, 2);  view_copy_2 = None
960
    div = torch.ops.aten.div.Tensor(mul, 1);  mul = None
961
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1);  arg0_1 = view_copy_1 = None
962
    return div
963
    """)
964

965
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
966
    def test_metadata_change(self):
967
        def f(x):
968
            # ops like ge_() are allowed to change the dtype of the input.
969
            # functionalization should pick up on that.
970
            y = x.clone()
971
            out = y.ge_(0)
972
            return out
973
        self.assert_functionalization(f, torch.ones(4, 2))
974
        logs = self.get_logs(f, torch.ones(4, 2))
975
        self.assertExpectedInline(logs, """\
976

977

978

979
def forward(self, arg0_1):
980
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
981
    ge = torch.ops.aten.ge.Scalar(clone, 0);  clone = None
982
    _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided);  ge = None
983
    return _to_copy
984
    """)
985

986
        reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
987
        self.assertExpectedInline(reinplaced_logs, """\
988

989

990

991
def forward(self, arg0_1):
992
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
993
    ge = torch.ops.aten.ge.Scalar(clone, 0);  clone = None
994
    _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided);  ge = None
995
    return _to_copy
996
    """)  # noqa: B950
997

998
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
999
    def test_metadata_change_out_op(self):
1000
        def f(t, y):
1001
            out_1 = torch.ones(1)
1002
            return torch.add(t, y, out=out_1)
1003

1004
        inpt1, inpt2 = torch.tensor([1]), torch.tensor([1])
1005
        inpt1_func, inpt2_func = torch._to_functional_tensor(inpt1), torch._to_functional_tensor(inpt2)
1006

1007
        out_ref = f(inpt1, inpt2)
1008
        torch._enable_functionalization(reapply_views=True)
1009
        try:
1010
            out_functional = f(inpt1_func, inpt2_func)
1011
        finally:
1012
            torch._disable_functionalization()
1013
        self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
1014

1015

1016
    def test_only_one_view(self):
1017
        def f(x):
1018
            # This tests that we don't have any unnecessary views in the trace.
1019
            # If the input wasn't mutated, we don't need to regenerate it,
1020
            # so there should be a total of 1 op in the output trace.
1021
            return x.view(4, 2)
1022
        logs = self.get_logs(f, torch.ones(4, 2))
1023
        self.assertExpectedInline(logs, """\
1024

1025

1026

1027
def forward(self, arg0_1):
1028
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]);  arg0_1 = None
1029
    return view_copy
1030
    """)
1031

1032
    def test_everything(self):
1033
        def f(x):
1034
            # test: everything
1035
            tmp = torch.ones(2, 2)
1036
            x2 = x + x
1037
            y = x2.view(8)
1038
            z0 = y.reshape(2, 4)
1039
            z1 = z0.transpose(1, 0)
1040
            z1.unsqueeze_(0)
1041
            z1.squeeze_()
1042
            z2, z3 = z1.split(2)
1043
            z2.add_(tmp)
1044
            z4 = z0[0] + z2.reshape(4)
1045
            return z2
1046
        self.assert_functionalization(f, torch.ones(4, 2))
1047
        logs = self.get_logs(f, torch.ones(4, 2))
1048
        self.assertExpectedInline(logs, """\
1049

1050

1051

1052
def forward(self, arg0_1):
1053
    ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
1054
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1055
    view_copy = torch.ops.aten.view_copy.default(add, [8])
1056
    view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]);  view_copy = None
1057
    transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
1058
    unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0);  transpose_copy = None
1059
    squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy);  unsqueeze_copy = None
1060
    split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2);  squeeze_copy = None
1061
    getitem = split_copy[0]
1062
    getitem_1 = split_copy[1];  split_copy = None
1063
    add_1 = torch.ops.aten.add.Tensor(getitem, ones);  getitem = ones = None
1064
    view_copy_2 = torch.ops.aten.view_copy.default(add, [8]);  add = None
1065
    view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]);  view_copy_2 = None
1066
    transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0);  view_copy_3 = None
1067
    unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0);  transpose_copy_1 = None
1068
    squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1);  unsqueeze_copy_1 = None
1069
    slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2);  squeeze_copy_1 = add_1 = None
1070
    unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0);  slice_scatter = None
1071
    squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0);  unsqueeze_copy_2 = None
1072
    transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0);  squeeze_copy_2 = None
1073
    view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]);  transpose_copy_2 = None
1074
    view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]);  view_copy_4 = None
1075
    view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8])
1076
    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]);  view_copy_6 = None
1077
    transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0);  view_copy_7 = None
1078
    unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0);  transpose_copy_3 = None
1079
    squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3);  unsqueeze_copy_3 = None
1080
    split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2);  squeeze_copy_3 = None
1081
    getitem_2 = split_copy_1[0]
1082
    getitem_3 = split_copy_1[1];  split_copy_1 = None
1083
    select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0);  view_copy_1 = None
1084
    view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4])
1085
    view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8])
1086
    view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]);  view_copy_9 = None
1087
    select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0);  view_copy_10 = None
1088
    view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]);  view_copy_5 = None
1089
    view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]);  view_copy_11 = None
1090
    transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0);  view_copy_12 = None
1091
    unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0);  transpose_copy_4 = None
1092
    squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4);  unsqueeze_copy_4 = None
1093
    split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2);  squeeze_copy_4 = None
1094
    getitem_4 = split_copy_2[0]
1095
    getitem_5 = split_copy_2[1];  split_copy_2 = None
1096
    view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]);  getitem_4 = None
1097
    add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13);  select_copy_1 = view_copy_13 = None
1098
    return getitem_2
1099
    """)  # noqa: B950
1100

1101
        reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
1102
        self.assertExpectedInline(reinplaced_logs, """\
1103

1104

1105

1106
def forward(self, arg0_1):
1107
    ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
1108
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1109
    view = torch.ops.aten.view.default(add, [8])
1110
    view_1 = torch.ops.aten.view.default(view, [2, 4]);  view = None
1111
    transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
1112
    unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0);  transpose = None
1113
    squeeze = torch.ops.aten.squeeze.default(unsqueeze);  unsqueeze = None
1114
    split = torch.ops.aten.split.Tensor(squeeze, 2);  squeeze = None
1115
    getitem = split[0]
1116
    getitem_1 = split[1];  split = None
1117
    add_1 = torch.ops.aten.add_.Tensor(getitem, ones);  getitem = ones = None
1118
    view_2 = torch.ops.aten.view.default(add, [8]);  add = None
1119
    view_3 = torch.ops.aten.view.default(view_2, [2, 4]);  view_2 = None
1120
    transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0);  view_3 = None
1121
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0);  transpose_1 = None
1122
    squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1);  unsqueeze_1 = None
1123
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0);  squeeze_1 = None
1124
    squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0);  unsqueeze_2 = None
1125
    transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0);  squeeze_2 = None
1126
    view_4 = torch.ops.aten.view.default(transpose_2, [8]);  transpose_2 = None
1127
    view_5 = torch.ops.aten.view.default(view_4, [4, 2]);  view_4 = None
1128
    view_6 = torch.ops.aten.view.default(view_5, [8])
1129
    view_7 = torch.ops.aten.view.default(view_6, [2, 4]);  view_6 = None
1130
    transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0);  view_7 = None
1131
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0);  transpose_3 = None
1132
    squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3);  unsqueeze_3 = None
1133
    split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2);  squeeze_3 = None
1134
    getitem_2 = split_1[0]
1135
    getitem_3 = split_1[1];  split_1 = None
1136
    select = torch.ops.aten.select.int(view_1, 0, 0);  view_1 = None
1137
    clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format)
1138
    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]);  clone = None
1139
    view_8 = torch.ops.aten.view.default(view_5, [8]);  view_5 = None
1140
    view_9 = torch.ops.aten.view.default(view_8, [2, 4]);  view_8 = None
1141
    select_1 = torch.ops.aten.select.int(view_9, 0, 0);  view_9 = None
1142
    add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view);  select_1 = _unsafe_view = None
1143
    return getitem_2
1144
    """)
1145

1146
    def test_reapply_views_simple(self):
1147
        def f(x):
1148
            tmp = torch.ones(4, 2)
1149
            y = x.view(4, 2)
1150
            y.add_(tmp)
1151
            z = x * x
1152
            return y
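        # Note: with reapply_views=True, the trace below keeps regular aten.view ops (rather than
        # view_copy variants), and the mutation of the input through y shows up as a trailing
        # aten.copy_ back into arg0_1.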
1153
        self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
1154
        logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
1155
        self.assertExpectedInline(logs, """\
1156

1157

1158

1159
def forward(self, arg0_1):
1160
    ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
1161
    view = torch.ops.aten.view.default(arg0_1, [4, 2])
1162
    add = torch.ops.aten.add.Tensor(view, ones);  view = ones = None
1163
    view_1 = torch.ops.aten.view.default(add, [4, 2]);  add = None
1164
    view_2 = torch.ops.aten.view.default(view_1, [4, 2])
1165
    mul = torch.ops.aten.mul.Tensor(view_1, view_1)
1166
    copy_ = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = None
1167
    return view_2
1168
    """)
1169

1170
    def test_aliases_maintained_after_pass_when_reapplying_views(self):
1171
        def f(x):
1172
            tmp = torch.ones(4, 2)
1173
            y = x.view(4, 2)
1174
            z = x.view(4, 2)
1175
            y.add_(tmp)
1176
            return y, z
1177

1178
        input_functional = torch._to_functional_tensor(torch.ones(4, 2))
1179
        torch._enable_functionalization(reapply_views=True)
1180
        try:
1181
            y, z = f(input_functional)
1182
            torch._sync(y)
1183
            torch._sync(z)
1184
        finally:
1185
            torch._disable_functionalization()
1186

1187
        # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
1188
        _y = torch._from_functional_tensor(y)
1189
        _z = torch._from_functional_tensor(z)
1190
        self.assertTrue(are_aliased(_y, _z))
1191

1192
    # copy_() gets its own test, because it used to be special-cased in functionalization.
1193
    # However, it now works pretty similarly to other functional ops.
1194
    def test_copy_(self):
1195
        def f(x):
1196
            tmp = torch.zeros(2, 2)
1197
            tmp_slice = tmp.diagonal()
1198
            y = tmp_slice.copy_(x)
1199
            z = y.add_(x)
1200
            return z
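        # Note: under functionalization, the in-place tmp_slice.copy_(x) and y.add_(x) above
        # are expected to show up as out-of-place aten.copy / aten.add calls whose results are
        # scattered back into tmp via aten.diagonal_scatter, which is what the traces below assert.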
1201

1202
        # Test 1: copy_() with same dtype and shape
1203
        # to() is a composite op that no-ops when the dtype/shape match, so nothing gets logged.
1204
        # self.assert_functionalization(f, torch.ones(2))
1205
        logs = self.get_logs(f, torch.ones(2))
1206
        self.assertExpectedInline(logs, """\
1207

1208

1209

1210
def forward(self, arg0_1):
1211
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1212
    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1213
    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1214
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1215
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1216
    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1217
    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1218
    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1219
    return diagonal_copy_2
1220
    """)
1221

1222
        reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
1223
        self.assertExpectedInline(reinplaced_logs, """\
1224

1225

1226

1227
def forward(self, arg0_1):
1228
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1229
    diagonal = torch.ops.aten.diagonal.default(zeros)
1230
    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = None
1231
    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1232
    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = None
1233
    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1234
    return diagonal_2
1235
    """)
1236

1237
        # Test 2: copy_() with same dtype, different shape
1238
        self.assert_functionalization(f, torch.ones(1))
1239
        logs = self.get_logs(f, torch.ones(1))
1240
        self.assertExpectedInline(logs, """\
1241

1242

1243

1244
def forward(self, arg0_1):
1245
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1246
    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1247
    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1248
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1249
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1250
    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1251
    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1252
    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1253
    return diagonal_copy_2
1254
    """)
1255

1256
        reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
1257
        self.assertExpectedInline(reinplaced_logs, """\
1258

1259

1260

1261
def forward(self, arg0_1):
1262
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1263
    diagonal = torch.ops.aten.diagonal.default(zeros)
1264
    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = None
1265
    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1266
    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = None
1267
    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1268
    return diagonal_2
1269
    """)
1270

1271
        # Test 3: copy_() with different dtype, same shape
1272
        self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
1273
        logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
1274
        self.assertExpectedInline(logs, """\
1275

1276

1277

1278
def forward(self, arg0_1):
1279
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1280
    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1281
    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1282
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1283
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1284
    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1285
    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1286
    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1287
    return diagonal_copy_2
1288
    """)  # noqa: B950
1289

1290
        reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
1291
        self.assertExpectedInline(reinplaced_logs, """\
1292

1293

1294

1295
def forward(self, arg0_1):
1296
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1297
    diagonal = torch.ops.aten.diagonal.default(zeros)
1298
    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = None
1299
    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1300
    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = None
1301
    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1302
    return diagonal_2
1303
    """)  # noqa: B950
1304

1305
        # Test 4: copy_() with different dtype, different shape
1306
        self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
1307
        logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
1308
        self.assertExpectedInline(logs, """\
1309

1310

1311

1312
def forward(self, arg0_1):
1313
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1314
    diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1315
    copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1);  diagonal_copy = None
1316
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy);  zeros = copy = None
1317
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1318
    add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1);  diagonal_copy_1 = arg0_1 = None
1319
    diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add);  diagonal_scatter = add = None
1320
    diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1);  diagonal_scatter_1 = None
1321
    return diagonal_copy_2
1322
    """)  # noqa: B950
1323

1324
        reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
1325
        self.assertExpectedInline(reinplaced_logs, """\
1326

1327

1328

1329
def forward(self, arg0_1):
1330
    zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1331
    diagonal = torch.ops.aten.diagonal.default(zeros)
1332
    copy = torch.ops.aten.copy_.default(diagonal, arg0_1);  diagonal = None
1333
    diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1334
    add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1);  diagonal_1 = arg0_1 = None
1335
    diagonal_2 = torch.ops.aten.diagonal.default(zeros);  zeros = None
1336
    return diagonal_2
1337
    """)  # noqa: B950
1338

1339
    def test_expand_symint(self):
1340
        # Once some existing SymInt bugs are ironed out, we should update
1341
        # this test to plumb FakeSymbolicTensors through it
1342
        def f(x):
1343
            return x.expand(x.size(0), x.size(1))
1344

1345
        self.assert_functionalization(f, torch.ones(2, 2))
1346
        logs = self.get_logs(f, torch.ones(2, 2))
1347
        self.assertExpectedInline(logs, """\
1348

1349

1350

1351
def forward(self, arg0_1):
1352
    expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]);  arg0_1 = None
1353
    return expand_copy
1354
    """)
1355

1356
    def test_fill_(self):
1357
        def f(x):
1358
            y = x + x
1359
            z = y.diagonal()
1360
            z.fill_(0)
1361
            return y
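        # Note: z is a diagonal view of y, so the functionalized trace below replaces z.fill_(0)
        # with an out-of-place aten.fill followed by aten.diagonal_scatter back into y, while the
        # reinplaced trace recovers an in-place aten.fill_ on the diagonal view.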
1362

1363
        self.assert_functionalization(f, torch.ones(2, 2))
1364
        logs = self.get_logs(f, torch.ones(2, 2))
1365
        self.assertExpectedInline(logs, """\
1366

1367

1368

1369
def forward(self, arg0_1):
1370
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1371
    diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
1372
    fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0);  diagonal_copy = None
1373
    diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill);  add = fill = None
1374
    diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1375
    return diagonal_scatter
1376
    """)
1377

1378
        reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
1379
        self.assertExpectedInline(reinplaced_logs, """\
1380

1381

1382

1383
def forward(self, arg0_1):
1384
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
1385
    diagonal = torch.ops.aten.diagonal.default(add)
1386
    fill = torch.ops.aten.fill_.Scalar(diagonal, 0);  diagonal = None
1387
    diagonal_1 = torch.ops.aten.diagonal.default(add)
1388
    return add
1389
    """)
1390

1391
    def test_resize_smaller(self):
1392
        def f(w):
1393
            # Resizing to a smaller size doesn't affect storage
1394
            x = w + 1
1395
            y = x.view(4, 4)
1396
            y.resize_(3, 3)
1397
            y2 = y.view(-1)
1398
            y2.add_(1)
1399
            z = y + 1
1400
            return z
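        # Note: resizing the view y down to (3, 3) keeps the original storage, so the trace below
        # models it with aten.resize plus as_strided_copy views of the (4, 4) buffer, and scatters
        # the y2.add_(1) mutation back with aten.as_strided_scatter.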
1401

1402
        self.assert_functionalization(f, torch.ones(8, 2))
1403
        logs = self.get_logs(f, torch.ones(8, 2))
1404
        self.assertExpectedInline(logs, """\
1405

1406

1407

1408
def forward(self, arg0_1):
1409
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1410
    view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
1411
    resize = torch.ops.aten.resize.default(view_copy, [3, 3])
1412
    as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]);  view_copy = None
1413
    view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]);  as_strided_copy = None
1414
    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1);  view_copy_1 = None
1415
    view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]);  add = None
1416
    as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
1417
    view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]);  add_1 = None
1418
    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]);  view_copy_2 = view_copy_3 = None
1419
    view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]);  as_strided_scatter = None
1420
    view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
1421
    as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]);  view_copy_5 = None
1422
    view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]);  as_strided_copy_2 = None
1423
    view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]);  view_copy_4 = None
1424
    as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]);  view_copy_7 = None
1425
    add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1);  as_strided_copy_3 = None
1426
    return add_2
1427
    """)  # noqa: B950
1428

1429
        reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
1430
        self.assertExpectedInline(reinplaced_logs, """\
1431

1432

1433

1434
def forward(self, arg0_1):
1435
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1436
    view = torch.ops.aten.view.default(add, [4, 4])
1437
    resize = torch.ops.aten.resize.default(view, [3, 3])
1438
    as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]);  view = None
1439
    view_1 = torch.ops.aten.view.default(as_strided, [-1]);  as_strided = None
1440
    add_1 = torch.ops.aten.add_.Tensor(view_1, 1)
1441
    view_2 = torch.ops.aten.view.default(add, [4, 4]);  add = None
1442
    as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
1443
    view_3 = torch.ops.aten.view.default(view_1, [3, 3]);  view_1 = None
1444
    view_4 = torch.ops.aten.view.default(view_2, [8, 2]);  view_2 = None
1445
    view_5 = torch.ops.aten.view.default(view_4, [4, 4])
1446
    as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]);  view_5 = None
1447
    view_6 = torch.ops.aten.view.default(as_strided_2, [-1]);  as_strided_2 = None
1448
    view_7 = torch.ops.aten.view.default(view_4, [4, 4]);  view_4 = None
1449
    as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]);  view_7 = None
1450
    add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1)
1451
    return as_strided_3
1452
    """)
1453

1454
    def test_resize_same_size_diff_rank(self):
1455
        def f(x):
1456
            y = x.clone()
1457
            y.resize_(25, 5)
1458
            return y
1459

1460
        self.assert_functionalization(f, torch.ones(5, 5, 5))
1461

1462
    def test_resize_larger_valid(self):
1463
        def f(x):
1464
            y = x + 1
1465
            # resizing a tensor to a larger size is only currently allowed
1466
            # if the tensor-to-resize is not a view / has no outstanding views.
1467
            # See Note [resize_() in functionalization pass]
1468
            y.resize_(5, 5)
1469
            y2 = y.view(25)
1470
            # Do a mutation to ensure that aliases of the output of resize_()
1471
            # propagate mutations correctly.
1472
            # I'm using fill_ specifically because I want to guarantee that
1473
            # none of the output has uninitialized memory at the end
1474
            # (since these tests compare the data output against a reference impl)
1475
            y2.fill_(1)
1476
            out = y + 1
1477
            return y, out
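        # Note: y is not a view and has no outstanding views when it is resized, so the trace
        # below keeps a plain aten.resize (aten.resize_ once reinplaced) rather than raising,
        # in contrast to test_resize_larger_invalid below.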
1478

1479
        self.assert_functionalization(f, torch.ones(8, 2))
1480
        logs = self.get_logs(f, torch.ones(8, 2))
1481
        self.assertExpectedInline(logs, """\
1482

1483

1484

1485
def forward(self, arg0_1):
1486
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1487
    resize = torch.ops.aten.resize.default(add, [5, 5]);  add = None
1488
    view_copy = torch.ops.aten.view_copy.default(resize, [25]);  resize = None
1489
    fill = torch.ops.aten.fill.Scalar(view_copy, 1);  view_copy = None
1490
    view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]);  fill = None
1491
    view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25])
1492
    add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
1493
    return (view_copy_1, add_1)
1494
    """)
1495

1496
        reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
1497
        self.assertExpectedInline(reinplaced_logs, """\
1498

1499

1500

1501
def forward(self, arg0_1):
1502
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
1503
    resize = torch.ops.aten.resize_.default(add, [5, 5])
1504
    view = torch.ops.aten.view.default(add, [25]);  add = None
1505
    fill = torch.ops.aten.fill_.Scalar(view, 1)
1506
    view_1 = torch.ops.aten.view.default(view, [5, 5]);  view = None
1507
    view_2 = torch.ops.aten.view.default(view_1, [25])
1508
    add_1 = torch.ops.aten.add.Tensor(view_1, 1)
1509
    return (view_1, add_1)
1510
    """)
1511

1512
    def test_resize_larger_invalid(self):
1513
        def f(x):
1514
            y = x + 1
1515
            z = y.view(4, 4)
1516
            # resizing a tensor to a larger size is only currently allowed
1517
            # if the tensor-to-resize is not a view / has no outstanding views.
1518
            # See Note [resize_() in functionalization pass]
1519
            # This should fail
1520
            z.resize_(5, 5)
1521
            z2 = z.view(25)
1522
            z2.fill_(1)
1523
            out = z + 1
1524
            return y, out
1525

1526
        with self.assertRaisesRegex(
1527
                RuntimeError,
1528
                r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'):
1529
            self.assert_functionalization(f, torch.ones(8, 2))
1530

1531
    def test_nested_functions_propagate_updates(self):
1532
        def g(x):
1533
            # Create a view of x
1534
            y = x[0]
1535
            y.add_(1)
1536
            # The view, y, gets deallocated at the end of this function
1537

1538
        def f(x):
1539
            # Calling g(x) should mutate x
1540
            g(x)
1541
            # We expect x to be synced here, even though the alias created in g() has been deallocated!
1542
            y = x + x
1543
            return y
1544

1545
        self.assert_functionalization(f, torch.ones(2, 2))
1546

1547
    def test_mixed_wrappers_valid(self):
1548
        def f(x, y):
1549
            z = x + y
1550
            z.add_(1)
1551
            return z
1552

1553
        x1_not_functional = LoggingTensor(torch.ones(4))
1554
        x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))
1555

1556
        with capture_logs() as logs:
1557
            y = f(x1_not_functional, x2_functional)
1558

1559
        # Make sure that functionalization ran the "+" kernel
1560
        # with a functional + non-functional tensor, and wrapped the output appropriately.
1561
        self.assertExpectedInline('\n'.join(logs), """\
1562
$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
1563
$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""")
1564

1565
    def test_mixed_wrappers_invalid(self):
1566
        x1_not_functional = torch.ones(4)
1567
        x2_functional = torch._to_functional_tensor(torch.ones(4))
1568

1569
        # When dealing with mixed functional + non-functional tensors,
1570
        # normal_tensor.add_(functional_tensor) is not valid
1571
        # because normal_tensor would need to be "promoted" to a functional tensor.
1572
        with self.assertRaises(RuntimeError):
1573
            x1_not_functional.add_(x2_functional)
1574

1575
    def test_index_mutation_on_non_input(self):
1576
        def f(x):
1577
            tmp = torch.zeros(10)
1578
            tmp[5].fill_(1)
1579
            return tmp
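        # Note: tmp[5] is a select view, so the trace below turns fill_(1) on it into aten.fill
        # plus aten.select_scatter back into tmp; since tmp is not an input, no input copy_ is
        # emitted at the end of the graph.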
1580
        self.assert_functionalization(f, torch.ones(2))
1581
        logs = self.get_logs(f, torch.ones(2))
1582
        self.assertExpectedInline(logs, """\
1583

1584

1585

1586
def forward(self, arg0_1):
1587
    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1588
    select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
1589
    fill = torch.ops.aten.fill.Scalar(select_copy, 1);  select_copy = None
1590
    select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5);  zeros = fill = None
1591
    select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5)
1592
    return select_scatter
1593
    """)  # noqa: B950
1594

1595
        reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
1596
        self.assertExpectedInline(reinplaced_logs, """\
1597

1598

1599

1600
def forward(self, arg0_1):
1601
    zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1602
    select = torch.ops.aten.select.int(zeros, 0, 5)
1603
    fill = torch.ops.aten.fill_.Scalar(select, 1);  select = None
1604
    select_1 = torch.ops.aten.select.int(zeros, 0, 5)
1605
    return zeros
1606
    """)
1607

1608

1609
    def test_instance_norm(self):
1610
        size = 100
1611

1612
        def f(x, running_mean, running_var):
1613
            with enable_python_dispatcher():
1614
                return torch.instance_norm(x, None, None, running_mean, running_var,
1615
                                           use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
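        # Note: running_mean and running_var are inputs that instance_norm mutates, so the traces
        # below compute the updated stats functionally (via _native_batch_norm_legit_functional)
        # and write them back to the inputs with aten.copy_ at the end of the graph.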
1616
        self.assert_functionalization(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size))
1617
        # On Windows, for instance_norm, the alias_copy ops are reordered to come right before they need to be used,
1618
        # whereas on other platforms, the alias_copy ops come before the view_copy ops.
1619
        # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
1620
        if not IS_WINDOWS:
1621
            logs = self.get_logs(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size))
1622
            self.assertExpectedInline(logs, """\
1623

1624

1625

1626
def forward(self, arg0_1, arg1_1, arg2_1):
1627
    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1628
    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1629
    view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
1630
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1631
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05);  view_copy = repeat = repeat_1 = None
1632
    getitem = _native_batch_norm_legit_functional[0]
1633
    getitem_1 = _native_batch_norm_legit_functional[1]
1634
    getitem_2 = _native_batch_norm_legit_functional[2]
1635
    getitem_3 = _native_batch_norm_legit_functional[3]
1636
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
1637
    alias_copy = torch.ops.aten.alias_copy.default(arg1_1)
1638
    view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
1639
    view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]);  getitem_3 = None
1640
    mean = torch.ops.aten.mean.dim(view_copy_2, [0]);  view_copy_2 = None
1641
    copy = torch.ops.aten.copy.default(alias_copy, mean);  alias_copy = mean = None
1642
    alias_copy_1 = torch.ops.aten.alias_copy.default(copy);  copy = None
1643
    alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1)
1644
    alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1)
1645
    view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
1646
    view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]);  getitem_4 = None
1647
    mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]);  view_copy_4 = None
1648
    copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1);  alias_copy_3 = mean_1 = None
1649
    alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1);  copy_1 = None
1650
    alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4)
1651
    view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]);  getitem = None
1652
    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1);  arg1_1 = alias_copy_1 = None
1653
    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4);  arg2_1 = alias_copy_4 = None
1654
    return view_copy_5
1655
    """)  # noqa: B950
1656

1657
            reinplaced_logs = self.get_logs(
1658
                f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size),
1659
                reapply_views=True, run_reinplace=True
1660
            )
1661
            self.assertExpectedInline(reinplaced_logs, """\
1662

1663

1664

1665
def forward(self, arg0_1, arg1_1, arg2_1):
1666
    repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1667
    repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1668
    view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
1669
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1670
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05);  view = repeat = repeat_1 = None
1671
    getitem = _native_batch_norm_legit_functional[0]
1672
    getitem_1 = _native_batch_norm_legit_functional[1]
1673
    getitem_2 = _native_batch_norm_legit_functional[2]
1674
    getitem_3 = _native_batch_norm_legit_functional[3]
1675
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
1676
    alias = torch.ops.aten.alias.default(arg1_1)
1677
    view_1 = torch.ops.aten.view.default(getitem_3, [20, 100])
1678
    view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]);  getitem_3 = None
1679
    mean = torch.ops.aten.mean.dim(view_2, [0]);  view_2 = None
1680
    copy = torch.ops.aten.copy.default(alias, mean);  alias = mean = None
1681
    alias_1 = torch.ops.aten.alias.default(copy);  copy = None
1682
    alias_2 = torch.ops.aten.alias.default(alias_1)
1683
    alias_3 = torch.ops.aten.alias.default(arg2_1)
1684
    view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
1685
    view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]);  getitem_4 = None
1686
    mean_1 = torch.ops.aten.mean.dim(view_4, [0]);  view_4 = None
1687
    copy_1 = torch.ops.aten.copy.default(alias_3, mean_1);  alias_3 = mean_1 = None
1688
    alias_4 = torch.ops.aten.alias.default(copy_1);  copy_1 = None
1689
    alias_5 = torch.ops.aten.alias.default(alias_4)
1690
    view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]);  getitem = None
1691
    copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1);  arg1_1 = alias_1 = None
1692
    copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4);  arg2_1 = alias_4 = None
1693
    return view_5
1694
    """)  # noqa: B950
1695

1696

1697
    def test_mutation_overlapping_mem(self):
1698
        def fn(x):
1699
            # x: (1, 5)
1700
            t1 = torch.add(x, x)
1701
            t2 = t1.unfold(1, 3, 2)
1702
            t3 = t2.abs_()
1703
            return t3
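        # Note: t1.unfold(1, 3, 2) on a length-5 dimension produces windows [0:3] and [2:5],
        # which share element 2, so the result has internal memory overlap and mutating it
        # in place with abs_() is rejected, as asserted below.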
1704

1705
        with self.assertRaisesRegex(RuntimeError, r'encountered a tensor being mutated that has internal overlap'):
1706
            x = torch.ones(1, 5)
1707
            out = _functionalize(fn, reapply_views=True, crossref=False)(x)
1708

1709

1710
    def test_batch_norm(self):
1711
        def f(x, running_mean, running_var):
1712
            with enable_python_dispatcher():
1713
                return torch.batch_norm(x, None, None, running_mean, running_var, True, 0.1, 1e-5, False)
1714

1715
        self.assert_functionalization(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100))
1716
        logs = self.get_logs(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100))
1717
        self.assertExpectedInline(logs, """\
1718

1719

1720

1721
def forward(self, arg0_1, arg1_1, arg2_1):
1722
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1723
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05);  arg0_1 = None
1724
    getitem = _native_batch_norm_legit_functional[0]
1725
    getitem_1 = _native_batch_norm_legit_functional[1]
1726
    getitem_2 = _native_batch_norm_legit_functional[2]
1727
    getitem_3 = _native_batch_norm_legit_functional[3]
1728
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
1729
    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3);  arg1_1 = getitem_3 = None
1730
    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4);  arg2_1 = getitem_4 = None
1731
    return getitem
1732
    """)  # noqa: B950
1733

1734
        reinplaced_logs = self.get_logs(
1735
            f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100), reapply_views=True, run_reinplace=True
1736
        )
1737
        self.assertExpectedInline(reinplaced_logs, """\
1738

1739

1740

1741
def forward(self, arg0_1, arg1_1, arg2_1):
1742
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1743
    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05);  arg0_1 = None
1744
    getitem = _native_batch_norm_legit_functional[0]
1745
    getitem_1 = _native_batch_norm_legit_functional[1]
1746
    getitem_2 = _native_batch_norm_legit_functional[2]
1747
    getitem_3 = _native_batch_norm_legit_functional[3]
1748
    getitem_4 = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
1749
    copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3);  arg1_1 = getitem_3 = None
1750
    copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4);  arg2_1 = getitem_4 = None
1751
    return getitem
1752
    """)  # noqa: B950
1753

1754
    # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode
1755
    def test_python_functionalization(self):
1756
        def f(x):
1757
            x_view = x.view(-1)
1758
            x.mul_(2)
1759
            return x_view + 1
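        # Note: x_view aliases x, so after x.mul_(2) the functionalized graph recomputes the
        # view from the mutated value (view replay) before the final add; see the expected
        # graph and the NB about the unused view_1 node below.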
1760

1761
        def f_functionalized(x):
1762
            # Note [Disabling Functionalize TLS Above Python Functionalization]
1763
            # This UX is pretty annoying (although python functionalization's main customer is AOTAutograd,
1764
            # and it is not really advertised as a user API).
1765
            # We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode.
1766
            # Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper.
1767
            # Since the inner tensor has `DispatchKey.Functionalize` in its keyset, by default,
1768
            # our FunctionalTensor will inherit the same keyset.
1769
            # We don't have an easy way of directly mutating a tensor's keyset from python,
1770
            # so globally disabling functionalization here is easier.
1771
            maybe_disable = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
1772
            with maybe_disable, FunctionalTensorMode():
1773
                x_wrapped = FunctionalTensor.to_functional(x)
1774
                out_wrapped = f(x_wrapped)
1775
            out_unwrapped = out_wrapped.elem
1776
            torch._sync(out_unwrapped)
1777
            return torch._from_functional_tensor(out_unwrapped)
1778

1779
        # Make a non-leaf
1780
        x = torch.randn(2, requires_grad=True) + 1
1781
        fx_g = make_fx(f_functionalized)(x)
1782
        # NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a
1783
        # DCE pass that will remove nodes like this later on.
1784
        self.assertExpectedInline(fx_g.code.strip(), """\
1785
def forward(self, x_1):
1786
    view = torch.ops.aten.view.default(x_1, [-1])
1787
    mul = torch.ops.aten.mul.Tensor(x_1, 2);  x_1 = None
1788
    view_1 = torch.ops.aten.view.default(mul, [-1])
1789
    view_2 = torch.ops.aten.view.default(mul, [-1]);  mul = None
1790
    add = torch.ops.aten.add.Tensor(view_2, 1);  view_2 = None
1791
    return add""")
1792

1793
    def test_python_functionalization_zero_tensor(self):
1794
        def f(x):
1795
            y = torch.ops.aten._efficientzerotensor([4])
1796
            out = x + y
1797
            out.mul_(2)
1798
            return out
1799
        x = torch.randn(4)
1800
        out_ref = f(x)
1801
        out_test = dispatch_functionalize(f)(x)
1802
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1803
        self.assertEqual(out_ref, out_test)
1804
        self.assertEqual(out_ref, out_test_cpp)
1805
        fx_g = make_fx(dispatch_functionalize(f))(x)
1806
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1807
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1808

1809
    def test_python_functionalization_is_conj(self):
1810
        def f(x):
1811
            out = x.conj()
1812
            return out, out.is_conj()
1813

1814
        x = torch.randn(4, dtype=torch.complex64)
1815
        out_ref = f(x)
1816
        out_test = dispatch_functionalize(f)(x)
1817
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
1818
        self.assertEqual(out_ref[0], out_test[0])
1819
        self.assertEqual(out_ref[1], out_test[1])
1820
        self.assertEqual(out_ref[0], out_test_cpp[0])
1821
        self.assertEqual(out_ref[1], out_test_cpp[1])
1822

1823
    def test_python_functionalization_is_neg(self):
1824
        def f(x):
1825
            out = x.neg()
1826
            return out, out.is_neg()
1827

1828
        x = torch.randn(4, dtype=torch.complex64)
1829
        out_ref = f(x)
1830
        out_test = dispatch_functionalize(f)(x)
1831
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
1832
        self.assertEqual(out_ref[0], out_test[0])
1833
        self.assertEqual(out_ref[1], out_test[1])
1834
        self.assertEqual(out_ref[0], out_test_cpp[0])
1835
        self.assertEqual(out_ref[1], out_test_cpp[1])
1836

1837

1838
    def test_python_functionalization_conj(self):
1839
        def f(x):
1840
            y = x.clone().conj()
1841
            y.mul_(2)
1842
            return torch.view_as_real(y.resolve_conj())
1843

1844
        x = torch.randn(4, dtype=torch.complex64)
1845
        out_ref = f(x)
1846
        out_test = dispatch_functionalize(f)(x)
1847
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1848
        self.assertEqual(out_ref, out_test)
1849
        self.assertEqual(out_test, out_test_cpp)
1850
        fx_g = make_fx(dispatch_functionalize(f))(x)
1851
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1852
        self.assertExpectedInline(fx_g.code.strip(), """\
1853
def forward(self, arg0_1):
1854
    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1855
    _conj = torch.ops.aten._conj.default(clone);  clone = None
1856
    clone_1 = torch.ops.aten.clone.default(_conj)
1857
    mul = torch.ops.aten.mul.Tensor(clone_1, 2);  clone_1 = None
1858
    clone_2 = torch.ops.aten.clone.default(_conj);  _conj = None
1859
    copy = torch.ops.aten.copy.default(clone_2, mul);  clone_2 = mul = None
1860
    _conj_1 = torch.ops.aten._conj.default(copy);  copy = None
1861
    _conj_2 = torch.ops.aten._conj.default(_conj_1);  _conj_1 = None
1862
    clone_3 = torch.ops.aten.clone.default(_conj_2);  _conj_2 = None
1863
    view_as_real = torch.ops.aten.view_as_real.default(clone_3);  clone_3 = None
1864
    return view_as_real""")
1865
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1866

1867
    def test_python_functionalization_neg(self):
1868
        def f(x):
1869
            y = x._neg_view()
1870
            z = y.resolve_neg()
1871
            return z + 1
1872

1873
        x = torch.randn(4)
1874
        out_ref = f(x)
1875
        out_test = dispatch_functionalize(f)(x)
1876
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1877
        self.assertEqual(out_ref, out_test)
1878
        self.assertEqual(out_ref, out_test_cpp)
1879
        fx_g = make_fx(dispatch_functionalize(f))(x)
1880
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1881
        self.assertExpectedInline(fx_g.code.strip(), """\
1882
def forward(self, arg0_1):
1883
    _neg_view = torch.ops.aten._neg_view.default(arg0_1);  arg0_1 = None
1884
    clone = torch.ops.aten.clone.default(_neg_view);  _neg_view = None
1885
    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1886
    return add""")
1887
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1888

1889
    def test_python_functionalization_lift_fresh_storage(self):
1890
        unlifted = torch.tensor([0.0])
1891

1892
        maybe_disable = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
1893
        with maybe_disable, FunctionalTensorMode():
1894
            lifted = torch.ops.aten.lift_fresh.default(unlifted)
1895

1896
        self.assertNotEqual(unlifted.untyped_storage(), lifted.untyped_storage())
1897

1898
    def test_python_functionalization_lift_fresh(self):
1899
        def f(x):
1900
            tmp = torch.tensor([0.0])
1901
            return tmp + x
1902

1903
        x = torch.randn(4)
1904
        out_ref = f(x)
1905
        out_test = dispatch_functionalize(f)(x)
1906
        out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1907
        self.assertEqual(out_ref, out_test)
1908
        self.assertEqual(out_ref, out_test_cpp)
1909
        fx_g = make_fx(dispatch_functionalize(f))(x)
1910
        fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1911
        self.assertExpectedInline(fx_g.code.strip(), """\
1912
def forward(self, arg0_1):
1913
    _tensor_constant0 = self._tensor_constant0
1914
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
1915
    add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1);  lift_fresh_copy = arg0_1 = None
1916
    return add""")
1917
        self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1918

1919
@xfail_inherited_tests([
1920
    "test_as_strided",
1921
    "test_copy_",
1922
    "test_diagonal",
1923
    "test_diagonal_mutated_input",
1924
    "test_everything",
1925
    "test_fill_",
1926
    "test_slice",
1927
    "test_split",
1928
    "test_split_with_sizes",
1929
    "test_unbind",
1930
    "test_view_clone_view_inplace",
1931
    "test_view_inplace",
1932
])
1933
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesn't work well")
1934
class TestCrossRefFunctionalization(TestFunctionalization):
1935
    crossref = True
1936

1937
if __name__ == '__main__':
1938
    run_tests()
1939
