import functools
import itertools
import operator
import re
import unittest
import warnings

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase, run_tests
from collections.abc import Iterable
from torch.nn.utils import stateless
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._decomp import decomposition_table
from torch.fx.experimental.symbolic_shapes import (
    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
    guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.common_device_type import ops
import torch.testing._internal.optests as optests
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
import torch._functorch.config

aten = torch.ops.aten

HAS_CUDA = torch.cuda.is_available()

def strip_end(s, suffix):
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    else:
        return s


def show_guards(gm):
    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
    return "\n".join(
        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
    )
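
# Worked example (illustrative values): fx placeholder targets look like "x_1",
# so strip_end(n, "_1") recovers the source-level name "x", and show_guards
# then renders guards such as "L['x'].size()[0] <= 19".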

def process_failures():
    """
    Takes a file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and processes them into a list of opinfo xfails
    """
    f = open('pytest_failures')
    failures = f.readlines()
    failures = [i.strip() for i in failures]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")

USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)

def _create_new_input(x):
    if not isinstance(x, torch.Tensor):
        return x
    if x.dtype != torch.float:
        return x + 1
    if x.is_leaf:
        return torch.rand_like(x, requires_grad=x.requires_grad)
    else:
        return torch.rand_like(x)

"""
Delays a cos being executed on the UnwrapTensor until it is used. Simulates how
a CommTensor is used.
"""
class UnwrapTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            ret = e
            if isinstance(e, UnwrapTensor):
                ret = e._tensor.cos()
            return ret

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        return func(*args, **kwargs)
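
# A minimal sketch of the behavior above (hedged; shapes arbitrary): any op
# touching an UnwrapTensor first swaps in cos() of the payload, so
#
#     x = torch.randn(5)
#     y = UnwrapTensor(x) * 2
#
# computes x.cos() * 2, which is what test_trace_subclasses relies on below.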

class TestGenericProxyTensor(TestCase):
    # NB: _create_new_input randomizes float inputs, so do not use this helper
    # with index tensors.
    def _test(self, f, inps):
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_pre_dispatch_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)

        inp = torch.ones(4, 4)
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            fx_g = make_fx(f, pre_dispatch=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
    return matmul""")

    def test_pre_dispatch_linear(self):
        def f(a, b, c):
            return torch.nn.functional.linear(a, b, c)

        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        c = torch.randn(4)
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
        out1 = f(a, b, c)
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()

        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.clone().detach().requires_grad_(True)
        a_tmp = a1.clone().detach().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)

        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        def f(a, b):
            return a + b

        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally, which must not leak
        # into the outer trace
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        def inner_with_factory():
            val = torch.tensor(float(1))
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        def f3(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))
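
    # The invariant exercised above, in miniature (hedged sketch): nested
    # tracing must stay isolated, so
    #
    #     def outer(x):
    #         get_isolated_graphmodule(lambda y: y.sum(), (x,), {})
    #         return x + 1
    #
    # traced with make_fx(outer) yields a graph containing add but no sum.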

    def test_empty_like_doesnt_burn_in_defaults(self):
        def f(x):
            return torch.empty_like(x)

        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False); x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When we trace through the factory-function decomposition, the proxy
        # mode must still be active, so the zeros call is captured in the graph.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\
def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
    return copy_
""")

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}

        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))

        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)

        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_pickle_issue89626(self):
        import pickle
        x = torch.randn(2)
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
        pickle.dumps(x)

    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_pre_dispatch_functionalization(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                y = y + x_unwrapped
                y = y * 5
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    add = torch.ops.aten.add.Tensor(matmul, x_1); matmul = x_1 = None
    mul = torch.ops.aten.mul.Tensor(add, 5); add = None
    return mul""")

    def test_pre_dispatch_functionalization_view_op(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                x_unwrapped = x_unwrapped.transpose(1, 0)
                y = y + x_unwrapped
                y = y.view(2, 8)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0); x_1 = None
    add = torch.ops.aten.add.Tensor(matmul, transpose); matmul = transpose = None
    view = torch.ops.aten.view.default(add, [2, 8]); add = None
    return view""")

    def test_val_metadata_mutation(self):
        def f(x):
            y = x.clone()
            y.unsqueeze_(0)
            return y

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
        self.assertEqual([
            tuple(node.meta['val'].shape)
            for node in traced.graph.nodes
            if 'val' in node.meta
        ], [(3,), (3,), (1, 3)])

    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_allclose(self):
        def f(a, b):
            return torch.allclose(a, b)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )

        if self.tracing_mode != "real":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return bool(blowup.sum().item() == 2)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_random(self):
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return bool(val.item() == 2.1)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))
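
    # In outline (hedged): DecompositionInterpreter re-executes fx_module into
    # new_graph, expanding every op found in its decomposition_table, so the
    # rebuilt module computes identical values with silu lowered away.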

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = torch.func.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in the returned list, so compare
        # against both orders
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(input_dim)

            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z

        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
        self.assertEqual(len(ops), 2)

    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = torch.func.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in the returned list, so compare
        # against both orders
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )

    def test_trace_subclasses(self):
        def f(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f, inp)

    def test_partial_decomp(self):
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y

        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)
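
    # Returning NotImplemented from a decomposition keeps the original op in
    # the graph, which is why exactly one of the two addmm calls above is
    # decomposed.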

    def test_decomp_of_capture(self):
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')
        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_strides(self):
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x

        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return y.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))

    def test_pr_86917(self):
        # Tests the issue seen in https://github.com/pytorch/pytorch/pull/86917
        def f(a, b):
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)

        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])


class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"


class TestGenericProxyTensorFake(TestGenericProxyTensor):
    tracing_mode = "fake"


class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    tracing_mode = "symbolic"


del TestGenericProxyTensor

class TestRealProxyTensor(TestCase):
    def test_error_on_data_dependent_ops(self):
        def f():
            x = torch.randn([])
            y = torch.randn([])
            assert torch.allclose(x * y, y * x)

        # data-dependent ops do not error when the check is disabled
        make_fx(f, _error_on_data_dependent_ops=False)()
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()

class TestFakeProxyTensor(TestCase):
    def test_issue82547(self):
        x = nn.Parameter(torch.randn(3, 3))

        def f():
            return torch.ops.aten.t.default(x)

        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_free_fake(self):
        def f(x):
            return torch.add(x, y)

        with FakeTensorMode() as fake_mode:
            y = torch.randn(2)
            make_fx(f, tracing_mode="real")(torch.randn(2))

    def test_fused_adam(self):
        params = [torch.randn(10, 10) for _ in range(10)]
        grads = [torch.randn(10, 10) for _ in range(10)]
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        state_steps = [torch.tensor(0) for _ in range(10)]

        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
            (new_params, _, _, _, _) = aten._fused_adam.default(
                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
                lr=0.1, beta1=0.9, beta2=0.999, weight_decay=0.01, eps=1e-8,
                amsgrad=False, maximize=False,
            )

            for p, new_p in zip(params, new_params):
                p.copy_(new_p)

            return params

        gm = make_fx(fused_adam, tracing_mode='fake')(
            params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
        )
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in ensure_ops_have_val:
                self.assertIn('val', n.meta)

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
    alias = torch.ops.aten.alias.default(x_1); x_1 = None
    return alias""")

    def test_meta(self):
        def f(a):
            b = torch.var_mean(a, dim=0)
            return b

        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
        for n in out.graph.nodes:
            if n.op == 'output':
                continue
            self.assertTrue('val' in n.meta)

def _get_node(fx_g, cond):
    for n in fx_g.graph.nodes:
        if cond(n):
            return n
    raise AssertionError


def _get_free_symbols(shape_env):
    vars = tuple(shape_env.var_to_val.keys())
    return len([var for var in vars if var not in shape_env.replacements])


def _trace(f, *args):
    inps = [torch.randn(arg) for arg in args]
    return make_fx(f, tracing_mode="symbolic")(*inps)
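
# Example (illustrative): _trace(f, (5, 6), (5, 6)) traces f symbolically on
# two fresh randn(5, 6) inputs; each dimension gets its own symbol unless
# tracing proves two of them equal.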

class TestSymbolicTracing(TestCase):
    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
        """
        Tests fn traced with trace_inputs against test_inputs
        Also returns shape env
        """
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
        for input in test_inputs:
            input = [torch.randn(shape) for shape in input]
            rx, ry = traced_f(*input), fn(*input)
            if assert_eq:
                self.assertEqual(rx, ry)
        return traced_f

    def test_debug_interpreter(self):
        from torch.library import Library

        foo = Library("foo", "DEF")
        foo.define("foo(Tensor self) -> Tensor")

        # Operator where meta and cpu disagree on strides
        @torch.library.impl(foo, "foo", "CPU")
        def foo_cpu(x):
            return x.clone().T

        @torch.library.impl(foo, "foo", "Meta")
        def foo_meta(x):
            return x.clone()

        def f(x):
            return torch.ops.foo.foo.default(x)

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
        from torch._functorch.compilers import DebugInterpreter

        interp = DebugInterpreter(gm)

        # an input size mismatch is caught (it indicates a guard problem)
        self.assertRaisesRegex(
            AssertionError, r"3 != 1",
            lambda: interp.run(torch.randn(3, 3).T),
        )

        # the incorrect meta kernel is caught via the mismatched strides
        self.assertRaisesRegex(
            AssertionError, r"\(3, 1\) != \(1, 3\)",
            lambda: interp.run(torch.randn(3, 3))
        )

    def test_int_input(self):
        def f(x, y):
            return x.view(y)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    view = torch.ops.aten.view.default(x_1, [y_1]); x_1 = y_1 = None
    return view""")

    def test_resize_from_zero(self):
        def f(x, y):
            x.resize_(y.size(0))

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = resize_ = None
    return None""")

    def test_broadcast_shapes(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0])

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    return (sym_size_int, sym_size_int_1)""")

    def test_deduped_shape(self):
        def f(s0, s1, x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, s0_1, s1_1, x_1, y_1):
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
    return ((s0_1, s1_1), empty)""")

    def test_non_deduped_shape(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
    return ((sym_size_int, sym_size_int_1), empty)""")

    def test_unary(self):
        def f(x):
            assert x.shape[0] < 20
            return x.cos()

        test_inputs = []
        test_inputs.append([(2, 5)])
        test_inputs.append([(6, 8)])

        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
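
    # eval_guards re-checks the guards recorded during tracing against concrete
    # tensors: randn(4, 5) satisfies size(0) <= 19, randn(25, 5) does not, so
    # the traced graph cannot be soundly reused for the latter.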

    def test_repeat_interleave(self):
        def f(src_tokens, beam_size_src):
            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
        self.assertEqual(len(gm.shape_env.guards), 0)

    def test_non_symint_size_spec(self):
        def f(x):
            torch._C._non_sym_sizes(x)
            return x + 1

        x = torch.randn(2, 3)
        make_fx(f, tracing_mode="symbolic")(x)

    def test_symbolic_repeat_interleave(self):
        def f(y, x):
            return y.repeat_interleave(x, dim=1)

        y = torch.tensor([[1, 2], [3, 4]])
        x = torch.tensor([2, 3])
        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, y_1, x_1):
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1); x_1 = None
    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave); y_1 = repeat_interleave = None
    return index_select""")

    def test_mod_gcd_unbacked(self):
        def f(_a, _b, _stride):
            a = _a.item()
            b = _b.item()
            stride = _stride.item()
            torch._check_is_size(a)
            torch._check_is_size(b)
            torch._check_is_size(stride)
            ta = torch.randn(a * stride)
            tb = torch.randn(b * stride)
            r = torch.cat([ta, tb])
            return r.view(a + b, stride)

        _a = torch.tensor(30)
        _b = torch.tensor(20)
        _stride = torch.tensor(10)
        r = str(make_fx(f, tracing_mode="symbolic")(_a, _b, _stride).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, _a_1, _b_1, _stride_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(_a_1); _a_1 = None
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(_b_1); _b_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(_stride_1); _stride_1 = None
    mul = _local_scalar_dense * _local_scalar_dense_2
    randn = torch.ops.aten.randn.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
    mul_1 = _local_scalar_dense_1 * _local_scalar_dense_2
    randn_1 = torch.ops.aten.randn.default([mul_1], device = device(type='cpu'), pin_memory = False); mul_1 = None
    cat = torch.ops.aten.cat.default([randn, randn_1]); randn = randn_1 = None
    add = _local_scalar_dense + _local_scalar_dense_1; _local_scalar_dense = _local_scalar_dense_1 = None
    view = torch.ops.aten.view.default(cat, [add, _local_scalar_dense_2]); cat = add = _local_scalar_dense_2 = None
    return view""")
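
    # a, b and stride above are "unbacked" SymInts: .item() produces symbols
    # with no concrete hint, and torch._check_is_size marks them as sizes so
    # the cat + view arithmetic ((a*s) + (b*s) == (a+b)*s) can be discharged
    # symbolically instead of guarding on data.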

    def test_cumsum_unbacked(self):
        def f(x):
            y = x.item()
            z = torch.randn((3, y, 3))
            return z.cumsum(0)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    cumsum = torch.ops.aten.cumsum.default(randn, 0); randn = None
    return cumsum"""
        )

    def test_repeat_interleave_unbacked_output_size(self):
        def f(x, y):
            s = x.sum().item()
            return y.repeat_interleave(x, dim=0, output_size=s)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1, y_1):
    sum_1 = torch.ops.aten.sum.default(x_1)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1); sum_1 = None
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense); x_1 = _local_scalar_dense = None
    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave); y_1 = repeat_interleave = None
    return index_select"""
        )

    def test_arange_unbacked_output_size(self):
        def f(x):
            return torch.arange(0, x)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return arange"""
        )

    def test_adv_index_batch(self):
        def f(src_tokens):
            bsz, src_len = src_tokens.size()[:2]
            start_step = src_tokens.shape[1]
            beam_size = 1
            generate_size = 16
            max_len = src_len + generate_size
            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
            return tokens

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
        # one guard is expected here (it rules out a degenerate batch size)
        self.assertEqual(len(gm.shape_env.guards), 1)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_cpu_scalar_cuda(self):
        def f(a, b):
            return (a * b) @ b

        r = str(
            make_fx(f, tracing_mode="symbolic")(
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
            ).code
        ).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1, b_1):
    mul = torch.ops.aten.mul.Tensor(a_1, b_1); a_1 = None
    mm = torch.ops.aten.mm.default(mul, b_1); mul = b_1 = None
    return mm""")

    def test_binary_broadcast(self):
        def f(a, b):
            c = a * b
            return c

        test_inputs = []
        test_inputs.append([(1, 5), (3, 1)])
        test_inputs.append([(1, 4), (4, 1)])
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
        assert len(shape_env.guards) == 0

    def test_multiply_shape(self):
        def f(a):
            return torch.empty(a.shape[0] * 2)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
    mul = sym_size_int * 2; sym_size_int = None
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
    return empty""")

    def test_item(self):
        def f(a):
            r = a.item()
            return r * a

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense); a_1 = _local_scalar_dense = None
    return mul""")

    def test_tensor_symfloat(self):
        def f(a):
            r = torch.tensor(a.size(0) ** 2.0)
            assert r.dtype is torch.float
            return r

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
        r = str(gm.code).strip()
        # NB: this specializes (2 ** 2.0 is baked in as a constant); the point
        # is that the dtype inference is correct
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
    return lift_fresh_copy""")
        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))

    def test_item_to_constructor(self):
        def f(a):
            r = a.item()
            return torch.empty(r)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return empty"""
        )

    def test_setitem_symint(self):
        def f(x):
            x[0] = x.size(0)
            return x

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None
    select = torch.ops.aten.select.int(x_1, 0, 0)
    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = copy_ = None
    return x_1"""
        )

    def test_dynamic_pointwise_scalar(self):
        def f(gravity, mask):
            gravity[mask, 0] = gravity[mask, 0] * -1

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 4)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, gravity_1, mask_1):
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
    index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None
    mul = torch.ops.aten.mul.Tensor(index, -1); index = None
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = index_put_ = None
    return None""")

    def test_reflect_r_over_x(self):
        def reflect_R_over_x(R):
            reflect = torch.eye(3, device=R.device)
            reflect[0][0] = -1
            return reflect @ R @ reflect

        def f(crop_camera, mask):
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = copy_ = None
    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]); index = None
    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
    bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
    mul_4 = sym_size_int * 3
    view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
    mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
    view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None
    return None""")

    def test_unbacked_slice(self):
        def f(x, m):
            x = x[m]
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

        make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        )

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_unbacked_batch_resnet(self):
        mod = torchvision.models.resnet18()

        def f(x, mask, params, buffers):
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
                for s in p.shape:
                    guard_int(s)
            x = x[mask]
            torch._check(x.shape[0] >= 1)
            for p in params.values():
                p.grad = None
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

        make_fx(f, tracing_mode="symbolic")(
            torch.randn(3, 3, 250, 250),
            torch.randint(0, 2, (3,), dtype=torch.bool),
            dict(mod.named_parameters()),
            dict(mod.named_buffers()),
        )

    def test_boolean_index(self):
        def f(images, handedness, valid):
            images = images[valid]
            handedness = handedness[valid]
            right_hand_mask = handedness == 1
            images[right_hand_mask] = images[right_hand_mask].flip(-1)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randint(0, 256, (512, 1, 96, 96)),
            torch.randint(0, 1, (512,)),
            torch.randint(0, 2, (512,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = index_put_ = None
    return None""")

    def test_neg_shape(self):
        def f(a):
            return torch.empty(-a.shape[0] + 10)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
    neg = -sym_size_int; sym_size_int = None
    add = neg + 10; neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
    return empty""")

    def test_unbacked_unification(self):
        def f(x, y):
            z = torch.zeros(x.item())
            return z + y

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    add = torch.ops.aten.add.Tensor(zeros, y_1); zeros = y_1 = None
    return add""")
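
    # Unbacked unification at work: zeros(x.item()) has unbacked size u0, and
    # adding it to y_1 (size s0) teaches the shape env u0 == s0 rather than
    # failing on a data-dependent guard.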

    def test_reshape_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 4, 20)
            r = r.transpose(2, 1)
            return r.reshape(-1, 80)

        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    def test_view_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 192)
            return r.view(12, -1, 192)

        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_view_divisibility_unbacked_relatively_prime(self):
        def f(x):
            i0 = x.item()
            torch._check_is_size(i0)
            torch._check(i0 <= 448)
            return torch.zeros(256 * i0).view(-1, 447)

        make_fx(f, tracing_mode="symbolic")(torch.tensor(256 * 447, device="cuda"))

    def test_unbacked_unify_guard(self):
        def f(x, y):
            z = torch.zeros(x.item())
            torch._check(z.size(0) == y.size(0))  # refines u0 = s0
            if z.size(0) == 4:
                return y * 2
            else:
                return y + 2

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = zeros = None
    add = torch.ops.aten.add.Tensor(y_1, 2); y_1 = None
    return add""")

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    @unittest.expectedFailure
    def test_unbacked_unify_guard_transitivity(self):
        def f(x1, x2, y):
            z1 = torch.zeros(x1.item())
            z2 = torch.zeros(x2.item())
            torch._check(z1.size(0) == z2.size(0))
            torch._check(z2.size(0) == y.size(0))
            if z1.size(0) == 4:
                return y * 2
            else:
                return y + 2

        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"),
            torch.tensor(10, device="cuda"),
            torch.randn(10, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        r = str(gm.code).strip()

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_unbacked_unify_dependency_violation(self):
        def f(x1, x2, x3, y):
            z1 = x1.item()
            torch._check(z1 // 9 == 1)
            z2 = x2.item()
            z3 = x3.item()
            torch._check(z1 == z2 + z3)
            return y * 2

        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"), torch.tensor(5, device="cuda"),
            torch.tensor(5, device="cuda"), torch.randn(1, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        self.assertEqual(gm(
            torch.tensor(12, device="cuda"), torch.tensor(6, device="cuda"),
            torch.tensor(6, device="cuda"), torch.tensor([1.0], device="cuda")),
            torch.tensor([2.0], device="cuda")
        )
        with self.assertRaises(RuntimeError):
            gm(
                torch.tensor(20, device="cuda"), torch.tensor(10, device="cuda"),
                torch.tensor(10, device="cuda"), torch.tensor([1.0], device="cuda")
            )

    def test_split_unbacked_sizes(self):
        def f(lengths, values):
            # tolist is not directly supported yet, so read out each element
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
            for s in sizes:
                torch._constrain_as_size(s)
            return torch.split(values, sizes)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.tensor([2, 3, 4]),
            torch.randn(9)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None
    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense); sym_constrain_range_for_size = None
    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1); sym_constrain_range_for_size_1 = None
    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2); sym_constrain_range_for_size_2 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]
    getitem_2 = split_with_sizes[2]; split_with_sizes = None
    return (getitem, getitem_1, getitem_2)""")

    def test_invalidate_nonzero(self):
        def f(a):
            b = a.clone()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            b.normal_()
            y = b.nonzero()
            try:
                bool(x1.shape[0] == y.shape[0])
                self.fail("didn't raise exception")
            except GuardOnDataDependentSymNode:
                pass

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
    def test_invalidate_nonzero_propagate_real_tensors(self):
        def f(a):
            b = a.clone()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            b.normal_()
            y = b.nonzero()
            # propagating real tensors makes this comparison answerable
            assert x1.shape[0] == y.shape[0]

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    def test_sqrt_size(self):
        def f(a):
            return a / a.size(-1) ** 0.5

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int); sym_size_int = None
    pow_1 = sym_float ** 0.5; sym_float = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
    return div""")

    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + 1

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                return x + self.bar(x)

        gm = make_fx(Foo())(torch.randn(4, 4))
        for node in gm.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

        foo = Foo()

        def functional_call(*args, **kwargs):
            with stateless._reparametrize_module(foo, {}):
                return foo(*args, **kwargs)

        functional_call._orig_mod = foo

        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
        found = False
        for node in gm_with_stack.graph.nodes:
            if "nn_module_stack" in node.meta:
                if len(node.meta["nn_module_stack"]) == 1:
                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
                    found = True
                elif len(node.meta["nn_module_stack"]) == 2:
                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
                    found = True
                else:
                    # there can be at most two levels
                    self.assertTrue(False)

        self.assertTrue(found)

        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
        for node in gm_without_stack.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

    def test_symint_to_tensor(self):
        def f(a):
            return a / a.shape[0]

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size_int); a_1 = sym_size_int = None
    return div""")

        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int); sym_size_int = None
    div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
    return div""")

    def test_cat(self):
        def f(a, b):
            val = torch.mul(a, b)
            out = torch.cat([val, val])
            if out.shape[0] * out.shape[1] > 20:
                out = out.cos()
            return out

        test_inputs = []
        test_inputs.append([(1, 5), (6, 1)])
        test_inputs.append([(1, 4), (3, 1)])
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")

    def test_new_empty(self):
        def f(a, b):
            return a.new_empty(b.shape[0], b.shape[1] * 2)

        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

    def test_size_with_tensor(self):
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        f(a)
        make_fx(f, tracing_mode="symbolic")(a)

    def test_fake_tensor_as_size(self):
        def f(x):
            r = torch.zeros([x])
            return r

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return zeros""")

    def test_expand(self):
        def f(a):
            b = torch.mul(a, a)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
        # the input expression is updated even though the constraint happened
        # after the input was used
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)

    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = torch.sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

    def test_return_symint(self):
        def f(x):
            return x.shape[0], x.cos(), x.shape[0] / 5

        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

        def f(x):
            return x.shape

        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])

    def test_rmethod(self):
        def f(x):
            return x.size(0) + x

        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

    def test_mega_guard(self):
        def f(a, b):
            assert a.shape[0] == b.shape[0] * 2
            return a.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
        from torch._dynamo.source import LocalSource
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""
        )
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""
        )

    def test_guard_upperbound_range_refinement(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")

    def test_guard_lowerbound_range_refinement(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")

    def test_guard_upperbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
        self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
13 <= L['a'].size()[0]
14 <= L['a'].size()[1]""")

    def test_guard_lowerbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
            return a.cos()

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
        self.assertExpectedInline(
            show_guards(tensor),
            """\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")
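
    # Range refinement in the tests above: sizes are integers, so asserting
    # shape[0] > 12 refines the lower bound to 13, and only the strongest bound
    # in each direction survives as a guard.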

    def test_sym_storage_offset(self):
        def f(x, y):
            return x + y

        inp = (torch.randn(8)[3:], torch.randn(5))
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
        inp = (torch.randn(8)[3:], torch.randn(5))
        self.assertEqual(fx_g(*inp), f(*inp))

    def _assert_no_guards(self, fx_g, free_symbols):
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

    def test_guards_equal(self):
        def f(a, b):
            return a * b

        # tensor unification checks that a and b's dims are equal
        fx_g = _trace(f, (5, 6), (5, 6))
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (5, 1), (1, 6))
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d):
            a = a + b
            cat = torch.cat([c, d])
            return a + cat

        fx_g = _trace(f, 7, 7, 4, 3)
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            x = a
            for idx in range(len(vals) - 1):
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
            return x

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self._assert_no_guards(fx_g, 1)

        def f(a, b):
            a = a.view(b.shape[0])
            return a + b.sum()

        fx_g = _trace(f, (4, 2), 8)
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (4, 2), (8, 5))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (2, 3, 4), 24)
        self._assert_no_guards(fx_g, 3)

    def test_nonidentity_transitive_guards(self):
        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            cat_vals = []
            for idx in range(len(vals) - 1):
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
            final_vals = []
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
                final_vals.append(a + b)
            return final_vals

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self.assertExpectedInline(show_guards(fx_g), """""")

    @torch.fx.experimental._config.patch(translation_validation=True)
    def test_constant_specialization(self):
        def f(t):
            assert t.shape[0] == 10
            return t

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
        self.assertExpectedInline(show_guards(tensor), """""")

make_fx_failures = {
    skip('empty_permuted'),
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),
    xfail('nanquantile'),
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),
    skip('empty_strided', '', device_type='cpu'),
}

only_real_tensor_failures = {
    xfail('nonzero_static'),
}

only_fake_tensor_failures = {
    xfail('nonzero_static'),
}

fake_tensor_failures = {
    skip('nn.functional.nll_loss'),
}

symbolic_tensor_failures = {
    xfail('combinations', ''),
    xfail('histogram', ''),
    xfail('histogramdd', ''),
    xfail('nanquantile', ''),
    xfail('nn.functional.binary_cross_entropy', ''),
    xfail('nn.functional.cross_entropy', ''),
    xfail('nn.functional.ctc_loss'),
    xfail('quantile', ''),
    xfail('unique_consecutive', ''),

    xfail('max_pool2d_with_indices_backward', ''),

    # fft ops with incorrect striding/metadata under symbolic shapes
    xfail('fft.fft', ''),
    xfail('fft.hfft2', ''),
    xfail('fft.hfft', ''),
    xfail('fft.hfftn', ''),
    xfail('fft.ifft', ''),
    xfail('fft.ihfft2', ''),
    xfail('fft.ihfft', ''),
    xfail('fft.ihfftn', ''),
    xfail('fft.irfft2', ''),
    xfail('fft.irfft', ''),
    xfail('fft.irfftn', ''),
    xfail('fft.rfft2', ''),
    xfail('fft.rfft', ''),
    xfail('fft.rfftn', ''),
}

symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm')
}

symbolic_tensor_failures.update(symbolic_tensor_segfaults)

inplace_symbolic_tensor_failures = {
    xfail('float_power', ''),
}

out_symbolic_tensor_failures = {
    xfail('_batch_norm_with_update', ''),
    xfail('_native_batch_norm_legit', ''),
    xfail('argmax', ''),
    xfail('argmin', ''),
    xfail('fft.fft2', ''),
    xfail('fft.fftn', ''),
    xfail('fft.ifft2', ''),
    xfail('fft.ifftn', ''),
    xfail('gather', ''),
    xfail('linalg.pinv', ''),
    xfail('linalg.pinv', 'hermitian'),
    xfail('scatter_add', ''),
    xfail('scatter', ''),
    xfail('take_along_dim', ''),
    xfail('triangular_solve', ''),

    xfail('index_reduce', 'prod'),
    xfail('index_reduce', 'mean'),
    xfail('index_reduce', 'amax'),
    xfail('index_reduce', 'amin'),
}

out_symbolic_tensor_segfaults = {
    skip('nanmean', ''),
}

out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)

def _get_safe_inplace(inplace_variant):
    @functools.wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)

    return _fn
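
# Example (illustrative): _get_safe_inplace(torch.Tensor.add_) returns a
# function that adds into a clone, so the OpInfo sample inputs are never
# mutated between the eager and traced runs being compared.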

def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # limit the number of samples so the symbolic tests don't take too long
    count = 100
    if out:
        count = 5
    for sample_input in itertools.islice(sample_inputs_itr, count):
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        if out:
            expected = fn(*args, **kwargs)
            kwargs['out'] = expected

        try:
            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
                                  randomize_data=True)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")

def skipIfNameMatches(pattern):
    """
    Decorator to skip a test if its name matches the given pattern.
    """
    def decorator(test_func):
        def wrapper(*args, **kwargs):
            if re.match(pattern, test_func.__name__):
                raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'")
            return test_func(*args, **kwargs)
        return wrapper
    return decorator


# make_fx is not directly compatible with auto_functionalize, so filter it out
filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"]

@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo")
class TestProxyTensorOpInfo(TestCase):
    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures.union(only_real_tensor_failures))
    def test_make_fx_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "real")

    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures.union(fake_tensor_failures, only_fake_tensor_failures))
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "fake")

    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "symbolic")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
        if not op.get_inplace():
            self.skipTest("No inplace variant for this op")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op):
        if not op.supports_out:
            self.skipTest("Op doesn't support out")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()