import functools
import itertools
import operator
import re
import unittest
import warnings

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase, run_tests
from collections.abc import Iterable
from torch.nn.utils import stateless
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._decomp import decomposition_table
from torch.fx.experimental.symbolic_shapes import (
    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
    guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
from torch.testing._internal.common_device_type import ops
import torch.testing._internal.optests as optests
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map

aten = torch.ops.aten

HAS_CUDA = torch.cuda.is_available()


def strip_end(s, suffix):
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    else:
        return s


def show_guards(gm):
    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
    return "\n".join(
        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
    )
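

# Usage sketch (mirrors test_unary below, where this exact guard is asserted):
# tracing a function that branches on a symbolic shape records a guard that
# show_guards() renders in simplified form, e.g.
#
#   def f(x):
#       assert x.shape[0] < 20
#       return x.cos()
#   gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
#   print(show_guards(gm))  # -> L['x'].size()[0] <= 19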


def process_failures():
    """
    Takes file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and processes them into a list of opinfo xfails
    """
    f = open('pytest_failures')
    failures = f.readlines()
    failures = [i.strip() for i in failures]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        if op.variant_test_name == '':
            s = f"{op.name}"
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")


USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)


def _create_new_input(x):
    if not isinstance(x, torch.Tensor):
        return x
    if x.dtype != torch.float:
        return x + 1
    if x.is_leaf:
        return torch.rand_like(x, requires_grad=x.requires_grad)
    else:
        return torch.rand_like(x)
107
Delays a cos being executed on the unwraptensor until its used. Simulates a CommTensor used
109
class UnwrapTensor(torch.Tensor):
111
def __new__(cls, tensor: torch.Tensor):
112
r = torch.Tensor._make_wrapper_subclass(
116
device=tensor.device,
117
layout=tensor.layout,
118
requires_grad=tensor.requires_grad,
125
return f"UnwrapTensor({self._tensor})"
127
__torch_function__ = _disabled_torch_function_impl
130
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
133
if isinstance(e, UnwrapTensor):
134
ret = e._tensor.cos()
138
args = tree_map(unwrap, args)
139
kwargs = tree_map(unwrap, kwargs)
140
return func(*args, **kwargs)
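
    # Behavior sketch: any op on an UnwrapTensor first unwraps it to
    # self._tensor.cos(), so e.g. (UnwrapTensor(x) * 2) computes (roughly)
    # x.cos() * 2; test_trace_subclasses below traces through this wrapper.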


class TestGenericProxyTensor(TestCase):
    def _test(self, f, inps):
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)
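    # The round-trip above is the core check of this suite: trace f once, then
    # run both the traced graph and eager f on fresh randomized inputs and
    # compare, so any input that got burned into the graph shows up as a diff.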

    def test_pre_dispatch_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_dispatch=True) clears caches properly.
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
        fx_g = make_fx(f, pre_dispatch=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
    return matmul""")

    def test_pre_dispatch_linear(self):
        def f(a, b, c):
            return torch.nn.functional.linear(a, b, c)
        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        c = torch.randn(4)
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
        out1 = f(a, b, c)
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()
        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.clone().detach().requires_grad_(True)
        a_tmp = a1.clone().detach().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally; that inner trace
        # should not leak into the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify nested make_fx calls don't leak factory functions
        # into the outer graph
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # The forward of a graph module produced by an interior make_fx
        # is still traced when it is called
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with tensor subclasses
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

    def test_empty_like_doesnt_burn_in_defaults(self):
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False); x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the torch.zeros() call
        # returns a proxy tensor.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\
def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
    return copy_
    """)

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}

        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))

        # Tracing should re-enter the decomposition table, so neither norm nor
        # square (which itself decomposes further) appears in the final graph.
        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        # Wrap parameters and buffers via functional_call so that fake tensor
        # tracing does not see free (non-fake) tensors.
        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)

        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_pickle_issue89626(self):
        import pickle
        x = torch.randn(2)
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
        pickle.dumps(x)

    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_pre_dispatch_functionalization(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                y = y + x_unwrapped
                y.mul_(5)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher

        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    add = torch.ops.aten.add.Tensor(matmul, x_1); matmul = x_1 = None
    mul = torch.ops.aten.mul.Tensor(add, 5); add = None
    return mul""")

    def test_pre_dispatch_functionalization_view_op(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                x_unwrapped = x_unwrapped.transpose(1, 0)
                y = y + x_unwrapped
                y = y.view(2, 8)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher

        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0); x_1 = None
    add = torch.ops.aten.add.Tensor(matmul, transpose); matmul = transpose = None
    view = torch.ops.aten.view.default(add, [2, 8]); add = None
    return view""")

    def test_val_metadata_mutation(self):
        def f(x):
            y = x.clone()
            y.unsqueeze_(0)
            return y

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
        self.assertEqual([
            tuple(node.meta['val'].shape)
            for node in traced.graph.nodes
            if 'val' in node.meta
        ], [(3,), (3,), (1, 3)])

    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_allclose(self):
        def f(a, b):
            return torch.allclose(a, b)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )

        if self.tracing_mode != "real":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return bool(blowup.sum().item() == 2)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_random(self):
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return bool(val.item() == 2.1)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = torch.func.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in the returned list
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(input_dim)

            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z

        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
        self.assertEqual(len(ops), 2)

    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = torch.func.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in the returned list; results
        # also differ slightly numerically, hence atol=1e-03
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )

    def test_trace_subclasses(self):
        def f(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f, inp)

    def test_partial_decomp(self):
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y

        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)
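    # Note: returning NotImplemented from a decomposition tells make_fx to fall
    # back to the original op for that call, which is why exactly one of the
    # two addmm calls above survives in decomposed_fx.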

    def test_decomp_of_capture(self):
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')

        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_strides(self):
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x

        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return x.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))

    def test_pr_86917(self):
        # Regression test for https://github.com/pytorch/pytorch/pull/86917
        def f(a, b):
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)

        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])


class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"


class TestGenericProxyTensorFake(TestGenericProxyTensor):
    tracing_mode = "fake"


class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    tracing_mode = "symbolic"


del TestGenericProxyTensor


class TestRealProxyTensor(TestCase):
    def test_error_on_data_dependent_ops(self):
        def f():
            x = torch.randn([])
            y = torch.randn([])
            assert torch.allclose(x * y, y * x)

        # Smoke tests: data-dependent ops are allowed when explicitly disabled
        make_fx(f, _error_on_data_dependent_ops=False)()
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()


class TestFakeProxyTensor(TestCase):
    def test_issue82547(self):
        x = nn.Parameter(torch.randn(3, 3))

        def f():
            return torch.ops.aten.t.default(x)
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_free_fake(self):
        def f(x):
            return torch.add(x, y)

        with FakeTensorMode() as fake_mode:
            y = torch.randn(2)
            make_fx(f, tracing_mode="real")(torch.randn(2))

    def test_fused_adam(self):
        params = [torch.randn(10, 10) for _ in range(10)]
        grads = [torch.randn(10, 10) for _ in range(10)]
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        state_steps = [torch.tensor(0) for _ in range(10)]

        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
            (new_params, _, _, _, _) = aten._fused_adam.default(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                lr=0.1,
                beta1=0.9,
                beta2=0.999,
                weight_decay=0.01,
                eps=1e-8,
                amsgrad=False,
                maximize=False,
            )

            for p, new_p in zip(params, new_params):
                p.copy_(new_p)

            return params

        gm = make_fx(fused_adam, tracing_mode='fake')(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in ensure_ops_have_val:
                self.assertIn('val', n.meta)

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
    alias = torch.ops.aten.alias.default(x_1); x_1 = None
    return alias""")

    def test_meta(self):
        def f(x):
            a = x.cos()
            b = torch.var_mean(a, dim=0)
            c = b * 2
            return c

        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
        for n in out.graph.nodes:
            if n.op == 'output':
                continue
            self.assertTrue('val' in n.meta)


def _get_node(fx_g, cond):
    for n in fx_g.graph.nodes:
        if cond(n):
            return n
    raise AssertionError


def _get_free_symbols(shape_env):
    vars = tuple(shape_env.var_to_val.keys())
    return len([var for var in vars if var not in shape_env.replacements])


def _trace(f, *args):
    inps = [torch.randn(arg) for arg in args]
    return make_fx(f, tracing_mode="symbolic")(*inps)


class TestSymbolicTracing(TestCase):
    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
        """
        Tests fn traced with trace_inputs against test_inputs
        Also returns shape env
        """
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
        for input in test_inputs:
            input = [torch.randn(shape) for shape in input]
            rx, ry = traced_f(*input), fn(*input)
            if assert_eq:
                self.assertEqual(rx, ry)
        return traced_f
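    # Usage sketch: trace once at one set of shapes, then replay the traced
    # graph at different shapes to confirm nothing specialized, e.g.
    #   gm = self._test_dynamic(f, [(3, 4)], [[(2, 5)], [(6, 8)]])
    #   gm.shape_env.guards  # inspect what guards tracing accumulated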

    def test_debug_interpreter(self):
        import torch.library
        from torch.library import Library

        foo = Library("foo", "DEF")  # noqa: TOR901
        foo.define("foo(Tensor self) -> Tensor")

        # Operator where meta and cpu disagree on strides
        @torch.library.impl(foo, "foo", "CPU")
        def foo_cpu(x):
            return x.clone().T

        @torch.library.impl(foo, "foo", "Meta")
        def foo_meta(x):
            return x.clone()

        def f(x):
            return torch.ops.foo.foo.default(x)

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
        from torch._functorch.compilers import DebugInterpreter

        interp = DebugInterpreter(gm)

        # input mismatch is caught (indicates a guard problem)
        self.assertRaisesRegex(
            AssertionError, r"3 != 1",
            lambda: interp.run(torch.randn(3, 3).T),
        )

        # the incorrect meta stride is caught
        self.assertRaisesRegex(
            AssertionError, r"\(3, 1\) != \(1, 3\)",
            lambda: interp.run(torch.randn(3, 3))
        )

    def test_int_input(self):
        def f(x, y):
            return x.view(y)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    view = torch.ops.aten.view.default(x_1, [y_1]); x_1 = y_1 = None
    return view""")

    def test_resize_from_zero(self):
        def f(x, y):
            x.resize_(y.size(0))

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = None
    return None""")

    def test_broadcast_shapes(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0])

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    return (sym_size_int, sym_size_int_1)""")

    def test_deduped_shape(self):
        def f(s0, s1, x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, s0_1, s1_1, x_1, y_1):
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
    return ((s0_1, s1_1), empty)""")

    def test_non_deduped_shape(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    return ((sym_size_int, sym_size_int_1), empty)""")

    def test_unary(self):
        def f(x):
            assert x.shape[0] < 20
            return x.cos()

        test_inputs = []
        test_inputs.append([(2, 5)])
        test_inputs.append([(6, 8)])
        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")

    def test_repeat_interleave(self):
        def f(src_tokens, beam_size_src):
            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
        self.assertEqual(len(gm.shape_env.guards), 0)

    def test_non_symint_size_spec(self):
        def f(x):
            torch._C._non_sym_sizes(x)
            return x + 1

        x = torch.randn(2, 3)
        make_fx(f, tracing_mode="symbolic")(x)

    def test_symbolic_repeat_interleave(self):
        def f(y, x):
            return y.repeat_interleave(x, dim=1)

        y = torch.tensor([[1, 2], [3, 4]])
        x = torch.tensor([2, 3])
        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, y_1, x_1):
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1); x_1 = None
    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave); y_1 = repeat_interleave = None
    return index_select""")

    def test_cumsum_unbacked(self):
        def f(x):
            y = x.item()
            z = torch.randn((3, y, 3))
            return z.cumsum(0)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    cumsum = torch.ops.aten.cumsum.default(randn, 0); randn = None
    return cumsum"""  # noqa: B950
        )

    def test_repeat_interleave_unbacked_output_size(self):
        def f(x, y):
            s = x.sum().item()
            return y.repeat_interleave(x, dim=0, output_size=s)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1, y_1):
    sum_1 = torch.ops.aten.sum.default(x_1)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1); sum_1 = None
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense); x_1 = _local_scalar_dense = None
    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave); y_1 = repeat_interleave = None
    return index_select"""  # noqa: B950
        )

    def test_arange_unbacked_output_size(self):
        def f(x):
            return torch.arange(0, x)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return arange"""  # noqa: B950
        )

    def test_adv_index_batch(self):
        def f(src_tokens):
            bsz, src_len = src_tokens.size()[:2]
            start_step = src_tokens.shape[1]
            beam_size = 1
            generate_size = 64
            max_len = src_len + generate_size
            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
            return tokens

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
        self.assertEqual(len(gm.shape_env.guards), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_cpu_scalar_cuda(self):
        def f(a, b):
            return (a * b) @ b

        r = str(
            make_fx(f, tracing_mode="symbolic")(
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
            ).code
        ).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1, b_1):
    mul = torch.ops.aten.mul.Tensor(a_1, b_1); a_1 = None
    mm = torch.ops.aten.mm.default(mul, b_1); mul = b_1 = None
    return mm""")

    def test_binary_broadcast(self):
        def f(a, b):
            c = a * b
            return c

        test_inputs = []
        test_inputs.append([(1, 5), (3, 1)])
        test_inputs.append([(1, 4), (4, 1)])
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
        assert len(shape_env.guards) == 0

    def test_multiply_shape(self):
        def f(a):
            return torch.empty(a.shape[0] * 2)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
    mul = sym_size_int * 2; sym_size_int = None
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
    return empty""")

    def test_item(self):
        def f(a):
            r = a.item()
            return r * a

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense); a_1 = _local_scalar_dense = None
    return mul""")

    def test_tensor_symfloat(self):
        def f(a):
            r = torch.tensor(a.size(0) ** 2.0)
            assert r.dtype is torch.float
            return r

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
        r = str(gm.code).strip()
        # NB: this specializes, which is fine, the point is to make sure the
        # dtype inference is correct
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
    return lift_fresh_copy""")
        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))

    def test_item_to_constructor(self):
        def f(a):
            r = a.item()
            return torch.empty(r)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return empty"""  # noqa: B950
        )

    def test_setitem_symint(self):
        def f(x):
            x[0] = x.size(0)
            return x

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None
    select = torch.ops.aten.select.int(x_1, 0, 0)
    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = None
    return x_1"""  # noqa: B950
        )

    def test_dynamic_pointwise_scalar(self):
        def f(gravity, mask):
            gravity[mask, 0] = gravity[mask, 0] * -1

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 4)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, gravity_1, mask_1):
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
    index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None
    mul = torch.ops.aten.mul.Tensor(index, -1); index = None
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None
    return None""")

    def test_reflect_r_over_x(self):
        def reflect_R_over_x(R):
            reflect = torch.eye(3, device=R.device)
            reflect[0, 0] = -1
            return reflect @ R @ reflect

        def f(crop_camera, mask):
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]); index = None
    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
    bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
    mul = sym_size_int * 3
    view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
    mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
    view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = None
    return None""")

    def test_unbacked_slice(self):
        def f(x, m):
            x = x[m]
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

        make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        )

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_unbacked_batch_resnet(self):
        mod = torchvision.models.resnet18()

        def f(x, mask, params, buffers):
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
                for s in p.shape:
                    guard_int(s)
            x = x[mask]
            torch._constrain_as_value(x.shape[0], min=1)
            for p in params.values():
                p.grad = None
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

        make_fx(f, tracing_mode="symbolic")(
            torch.randn(3, 3, 250, 250),
            torch.randint(0, 2, (3,), dtype=torch.bool),
            dict(mod.named_parameters()),
            dict(mod.named_buffers()),
        )

    def test_boolean_index(self):
        def f(images, handedness, valid):
            images = images[valid]
            handedness = handedness[valid]
            right_hand_mask = handedness == 1
            images[right_hand_mask] = images[right_hand_mask].flip(-1)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randint(0, 256, (512, 1, 96, 96)),
            torch.randint(0, 1, (512,)),
            torch.randint(0, 2, (512,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None
    return None""")

    def test_neg_shape(self):
        def f(a):
            return torch.empty(-a.shape[0] + 10)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
    neg = -sym_size_int; sym_size_int = None
    add = neg + 10; neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
    return empty""")

    def test_unbacked_unification(self):
        def f(x, y):
            z = torch.zeros(x.item())
            return z + y

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    add = torch.ops.aten.add.Tensor(zeros, y_1); zeros = y_1 = None
    return add""")

    def test_view_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 192)
            return r.view(12, -1, 192)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    def test_unbacked_unify_guard(self):
        def f(x, y):
            z = torch.zeros(x.item())
            torch._check(z.size(0) == y.size(0))  # refines i0 = s0
            if z.size(0) == 4:
                return y * 2
            else:
                return y + 2

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    add = torch.ops.aten.add.Tensor(y_1, 2); y_1 = None
    return add""")

    def test_unbacked_unify_guard_transitivity(self):
        def f(x1, x2, y):
            z1 = torch.zeros(x1.item())
            z2 = torch.zeros(x2.item())
            torch._check(z1.size(0) == z2.size(0))  # refines i0 = i1
            torch._check(z2.size(0) == y.size(0))  # refines i1 = s0
            if z1.size(0) == 4:
                return y * 2
            else:
                return y + 2

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x1_1, x2_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x1_1); x1_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(x2_1); x2_1 = None
    zeros_1 = torch.ops.aten.zeros.default([_local_scalar_dense_1], device = device(type='cpu'), pin_memory = False); _local_scalar_dense_1 = None
    add = torch.ops.aten.add.Tensor(y_1, 2); y_1 = None
    return add""")

    def test_split_unbacked_sizes(self):
        def f(lengths, values):
            # tolist not directly supported atm
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
            for s in sizes:
                torch._constrain_as_size(s)
            return torch.split(values, sizes)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.tensor([2, 3, 4]),
            torch.randn(9)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense)
    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1)
    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]
    getitem_2 = split_with_sizes[2]; split_with_sizes = None
    return (getitem, getitem_1, getitem_2)""")  # noqa: B950

    def test_invalidate_nonzero(self):
        def f(a):
            b = a.clone()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            b.normal_()
            y = b.nonzero()
            try:
                bool(x1.shape[0] == y.shape[0])
                self.fail("didn't raise exception")
            except GuardOnDataDependentSymNode:
                pass

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    def test_sqrt_size(self):
        def f(a):
            return a / a.size(-1) ** 0.5

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    pow_1 = sym_size_int ** 0.5; sym_size_int = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
    return div""")

    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x + 1

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                return x + self.bar(x)

        gm = make_fx(Foo())(torch.randn(4, 4))
        for node in gm.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

        foo = Foo()

        def functional_call(*args, **kwargs):
            with stateless._reparametrize_module(foo, {}):
                return foo(*args, **kwargs)

        functional_call._orig_mod = foo

        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
        found = False
        for node in gm_with_stack.graph.nodes:
            if "nn_module_stack" in node.meta:
                if len(node.meta["nn_module_stack"]) == 1:
                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
                    found = True
                elif len(node.meta["nn_module_stack"]) == 2:
                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
                    found = True
                else:
                    # there can be at most two levels of module stack
                    self.assertTrue(False)

        self.assertTrue(found)

        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
        for node in gm_without_stack.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

    def test_symint_to_tensor(self):
        def f(a):
            return a / a.shape[0]

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size_int); a_1 = sym_size_int = None
    return div""")

        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int); sym_size_int = None
    div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
    return div""")

    def test_cat(self):
        def f(a, b):
            val = torch.mul(a, b)
            out = torch.cat([val, val])
            if out.shape[0] * out.shape[1] > 20:
                out = out.cos()
            return out

        test_inputs = []
        test_inputs.append([(1, 5), (6, 1)])
        test_inputs.append([(1, 4), (3, 1)])
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")

    def test_new_empty(self):
        def f(a, b):
            return a.new_empty(b.shape[0], b.shape[1] * 2)

        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

    def test_size_with_tensor(self):
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        make_fx(f, tracing_mode="symbolic")(a)

    def test_fake_tensor_as_size(self):
        def f(x):
            r = torch.zeros([x])
            return r

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return zeros""")

    def test_expand(self):
        def f(a):
            b = torch.mul(a, 5)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
        # Checks if the input expr has been updated even though the constraint
        # happened after the placeholder was created
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)

    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = torch.sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

    def test_return_symint(self):
        def f(x):
            return x.shape[0], x.cos(), x.shape[0] / 5
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

        def f(x):
            return x.shape
        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])

    def test_rmethod(self):
        def f(x):
            return x.size(0) + x
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

    def test_mega_guard(self):
        def f(a, b):
            assert a.shape[0] == b.shape[0] * 2
            return a.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
        from torch._dynamo.source import LocalSource
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )

    def test_guard_upperbound_range_refinement(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")

    def test_guard_lowerbound_range_refinement(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")

    def test_guard_upperbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
        self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
13 <= L['a'].size()[0]
14 <= L['a'].size()[1]""")

    def test_guard_lowerbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
        self.assertExpectedInline(
            show_guards(tensor),
            """\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")

    def test_sym_storage_offset(self):
        def f(x, y):
            return x + y

        inp = (torch.randn(8)[3:], torch.randn(5))
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
        inp = (torch.randn(8)[3:], torch.randn(5))
        self.assertEqual(fx_g(*inp), f(*inp))

    def _assert_no_guards(self, fx_g, free_symbols):
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

    def test_guards_equal(self):
        def f(a, b):
            return a * b

        # NB: Numbers are carefully chosen to avoid duck shaping from applying

        fx_g = _trace(f, (5, 6), (5, 6))
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (5, 1), (1, 6))
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d):
            a = a + b
            cat = torch.cat([c, d])
            return a + cat

        fx_g = _trace(f, 7, 7, 4, 3)
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            x = a
            for idx in range(len(vals) - 1):
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
            return x

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self._assert_no_guards(fx_g, 1)

        def f(a, b):
            a = a.view(b.shape[0])
            return a + b.sum()

        fx_g = _trace(f, (4, 2), 8)
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (4, 2), (8, 5))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (2, 3, 4), 24)
        self._assert_no_guards(fx_g, 3)

    def test_nonidentity_transitive_guards(self):
        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            cat_vals = []
            for idx in range(len(vals) - 1):
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
            final_vals = []
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
                final_vals.append(a + b)
            return final_vals

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self.assertExpectedInline(show_guards(fx_g), """""")

    @torch.fx.experimental._config.patch(translation_validation=True)
    def test_constant_specialization(self):
        def f(t):
            assert t.shape[0] == 10
            return t

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
        self.assertExpectedInline(show_guards(tensor), """""")


make_fx_failures = {
    skip('empty_permuted'),
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),
    xfail('nanquantile'),
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),
    skip('empty_strided', '', device_type='cpu'),
}

fake_tensor_failures = {
    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),
}

symbolic_tensor_failures = {
    xfail('linalg.eig'),
    xfail('linalg.eigvals'),
    xfail('combinations', ''),
    xfail('histogram', ''),
    xfail('histogramdd', ''),
    xfail('kthvalue', ''),
    xfail('nanquantile', ''),
    xfail('narrow', ''),
    xfail('nn.functional.binary_cross_entropy', ''),
    xfail('nn.functional.cross_entropy', ''),
    xfail('nn.functional.ctc_loss'),
    xfail('nn.functional.fractional_max_pool2d', ''),
    xfail('nn.functional.fractional_max_pool3d', ''),
    xfail('quantile', ''),
    xfail('resize_as_', ''),
    xfail('unique_consecutive', ''),
    xfail('unique', ''),
    xfail('max_pool2d_with_indices_backward', ''),
    xfail('fft.fft', ''),
    xfail('fft.hfft2', ''),
    xfail('fft.hfft', ''),
    xfail('fft.hfftn', ''),
    xfail('fft.ifft', ''),
    xfail('fft.ihfft2', ''),
    xfail('fft.ihfft', ''),
    xfail('fft.ihfftn', ''),
    xfail('fft.irfft2', ''),
    xfail('fft.irfft', ''),
    xfail('fft.irfftn', ''),
    xfail('fft.rfft2', ''),
    xfail('fft.rfft', ''),
    xfail('fft.rfftn', ''),
}

symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm'),
}

symbolic_tensor_failures.update(symbolic_tensor_segfaults)

inplace_symbolic_tensor_failures = {
    xfail('float_power', ''),
    xfail('unique', ''),
}

out_symbolic_tensor_failures = {
    xfail('_native_batch_norm_legit', ''),
    xfail('argmax', ''),
    xfail('argmin', ''),
    xfail('fft.fft2', ''),
    xfail('fft.fftn', ''),
    xfail('fft.ifft2', ''),
    xfail('fft.ifftn', ''),
    xfail('gather', ''),
    xfail('linalg.cholesky', ''),
    xfail('linalg.cholesky_ex', ''),
    xfail('linalg.det', ''),
    xfail('linalg.det', 'singular'),
    xfail('linalg.inv', ''),
    xfail('linalg.inv_ex', ''),
    xfail('linalg.pinv', ''),
    xfail('linalg.pinv', 'hermitian'),
    xfail('linalg.svdvals', ''),
    xfail('max', 'reduction_with_dim'),
    xfail('min', 'reduction_with_dim'),
    xfail('nn.functional.avg_pool2d', ''),
    xfail('nn.functional.linear', ''),
    xfail('scatter_add', ''),
    xfail('scatter', ''),
    xfail('take_along_dim', ''),
    xfail('triangular_solve', ''),
    xfail('view_copy', ''),
}

out_symbolic_tensor_segfaults = {
    skip('nanmean', ''),
}

out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)


def _get_safe_inplace(inplace_variant):
    @functools.wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)

    return _fn
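

# Usage sketch: wraps an OpInfo inplace variant so exhaustive testing never
# mutates the shared sample input, e.g.
#   safe_add_ = _get_safe_inplace(torch.Tensor.add_)
#   safe_add_(t, 1)  # returns t.clone().add_(1); t itself is untouched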


def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # Limit ourselves to the first few inputs so symbolic tracing tests don't take too long
    count = 100
    if out:
        count = 5
    for sample_input in itertools.islice(sample_inputs_itr, count):
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        if out:
            expected = fn(*args, **kwargs)
            kwargs['out'] = expected

        try:
            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
                                  randomize_data=True)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")


class TestProxyTensorOpInfo(TestCase):
    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "real")

    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "fake")

    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "symbolic")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
        if not op.get_inplace():
            self.skipTest("No inplace variant for this op")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op):
        if not op.supports_out:
            self.skipTest("Op doesn't support out")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True)


only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()