test_proxy_tensor.py
1
# Owner(s): ["module: ProxyTensor"]
2

3
from torch.testing._internal.common_utils import TestCase, run_tests
4
import torch
5
import torch._dynamo
6
import unittest
7
import warnings
8
import operator
9
from collections.abc import Iterable
10
from torch.nn.utils import stateless
11
from torch.testing._internal.common_device_type import instantiate_device_type_tests
12
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
13
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
14
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
15
from torch._decomp import decomposition_table
16
from torch.fx.experimental.symbolic_shapes import (
17
    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
18
    guard_int, GuardOnDataDependentSymNode
19
)
20
from torch.testing._internal.custom_op_db import custom_op_db
21
from torch.testing._internal.hop_db import hop_db
22
from torch.testing._internal.common_device_type import ops
23
import torch.testing._internal.optests as optests
24
from torch._C import _disabled_torch_function_impl
25
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
26
from torch.utils._pytree import tree_map
27
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
28
from torch import nn
29
import torch._functorch.config
30
import re
31

32
import functools
33
import itertools
34

35
aten = torch.ops.aten
36

37
HAS_CUDA = torch.cuda.is_available()
38

39

40
def strip_end(s, suffix):
41
    if suffix and s.endswith(suffix):
42
        return s[:-len(suffix)]
43
    else:
44
        return s
45

46

47
def show_guards(gm):
48
    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
49
    return "\n".join(
50
        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
51
    )
52

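# Illustrative sketch (an addition, not part of the original helpers): show_guards
# renders the simplified guards recorded while tracing with
# tracing_mode="symbolic", one per line, keyed by placeholder name (with any
# trailing "_1" stripped by strip_end above).
def _show_guards_example():
    def f(x):
        assert x.shape[0] < 20  # records a guard on the symbolic size
        return x.cos()
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
    return show_guards(gm)  # e.g. a single line bounding x's first dimension
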
53

54
def process_failures():
55
    """
56
    Takes a file containing failures like
57

58
    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950
59

60
    and processes them into a list of opinfo xfails
61
    """
62
    f = open('pytest_failures')
63
    failures = f.readlines()
64
    failures = [i.strip() for i in failures]
65

66
    def process_failure_string(s, matcher):
67
        out = re.search(matcher, s)
68
        return out.groups()
69

70
    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
71
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]
72

73
    def create_normalized_name(op):
74
        if op.variant_test_name == '':
75
            s = op.name
76
        else:
77
            s = f"{op.name}.{op.variant_test_name}"
78
        return s.replace('.', '_')
79

80
    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}
81

82
    print("symbolic_tensor_failures = {")
83
    for failure, reason in failures:
84
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
85
    print("}")
86

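# Illustrative sketch (an addition, not part of the original file): what
# SYMBOLIC_TRACE_MATCH extracts from a pytest failure line of the form shown in
# the docstring above.  The sample text here is an assumption for illustration.
def _example_process_failure_line():
    sample = (
        "FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::"
        "test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - "
        "RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition"
    )
    name, reason = re.search(r'exhaustive_(.*)_cpu.*: (.*)', sample).groups()
    # name should come out as '__getitem__' (the normalized opinfo name used as
    # a key inside process_failures); reason is the trailing error message that
    # becomes the xfail comment.
    return name, reason
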
87

88
USE_TORCHVISION = False
89
try:
90
    import torchvision
91
    USE_TORCHVISION = True
92
except ImportError:
93
    warnings.warn("Couldn't import torchvision. Some of our tests use it; try "
94
                  "to install it with commands from pytorch.org, post-fixed with "
95
                  "`--no-deps` to avoid overwriting the pytorch installation",
96
                  UserWarning)
97

98

99
def _create_new_input(x):
100
    if not isinstance(x, torch.Tensor):
101
        return x
102
    if x.dtype != torch.float:
103
        return x + 1
104
    if x.is_leaf:
105
        return torch.rand_like(x, requires_grad=x.requires_grad)
106
    else:
107
        return torch.rand_like(x)
108

109
"""
110
Delays a cos() from being executed on the UnwrapTensor until it is used. Simulates how a CommTensor is used.
111
"""
112
class UnwrapTensor(torch.Tensor):
113
    @staticmethod
114
    def __new__(cls, tensor: torch.Tensor):
115
        r = torch.Tensor._make_wrapper_subclass(
116
            cls,
117
            tensor.size(),
118
            dtype=tensor.dtype,
119
            device=tensor.device,
120
            layout=tensor.layout,
121
            requires_grad=tensor.requires_grad,
122
        )
123
        r._tensor = tensor
124
        return r
125

126
    def __repr__(self):
127
        # TODO: consider all_gather the local tensors for better debugging
128
        return f"UnwrapTensor({self._tensor})"
129

130
    __torch_function__ = _disabled_torch_function_impl
131

132
    @classmethod
133
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
134
        def unwrap(e):
135
            ret = e
136
            if isinstance(e, UnwrapTensor):
137
                ret = e._tensor.cos()
138

139
            return ret
140

141
        args = tree_map(unwrap, args)
142
        kwargs = tree_map(unwrap, kwargs)
143
        return func(*args, **kwargs)
144

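# Illustrative sketch (an addition, not part of the original tests): wrapping a
# tensor defers a cos() until the wrapper is consumed by an op, at which point
# __torch_dispatch__ above unwraps it as tensor.cos().
def _unwrap_tensor_example():
    t = torch.ones(3)
    w = UnwrapTensor(t)
    out = w * 2  # dispatch unwraps w as t.cos(), so out == t.cos() * 2
    assert torch.allclose(out, t.cos() * 2)
    return out
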
145
class TestGenericProxyTensor(TestCase):
146
    # WARNING: if any of your inputs are index tensors, DO NOT use this
147
    # function
148
    def _test(self, f, inps):
149
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
150
        new_inps = tree_map(_create_new_input, inps)
151
        r1 = fx_f(*new_inps)
152
        r2 = f(*new_inps)
153
        self.assertEqual(r1, r2)
154

155
    def test_pre_dispatch_mode_stack(self):
156
        def f(a):
157
            b = torch.ones(4, 4)
158
            return torch.matmul(a, b)
159
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
160
        # Also, torch.ones() doesn't show up in the trace.
161
        # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
162
        # so our mode never sees it - it goes directly to the BackendSelect key.
163
        inp = torch.ones(4, 4)
164
        # Test that make_fx(pre_dispatch=True) clears caches properly.
165
        from torch._dispatch.python import enable_python_dispatcher
166
        with enable_python_dispatcher():
167
            out1 = f(inp)
168
        fx_g = make_fx(f, pre_dispatch=True)(inp)
169
        self.assertExpectedInline(fx_g.code.strip(), """\
170
def forward(self, a_1):
171
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
172
    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
173
    return matmul""")
174

175
    def test_pre_dispatch_linear(self):
176
        def f(a, b, c):
177
            return torch.nn.functional.linear(a, b, c)
178
        a = torch.ones(4, 4)
179
        b = torch.ones(4, 4)
180
        c = torch.ones(4)
181
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
182
        out1 = f(a, b, c)
183
        out2 = fx_g(a, b, c)
184
        self.assertEqual(out1, out2)
185

186
    def test_pre_dispatch_no_grad(self):
187
        def f(a):
188
            b = a.sin()
189
            torch.set_grad_enabled(False)
190
            c = b.cos()
191
            torch.set_grad_enabled(True)
192
            return b + c.sin()
193
        a1 = torch.randn(4, requires_grad=True)
194
        a2 = a1.clone().detach().requires_grad_(True)
195
        a_tmp = a1.clone().detach().requires_grad_(True)
196
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
197
        out1 = f(a1)
198
        out2 = fx_g(a2)
199
        self.assertEqual(out1, out2)
200
        out1.sum().backward()
201
        out2.sum().backward()
202
        self.assertEqual(a1.grad, a2.grad)
203

204
    def test_make_fx_simple(self):
205
        def f(x):
206
            return torch.sin(x)
207
        self._test(f, (torch.randn(3),))
208

209
    def test_scalar_device(self, device='cpu'):
210
        def f(a, b):
211
            return a + b
212
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
213

214
    def test_isolated_graphmodule(self):
215
        def is_any_sum(gm):
216
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)
217

218
        def is_any_digamma(gm):
219
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)
220

221
        def is_any_sigmoid(gm):
222
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)
223

224
        def inner(x):
225
            return torch.sum(x)
226

227
        def f(x):
228
            gm = get_isolated_graphmodule(inner, (x,), {})
229
            self.assertTrue(is_any_sum(gm))
230
            return x + torch.randn(x.shape)
231

232
        # get_isolated_graphmodule uses make_fx internally, which shouldn't be traced
233
        # by the outer make_fx call
234
        traced = make_fx(f)(torch.randn(3))
235
        self.assertFalse(is_any_sum(traced))
236

237
        # When factory functions are used, they should not be traced
238
        # by the outer make_fx call
239
        def inner_with_factory():
240
            val = torch.tensor(float(1))
241
            val.add_(2)
242
            return torch.full((10, 10), val).sum()
243

244
        def f1(x):
245
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
246
            self.assertTrue(is_any_sum(gm))
247
            return torch.sigmoid(x)
248

249
        def f2(x):
250
            gm = get_isolated_graphmodule(f1, (x,), {})
251
            self.assertFalse(is_any_sum(gm))
252
            self.assertTrue(is_any_sigmoid(gm))
253
            return torch.digamma(x)
254

255
        traced = make_fx(f2)(torch.randn(3))
256
        self.assertFalse(is_any_sum(traced))
257
        self.assertFalse(is_any_sigmoid(traced))
258
        self.assertTrue(is_any_digamma(traced))
259

260
        # Verify that nested make_fx calls don't cause factory functions to be leaked
261
        # into the outer graph. Verify that `make_fx` itself does not leak its execution.
262
        def f2(x):
263
            gm = make_fx(f1)(x)
264
            self.assertFalse(is_any_sum(gm))
265
            self.assertTrue(is_any_sigmoid(gm))
266
            return torch.digamma(x)
267

268
        traced = make_fx(f2)(torch.randn(3))
269
        self.assertFalse(is_any_sum(traced))
270
        self.assertFalse(is_any_sigmoid(traced))
271
        self.assertTrue(is_any_digamma(traced))
272

273
        # Verify that the `forward` function of a graph module produced as a
274
        # side effect of an interior `make_fx` is still traced
275
        def f3(x):
276
            gm = make_fx(f1)(x)
277
            self.assertFalse(is_any_sum(gm))
278
            self.assertTrue(is_any_sigmoid(gm))
279
            # `gm.forward` is still traced
280
            return torch.digamma(gm(x))
281

282
        traced = make_fx(f3)(torch.randn(3))
283
        self.assertFalse(is_any_sum(traced))
284
        self.assertTrue(is_any_sigmoid(traced))
285
        self.assertTrue(is_any_digamma(traced))
286

287
        # Verify interaction with non-ProxyTensor modes
288
        from torch.testing._internal.logging_tensor import LoggingTensorMode
289

290
        def f1_logging(x):
291
            with LoggingTensorMode():
292
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
293
            self.assertTrue(is_any_sum(gm))
294
            return torch.sigmoid(x)
295

296
        def f2_logging(x):
297
            with LoggingTensorMode(), LoggingTensorMode():
298
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
299
            self.assertFalse(is_any_sum(gm))
300
            self.assertTrue(is_any_sigmoid(gm))
301
            return torch.digamma(x)
302

303
        traced = make_fx(f2_logging)(torch.randn(3))
304
        self.assertFalse(is_any_sum(traced))
305
        self.assertFalse(is_any_sigmoid(traced))
306
        self.assertTrue(is_any_digamma(traced))
307

308
        # Verify interaction with another tensor subclass
309
        # This case currently doesn't work and should raise an error
310
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
311
        from torch.testing._internal.logging_tensor import LoggingTensor
312

313
        def f1_logging_tensor(x):
314
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
315
            self.assertTrue(is_any_sum(gm))
316
            return torch.sigmoid(x)
317

318
        def f2_logging_tensor(x):
319
            x = LoggingTensor(x)
320
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
321
            self.assertFalse(is_any_sum(gm))
322
            self.assertTrue(is_any_sigmoid(gm))
323
            return torch.digamma(x)
324

325
        traced = make_fx(f2_logging_tensor)(torch.randn(3))
326
        self.assertFalse(is_any_sum(traced))
327
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
328
        self.assertTrue(is_any_digamma(traced))
329

330
    # See https://github.com/pytorch/pytorch/issues/97541
331
    def test_empty_like_doesnt_burn_in_defaults(self):
332
        def f(x):
333
            return torch.empty_like(x)
334
        out = make_fx(f)(torch.randn(3))
335
        self.assertExpectedInline(out.code.strip(), """\
336
def forward(self, x_1):
337
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
338
    return empty_like""")
339

340
    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
341
        def f(x):
342
            y = x.new_zeros(x.size())
343
            y.copy_(x)
344
            return y
345

346
        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
347
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)
348

349
        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}
350

351
        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
352
        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
353
        # returns a ProxyTensor.
354
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
355
        self.assertExpectedInline(out.code, """\
356

357

358

359
def forward(self, x_1):
360
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
361
    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
362
    return copy_
363
    """)
364

365
    def test_make_fx_reentrant_dispatch(self):
366
        def f(x):
367
            return torch.ops.aten.norm.Scalar(x, 2.0)
368

369
        def norm_decomp(x, p=2.0):
370
            if p != 2.0:
371
                raise RuntimeError("can't handle with p != 2")
372
            return torch.sqrt(torch.sum(torch.square(x)))
373

374
        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}
375

376
        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))
377

378
        for n in traced.graph.nodes:
379
            self.assertTrue("square" not in str(n.target))
380
            self.assertTrue("norm" not in str(n.target))
381

382
    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
383
    def test_resnet18_backward_trace(self):
384
        mod = torchvision.models.resnet18()
385

386
        # An old version of this test called the module directly.  This works
387
        # for tracing_mode == "real", but for fake tensors, we also have to
388
        # ensure that the parameters and buffers get wrapped in fake tensors
389
        # because free fake tensors are not supported.  Fortunately functional_call
390
        # does precisely this for us.
391
        def f(x, params, buffers):
392
            for p in params.values():
393
                p.grad = None
394
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
395
            # I could have done this with the functional API, but there is already
396
            # plenty of coverage exercising that; I want to show that the mutating API still
397
            # works
398
            loss.backward()
399
            return [p.grad for p in params.values()]
400

401
        inp = torch.randn(3, 3, 250, 250)
402
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])
403

404
    def test_varargs(self):
405
        def f(*args):
406
            return sum(args)
407

408
        self._test(f, [torch.randn(2), torch.randn(2)])
409

410
    def test_proxy_tensor(self):
411
        def f_grad(x):
412
            val = x.cos().cos().sum()
413
            return torch.autograd.grad(val, x)
414

415
        def f_backward(x):
416
            val = x.cos().cos().sum()
417
            val.backward()
418
            return x.grad
419

420
        for f in [f_grad, f_backward]:
421
            self._test(f, [torch.randn(3, requires_grad=True)])
422

423
    def test_pickle_issue89626(self):
424
        import pickle
425
        x = torch.randn(2)
426
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
427
        pickle.dumps(x)
428

429
    def test_inplace_metadata(self):
430
        def f(x):
431
            x = x.clone()
432
            x.unsqueeze_(-1)
433
            assert x.shape[-1] == 1
434
            return x
435

436
        self._test(f, [torch.randn(5)])
437

438
    def test_mode_tracing_factory_function(self):
439
        def f(x):
440
            return x + torch.randn(x.shape)
441

442
        # default behavior should trace factory functions
443
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
444
        self.assertTrue(
445
            any(
446
                node.target == aten.randn.default
447
                for node in traced.graph.nodes
448
            )
449
        )
450

451
    def test_pre_dispatch_functionalization(self):
452
        def f(x):
453
            a = FunctionalTensorMode(pre_dispatch=True)
454
            with a:
455
                x_unwrapped = FunctionalTensor.to_functional(x)
456
                y = torch.matmul(x_unwrapped, x_unwrapped)
457
                y = y + x_unwrapped
458
                y.mul_(5)
459
                y_unwrapped = torch._from_functional_tensor(y.elem)
460
                return y_unwrapped
461

462
        from torch._dispatch.python import enable_python_dispatcher
463

464
        with enable_python_dispatcher():
465
            inp = torch.randn(4, 4)
466
            gm = make_fx(f, pre_dispatch=True)(inp)
467

468
        # TODO actually not decompose
469
        self.assertExpectedInline(gm.code.strip(), """\
470
def forward(self, x_1):
471
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
472
    add = torch.ops.aten.add.Tensor(matmul, x_1);  matmul = x_1 = None
473
    mul = torch.ops.aten.mul.Tensor(add, 5);  add = None
474
    return mul""")
475

476
    def test_pre_dispatch_functionalization_view_op(self):
477
        def f(x):
478
            a = FunctionalTensorMode(pre_dispatch=True)
479
            with a:
480
                x_unwrapped = FunctionalTensor.to_functional(x)
481
                y = torch.matmul(x_unwrapped, x_unwrapped)
482
                x_unwrapped = x_unwrapped.transpose(1, 0)
483
                y = y + x_unwrapped
484
                y = y.view(2, 8)
485
                y_unwrapped = torch._from_functional_tensor(y.elem)
486
                return y_unwrapped
487

488
        from torch._dispatch.python import enable_python_dispatcher
489

490
        with enable_python_dispatcher():
491
            inp = torch.randn(4, 4)
492
            gm = make_fx(f, pre_dispatch=True)(inp)
493

494
        # TODO actually not decompose
495
        self.assertExpectedInline(gm.code.strip(), """\
496
def forward(self, x_1):
497
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
498
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0);  x_1 = None
499
    add = torch.ops.aten.add.Tensor(matmul, transpose);  matmul = transpose = None
500
    view = torch.ops.aten.view.default(add, [2, 8]);  add = None
501
    return view""")
502

503
    def test_val_metadata_mutation(self):
504
        def f(x):
505
            y = x.clone()
506
            y.unsqueeze_(0)
507
            return y
508

509
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
510
        self.assertEqual([
511
            tuple(node.meta['val'].shape)
512
            for node in traced.graph.nodes
513
            if 'val' in node.meta
514
        ], [(3,), (3,), (1, 3)])
515

516
    def test_make_fx_overloads(self):
517
        def f(x):
518
            return x.cos() + torch.randn(x.shape)
519

520
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
521

522
        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
523
                            for node in traced.graph.nodes if node.op == 'call_function'))
524

525
    def test_tensor_constants(self):
526
        def f():
527
            val = torch.tensor(float('inf'))
528
            return torch.full((100, 100), val)
529

530
        self._test(f, [])
531

532
    def test_allclose(self):
533
        def f(a, b):
534
            return torch.allclose(a, b)
535

536
        def test_f():
537
            make_fx(f, tracing_mode=self.tracing_mode)(
538
                torch.zeros(3), torch.zeros(3)
539
            )
540

541
        if self.tracing_mode != "real":
542
            self.assertRaises(DataDependentOutputException, test_f)
543
        else:
544
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
545

546
    def test_constant_proxy_tensor_mut(self):
547
        def f():
548
            val = torch.tensor(float(1))
549
            val.add_(2)
550
            return torch.full((100, 100), val)
551

552
        g = make_fx(f, tracing_mode=self.tracing_mode)()
553
        self.assertEqual(g(), f())
554
        # In case we mutated shared state in the g graph!
555
        self.assertEqual(g(), f())
556

557
    def test_constant_unbind(self):
558
        def f():
559
            val = torch.tensor([2])
560
            r, = torch.unbind(val, 0)
561
            return r.item()
562

563
        g = make_fx(f, tracing_mode=self.tracing_mode)()
564
        self.assertEqual(g(), f())
565

566
    def test_constant_blowup(self):
567
        def f():
568
            val = torch.tensor([2])
569
            blowup = val.repeat(1000)
570
            return bool(blowup.sum().item() == 2)
571

572
        def test_f():
573
            make_fx(f, tracing_mode=self.tracing_mode)()
574

575
        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
576

577
    def test_constant_random(self):
578
        def f():
579
            val = torch.tensor([2.0])
580
            val.normal_()
581
            return bool(val.item() == 2.1)
582

583
        def test_f():
584
            make_fx(f, tracing_mode=self.tracing_mode)()
585

586
        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
587

588
    def test_decomposition_interpreter(self):
589
        def fn(x):
590
            return torch.nn.functional.silu(x)
591

592
        x = torch.rand((4, 4))
593
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)
594

595
        found_silu = False
596
        for n in fx_module.graph.nodes:
597
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
598
                found_silu = True
599

600
        self.assertTrue(found_silu)
601

602
        new_graph = torch.fx.Graph()
603
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
604
        DecompositionInterpreter(
605
            fx_module,
606
            new_graph=new_graph,
607
            decomposition_table=silu_decomp_table,
608
        ).run(x)
609

610
        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
611

612
        for n in decomposed_module.graph.nodes:
613
            self.assertTrue(n.target != torch.ops.aten.silu)
614
            self.assertTrue(n.target != torch.ops.aten.silu.default)
615

616
        self.assertEqual(fx_module(x), decomposed_module(x))
617

618
    def test_make_fx_model_fwd_bwd(self):
619
        class Foo(torch.nn.Module):
620
            def __init__(self) -> None:
621
                super().__init__()
622
                self.linear = torch.nn.Linear(5, 5)
623

624
            def forward(self, x):
625
                return self.linear(x).relu()
626

627
        model = Foo()
628

629
        def f(x, params):
630
            out = torch.func.functional_call(model, params, x).sum()
631
            out.backward()
632
            return list(params.values())
633
        input = torch.randn(3, 5, requires_grad=True)
634
        params = dict(model.named_parameters())
635
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
636
        # fx may change the order of the returned parameters, so compare against both reference outputs
637
        self.assertTrue(
638
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
639
            or
640
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
641
        )
642
        self.assertTrue(
643
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
644
            or
645
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
646
        )
647

648
    def test_make_fx_model_double_param(self):
649
        class Emformer(torch.nn.Module):
650
            def __init__(
651
                self,
652
                input_dim: int = 256,
653
            ) -> None:
654
                super().__init__()
655

656
                self.layer_norm = torch.nn.LayerNorm(input_dim)
657

658
            def forward(mod_self, x):  # noqa: B902
659
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
660
                y = mod_self.layer_norm(x)
661
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
662
                z = mod_self.layer_norm(y)
663
                return z
664

665

666
        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
667
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
668
        self.assertEqual(len(ops), 2)
669

670

671
    def test_make_fx_model_fwd_bwd_wgtupdate(self):
672
        class Foo(torch.nn.Module):
673
            def __init__(self) -> None:
674
                super().__init__()
675
                self.linear = torch.nn.Linear(5, 5)
676

677
            def forward(self, x):
678
                return self.linear(x).relu()
679

680
        model = Foo()
681

682
        def f(args, params, buffers):
683
            for p in params.values():
684
                p.grad = None
685
            if not isinstance(args, Iterable):
686
                args = [args]
687
            params_and_buffers = {**params, **buffers}
688
            out = torch.func.functional_call(model, params_and_buffers, args)
689
            out.sum().backward()
690
            return [p - 1e-4 * p.grad for p in params.values()]
691

692
        input = torch.randn(3, 5, requires_grad=True)
693
        params = dict(model.named_parameters())
694
        buffers = dict(model.named_buffers())
695
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
696
        # fx may change the order of the returned parameters, so compare against both reference outputs
697
        # there is also a numerical difference in the results, so atol is relaxed from 1e-08 to 1e-03
698
        self.assertTrue(
699
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
700
            or
701
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
702
        )
703
        self.assertTrue(
704
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
705
            or
706
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
707
        )
708

709
    def test_trace_subclasses(self):
710
        def f1(x):
711
            x = UnwrapTensor(x)
712
            y = x * 2
713
            return y
714

715
        def f2(x):
716
            wrapped = UnwrapTensor(x)
717
            y = x * wrapped
718
            return y
719

720
        inp = [torch.randn(5)]
721
        self._test(f1, inp)
722
        self._test(f2, inp)
723

724
    def test_partial_decomp(self):
725
        def f(a, b, c):
726
            x = torch.addmm(a, b, c)
727
            y = torch.addmm(a, b, c, beta=2, alpha=1)
728
            return x + y
729
        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
730
        fx_g = make_fx(f)(*inps)
731

732
        def addmm(a, b, c, beta=1, alpha=1):
733
            if beta == 1 and alpha == 1:
734
                return NotImplemented
735
            return beta * a + alpha * (b @ c)
736

737
        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)
738

739
        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
740
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
741
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)
742

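    # Note (added commentary, not in the original test): test_partial_decomp above
    # relies on the convention that a decomposition returning NotImplemented falls
    # back to the real op, which is why the default beta/alpha addmm survives as
    # aten.addmm while the beta=2 call is decomposed away.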
743
    def test_decomp_of_capture(self):
744
        val = torch.randn(5)
745

746
        def f(x):
747
            return x.t() + val.t()
748

749
        def nop(x):
750
            return x.cos()
751

752
        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
753
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)
754

755

756
    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
757
    def test_amp_cache(self):
758
        layer = torch.nn.Conv2d(3, 3, 3).cuda()
759

760
        def f(x, w):
761
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)
762

763
        inp = torch.randn(4, 3, 10, 10, device='cuda')
764
        with torch.autocast('cuda'):
765
            out_graph = make_fx(f)(inp, layer.weight).graph
766
            out_graph2 = make_fx(f)(inp, layer.weight).graph
767

768
        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
769
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
770
            self.assertEqual(a.op, b.op)
771

772
    def test_strides(self):
773
        def f(x):
774
            self.assertTrue(x.is_contiguous())
775
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
776
            x = x.permute(0, 3, 1, 2)
777
            self.assertFalse(x.is_contiguous())
778
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
779
            return x
780
        make_fx(f)(torch.randn(2, 3, 4, 5))
781

782
        def f(x):
783
            self.assertTrue(x.is_contiguous())
784
            y = x[:, 1]
785
            self.assertFalse(y.is_contiguous())
786
            y = x[:, ::2]
787
            self.assertFalse(y.is_contiguous())
788
            return x.cos()
789

790
        make_fx(f)(torch.randn(2, 3, 4, 5))
791

792
    def test_pr_86917(self):
793
        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
794
        def f(a, b):
795
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)
796

797
        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
798

799
class TestGenericProxyTensorReal(TestGenericProxyTensor):
800
    tracing_mode = "real"
801

802

803
class TestGenericProxyTensorFake(TestGenericProxyTensor):
804
    tracing_mode = "fake"
805

806

807
class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
808
    tracing_mode = "symbolic"
809

810

811
del TestGenericProxyTensor
812

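# Added commentary (not in the original file): the three subclasses above
# parametrize TestGenericProxyTensor over tracing_mode ("real", "fake",
# "symbolic"); the base class is deleted so unittest does not also collect and
# run it without a tracing_mode attribute set.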
813

814
class TestRealProxyTensor(TestCase):
815
    def test_error_on_data_dependent_ops(self):
816
        def f():
817
            x = torch.randn([])
818
            y = torch.randn([])
819
            assert torch.allclose(x * y, y * x)
820
            z = float(x)
821
            z2 = float(y)
822

823
        # Smoke tests
824
        make_fx(f, _error_on_data_dependent_ops=False)()
825
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()
826

827
class TestFakeProxyTensor(TestCase):
828
    def test_issue82547(self):
829
        x = nn.Parameter(torch.randn(3, 3))
830

831
        def f():
832
            return torch.ops.aten.t.default(x)
833
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())
834

835
        class A(torch.Tensor):
836
            pass
837

838
        x = A(torch.randn(3, 3))
839
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())
840

841
    def test_use_fake_and_tensor(self):
842
        def f(x, y):
843
            z = torch.tensor([2.0, 3.0])
844
            return x + y + z
845

846
        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
847
        x, y = torch.randn(2), torch.randn(2)
848
        self.assertEqual(g(x, y), f(x, y))
849

850
    def test_free_fake(self):
851
        def f(x):
852
            return torch.add(x, y)
853

854
        with FakeTensorMode() as fake_mode:
855
            y = torch.randn(2)
856
            make_fx(f, tracing_mode="real")(torch.randn(2))
857

858
    def test_fused_adam(self):
859
        # See https://github.com/pytorch/pytorch/issues/99356
860
        params = [torch.randn(10, 10) for _ in range(10)]
861
        grads = [torch.randn(10, 10) for _ in range(10)]
862
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
863
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
864
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
865
        state_steps = [torch.tensor(0) for _ in range(10)]
866

867
        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
868
            (new_params, _, _, _, _) = aten._fused_adam.default(
869
                params,
870
                grads,
871
                exp_avgs,
872
                exp_avg_sqs,
873
                max_exp_avg_sqs,
874
                state_steps,
875
                lr=0.1,
876
                beta1=0.9,
877
                beta2=0.999,
878
                weight_decay=0.01,
879
                eps=1e-8,
880
                amsgrad=False,
881
                maximize=False,
882
            )
883

884
            for p, new_p in zip(params, new_params):
885
                p.copy_(new_p)
886

887
            return params
888

889
        gm = make_fx(fused_adam, tracing_mode='fake')(
890
            params,
891
            grads,
892
            exp_avgs,
893
            exp_avg_sqs,
894
            max_exp_avg_sqs,
895
            state_steps,
896
        )
897
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
898
        for n in gm.graph.nodes:
899
            if n.op == "call_function" and n.target in ensure_ops_have_val:
900
                self.assertIn('val', n.meta)
901

902
    def test_alias(self):
903
        def f(x):
904
            return torch.ops.aten.alias(x)
905

906
        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
907
        # NB: this should not have a detach call
908
        self.assertExpectedInline(r, """\
909
def forward(self, x_1):
910
    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
911
    return alias""")
912

913
    def test_meta(self):
914
        def f(x):
915
            a = x.cos()
916
            b = torch.var_mean(a, dim=0)
917
            c = b * 2
918
            return c
919

920
        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
921
        for n in out.graph.nodes:
922
            if n.op == 'output':
923
                continue
924
            self.assertTrue('val' in n.meta)
925

926
def _get_node(fx_g, cond):
927
    for n in fx_g.graph.nodes:
928
        if cond(n):
929
            return n
930
    raise AssertionError
931

932
def _get_free_symbols(shape_env):
933
    vars = tuple(shape_env.var_to_val.keys())
934
    return len([var for var in vars if var not in shape_env.replacements])
935

936
def _trace(f, *args):
937
    inps = [torch.randn(arg) for arg in args]
938
    return make_fx(f, tracing_mode="symbolic")(*inps)
939

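# Illustrative sketch (an addition): _trace builds fresh randn inputs from the
# given shapes and traces f with symbolic shapes, so callers can inspect the
# resulting module's shape_env, as the tests below do.
def _trace_example():
    gm = _trace(lambda x: x.cos(), (3, 4))
    return gm.shape_env  # symbolic size state for the (3, 4) input lives here
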
940
# TODO: Need to test the guards themselves specifically as well
941
class TestSymbolicTracing(TestCase):
942
    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
943
        """
944
        Tests fn traced with trace_inputs against test_inputs
945
        Also returns the traced graph module, whose shape_env callers can inspect
946
        """
947
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
948
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
949
        for input in test_inputs:
950
            input = [torch.randn(shape) for shape in input]
951
            rx, ry = traced_f(*input), fn(*input)
952
            if assert_eq:
953
                self.assertEqual(rx, ry)
954
        return traced_f
955

956

957
    def test_debug_interpreter(self):
958
        import torch.library
959
        from torch.library import Library
960

961
        foo = Library("foo", "DEF")  # noqa: TOR901
962
        foo.define("foo(Tensor self) -> Tensor")
963

964
        # Operator where meta and cpu disagree on strides
965
        @torch.library.impl(foo, "foo", "CPU")
966
        def foo_cpu(x):
967
            return x.clone().T
968

969
        @torch.library.impl(foo, "foo", "Meta")
970
        def foo_meta(x):
971
            return x.clone()
972

973
        def f(x):
974
            return torch.ops.foo.foo.default(x)
975

976
        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
977
        from torch._functorch.compilers import DebugInterpreter
978

979
        interp = DebugInterpreter(gm)
980

981
        # input mismatch is caught (indicates guard problem)
982
        self.assertRaisesRegex(
983
            AssertionError, r"3 != 1",
984
            lambda: interp.run(torch.randn(3, 3).T),
985
        )
986

987
        # Catch the incorrect meta
988
        self.assertRaisesRegex(
989
            AssertionError, r"\(3, 1\) != \(1, 3\)",
990
            lambda: interp.run(torch.randn(3, 3))
991
        )
992

993
    def test_int_input(self):
994
        def f(x, y):
995
            return x.view(y)
996

997
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
998
        self.assertExpectedInline(r, """\
999
def forward(self, x_1, y_1):
1000
    view = torch.ops.aten.view.default(x_1, [y_1]);  x_1 = y_1 = None
1001
    return view""")
1002

1003
    def test_resize_from_zero(self):
1004
        def f(x, y):
1005
            x.resize_(y.size(0))
1006

1007
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
1008
        self.assertExpectedInline(r, """\
1009
def forward(self, x_1, y_1):
1010
    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1011
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]);  x_1 = sym_size_int = resize_ = None
1012
    return None""")
1013

1014
    def test_broadcast_shapes(self):
1015
        def f(x, y):
1016
            return torch.functional.broadcast_shapes(x.size(), y.size()[0])
1017

1018
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
1019
        self.assertExpectedInline(r, """\
1020
def forward(self, x_1, y_1):
1021
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
1022
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1023
    return (sym_size_int, sym_size_int_1)""")
1024

1025
    def test_deduped_shape(self):
1026
        def f(s0, s1, x, y):
1027
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
1028

1029
        x = torch.empty(3, 1)
1030
        y = torch.empty(5)
1031
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
1032
        shape_env = ShapeEnv()
1033

1034
        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
1035
            x = fake_mode.from_tensor(x)
1036
            y = fake_mode.from_tensor(y)
1037
            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
1038
            self.assertExpectedInline(r, """\
1039
def forward(self, s0_1, s1_1, x_1, y_1):
1040
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
1041
    return ((s0_1, s1_1), empty)""")
1042

1043
    def test_non_deduped_shape(self):
1044
        def f(x, y):
1045
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
1046

1047
        x = torch.empty(3, 1)
1048
        y = torch.empty(5)
1049
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
1050
        shape_env = ShapeEnv()
1051

1052
        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
1053
            x = fake_mode.from_tensor(x)
1054
            y = fake_mode.from_tensor(y)
1055
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
1056
            self.assertExpectedInline(r, """\
1057
def forward(self, x_1, y_1):
1058
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
1059
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1060
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
1061
    return ((sym_size_int, sym_size_int_1), empty)""")
1062

1063
    def test_unary(self):
1064
        def f(x):
1065
            assert x.shape[0] < 20
1066
            return x.cos()
1067
        test_inputs = []
1068
        test_inputs.append([(2, 5)])
1069
        test_inputs.append([(6, 8)])
1070
        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
1071
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
1072
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
1073
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
1074
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
1075

1076
    def test_repeat_interleave(self):
1077
        def f(src_tokens, beam_size_src):
1078
            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)
1079

1080
        prompt_size = 64
1081
        vocab_size = 64
1082
        batch_size = 4
1083
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
1084
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
1085
        self.assertEqual(len(gm.shape_env.guards), 0)
1086

1087
    def test_non_symint_size_spec(self):
1088
        # this isn't really a proxy tensor test, but it's the most convenient
1089
        # way to get a fake tensor with symbolic sizes
1090
        def f(x):
1091
            torch._C._non_sym_sizes(x)
1092
            return x + 1
1093

1094
        x = torch.randn(2, 3)
1095
        make_fx(f, tracing_mode="symbolic")(x)
1096

1097
    # https://github.com/pytorch/pytorch/issues/108195
1098
    def test_symbolic_repeat_interleave(self):
1099
        def f(y, x):
1100
            return y.repeat_interleave(x, dim=1)
1101

1102
        y = torch.tensor([[1, 2], [3, 4]])
1103
        x = torch.tensor([2, 3])
1104
        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
1105
        self.assertExpectedInline(r, """\
1106
def forward(self, y_1, x_1):
1107
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1);  x_1 = None
1108
    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave);  y_1 = repeat_interleave = None
1109
    return index_select""")
1110

1111
    def test_mod_gcd_unbacked(self):
1112
        def f(_a, _b, _stride):
1113
            a = _a.item()
1114
            b = _b.item()
1115
            stride = _stride.item()
1116
            torch._check_is_size(a)
1117
            torch._check_is_size(b)
1118
            torch._check_is_size(stride)
1119
            ta = torch.randn(a * stride)
1120
            tb = torch.randn(b * stride)
1121
            r = torch.cat([ta, tb])
1122
            return r.view(a + b, stride)
1123

1124
        _a = torch.tensor(30)
1125
        _b = torch.tensor(20)
1126
        _stride = torch.tensor(10)
1127
        r = str(make_fx(f, tracing_mode="symbolic")(_a, _b, _stride).code).strip()
1128
        self.assertExpectedInline(r, """\
1129
def forward(self, _a_1, _b_1, _stride_1):
1130
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(_a_1);  _a_1 = None
1131
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(_b_1);  _b_1 = None
1132
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(_stride_1);  _stride_1 = None
1133
    mul = _local_scalar_dense * _local_scalar_dense_2
1134
    randn = torch.ops.aten.randn.default([mul], device = device(type='cpu'), pin_memory = False);  mul = None
1135
    mul_1 = _local_scalar_dense_1 * _local_scalar_dense_2
1136
    randn_1 = torch.ops.aten.randn.default([mul_1], device = device(type='cpu'), pin_memory = False);  mul_1 = None
1137
    cat = torch.ops.aten.cat.default([randn, randn_1]);  randn = randn_1 = None
1138
    add = _local_scalar_dense + _local_scalar_dense_1;  _local_scalar_dense = _local_scalar_dense_1 = None
1139
    view = torch.ops.aten.view.default(cat, [add, _local_scalar_dense_2]);  cat = add = _local_scalar_dense_2 = None
1140
    return view""")
1141

1142
    def test_cumsum_unbacked(self):
1143
        def f(x):
1144
            y = x.item()
1145
            z = torch.randn((3, y, 3))
1146
            return z.cumsum(0)
1147

1148
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
1149
        self.assertExpectedInline(
1150
            r, """\
1151
def forward(self, x_1):
1152
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1153
    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1154
    cumsum = torch.ops.aten.cumsum.default(randn, 0);  randn = None
1155
    return cumsum"""  # noqa: B950
1156
        )
1157

1158

1159
    def test_repeat_interleave_unbacked_output_size(self):
1160
        def f(x, y):
1161
            s = x.sum().item()
1162
            return y.repeat_interleave(x, dim=0, output_size=s)
1163

1164
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
1165
        self.assertExpectedInline(
1166
            r, """\
1167
def forward(self, x_1, y_1):
1168
    sum_1 = torch.ops.aten.sum.default(x_1)
1169
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1);  sum_1 = None
1170
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense);  x_1 = _local_scalar_dense = None
1171
    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave);  y_1 = repeat_interleave = None
1172
    return index_select"""  # noqa: B950
1173
        )
1174

1175
    def test_arange_unbacked_output_size(self):
1176
        def f(x):
1177
            return torch.arange(0, x)
1178

1179
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
1180
        self.assertExpectedInline(
1181
            r, """\
1182
def forward(self, x_1):
1183
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1184
    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1185
    return arange"""  # noqa: B950
1186
        )
1187

1188
    def test_adv_index_batch(self):
1189
        def f(src_tokens):
1190
            bsz, src_len = src_tokens.size()[:2]
1191
            start_step = src_tokens.shape[1]
1192
            beam_size = 1
1193
            generate_size = 64
1194
            max_len = src_len + generate_size
1195
            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
1196
            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
1197
            return tokens
1198

1199
        prompt_size = 64
1200
        vocab_size = 64
1201
        batch_size = 4
1202
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
1203
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
1204
        # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
1205
        # 1 is ok)
1206
        self.assertEqual(len(gm.shape_env.guards), 1)
1207

1208
    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
1209
    def test_cpu_scalar_cuda(self):
1210
        # Extracted from wav2vec2
1211
        def f(a, b):
1212
            return (a * b) @ b
1213

1214
        r = str(
1215
            make_fx(f, tracing_mode="symbolic")(
1216
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
1217
            ).code
1218
        ).strip()
1219
        self.assertExpectedInline(r, """\
1220
def forward(self, a_1, b_1):
1221
    mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None
1222
    mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None
1223
    return mm""")
1224

1225
    def test_binary_broadcast(self):
1226
        def f(a, b):
1227
            c = a * b
1228
            return c
1229

1230
        test_inputs = []
1231
        test_inputs.append([(1, 5), (3, 1)])
1232
        test_inputs.append([(1, 4), (4, 1)])
1233
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
1234
        assert len(shape_env.guards) == 0
1235

1236
    def test_multiply_shape(self):
1237
        def f(a):
1238
            return torch.empty(a.shape[0] * 2)
1239

1240
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1241
        self.assertExpectedInline(r, """\
1242
def forward(self, a_1):
1243
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
1244
    mul = sym_size_int * 2;  sym_size_int = None
1245
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None
1246
    return empty""")
1247

1248
    def test_item(self):
1249
        def f(a):
1250
            r = a.item()
1251
            return r * a
1252

1253
        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
1254
        self.assertExpectedInline(r, """\
1255
def forward(self, a_1):
1256
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
1257
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None
1258
    return mul""")
1259

1260
    def test_tensor_symfloat(self):
1261
        def f(a):
1262
            r = torch.tensor(a.size(0) ** 2.0)
1263
            assert r.dtype is torch.float
1264
            return r
1265

1266
        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
1267
        r = str(gm.code).strip()
1268
        # NB: this specializes, which is fine, the point is to make sure the
1269
        # dtype inference is correct
1270
        self.assertExpectedInline(r, """\
1271
def forward(self, a_1):
1272
    _tensor_constant0 = self._tensor_constant0
1273
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
1274
    return lift_fresh_copy""")
1275
        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))
1276

1277
    def test_item_to_constructor(self):
1278
        def f(a):
1279
            r = a.item()
1280
            return torch.empty(r)
1281

1282
        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
1283
        self.assertExpectedInline(
1284
            r, """\
1285
def forward(self, a_1):
1286
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
1287
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1288
    return empty"""  # noqa: B950
1289
        )
1290

1291

1292
    def test_setitem_symint(self):
1293
        # from moco
1294
        # https://github.com/pytorch/pytorch/issues/101939
1295
        def f(x):
1296
            x[0] = x.size(0)
1297
            return x
1298

1299
        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
1300
        self.assertExpectedInline(
1301
            r, """\
1302
def forward(self, x_1):
1303
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
1304
    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  sym_size_int = None
1305
    select = torch.ops.aten.select.int(x_1, 0, 0)
1306
    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor);  select = scalar_tensor = copy_ = None
1307
    return x_1"""  # noqa: B950
1308
        )
1309

1310
    def test_dynamic_pointwise_scalar(self):
1311
        def f(gravity, mask):
1312
            gravity[mask, 0] = gravity[mask, 0] * -1
1313

1314
        r = str(make_fx(f, tracing_mode="symbolic")(
1315
            torch.randn((12, 4)),
1316
            torch.randint(0, 2, (12,), dtype=torch.bool)
1317
        ).code).strip()
1318
        self.assertExpectedInline(r, """\
1319
def forward(self, gravity_1, mask_1):
1320
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
1321
    index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None
1322
    mul = torch.ops.aten.mul.Tensor(index, -1);  index = None
1323
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None
1324
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = index_put_ = None
1325
    return None""")

    def test_reflect_r_over_x(self):
        def reflect_R_over_x(R):
            reflect = torch.eye(3, device=R.device)
            reflect[0, 0] = -1
            return reflect @ R @ reflect

        def f(crop_camera, mask):
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = copy_ = None
    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]);  expand = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]);  index = None
    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None
    bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None
    mul_4 = sym_size_int * 3
    view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]);  view_2 = mul_4 = None
    mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None
    view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4);  crop_camera_1 = mask_1 = view_4 = index_put_ = None
    return None""")  # noqa: B950

    def test_unbacked_slice(self):
        def f(x, m):
            x = x[m]
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

        make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        )

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_unbacked_batch_resnet(self):
        mod = torchvision.models.resnet18()

        def f(x, mask, params, buffers):
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
                for s in p.shape:
                    guard_int(s)
            x = x[mask]
            torch._check(x.shape[0] >= 1)
            for p in params.values():
                p.grad = None
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

        make_fx(f, tracing_mode="symbolic")(
            torch.randn(3, 3, 250, 250),
            torch.randint(0, 2, (3,), dtype=torch.bool),
            dict(mod.named_parameters()),
            dict(mod.named_buffers()),
        )

    def test_boolean_index(self):
        def f(images, handedness, valid):
            images = images[valid]
            handedness = handedness[valid]
            right_hand_mask = handedness == 1
            images[right_hand_mask] = images[right_hand_mask].flip(-1)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randint(0, 256, (512, 1, 96, 96)),
            torch.randint(0, 1, (512,)),
            torch.randint(0, 2, (512,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = index_put_ = None
    return None""")

    def test_neg_shape(self):
        def f(a):
            return torch.empty(-a.shape[0] + 10)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
    neg = -sym_size_int;  sym_size_int = None
    add = neg + 10;  neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None
    return empty""")

    def test_unbacked_unification(self):
        def f(x, y):
            z = torch.zeros(x.item())
            return z + y

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    add = torch.ops.aten.add.Tensor(zeros, y_1);  zeros = y_1 = None
    return add""")  # noqa: B950

    def test_reshape_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 4, 20)
            r = r.transpose(2, 1)
            return r.reshape(-1, 80)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    def test_view_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 192)
            return r.view(12, -1, 192)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_view_divisibility_unbacked_relatively_prime(self):
        # See https://github.com/pytorch/pytorch/issues/123651
        def f(x):
            i0 = x.item()
            torch._check_is_size(i0)
            # To trigger the original issue, the max bound has to
            # be chosen such that 448 / 447 < 2 (which it is.)
            torch._check(i0 <= 448)
            return torch.zeros(256 * i0).view(-1, 447)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(256 * 447, device="cuda"))

    def test_unbacked_unify_guard(self):
        def f(x, y):
            z = torch.zeros(x.item())
            torch._check(z.size(0) == y.size(0))  # refines i0 = s0
            if z.size(0) == 4:
                return y * 2
            else:
                return y + 2

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
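        # Because torch._check unified the unbacked size i0 with y's backed
        # size (10), the `z.size(0) == 4` branch can be resolved statically,
        # so only the `y + 2` path shows up in the traced graph below.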
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = zeros = None
    add = torch.ops.aten.add.Tensor(y_1, 2);  y_1 = None
    return add""")  # noqa: B950

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    @unittest.expectedFailure
    def test_unbacked_unify_guard_transitivity(self):
        def f(x1, x2, y):
            z1 = torch.zeros(x1.item())
            z2 = torch.zeros(x2.item())
            torch._check(z1.size(0) == z2.size(0))  # refines i0 = i1
            torch._check(z2.size(0) == y.size(0))  # refines i0 = s0
            if z1.size(0) == 4:
                return y * 2
            else:
                return y + 2

        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"),
            torch.tensor(10, device="cuda"),
            torch.randn(10, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        r = str(gm.code).strip()
        # self.assertExpectedInline(
        #     r, """"""  # noqa: B950
        # )

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_unbacked_unify_dependency_violation(self):
        def f(x1, x2, x3, y):
            z1 = x1.item()
            torch._check(z1 // 9 == 1)
            z2 = x2.item()
            z3 = x3.item()
            torch._check(z1 == z2 + z3)
            return y * 2
            if z2 + z3 == z1:
                return y * 2
            else:
                return y + 3

        # NB: inputs are done as CUDA to ensure they aren't queried to be
        # backed

        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"), torch.tensor(5, device="cuda"),
            torch.tensor(5, device="cuda"), torch.randn(1, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        self.assertEqual(gm(
            torch.tensor(12, device="cuda"), torch.tensor(6, device="cuda"),
            torch.tensor(6, device="cuda"), torch.tensor([1.0], device="cuda")),
            torch.tensor([2.0], device="cuda")
        )
        with self.assertRaises(RuntimeError):
            gm(
                torch.tensor(20, device="cuda"), torch.tensor(10, device="cuda"),
                torch.tensor(10, device="cuda"), torch.tensor([1.0], device="cuda")
            )


    def test_split_unbacked_sizes(self):
        def f(lengths, values):
            # tolist not directly supported atm
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
            for s in sizes:
                # TODO(avik): no assertion generated with torch._check_is_size?
                torch._constrain_as_size(s)
            return torch.split(values, sizes)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.tensor([2, 3, 4]),
            torch.randn(9)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select);  select = None
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1);  select_1 = None
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2);  lengths_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2);  select_2 = None
    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense);  sym_constrain_range_for_size = None
    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1);  sym_constrain_range_for_size_1 = None
    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2);  sym_constrain_range_for_size_2 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]);  values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]
    getitem_2 = split_with_sizes[2];  split_with_sizes = None
    return (getitem, getitem_1, getitem_2)""")  # noqa: B950

    def test_invalidate_nonzero(self):
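        # nonzero() has a data-dependent output size, which the fake tensor
        # mode memoizes, so x1 and x2 can be compared without guarding.  The
        # in-place normal_() should invalidate that memo, so comparing against
        # a fresh nonzero() result must guard on a data-dependent SymNode and
        # raise.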
        ok = False

        def f(a):
            nonlocal ok
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            ok = True
            b.normal_()
            y = b.nonzero()
            try:
                bool(x1.shape[0] == y.shape[0])
                self.fail("didn't raise exception")
            except GuardOnDataDependentSymNode:
                pass

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
    def test_invalidate_nonzero_propagate_real_tensors(self):
        def f(a):
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            b.normal_()
            y = b.nonzero()
            # Because you're not actually going to generate exactly zero with
            # normal_ lol
            assert x1.shape[0] == y.shape[0]

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    def test_sqrt_size(self):
        def f(a):
            return a / a.size(-1) ** 0.5

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
    pow_1 = sym_float ** 0.5;  sym_float = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None
    return div""")

    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + 1

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                return x + self.bar(x)

        gm = make_fx(Foo())(torch.randn(4, 4))
        for node in gm.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

        foo = Foo()

        def functional_call(*args, **kwargs):
            with stateless._reparametrize_module(foo, {}):
                return foo(*args, **kwargs)

        functional_call._orig_mod = foo

        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
        found = False
        for node in gm_with_stack.graph.nodes:
            if "nn_module_stack" in node.meta:
                if len(node.meta["nn_module_stack"]) == 1:
                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
                    found = True
                elif len(node.meta["nn_module_stack"]) == 2:
                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
                    found = True
                else:
                    # there can be at most 2 levels
                    self.assertTrue(False)

        self.assertTrue(found)

        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
        for node in gm_without_stack.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

    def test_symint_to_tensor(self):
        def f(a):
            return a / a.shape[0]

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size_int);  a_1 = sym_size_int = None
    return div""")

        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
    div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None
    return div""")

    def test_cat(self):
        def f(a, b):
            val = torch.mul(a, b)
            out = torch.cat([val, val])
            if out.shape[0] * out.shape[1] > 20:
                out = out.cos()
            return out

        test_inputs = []
        test_inputs.append([(1, 5), (6, 1)])
        test_inputs.append([(1, 4), (3, 1)])
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")

    def test_new_empty(self):
        def f(a, b):
            return a.new_empty(b.shape[0], b.shape[1] * 2)

        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

    def test_size_with_tensor(self):
        # I think I messed up writing this test case originally, I think
        # I'm supposed to hit an error case, but the code here works in both
        # eager and tracing
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        f(a)
        make_fx(f, tracing_mode="symbolic")(a)

    def test_fake_tensor_as_size(self):
        def f(x):
            r = torch.zeros([x])
            return r

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return zeros""")  # noqa: B950

    def test_expand(self):
        def f(a):
            b = torch.mul(a, a)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
        # Checks if the input expr has been updated even though the constraint
        # happened afterwards
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)

    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = torch.sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

    def test_return_symint(self):
        def f(x):
            return x.shape[0], x.cos(), x.shape[0] / 5
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

        def f(x):
            return x.shape
        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])

    def test_rmethod(self):
        def f(x):
            return x.size(0) + x
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

    def test_mega_guard(self):
        def f(a, b):
            assert a.shape[0] == b.shape[0] * 2
            return a.cos()
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
        from torch._dynamo.source import LocalSource
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )

    def test_guard_upperbound_range_refinement(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")

    def test_guard_lowerbound_range_refinement(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")

    def test_guard_upperbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
        self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
13 <= L['a'].size()[0]
14 <= L['a'].size()[1]""")

    def test_guard_lowerbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
        self.assertExpectedInline(
            show_guards(tensor),
            """\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")

    def test_sym_storage_offset(self):
        def f(x, y):
            return x + y

        inp = (torch.randn(8)[3:], torch.randn(5))
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
        inp = (torch.randn(8)[3:], torch.randn(5))
        self.assertEqual(fx_g(*inp), f(*inp))

    def _assert_no_guards(self, fx_g, free_symbols):
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

    def test_guards_equal(self):
        def f(a, b):
            return a * b

        # NB: Numbers are carefully chosen to avoid duck shaping from applying
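        # ("Duck shaping" means the shape env reuses a single symbol for any
        # dimensions whose example values happen to be equal at trace time,
        # which would conflate sizes here and hide the guards under test.)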

        fx_g = _trace(f, (5, 6), (5, 6))
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (5, 1), (1, 6))
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d):
            a = a + b
            cat = torch.cat([c, d])
            return a + cat

        fx_g = _trace(f, 7, 7, 4, 3)
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            x = a
            for idx in range(len(vals) - 1):
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
            return x

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self._assert_no_guards(fx_g, 1)

        def f(a, b):
            a = a.view(b.shape[0])
            return a + b.sum()

        fx_g = _trace(f, (4, 2), 8)
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (4, 2), (8, 5))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (2, 3, 4), 24)
        self._assert_no_guards(fx_g, 3)

    def test_nonidentity_transitive_guards(self):
        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            cat_vals = []
            for idx in range(len(vals) - 1):
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
            final_vals = []
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
                final_vals.append(a + b)
            return final_vals

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self.assertExpectedInline(show_guards(fx_g), """""")

    @torch.fx.experimental._config.patch(translation_validation=True)
    def test_constant_specialization(self):
        def f(t):
            assert t.shape[0] == 10
            return t

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
        self.assertExpectedInline(show_guards(tensor), """""")


make_fx_failures = {
    # unknown
    xfail('allclose'),
    xfail('equal'),
    # empty
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    skip('empty_permuted'),
    # flaky
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),  # flaky, probably just a precision issue

    # data-dependent control flow
    skip('item'),
    xfail('cov'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),
    xfail('corrcoef'),
    xfail('quantile'),
    xfail('nanquantile'),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),

    # proxy tensor doesn't support sparse correctly right now
    skip('to_sparse'),
    # segfaults
    skip('block_diag'),

    # AssertionError: Tensor-likes are not close!
    skip('empty_strided', '', device_type='cpu'),
}

only_real_tensor_failures = {
    xfail('narrow'),
}

only_fake_tensor_failures = {
    xfail('narrow'),
}

fake_tensor_failures = {
    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),
}

symbolic_tensor_failures = {
    xfail('combinations', ''),
    xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
    xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
    xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
    xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition

    xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...

    # many complex operators have incorrect striding / metadata
    xfail('fft.fft', ''),
    xfail('fft.hfft2', ''),
    xfail('fft.hfft', ''),
    xfail('fft.hfftn', ''),
    xfail('fft.ifft', ''),
    xfail('fft.ihfft2', ''),
    xfail('fft.ihfft', ''),
    xfail('fft.ihfftn', ''),
    xfail('fft.ihfft2', ''),
    xfail('fft.irfft2', ''),
    xfail('fft.irfft', ''),
    xfail('fft.irfftn', ''),
    xfail('fft.rfft2', ''),
    xfail('fft.rfft', ''),
    xfail('fft.rfftn', ''),
    xfail('stft', '')
}
symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm')  # Segfault??
}

symbolic_tensor_failures.update(symbolic_tensor_segfaults)

inplace_symbolic_tensor_failures = {
    # bugs
    xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
}

out_symbolic_tensor_failures = {
    # Cast error details: Unable to cast (...) to Tensor
    #
    # This happens because the test is set up to call the out variant using the `out` kwarg:
    #   torch._some_op(arg1, arg2, out=(out1, out2, out3))
    #
    # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
    # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
    # way of calling its out variant.
    xfail('_batch_norm_with_update', ''),
    xfail('_native_batch_norm_legit', ''),
    xfail('angle', ''),
    xfail('argmax', ''),
    xfail('argmin', ''),
    xfail('fft.fft2', ''),
    xfail('fft.fftn', ''),
    xfail('fft.ifft2', ''),
    xfail('fft.ifftn', ''),
    xfail('gather', ''),
    xfail('linalg.pinv', ''),
    xfail('linalg.pinv', 'hermitian'),
    xfail('lu', ''),
    xfail('scatter_add', ''),
    xfail('scatter', ''),
    xfail('take_along_dim', ''),
    xfail('triangular_solve', ''),

    # SymIntArrayRef expected to contain only concrete
    xfail('ones', ''),
    xfail('randn', ''),
    xfail('zeros', ''),

    # RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
    xfail('index_reduce', 'prod'),
    xfail('index_reduce', 'mean'),
    xfail('index_reduce', 'amax'),
    xfail('index_reduce', 'amin'),
}

out_symbolic_tensor_segfaults = {
    skip('nanmean', ''),
}

out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)
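
# Minimal sketch of the `out=` calling convention described in the comment in
# out_symbolic_tensor_failures above (illustrative only, never called by the
# tests; torch.sin is an arbitrary torch-level op that has python bindings and
# an out variant, which is why this form works where the xfailed aten ops do not):
def _example_out_kwarg_call():
    a = torch.randn(4)
    result = torch.empty(4)
    torch.sin(a, out=result)
    return result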

# Copies inputs to inplace operations to avoid inplace modifications
#   to leaves requiring gradient
def _get_safe_inplace(inplace_variant):
    @functools.wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)

    return _fn
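
# Hedged usage sketch (not exercised by the tests): _get_safe_inplace wraps an
# inplace variant so each call mutates a clone rather than the tensor that was
# passed in; torch.Tensor.add_ is just an arbitrary inplace op for illustration.
def _example_safe_inplace_usage():
    safe_add_ = _get_safe_inplace(torch.Tensor.add_)
    t = torch.ones(3, requires_grad=True)
    return safe_add_(t, 1)  # returns t.clone() + 1; t itself is left unmodified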

def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
    count = 100
    if out:
        count = 5
    for sample_input in itertools.islice(sample_inputs_itr, count):
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        if out:
            expected = fn(*args, **kwargs)
            kwargs['out'] = expected

        try:
            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
                                  randomize_data=True)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")


def skipIfNameMatches(pattern):
    """
    Decorator to skip a test if its name matches the given pattern.
    """
    def decorator(test_func):
        def wrapper(*args, **kwargs):
            if re.match(pattern, test_func.__name__):
                raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'")
            return test_func(*args, **kwargs)
        return wrapper
    return decorator
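
# Hedged usage sketch (no test below actually applies this decorator):
#
#     @skipIfNameMatches(r"test_flaky_.*")
#     def test_flaky_thing(self):
#         ...
#
# Any wrapped test whose name matches the pattern raises unittest.SkipTest
# when it is called.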

# Auto functionalize shouldn't work with make_fx directly
filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"]

@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo")
class TestProxyTensorOpInfo(TestCase):
    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures.union(only_real_tensor_failures))
    def test_make_fx_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "real")

    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures.union(fake_tensor_failures, only_fake_tensor_failures))
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "fake")

    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "symbolic")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
        if not op.get_inplace():
            self.skipTest("No inplace variable for this op")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op):
        if not op.supports_out:
            self.skipTest("Op doesn't support out")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True)


only_for = ("cpu")
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()
