1
# Owner(s): ["module: codegen"]
4
from contextlib import nullcontext
5
from torch.testing._internal.common_utils import (
6
TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, IS_WINDOWS,
9
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, dispatch_functionalize
10
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
11
from torch.utils._pytree import tree_map_only
12
from torch.fx.experimental.proxy_tensor import make_fx
13
from torch.fx.passes.reinplace import reinplace
14
from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher
15
from torch.multiprocessing.reductions import StorageWeakRef
16
from torch.utils import _pytree as pytree
21
x_storage = StorageWeakRef(x.storage())
22
y_storage = StorageWeakRef(y.storage())
23
return x_storage == y_storage
25
# We can unify testing and use functionalize() here instead
26
# if/when functorch moves into core.
27
# This is basically a crappy version of `functionalize()`.
28
def _functionalize(f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False):
29
def to_fun(t: torch.Tensor):
30
func_t = torch._to_functional_tensor(t)
31
func_t.requires_grad = t.requires_grad
37
ctx = enable_crossref_functionalize()
39
inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs)
40
torch._enable_functionalization(reapply_views=reapply_views)
42
out = f(*inputs_functional)
44
torch._disable_functionalization()
45
flat_inputs = pytree.tree_leaves(inputs)
46
flat_inputs_functional = pytree.tree_leaves(inputs_functional)
48
for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
49
torch._sync(input_functional)
50
inpt_new = torch._from_functional_tensor(input_functional)
51
if inpt_new is not inpt and not skip_input_mutations:
52
# Existing deficiency in functionalize():
53
# we don't correctly mutate input metadata (yet?)
54
if inpt_new.shape == inpt.shape:
56
tree_map_only(torch.Tensor, torch._sync, out)
57
out_unwrapped = tree_map_only(torch.Tensor, torch._from_functional_tensor, out)
62
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
63
class TestFunctionalization(TestCase):
67
def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False):
68
inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts)
69
traced_f = make_fx(_functionalize(func, reapply_views=reapply_views, crossref=self.crossref))(*inpts)
71
traced_f = reinplace(traced_f, *inpts_clone)
74
def assert_functionalization(self, func, *inpts, reapply_views=False, mutated_input_metadata=False):
75
clones1 = tree_map_only(torch.Tensor, torch.clone, inpts)
76
clones2 = tree_map_only(torch.Tensor, torch.clone, inpts)
77
clones3 = tree_map_only(torch.Tensor, torch.clone, inpts)
79
# Compare outputs (and mutated inputs), with and without functionalization.
80
out_ref = func(*inpts)
81
out_functional = _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)(*clones1)
83
# The reinplacing pass is only valid to run with reapply_views=True.
84
functional_func = make_fx(_functionalize(func, reapply_views=True, crossref=self.crossref))(*clones2)
85
reinplace_func = reinplace(functional_func, *clones2)
87
# NOTE: for now, need to pass in fresh inputs here, because make_fx
88
# will directly mutate the inputs that you trace with.
89
# Once this is fixed we can clean this up.
90
out_reinplace = reinplace_func(*clones3)
92
# functionalize() deficiency: input metadata mutations aren't propagated properly,
93
# so we just need to skip checks here for the tests that exercise that.
94
if not mutated_input_metadata:
95
flat_inpts = pytree.tree_leaves(inpts)
96
flat_clones1 = pytree.tree_leaves(clones1)
97
flat_clones3 = pytree.tree_leaves(clones3)
98
for inpt, input_clone, input_clone3 in zip(flat_inpts, flat_clones1, flat_clones3):
99
self.assertEqual(inpt, input_clone) # input mutations should still occur
100
self.assertEqual(inpt, input_clone3)
102
# Handle tests with multi-tensor outputs
103
if isinstance(out_ref, tuple):
104
out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace)
106
out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace]
108
for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces):
109
self.assertEqual(out_ref_, out_functional_)
110
self.assertEqual(out_ref_, out_reinplace_)
112
def test_save_for_backwards_segfault(self):
113
inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
116
def test_multiple_views_of_same_base(self):
121
# y should have been updated.
123
# z should have been updated too.
126
self.assert_functionalization(f, torch.ones(4))
128
def test_freeze(self):
132
torch._freeze_functional_tensor(y)
134
self.assertRaises(RuntimeError, lambda: y.add_(1))
135
self.assertRaises(RuntimeError, lambda: z.add_(1))
138
_functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3))
140
def test_copy_stride_mismatch(self):
142
y = torch.empty_strided((2, 2), (5, 1))
146
r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
147
self.assertEqual(r.stride(), (5, 1))
155
# We should probaby get the crossref test to work,
156
# but fixing it for Storage() objects is annoying.
157
r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
158
self.assertEqual(str(r.device), 'cpu')
160
def test_advanced_indexing(self):
162
x = torch.zeros(3, 3)
163
idx = torch.tensor([0])
164
val = torch.ones(3, 1)
168
self.assert_functionalization(f)
170
def test_view_clone_view_inplace(self):
172
shape = [1, 1024, 128, 128]
173
input_reshaped = input.view(shape)
174
out = input_reshaped.clone()
175
r = out.view(input.shape)
181
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
182
import torch.fx.traceback as fx_traceback
183
setup_stacktrace_preservation_hooks([loss.grad_fn])
184
with fx_traceback.preserve_node_meta():
188
with torch.autograd.detect_anomaly(check_nan=False):
189
logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True))
190
self.assertExpectedInline(logs, """\
194
def forward(self, arg0_1):
195
view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]); arg0_1 = None
196
clone = torch.ops.aten.clone.default(view_copy); view_copy = None
197
view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128])
198
relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None
199
view_copy_2 = torch.ops.aten.view_copy.default(relu, [1, 1024, 128, 128]); relu = None
200
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [16, 64, 128, 128]); view_copy_2 = None
201
view_copy_4 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None
202
sum_1 = torch.ops.aten.sum.default(view_copy_3)
203
ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format); sum_1 = None
204
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
205
view_copy_5 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None
206
new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_5, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
207
copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_5); new_empty_strided = view_copy_5 = None
208
view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
209
view_copy_7 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128])
210
clone_1 = torch.ops.aten.clone.default(view_copy_7, memory_format = torch.contiguous_format)
211
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, view_copy_3, 0); clone_1 = view_copy_3 = None
212
copy_1 = torch.ops.aten.copy.default(view_copy_7, threshold_backward); view_copy_7 = threshold_backward = None
213
view_copy_8 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None
214
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128])
215
view_copy_10 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None
216
detach_copy = torch.ops.aten.detach_copy.default(view_copy_10); view_copy_10 = None
217
view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None
218
detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None
222
def test_simple(self):
224
# simple test: 1 view op, 1 inplace op
225
tmp = torch.ones(4, 2)
230
self.assert_functionalization(f, torch.ones(4, 2))
231
logs = self.get_logs(f, torch.ones(4, 2))
232
self.assertExpectedInline(logs, """\
236
def forward(self, arg0_1):
237
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
238
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
239
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
240
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
241
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
242
mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
243
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
247
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
248
self.assertExpectedInline(reinplaced_logs, """\
252
def forward(self, arg0_1):
253
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
254
view = torch.ops.aten.view.default(arg0_1, [4, 2])
255
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
256
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
257
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
258
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
259
copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
263
def test_simple_out(self):
265
tmp = torch.ones(4, 2)
267
# the out= tensor will get resized, since it has size=0 to start.
269
torch.add(y, tmp, out=z)
272
self.assert_functionalization(f, torch.ones(4, 2))
273
logs = self.get_logs(f, torch.ones(4, 2))
274
self.assertExpectedInline(logs, """\
278
def forward(self, arg0_1):
279
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
280
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None
281
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
282
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
283
mul = torch.ops.aten.mul.Tensor(add, add); add = None
287
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
288
self.assertExpectedInline(reinplaced_logs, """\
292
def forward(self, arg0_1):
293
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
294
view = torch.ops.aten.view.default(arg0_1, [4, 2]); arg0_1 = None
295
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
296
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
297
mul = torch.ops.aten.mul.Tensor(add, add); add = None
301
def test_multi_out(self):
303
# aminmax.out returns a tuple of tensors.
304
# functionalization should properly handle the tuple.
305
out_min = torch.empty(4)
306
out_max = torch.empty(4)
307
torch.aminmax(x, dim=0, out=(out_max, out_min))
309
self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
310
logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
311
self.assertExpectedInline(logs, """\
315
def forward(self, arg0_1):
316
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
317
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
318
aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None
320
getitem_1 = aminmax[1]; aminmax = None
324
reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True)
325
self.assertExpectedInline(reinplaced_logs, """\
329
def forward(self, arg0_1):
330
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
331
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
332
aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None
334
getitem_1 = aminmax[1]; aminmax = None
338
def test_tensor_ctr(self):
340
y = torch.tensor((1, 2, 3))
345
inpt = torch.arange(3, dtype=torch.float32)
346
self.assert_functionalization(f, inpt)
348
logs = self.get_logs(f, inpt)
349
self.assertExpectedInline(logs, """\
353
def forward(self, arg0_1):
354
_tensor_constant0 = self._tensor_constant0
355
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
356
view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
357
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
358
view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None
359
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1])
363
reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
364
self.assertExpectedInline(reinplaced_logs, """\
368
def forward(self, arg0_1):
369
_tensor_constant0 = self._tensor_constant0
370
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
371
view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
372
add = torch.ops.aten.add_.Tensor(view, 1)
373
view_1 = torch.ops.aten.view.default(view, [3]); view = None
374
view_2 = torch.ops.aten.view.default(view_1, [-1])
378
def test_advanced_indexing_correct_strides(self):
380
# This test requires that *_scatter ops are able to return
381
# non-contiguous tensors.
383
c = torch.ones_like(b, dtype=torch.bool)
384
d = b.masked_fill_(c, 0)
386
self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True)
388
def test_tensor_list_mixed_functional_nonfunctional(self):
389
nonfunctional_tensor = torch.ones(2, dtype=torch.long)
392
# simple test: 1 view op, 1 inplace op
393
functional_tensor = torch.ones(2, dtype=torch.long)
394
out = x[functional_tensor, nonfunctional_tensor]
396
out = f(torch.ones(2, 2))
397
out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
398
self.assertEqual(out, out_functional)
400
def test_inplace_on_non_view(self):
402
# test for the case where we functionalize an inplace op on the other tensor - not a view.
403
# This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
404
tmp = torch.ones(4, 2)
408
self.assert_functionalization(f, torch.ones(4, 2))
409
logs = self.get_logs(f, torch.ones(4, 2))
410
self.assertExpectedInline(logs, """\
414
def forward(self, arg0_1):
415
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
416
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
417
add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None
418
copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None
419
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
423
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
424
self.assertExpectedInline(reinplaced_logs, """\
428
def forward(self, arg0_1):
429
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
430
view = torch.ops.aten.view.default(arg0_1, [4, 2])
431
add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None
432
copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None
433
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
437
# Some ops that are mutable are neither inplace nor out= ops.
438
# They also need special handling.
439
def test_mutable_op_not_inplace_or_other(self):
441
return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
443
logs = self.get_logs(f, torch.ones(1))
444
self.assertExpectedInline(logs, """\
448
def forward(self, arg0_1):
449
_fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
450
getitem = _fused_moving_avg_obs_fq_helper_functional[0]
451
getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
452
getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
453
getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]
454
getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]
455
getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None
456
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None
457
return (getitem, getitem_1)
460
def test_as_strided(self):
462
y = x.as_strided((2,), (2,), 1)
465
self.assert_functionalization(f, torch.ones(9))
466
logs = self.get_logs(f, torch.ones(9))
467
self.assertExpectedInline(logs, """\
471
def forward(self, arg0_1):
472
as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1)
473
add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None
474
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None
475
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1)
476
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None
477
return as_strided_scatter
480
# NB: even with reapply_views=True, we expect to see scatter op
481
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False)
482
self.assertExpectedInline(reinplaced_logs, """\
486
def forward(self, arg0_1):
487
as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1)
488
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
489
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None
490
as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1)
491
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None
492
return as_strided_scatter
495
def test_tensor_list_composite(self):
497
# Test an op with TensorList input
498
y = torch.block_diag(x, x)
500
self.assert_functionalization(f, torch.ones(2, 2))
501
logs = self.get_logs(f, torch.ones(2, 2))
502
self.assertExpectedInline(logs, """\
506
def forward(self, arg0_1):
507
block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]); arg0_1 = None
514
torch.cat((x,), out=out)
516
self.assert_functionalization(f, torch.ones(2, 2))
517
logs = self.get_logs(f, torch.ones(2, 2))
518
self.assertExpectedInline(logs, """\
522
def forward(self, arg0_1):
523
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
524
cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None
528
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
529
self.assertExpectedInline(reinplaced_logs, """\
533
def forward(self, arg0_1):
534
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
535
cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None
540
def test_diagonal(self):
542
# test: view ops that take a subset of the original tensor (select/diagonal)
544
y = x.clone().diagonal()
548
self.assert_functionalization(f, torch.ones(2, 2))
549
logs = self.get_logs(f, torch.ones(2, 2))
550
self.assertExpectedInline(logs, """\
554
def forward(self, arg0_1):
555
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
556
clone = torch.ops.aten.clone.default(arg0_1)
557
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone)
558
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
559
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(clone, add); clone = add = None
560
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = None
561
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
565
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
566
self.assertExpectedInline(reinplaced_logs, """\
570
def forward(self, arg0_1):
571
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
572
clone = torch.ops.aten.clone.default(arg0_1)
573
diagonal = torch.ops.aten.diagonal.default(clone)
574
add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None
575
diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = None
576
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
580
def test_diagonal_mutated_input(self):
582
# simple test: there are pending updates afterwards, which the test syncs manually
588
self.assert_functionalization(f, x)
589
logs = self.get_logs(f, torch.ones(2, 2))
590
self.assertExpectedInline(logs, """\
594
def forward(self, arg0_1):
595
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
596
diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1)
597
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
598
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None
599
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
600
copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None
601
return diagonal_scatter
604
# NB: even with reapply_views=True, we expect to see scatter op
605
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False)
606
self.assertExpectedInline(reinplaced_logs, """\
610
def forward(self, arg0_1):
611
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
612
diagonal = torch.ops.aten.diagonal.default(arg0_1)
613
add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None
614
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None
615
diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter)
616
copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None
617
return diagonal_scatter
620
def test_channels_last_contiguous(self):
622
return x.contiguous(memory_format=torch.channels_last)
627
x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
628
self.assert_functionalization(f, x)
629
logs = self.get_logs(f, x).strip()
630
# There should be no clone in the graph
631
self.assertExpectedInline(logs, """\
632
def forward(self, arg0_1):
635
def test_split(self):
637
# test: view ops that return multiple tensors (split)
644
self.assert_functionalization(f, torch.ones(4, 2))
645
logs = self.get_logs(f, torch.ones(4, 2))
646
self.assertExpectedInline(logs, """\
650
def forward(self, arg0_1):
651
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
652
split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
653
getitem = split_copy[0]
654
getitem_1 = split_copy[1]; split_copy = None
655
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
656
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
657
split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2)
658
getitem_2 = split_copy_1[0]
659
getitem_3 = split_copy_1[1]; split_copy_1 = None
660
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None
661
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
662
split_copy_2 = torch.ops.aten.split_copy.Tensor(slice_scatter, 2)
663
getitem_4 = split_copy_2[0]
664
getitem_5 = split_copy_2[1]; split_copy_2 = None
665
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_5); getitem_5 = None
666
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
667
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
668
return diagonal_copy_1
671
# NB: even with reapply_views=True, we expect to see scatter op
672
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
673
self.assertExpectedInline(reinplaced_logs, """\
677
def forward(self, arg0_1):
678
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
679
split = torch.ops.aten.split.Tensor(arg0_1, 2)
681
getitem_1 = split[1]; split = None
682
diagonal = torch.ops.aten.diagonal.default(getitem_1); getitem_1 = None
683
add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None
684
split_1 = torch.ops.aten.split.Tensor(arg0_1, 2)
685
getitem_2 = split_1[0]
686
getitem_3 = split_1[1]; split_1 = None
687
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None
688
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
689
split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2)
690
getitem_4 = split_2[0]
691
getitem_5 = split_2[1]; split_2 = None
692
diagonal_1 = torch.ops.aten.diagonal.default(getitem_5); getitem_5 = None
693
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
694
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
698
def test_split_with_sizes(self):
700
# test: view ops that return multiple tensors (split_with_sizes)
702
y1, y2 = x.split_with_sizes([2, 2])
707
self.assert_functionalization(f, torch.ones(4, 2))
708
logs = self.get_logs(f, torch.ones(4, 2))
709
self.assertExpectedInline(logs, """\
713
def forward(self, arg0_1):
714
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
715
split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
716
getitem = split_with_sizes_copy[0]
717
getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = None
718
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem); getitem = None
719
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
720
split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
721
getitem_2 = split_with_sizes_copy_1[0]
722
getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = None
723
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None
724
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None
725
split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2])
726
getitem_4 = split_with_sizes_copy_2[0]
727
getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = None
728
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4); getitem_4 = None
729
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
730
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
731
return diagonal_copy_1
734
# NB: even with reapply_views=True, we expect to see scatter op
735
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
736
self.assertExpectedInline(reinplaced_logs, """\
740
def forward(self, arg0_1):
741
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
742
split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
743
getitem = split_with_sizes[0]
744
getitem_1 = split_with_sizes[1]; split_with_sizes = None
745
diagonal = torch.ops.aten.diagonal.default(getitem); getitem = None
746
add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None
747
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
748
getitem_2 = split_with_sizes_1[0]
749
getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = None
750
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None
751
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None
752
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2])
753
getitem_4 = split_with_sizes_2[0]
754
getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = None
755
diagonal_1 = torch.ops.aten.diagonal.default(getitem_4); getitem_4 = None
756
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
757
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
761
def test_slice(self):
768
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
769
logs = self.get_logs(f, torch.ones(4, 2))
770
self.assertExpectedInline(logs, """\
774
def forward(self, arg0_1):
775
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
776
transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
777
slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2); transpose_copy = None
778
add = torch.ops.aten.add.Tensor(slice_copy, ones); slice_copy = ones = None
779
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
780
slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2); transpose_copy_1 = add = None
781
transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0); slice_scatter = None
782
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
783
slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = None
784
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
785
return transpose_copy_4
788
# NB: even with reapply_views=True, we expect to see scatter op
789
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
790
self.assertExpectedInline(reinplaced_logs, """\
794
def forward(self, arg0_1):
795
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
796
transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
797
slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2); transpose = None
798
add = torch.ops.aten.add.Tensor(slice_1, ones); slice_1 = ones = None
799
transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None
800
slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2); transpose_1 = add = None
801
transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0); slice_scatter = None
802
transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
803
slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = None
804
transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None
808
def test_view_inplace(self):
810
# test: view + inplace op (transpose_)
816
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
817
logs = self.get_logs(f, torch.ones(4, 2))
818
self.assertExpectedInline(logs, """\
822
def forward(self, arg0_1):
823
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
824
transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
825
select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None
826
add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None
827
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
828
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
829
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
830
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
831
select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None
832
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
833
return transpose_copy_4
836
# NB: even with reapply_views=True, we expect to see scatter op
837
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
838
self.assertExpectedInline(reinplaced_logs, """\
842
def forward(self, arg0_1):
843
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
844
transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
845
select = torch.ops.aten.select.int(transpose, 0, 0); transpose = None
846
add = torch.ops.aten.add.Tensor(select, ones); select = ones = None
847
transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None
848
select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None
849
transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None
850
transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
851
select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = None
852
transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None
856
def test_unbind(self):
858
# test: view + inplace op (transpose_)
864
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
865
logs = self.get_logs(f, torch.ones(4, 2))
866
self.assertExpectedInline(logs, """\
870
def forward(self, arg0_1):
871
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
872
transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
873
unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy); transpose_copy = None
874
getitem = unbind_copy[0]
875
getitem_1 = unbind_copy[1]; unbind_copy = None
876
add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
877
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
878
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
879
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
880
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
881
unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3); transpose_copy_3 = None
882
getitem_2 = unbind_copy_1[0]
883
getitem_3 = unbind_copy_1[1]; unbind_copy_1 = None
884
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
885
return transpose_copy_4
888
# NB: even with reapply_views=True, we expect to see scatter op
889
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
890
self.assertExpectedInline(reinplaced_logs, """\
894
def forward(self, arg0_1):
895
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
896
transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
897
unbind = torch.ops.aten.unbind.int(transpose); transpose = None
899
getitem_1 = unbind[1]; unbind = None
900
add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
901
transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None
902
select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None
903
transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None
904
transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
905
unbind_1 = torch.ops.aten.unbind.int(transpose_3); transpose_3 = None
906
getitem_2 = unbind_1[0]
907
getitem_3 = unbind_1[1]; unbind_1 = None
908
transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None
912
def test_optional_tensor_list(self):
914
# test: an operator that takes in a List[Optional[Tensor]] argument
917
indices = torch.arange(4)
918
values = torch.arange(4, dtype=y.dtype)
919
y.index_put_((indices,), values, accumulate=False)
921
self.assert_functionalization(f, torch.ones(4, 2))
922
logs = self.get_logs(f, torch.ones(4, 2))
923
self.assertExpectedInline(logs, """\
927
def forward(self, arg0_1):
928
view_copy = torch.ops.aten.view_copy.default(arg0_1, [8])
929
arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
930
arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
931
index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None
932
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]); index_put = None
933
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8])
934
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
938
def test_scalars(self):
940
# test: the pass can handle scalar inputs properly
941
tmp = torch.ones(4, 2)
947
self.assert_functionalization(f, torch.ones(4, 2))
948
logs = self.get_logs(f, torch.ones(4, 2))
949
self.assertExpectedInline(logs, """\
953
def forward(self, arg0_1):
954
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
955
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2])
956
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
957
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
958
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2])
959
mul = torch.ops.aten.mul.Tensor(view_copy_2, 2); view_copy_2 = None
960
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
961
copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None
965
@skipIfTorchDynamo("Test does not work with TorchDynamo")
966
def test_metadata_change(self):
968
# ops like ge_() are allowed to change the dtype of the input.
969
# functionalization should pick up on that.
973
self.assert_functionalization(f, torch.ones(4, 2))
974
logs = self.get_logs(f, torch.ones(4, 2))
975
self.assertExpectedInline(logs, """\
979
def forward(self, arg0_1):
980
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
981
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
982
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
986
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
987
self.assertExpectedInline(reinplaced_logs, """\
991
def forward(self, arg0_1):
992
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
993
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
994
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
998
@skipIfTorchDynamo("Test does not work with TorchDynamo")
999
def test_metadata_change_out_op(self):
1001
out_1 = torch.ones(1)
1002
return torch.add(t, y, out=out_1)
1004
inpt1, inpt2 = torch.tensor([1]), torch.tensor([1])
1005
inpt1_func, inpt2_func = torch._to_functional_tensor(inpt1), torch._to_functional_tensor(inpt2)
1007
out_ref = f(inpt1, inpt2)
1008
torch._enable_functionalization(reapply_views=True)
1010
out_functional = f(inpt1_func, inpt2_func)
1012
torch._disable_functionalization()
1013
self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
1016
def test_only_one_view(self):
1018
# This tests that we don't have any unnecessary views in the trace.
1019
# If the input wasn't mutated, we don't need to regenerate it,
1020
# so there should be a total of 1 op in the output trace.
1022
logs = self.get_logs(f, torch.ones(4, 2))
1023
self.assertExpectedInline(logs, """\
1027
def forward(self, arg0_1):
1028
view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None
1032
def test_everything(self):
1035
tmp = torch.ones(2, 2)
1038
z0 = y.reshape(2, 4)
1039
z1 = z0.transpose(1, 0)
1042
z2, z3 = z1.split(2)
1044
z4 = z0[0] + z2.reshape(4)
1046
self.assert_functionalization(f, torch.ones(4, 2))
1047
logs = self.get_logs(f, torch.ones(4, 2))
1048
self.assertExpectedInline(logs, """\
1052
def forward(self, arg0_1):
1053
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
1054
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
1055
view_copy = torch.ops.aten.view_copy.default(add, [8])
1056
view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None
1057
transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
1058
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
1059
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
1060
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
1061
getitem = split_copy[0]
1062
getitem_1 = split_copy[1]; split_copy = None
1063
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
1064
view_copy_2 = torch.ops.aten.view_copy.default(add, [8]); add = None
1065
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [2, 4]); view_copy_2 = None
1066
transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_3, 1, 0); view_copy_3 = None
1067
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
1068
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
1069
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = add_1 = None
1070
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
1071
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
1072
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
1073
view_copy_4 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None
1074
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 2]); view_copy_4 = None
1075
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [8])
1076
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [2, 4]); view_copy_6 = None
1077
transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_7, 1, 0); view_copy_7 = None
1078
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
1079
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
1080
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
1081
getitem_2 = split_copy_1[0]
1082
getitem_3 = split_copy_1[1]; split_copy_1 = None
1083
select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None
1084
view_copy_8 = torch.ops.aten.view_copy.default(getitem_2, [4])
1085
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_5, [8])
1086
view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None
1087
select_copy_1 = torch.ops.aten.select_copy.int(view_copy_10, 0, 0); view_copy_10 = None
1088
view_copy_11 = torch.ops.aten.view_copy.default(view_copy_5, [8]); view_copy_5 = None
1089
view_copy_12 = torch.ops.aten.view_copy.default(view_copy_11, [2, 4]); view_copy_11 = None
1090
transpose_copy_4 = torch.ops.aten.transpose_copy.int(view_copy_12, 1, 0); view_copy_12 = None
1091
unsqueeze_copy_4 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_4, 0); transpose_copy_4 = None
1092
squeeze_copy_4 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_4); unsqueeze_copy_4 = None
1093
split_copy_2 = torch.ops.aten.split_copy.Tensor(squeeze_copy_4, 2); squeeze_copy_4 = None
1094
getitem_4 = split_copy_2[0]
1095
getitem_5 = split_copy_2[1]; split_copy_2 = None
1096
view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None
1097
add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = None
1101
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
1102
self.assertExpectedInline(reinplaced_logs, """\
1106
def forward(self, arg0_1):
1107
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
1108
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
1109
view = torch.ops.aten.view.default(add, [8])
1110
view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None
1111
transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
1112
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
1113
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
1114
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
1116
getitem_1 = split[1]; split = None
1117
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); getitem = ones = None
1118
view_2 = torch.ops.aten.view.default(add, [8]); add = None
1119
view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None
1120
transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None
1121
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
1122
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
1123
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
1124
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
1125
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
1126
view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None
1127
view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None
1128
view_6 = torch.ops.aten.view.default(view_5, [8])
1129
view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None
1130
transpose_3 = torch.ops.aten.transpose.int(view_7, 1, 0); view_7 = None
1131
unsqueeze_3 = torch.ops.aten.unsqueeze.default(transpose_3, 0); transpose_3 = None
1132
squeeze_3 = torch.ops.aten.squeeze.default(unsqueeze_3); unsqueeze_3 = None
1133
split_1 = torch.ops.aten.split.Tensor(squeeze_3, 2); squeeze_3 = None
1134
getitem_2 = split_1[0]
1135
getitem_3 = split_1[1]; split_1 = None
1136
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
1137
clone = torch.ops.aten.clone.default(getitem_2, memory_format = torch.contiguous_format)
1138
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
1139
view_8 = torch.ops.aten.view.default(view_5, [8]); view_5 = None
1140
view_9 = torch.ops.aten.view.default(view_8, [2, 4]); view_8 = None
1141
select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None
1142
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
1146
def test_reapply_views_simple(self):
1148
tmp = torch.ones(4, 2)
1153
self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
1154
logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
1155
self.assertExpectedInline(logs, """\
1159
def forward(self, arg0_1):
1160
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
1161
view = torch.ops.aten.view.default(arg0_1, [4, 2])
1162
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
1163
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
1164
view_2 = torch.ops.aten.view.default(view_1, [4, 2])
1165
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
1166
copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
1170
def test_aliases_maintained_after_pass_when_reapplying_views(self):
1172
tmp = torch.ones(4, 2)
1178
input_functional = torch._to_functional_tensor(torch.ones(4, 2))
1179
torch._enable_functionalization(reapply_views=True)
1181
y, z = f(input_functional)
1185
torch._disable_functionalization()
1187
# y and z are aliases inside of the function, and that aliasing relationship should be maintained.
1188
_y = torch._from_functional_tensor(y)
1189
_z = torch._from_functional_tensor(z)
1190
self.assertTrue(are_aliased(_y, _z))
1192
# copy_() gets its own test, because it used to be special cased in functionalization.
1193
# However, now it works pretty similar to other functional ops
1194
def test_copy_(self):
1196
tmp = torch.zeros(2, 2)
1197
tmp_slice = tmp.diagonal()
1198
y = tmp_slice.copy_(x)
1202
# Test 1: copy_() with same dtype and shape
1203
# to() is a composite op that noops when the dtype/shape match, so nothing gets logged.
1204
# self.assert_functionalization(f, torch.ones(2))
1205
logs = self.get_logs(f, torch.ones(2))
1206
self.assertExpectedInline(logs, """\
1210
def forward(self, arg0_1):
1211
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1212
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1213
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
1214
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
1215
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1216
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
1217
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
1218
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
1219
return diagonal_copy_2
1222
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
1223
self.assertExpectedInline(reinplaced_logs, """\
1227
def forward(self, arg0_1):
1228
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1229
diagonal = torch.ops.aten.diagonal.default(zeros)
1230
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
1231
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1232
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
1233
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
1237
# Test 2: copy_() with same dtype, different shape
1238
self.assert_functionalization(f, torch.ones(1))
1239
logs = self.get_logs(f, torch.ones(1))
1240
self.assertExpectedInline(logs, """\
1244
def forward(self, arg0_1):
1245
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1246
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1247
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
1248
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
1249
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1250
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
1251
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
1252
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
1253
return diagonal_copy_2
1256
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
1257
self.assertExpectedInline(reinplaced_logs, """\
1261
def forward(self, arg0_1):
1262
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1263
diagonal = torch.ops.aten.diagonal.default(zeros)
1264
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
1265
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1266
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
1267
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
1271
# Test 3: copy_() with different dtype, same shape
1272
self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
1273
logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
1274
self.assertExpectedInline(logs, """\
1278
def forward(self, arg0_1):
1279
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1280
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1281
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
1282
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
1283
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1284
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
1285
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
1286
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
1287
return diagonal_copy_2
1290
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
1291
self.assertExpectedInline(reinplaced_logs, """\
1295
def forward(self, arg0_1):
1296
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1297
diagonal = torch.ops.aten.diagonal.default(zeros)
1298
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
1299
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1300
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
1301
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
1305
# Test 4: copy_() with different dtype, different shape
1306
self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
1307
logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
1308
self.assertExpectedInline(logs, """\
1312
def forward(self, arg0_1):
1313
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1314
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros)
1315
copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None
1316
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(zeros, copy); zeros = copy = None
1317
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1318
add = torch.ops.aten.add.Tensor(diagonal_copy_1, arg0_1); diagonal_copy_1 = arg0_1 = None
1319
diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None
1320
diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None
1321
return diagonal_copy_2
1324
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
1325
self.assertExpectedInline(reinplaced_logs, """\
1329
def forward(self, arg0_1):
1330
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
1331
diagonal = torch.ops.aten.diagonal.default(zeros)
1332
copy = torch.ops.aten.copy_.default(diagonal, arg0_1); diagonal = None
1333
diagonal_1 = torch.ops.aten.diagonal.default(zeros)
1334
add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None
1335
diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None
1339
def test_expand_symint(self):
1340
# Once some existing SymInt bugs are ironed out, we should update
1341
# this test to plumb FakeSymbolicTensors through it
1343
return x.expand(x.size(0), x.size(1))
1345
self.assert_functionalization(f, torch.ones(2, 2))
1346
logs = self.get_logs(f, torch.ones(2, 2))
1347
self.assertExpectedInline(logs, """\
1351
def forward(self, arg0_1):
1352
expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]); arg0_1 = None
1356
def test_fill_(self):
1363
self.assert_functionalization(f, torch.ones(2, 2))
1364
logs = self.get_logs(f, torch.ones(2, 2))
1365
self.assertExpectedInline(logs, """\
1369
def forward(self, arg0_1):
1370
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
1371
diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
1372
fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None
1373
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None
1374
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
1375
return diagonal_scatter
1378
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
1379
self.assertExpectedInline(reinplaced_logs, """\
1383
def forward(self, arg0_1):
1384
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
1385
diagonal = torch.ops.aten.diagonal.default(add)
1386
fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None
1387
diagonal_1 = torch.ops.aten.diagonal.default(add)
1391
def test_resize_smaller(self):
1393
# Resizing to a smaller size doesn't affect storage
1402
self.assert_functionalization(f, torch.ones(8, 2))
1403
logs = self.get_logs(f, torch.ones(8, 2))
1404
self.assertExpectedInline(logs, """\
1408
def forward(self, arg0_1):
1409
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
1410
view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
1411
resize = torch.ops.aten.resize.default(view_copy, [3, 3])
1412
as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None
1413
view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None
1414
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
1415
view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None
1416
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
1417
view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
1418
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
1419
view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
1420
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
1421
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
1422
view_copy_6 = torch.ops.aten.view_copy.default(as_strided_copy_2, [-1]); as_strided_copy_2 = None
1423
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
1424
as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None
1425
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None
1429
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
1430
self.assertExpectedInline(reinplaced_logs, """\
1434
def forward(self, arg0_1):
1435
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
1436
view = torch.ops.aten.view.default(add, [4, 4])
1437
resize = torch.ops.aten.resize.default(view, [3, 3])
1438
as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None
1439
view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None
1440
add_1 = torch.ops.aten.add_.Tensor(view_1, 1)
1441
view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None
1442
as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
1443
view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None
1444
view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None
1445
view_5 = torch.ops.aten.view.default(view_4, [4, 4])
1446
as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None
1447
view_6 = torch.ops.aten.view.default(as_strided_2, [-1]); as_strided_2 = None
1448
view_7 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
1449
as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None
1450
add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1)
1454
def test_resize_same_size_diff_rank(self):
1460
self.assert_functionalization(f, torch.ones(5, 5, 5))
1462
def test_resize_larger_valid(self):
1465
# resizing a tensor to a larger size is only currently allowed
1466
# if the tensor-to-resize is not a view / has no outstanding views.
1467
# See Note [resize_() in functionalization pass]
1470
# Do a mutation to ensure that aliases of the output of resize_()
1471
# propagate mutations correctly.
1472
# I'm using fill_ specifically because I want to guarantee that
1473
# none of the output has uninitialized memory at the end
1474
# (since these tests compare the data output against a reference impl)
1479
self.assert_functionalization(f, torch.ones(8, 2))
1480
logs = self.get_logs(f, torch.ones(8, 2))
1481
self.assertExpectedInline(logs, """\
1485
def forward(self, arg0_1):
1486
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
1487
resize = torch.ops.aten.resize.default(add, [5, 5]); add = None
1488
view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None
1489
fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None
1490
view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None
1491
view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25])
1492
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
1493
return (view_copy_1, add_1)
1496
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
1497
self.assertExpectedInline(reinplaced_logs, """\
1501
def forward(self, arg0_1):
1502
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
1503
resize = torch.ops.aten.resize_.default(add, [5, 5])
1504
view = torch.ops.aten.view.default(add, [25]); add = None
1505
fill = torch.ops.aten.fill_.Scalar(view, 1)
1506
view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None
1507
view_2 = torch.ops.aten.view.default(view_1, [25])
1508
add_1 = torch.ops.aten.add.Tensor(view_1, 1)
1509
return (view_1, add_1)
1512
def test_resize_larger_invalid(self):
1516
# resizing a tensor to a larger size is only currently allowed
1517
# if the tensor-to-resize is not a view / has no outstanding views.
1518
# See Note [resize_() in functionalization pass]
1526
with self.assertRaisesRegex(
1528
r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'):
1529
self.assert_functionalization(f, torch.ones(8, 2))
1531
def test_nested_functions_propagate_updates(self):
1533
# Create a view of x
1536
# The view, y, gets deallocated at the end of this function
1539
# Calling g(x) should mutate x
1541
# We expect x to be synced here, even though the alias created in g() has been deallocated!
1545
self.assert_functionalization(f, torch.ones(2, 2))
1547
def test_mixed_wrappers_valid(self):
1553
x1_not_functional = LoggingTensor(torch.ones(4))
1554
x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))
1556
with capture_logs() as logs:
1557
y = f(x1_not_functional, x2_functional)
1559
# Make sure that functionalization ran the "+" kernel
1560
# with a functional + non-functional tensor, and wrapped the output appropriately.
1561
self.assertExpectedInline('\n'.join(logs), """\
1562
$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
1563
$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""")
1565
def test_mixed_wrappers_invalid(self):
1566
x1_not_functional = torch.ones(4)
1567
x2_functional = torch._to_functional_tensor(torch.ones(4))
1569
# When dealing with mixed functional + non functional tensors,
1570
# normal_tensor.add_(functional_tensor) is not valid
1571
# because normal_tensor would need to be "promoted" to a functional tensor.
1572
with self.assertRaises(RuntimeError):
1573
x1_not_functional.add_(x2_functional)
1575
# Mutating an index of a freshly-created intermediate (not a graph input)
# should functionalize to select_copy + fill + select_scatter, and reinplace
# back to select + fill_ when views are reapplied.
# NOTE(review): line-number jumps (1577 -> 1580, 1592 -> 1595, 1604 -> 1609)
# indicate the body of the inner `f` and the closing `""")` of both
# expected-inline strings are missing from this extraction.
def test_index_mutation_on_non_input(self):
1577
tmp = torch.zeros(10)
1580
self.assert_functionalization(f, torch.ones(2))
1581
logs = self.get_logs(f, torch.ones(2))
1582
self.assertExpectedInline(logs, """\
1586
def forward(self, arg0_1):
1587
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1588
select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
1589
fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None
1590
select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None
1591
select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5)
1592
return select_scatter
1595
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
1596
self.assertExpectedInline(reinplaced_logs, """\
1600
def forward(self, arg0_1):
1601
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
1602
select = torch.ops.aten.select.int(zeros, 0, 5)
1603
fill = torch.ops.aten.fill_.Scalar(select, 1); select = None
1604
select_1 = torch.ops.aten.select.int(zeros, 0, 5)
1609
# Functionalizes torch.instance_norm (with running stats) and pins the traced
# graph: instance norm lowers to _native_batch_norm_legit_functional on a
# [1, N*C, H, W] view, and the running-stat mutations become copy_ calls on
# the inputs at the end of the graph.
# NOTE(review): `size` is used below but its binding (original line ~1611) is
# missing from this extraction, as are the Windows-specific branch the
# comments refer to and the closing `""")` of both expected strings -- the
# interleaved line numbers jump 1616 -> 1617? no: 1619 -> 1621, 1653 -> 1657,
# 1692 -> 1697. Confirm against the upstream file.
def test_instance_norm(self):
1612
def f(x, running_mean, running_var):
1613
with enable_python_dispatcher():
1614
return torch.instance_norm(x, None, None, running_mean, running_var,
1615
use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
1616
self.assert_functionalization(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size))
1617
# On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used
1618
# whereas on other platforms, the alias_copy's are before the view_copy's.
1619
# e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
1621
logs = self.get_logs(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size))
1622
self.assertExpectedInline(logs, """\
1626
def forward(self, arg0_1, arg1_1, arg2_1):
1627
repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1628
repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1629
view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None
1630
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1631
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
1632
getitem = _native_batch_norm_legit_functional[0]
1633
getitem_1 = _native_batch_norm_legit_functional[1]
1634
getitem_2 = _native_batch_norm_legit_functional[2]
1635
getitem_3 = _native_batch_norm_legit_functional[3]
1636
getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
1637
alias_copy = torch.ops.aten.alias_copy.default(arg1_1)
1638
view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
1639
view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
1640
mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None
1641
copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None
1642
alias_copy_1 = torch.ops.aten.alias_copy.default(copy); copy = None
1643
alias_copy_2 = torch.ops.aten.alias_copy.default(alias_copy_1)
1644
alias_copy_3 = torch.ops.aten.alias_copy.default(arg2_1)
1645
view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
1646
view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
1647
mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None
1648
copy_1 = torch.ops.aten.copy.default(alias_copy_3, mean_1); alias_copy_3 = mean_1 = None
1649
alias_copy_4 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None
1650
alias_copy_5 = torch.ops.aten.alias_copy.default(alias_copy_4)
1651
view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
1652
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None
1653
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None
1657
reinplaced_logs = self.get_logs(
1658
f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size),
1659
reapply_views=True, run_reinplace=True
1661
self.assertExpectedInline(reinplaced_logs, """\
1665
def forward(self, arg0_1, arg1_1, arg2_1):
1666
repeat = torch.ops.aten.repeat.default(arg1_1, [20])
1667
repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
1668
view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None
1669
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1670
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
1671
getitem = _native_batch_norm_legit_functional[0]
1672
getitem_1 = _native_batch_norm_legit_functional[1]
1673
getitem_2 = _native_batch_norm_legit_functional[2]
1674
getitem_3 = _native_batch_norm_legit_functional[3]
1675
getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
1676
alias = torch.ops.aten.alias.default(arg1_1)
1677
view_1 = torch.ops.aten.view.default(getitem_3, [20, 100])
1678
view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
1679
mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None
1680
copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None
1681
alias_1 = torch.ops.aten.alias.default(copy); copy = None
1682
alias_2 = torch.ops.aten.alias.default(alias_1)
1683
alias_3 = torch.ops.aten.alias.default(arg2_1)
1684
view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
1685
view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
1686
mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None
1687
copy_1 = torch.ops.aten.copy.default(alias_3, mean_1); alias_3 = mean_1 = None
1688
alias_4 = torch.ops.aten.alias.default(copy_1); copy_1 = None
1689
alias_5 = torch.ops.aten.alias.default(alias_4)
1690
view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
1691
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None
1692
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None
1697
# Mutating a tensor whose memory internally overlaps (created via unfold with
# overlapping windows) must raise under functionalization.
# NOTE(review): the `def fn(x):` header and the mutation/return lines
# (original lines ~1699, 1702-1704) are missing from this extraction --
# confirm against the upstream file.
def test_mutation_overlapping_mem(self):
1700
t1 = torch.add(x, x)
1701
t2 = t1.unfold(1, 3, 2)
1705
with self.assertRaisesRegex(RuntimeError, r'encountered a tensor being mutated that has internal overlap'):
1706
x = torch.ones(1, 5)
1707
out = _functionalize(fn, reapply_views=True, crossref=False)(x)
1710
# Functionalizes torch.batch_norm in training mode and pins the traced graph:
# it lowers to _native_batch_norm_legit_functional, with the running-stat
# updates materialized as trailing copy_ calls into arg1_1/arg2_1. The
# reinplaced graph is identical (nothing to reinplace).
# NOTE(review): line-number jumps (1730 -> 1734, 1750 -> 1754) indicate the
# closing `return ...""")` lines of both expected strings are missing from
# this extraction.
def test_batch_norm(self):
1711
def f(x, running_mean, running_var):
1712
with enable_python_dispatcher():
1713
return torch.batch_norm(x, None, None, running_mean, running_var, True, 0.1, 1e-5, False)
1715
self.assert_functionalization(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100))
1716
logs = self.get_logs(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100))
1717
self.assertExpectedInline(logs, """\
1721
def forward(self, arg0_1, arg1_1, arg2_1):
1722
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1723
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
1724
getitem = _native_batch_norm_legit_functional[0]
1725
getitem_1 = _native_batch_norm_legit_functional[1]
1726
getitem_2 = _native_batch_norm_legit_functional[2]
1727
getitem_3 = _native_batch_norm_legit_functional[3]
1728
getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
1729
copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None
1730
copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None
1734
reinplaced_logs = self.get_logs(
1735
f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100), reapply_views=True, run_reinplace=True
1737
self.assertExpectedInline(reinplaced_logs, """\
1741
def forward(self, arg0_1, arg1_1, arg2_1):
1742
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
1743
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
1744
getitem = _native_batch_norm_legit_functional[0]
1745
getitem_1 = _native_batch_norm_legit_functional[1]
1746
getitem_2 = _native_batch_norm_legit_functional[2]
1747
getitem_3 = _native_batch_norm_legit_functional[3]
1748
getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
1749
copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None
1750
copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None
1754
# This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode
1755
# Traces a manually-driven python functionalization (wrap -> run -> sync ->
# unwrap) with make_fx and pins the resulting graph.
# NOTE(review): line-number jump 1755 -> 1761 means the inner `def f` called
# below is missing from this extraction, and the expected-inline string's
# closing `return ...""")` (after original line 1789) is also absent.
def test_python_functionalization(self):
1761
def f_functionalized(x):
1762
# Note [Disabling Functionalize TLS Above Python Functionalization]
1763
# This UX is pretty annoying (although python functionalization's main customer is AOTAutograd,
1764
# and is not really advertised as a user API).
1765
# We need to explicitly disable functionalization when using python FunctionalTensor and FunctionalTensorMode.
1766
# Why? FunctionalTensor is a wrapper tensor that holds an inner FunctionalTensorWrapper.
1767
# Since the inner tensor has `DispatchKey.Functionalize` in its keyset, then by default,
1768
# our FunctionalTensor will inherit the same keyset.
1769
# We don't have an easy way of directly mutating a tensor's keyset from python,
1770
# so globally disabling functionalization here is easier.
1771
maybe_disable = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
1772
with maybe_disable, FunctionalTensorMode():
1773
x_wrapped = FunctionalTensor.to_functional(x)
1774
out_wrapped = f(x_wrapped)
1775
out_unwrapped = out_wrapped.elem
1776
torch._sync(out_unwrapped)
1777
return torch._from_functional_tensor(out_unwrapped)
1780
x = torch.randn(2, requires_grad=True) + 1
1781
fx_g = make_fx(f_functionalized)(x)
1782
# NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a
1783
# DCE pass that will remove nodes like this later on.
1784
self.assertExpectedInline(fx_g.code.strip(), """\
1785
def forward(self, x_1):
1786
view = torch.ops.aten.view.default(x_1, [-1])
1787
mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None
1788
view_1 = torch.ops.aten.view.default(mul, [-1])
1789
view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None
1790
add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None
1793
# Cross-checks python functionalization (dispatch_functionalize) against the
# C++ path (_functionalize) on a function that touches an efficient zero
# tensor: outputs and traced graphs must match.
# NOTE(review): line-number jumps (1793 -> 1795 -> 1801) indicate the `def f`
# header/body, `x`, and `out_ref` bindings are missing from this extraction.
def test_python_functionalization_zero_tensor(self):
1795
y = torch.ops.aten._efficientzerotensor([4])
1801
out_test = dispatch_functionalize(f)(x)
1802
out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1803
self.assertEqual(out_ref, out_test)
1804
self.assertEqual(out_ref, out_test_cpp)
1805
fx_g = make_fx(dispatch_functionalize(f))(x)
1806
fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1807
self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1809
# Verifies that is_conj() survives both python and C++ functionalization on a
# complex input: values and the conj bit must match the eager reference.
# NOTE(review): line-number jumps (1809 -> 1812, 1814 -> 1816) show the start
# of the inner `def f` and the `out_ref = f(x)` line are missing from this
# extraction.
def test_python_functionalization_is_conj(self):
1812
return out, out.is_conj()
1814
x = torch.randn(4, dtype=torch.complex64)
1816
out_test = dispatch_functionalize(f)(x)
1817
out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
1818
self.assertEqual(out_ref[0], out_test[0])
1819
self.assertEqual(out_ref[1], out_test[1])
1820
self.assertEqual(out_ref[0], out_test_cpp[0])
1821
self.assertEqual(out_ref[1], out_test_cpp[1])
1823
# Same as test_python_functionalization_is_conj, but for the neg bit:
# is_neg() must survive both python and C++ functionalization.
# NOTE(review): line-number jumps (1823 -> 1826, 1828 -> 1830) show the start
# of the inner `def f` and the `out_ref = f(x)` line are missing from this
# extraction.
def test_python_functionalization_is_neg(self):
1826
return out, out.is_neg()
1828
x = torch.randn(4, dtype=torch.complex64)
1830
out_test = dispatch_functionalize(f)(x)
1831
out_test_cpp = _functionalize(f, reapply_views=True, crossref=False)(x)
1832
self.assertEqual(out_ref[0], out_test[0])
1833
self.assertEqual(out_ref[1], out_test[1])
1834
self.assertEqual(out_ref[0], out_test_cpp[0])
1835
self.assertEqual(out_ref[1], out_test_cpp[1])
1838
# Functionalizes a function that mutates a conjugated clone and resolves it:
# checks python vs C++ functionalization agree, and pins the traced graph
# (mutation through the conj view becomes copy + re-conj in the graph).
# NOTE(review): line-number jumps (1840 -> 1842, 1844 -> 1846) show a mutation
# line inside `f` and the `out_ref = f(x)` line are missing from this
# extraction.
def test_python_functionalization_conj(self):
1840
y = x.clone().conj()
1842
return torch.view_as_real(y.resolve_conj())
1844
x = torch.randn(4, dtype=torch.complex64)
1846
out_test = dispatch_functionalize(f)(x)
1847
out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1848
self.assertEqual(out_ref, out_test)
1849
self.assertEqual(out_test, out_test_cpp)
1850
fx_g = make_fx(dispatch_functionalize(f))(x)
1851
fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1852
self.assertExpectedInline(fx_g.code.strip(), """\
1853
def forward(self, arg0_1):
1854
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
1855
_conj = torch.ops.aten._conj.default(clone); clone = None
1856
clone_1 = torch.ops.aten.clone.default(_conj)
1857
mul = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None
1858
clone_2 = torch.ops.aten.clone.default(_conj); _conj = None
1859
copy = torch.ops.aten.copy.default(clone_2, mul); clone_2 = mul = None
1860
_conj_1 = torch.ops.aten._conj.default(copy); copy = None
1861
_conj_2 = torch.ops.aten._conj.default(_conj_1); _conj_1 = None
1862
clone_3 = torch.ops.aten.clone.default(_conj_2); _conj_2 = None
1863
view_as_real = torch.ops.aten.view_as_real.default(clone_3); clone_3 = None
1864
return view_as_real""")
1865
self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1867
# Functionalizes a function operating on a _neg_view and pins the traced
# graph; python and C++ functionalization must produce matching outputs and
# matching graphs.
# NOTE(review): line-number jump 1867 -> 1875 means the inner `def f`, `x`,
# and `out_ref` bindings are missing from this extraction, and the expected
# string's closing `return add""")` (after original line 1885) is absent.
def test_python_functionalization_neg(self):
1875
out_test = dispatch_functionalize(f)(x)
1876
out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1877
self.assertEqual(out_ref, out_test)
1878
self.assertEqual(out_ref, out_test_cpp)
1879
fx_g = make_fx(dispatch_functionalize(f))(x)
1880
fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1881
self.assertExpectedInline(fx_g.code.strip(), """\
1882
def forward(self, arg0_1):
1883
_neg_view = torch.ops.aten._neg_view.default(arg0_1); arg0_1 = None
1884
clone = torch.ops.aten.clone.default(_neg_view); _neg_view = None
1885
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
1887
self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1889
def test_python_functionalization_lift_fresh_storage(self):
    """lift_fresh under FunctionalTensorMode must not alias the source.

    The lifted result's untyped storage has to differ from that of the
    original constant tensor.
    """
    constant = torch.tensor([0.0])

    # See Note [Disabling Functionalize TLS Above Python Functionalization]:
    # the C++ Functionalize key is excluded while the python mode is active.
    exclude_guard = torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
    )
    with exclude_guard, FunctionalTensorMode():
        lifted = torch.ops.aten.lift_fresh.default(constant)

    self.assertNotEqual(constant.untyped_storage(), lifted.untyped_storage())
# A tensor constant created inside the traced function must be lifted via
# lift_fresh_copy; python and C++ functionalization must agree on outputs and
# on the traced graph.
# NOTE(review): line-number jumps (1900 -> 1905 and after 1915) mean the body
# of the inner `f` (which uses `tmp`), the `x`/`out_ref` bindings, and the
# expected string's closing `return add""")` are missing from this extraction.
def test_python_functionalization_lift_fresh(self):
1900
tmp = torch.tensor([0.0])
1905
out_test = dispatch_functionalize(f)(x)
1906
out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x)
1907
self.assertEqual(out_ref, out_test)
1908
self.assertEqual(out_ref, out_test_cpp)
1909
fx_g = make_fx(dispatch_functionalize(f))(x)
1910
fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x)
1911
self.assertExpectedInline(fx_g.code.strip(), """\
1912
def forward(self, arg0_1):
1913
_tensor_constant0 = self._tensor_constant0
1914
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
1915
add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1); lift_fresh_copy = arg0_1 = None
1917
self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip())
1919
# Re-runs the whole TestFunctionalization suite with crossref checking
# enabled; the listed inherited tests are expected to fail under crossref,
# and the class is skipped entirely under TorchDynamo.
# NOTE(review): the xfail list is incomplete here (line numbers jump
# 1919 -> 1923, 1923 -> 1928, 1931 -> 1933 -- entries and the closing `])`
# are missing), as is the class body (1934 -> 1937). Confirm against the
# upstream file.
@xfail_inherited_tests([
1923
"test_diagonal_mutated_input",
1928
"test_split_with_sizes",
1930
"test_view_clone_view_inplace",
1931
"test_view_inplace",
1933
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well")
1934
class TestCrossRefFunctionalization(TestFunctionalization):
1937
1937
if __name__ == '__main__':