pytorch

Форк
0
/
test_proxy_tensor.py 
2047 строк · 77.2 Кб
1
# Owner(s): ["module: ProxyTensor"]
2

3
from torch.testing._internal.common_utils import TestCase, run_tests
4
import torch
5
import unittest
6
import warnings
7
import operator
8
from collections.abc import Iterable
9
from torch.nn.utils import stateless
10
from torch.testing._internal.common_device_type import instantiate_device_type_tests
11
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
12
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
13
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
14
from torch._decomp import decomposition_table
15
from torch.fx.experimental.symbolic_shapes import (
16
    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
17
    guard_int, GuardOnDataDependentSymNode
18
)
19
from torch.testing._internal.custom_op_db import custom_op_db
20
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
21
from torch.testing._internal.common_device_type import ops
22
import torch.testing._internal.optests as optests
23
from torch._C import _disabled_torch_function_impl
24
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
25
from torch.utils._pytree import tree_map
26
from torch import nn
27
import re
28

29
import functools
30
import itertools
31

32
aten = torch.ops.aten
33

34
HAS_CUDA = torch.cuda.is_available()
35

36

37
def strip_end(s, suffix):
38
    if suffix and s.endswith(suffix):
39
        return s[:-len(suffix)]
40
    else:
41
        return s
42

43

44
def show_guards(gm):
45
    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
46
    return "\n".join(
47
        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
48
    )
49

50

51
def process_failures():
52
    """
53
    Takes file containing failures like
54

55
    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950
56

57
    and processes them into a list of opinfo xfails
58
    """
59
    f = open('pytest_failures')
60
    failures = f.readlines()
61
    failures = [i.strip() for i in failures]
62

63
    def process_failure_string(s, matcher):
64
        out = re.search(matcher, s)
65
        return out.groups()
66

67
    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
68
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]
69

70
    def create_normalized_name(op):
71
        if op.variant_test_name == '':
72
            s = op.name
73
        else:
74
            s = f"{op.name}.{op.variant_test_name}"
75
        return s.replace('.', '_')
76

77
    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}
78

79
    print("symbolic_tensor_failures = {")
80
    for failure, reason in failures:
81
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
82
    print("}")
83

84

85
USE_TORCHVISION = False
86
try:
87
    import torchvision
88
    USE_TORCHVISION = True
89
except ImportError:
90
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
91
                  "to install it with commands from pytorch.org, post-fixed with "
92
                  "`--no-deps` to avoid overwriting the pytorch installation",
93
                  UserWarning)
94

95

96
def _create_new_input(x):
97
    if not isinstance(x, torch.Tensor):
98
        return x
99
    if x.dtype != torch.float:
100
        return x + 1
101
    if x.is_leaf:
102
        return torch.rand_like(x, requires_grad=x.requires_grad)
103
    else:
104
        return torch.rand_like(x)
105

106
"""
107
Delays a cos being executed on the unwraptensor until its used. Simulates a CommTensor used
108
"""
109
class UnwrapTensor(torch.Tensor):
110
    @staticmethod
111
    def __new__(cls, tensor: torch.Tensor):
112
        r = torch.Tensor._make_wrapper_subclass(
113
            cls,
114
            tensor.size(),
115
            dtype=tensor.dtype,
116
            device=tensor.device,
117
            layout=tensor.layout,
118
            requires_grad=tensor.requires_grad,
119
        )
120
        r._tensor = tensor
121
        return r
122

123
    def __repr__(self):
124
        # TODO: consider all_gather the local tensors for better debugging
125
        return f"UnwrapTensor({self._tensor})"
126

127
    __torch_function__ = _disabled_torch_function_impl
128

129
    @classmethod
130
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
131
        def unwrap(e):
132
            ret = e
133
            if isinstance(e, UnwrapTensor):
134
                ret = e._tensor.cos()
135

136
            return ret
137

138
        args = tree_map(unwrap, args)
139
        kwargs = tree_map(unwrap, kwargs)
140
        return func(*args, **kwargs)
141

142
class TestGenericProxyTensor(TestCase):
143
    # WARNING: if any of your inputs are index tensors, DO NOT use this
144
    # function
145
    def _test(self, f, inps):
146
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
147
        new_inps = tree_map(_create_new_input, inps)
148
        r1 = fx_f(*new_inps)
149
        r2 = f(*new_inps)
150
        self.assertEqual(r1, r2)
151

152
    def test_pre_dispatch_mode_stack(self):
153
        def f(a):
154
            b = torch.ones(4, 4)
155
            return torch.matmul(a, b)
156
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
157
        # Also, torch.ones() doesn't show up in the trace.
158
        # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
159
        # so our mode never sees it - it goes directly to the BackendSelect key.
160
        inp = torch.ones(4, 4)
161
        # Test that make_fx(pre_dispatch=True) clears caches properly.
162
        from torch._dispatch.python import enable_python_dispatcher
163
        with enable_python_dispatcher():
164
            out1 = f(inp)
165
        fx_g = make_fx(f, pre_dispatch=True)(inp)
166
        self.assertExpectedInline(fx_g.code.strip(), """\
167
def forward(self, a_1):
168
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
169
    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
170
    return matmul""")
171

172
    def test_pre_dispatch_linear(self):
173
        def f(a, b, c):
174
            return torch.nn.functional.linear(a, b, c)
175
        a = torch.ones(4, 4)
176
        b = torch.ones(4, 4)
177
        c = torch.ones(4)
178
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
179
        out1 = f(a, b, c)
180
        out2 = fx_g(a, b, c)
181
        self.assertEqual(out1, out2)
182

183
    def test_pre_dispatch_no_grad(self):
184
        def f(a):
185
            b = a.sin()
186
            torch.set_grad_enabled(False)
187
            c = b.cos()
188
            torch.set_grad_enabled(True)
189
            return b + c.sin()
190
        a1 = torch.randn(4, requires_grad=True)
191
        a2 = a1.clone().detach().requires_grad_(True)
192
        a_tmp = a1.clone().detach().requires_grad_(True)
193
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
194
        out1 = f(a1)
195
        out2 = fx_g(a2)
196
        self.assertEqual(out1, out2)
197
        out1.sum().backward()
198
        out2.sum().backward()
199
        self.assertEqual(a1.grad, a2.grad)
200

201
    def test_make_fx_simple(self):
202
        def f(x):
203
            return torch.sin(x)
204
        self._test(f, (torch.randn(3),))
205

206
    def test_scalar_device(self, device='cpu'):
207
        def f(a, b):
208
            return a + b
209
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
210

211
    def test_isolated_graphmodule(self):
212
        def is_any_sum(gm):
213
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)
214

215
        def is_any_digamma(gm):
216
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)
217

218
        def is_any_sigmoid(gm):
219
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)
220

221
        def inner(x):
222
            return torch.sum(x)
223

224
        def f(x):
225
            gm = get_isolated_graphmodule(inner, (x,), {})
226
            self.assertTrue(is_any_sum(gm))
227
            return x + torch.randn(x.shape)
228

229
        # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
230
        # by the outer make_fx call
231
        traced = make_fx(f)(torch.randn(3))
232
        self.assertFalse(is_any_sum(traced))
233

234
        # When factory functions are used, they should not be traced
235
        # by the outer make_fx call
236
        def inner_with_factory():
237
            val = torch.tensor(float(1))
238
            val.add_(2)
239
            return torch.full((10, 10), val).sum()
240

241
        def f1(x):
242
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
243
            self.assertTrue(is_any_sum(gm))
244
            return torch.sigmoid(x)
245

246
        def f2(x):
247
            gm = get_isolated_graphmodule(f1, (x,), {})
248
            self.assertFalse(is_any_sum(gm))
249
            self.assertTrue(is_any_sigmoid(gm))
250
            return torch.digamma(x)
251

252
        traced = make_fx(f2)(torch.randn(3))
253
        self.assertFalse(is_any_sum(traced))
254
        self.assertFalse(is_any_sigmoid(traced))
255
        self.assertTrue(is_any_digamma(traced))
256

257
        # Verify nested make_fx calls don't make factory functions to be leaked
258
        # into the outer graph. Verify that `make_fx`` itself does not leak its execution.
259
        def f2(x):
260
            gm = make_fx(f1)(x)
261
            self.assertFalse(is_any_sum(gm))
262
            self.assertTrue(is_any_sigmoid(gm))
263
            return torch.digamma(x)
264

265
        traced = make_fx(f2)(torch.randn(3))
266
        self.assertFalse(is_any_sum(traced))
267
        self.assertFalse(is_any_sigmoid(traced))
268
        self.assertTrue(is_any_digamma(traced))
269

270
        # Verify that the `forward`` function of a graph module produced as a
271
        # side effect of an interior `make_fx` is still traced
272
        def f3(x):
273
            gm = make_fx(f1)(x)
274
            self.assertFalse(is_any_sum(gm))
275
            self.assertTrue(is_any_sigmoid(gm))
276
            # `gm.forward`` is still traced
277
            return torch.digamma(gm(x))
278

279
        traced = make_fx(f3)(torch.randn(3))
280
        self.assertFalse(is_any_sum(traced))
281
        self.assertTrue(is_any_sigmoid(traced))
282
        self.assertTrue(is_any_digamma(traced))
283

284
        # Verify interaction with non-ProxyTensor modes
285
        from torch.testing._internal.logging_tensor import LoggingTensorMode
286

287
        def f1_logging(x):
288
            with LoggingTensorMode():
289
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
290
            self.assertTrue(is_any_sum(gm))
291
            return torch.sigmoid(x)
292

293
        def f2_logging(x):
294
            with LoggingTensorMode(), LoggingTensorMode():
295
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
296
            self.assertFalse(is_any_sum(gm))
297
            self.assertTrue(is_any_sigmoid(gm))
298
            return torch.digamma(x)
299

300
        traced = make_fx(f2_logging)(torch.randn(3))
301
        self.assertFalse(is_any_sum(traced))
302
        self.assertFalse(is_any_sigmoid(traced))
303
        self.assertTrue(is_any_digamma(traced))
304

305
        # Verify interaction with another tensor subclass
306
        # This case currently doesn't work and should raise an error
307
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
308
        from torch.testing._internal.logging_tensor import LoggingTensor
309

310
        def f1_logging_tensor(x):
311
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
312
            self.assertTrue(is_any_sum(gm))
313
            return torch.sigmoid(x)
314

315
        def f2_logging_tensor(x):
316
            x = LoggingTensor(x)
317
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
318
            self.assertFalse(is_any_sum(gm))
319
            self.assertTrue(is_any_sigmoid(gm))
320
            return torch.digamma(x)
321

322
        traced = make_fx(f2_logging_tensor)(torch.randn(3))
323
        self.assertFalse(is_any_sum(traced))
324
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
325
        self.assertTrue(is_any_digamma(traced))
326

327
    # See https://github.com/pytorch/pytorch/issues/97541
328
    def test_empty_like_doesnt_burn_in_defaults(self):
329
        def f(x):
330
            return torch.empty_like(x)
331
        out = make_fx(f)(torch.randn(3))
332
        self.assertExpectedInline(out.code.strip(), """\
333
def forward(self, x_1):
334
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
335
    return empty_like""")
336

337
    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
338
        def f(x):
339
            y = x.new_zeros(x.size())
340
            y.copy_(x)
341
            return y
342

343
        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
344
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)
345

346
        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}
347

348
        # When new_zeros() decomposes into torch.zero(), we expect ProxyTensorMode
349
        # to still be (re-entrantly) enabled, so that the `torch.zero()` call
350
        # returns a ProxyTensor.
351
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
352
        self.assertExpectedInline(out.code, """\
353

354

355

356
def forward(self, x_1):
357
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
358
    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
359
    return copy_
360
    """)
361

362
    def test_make_fx_reentrant_dispatch(self):
363
        def f(x):
364
            return torch.ops.aten.norm.Scalar(x, 2.0)
365

366
        def norm_decomp(x, p=2.0):
367
            if p != 2.0:
368
                raise RuntimeError("can't handle with p != 2")
369
            return torch.sqrt(torch.sum(torch.square(x)))
370

371
        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}
372

373
        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))
374

375
        for n in traced.graph.nodes:
376
            self.assertTrue("square" not in str(n.target))
377
            self.assertTrue("norm" not in str(n.target))
378

379
    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
380
    def test_resnet18_backward_trace(self):
381
        mod = torchvision.models.resnet18()
382

383
        # An old version of this test called the module directly.  This works
384
        # for tracing_mode == "real", but for fake tensors, we also have to
385
        # ensure that the parameters and buffers get wrapped in fake tensors
386
        # because free fake tensors are not supported.  Fortunately functional_call
387
        # does precisely this for us.
388
        def f(x, params, buffers):
389
            for p in params.values():
390
                p.grad = None
391
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
392
            # I could have done this with the functional API, but there is
393
            # plenty of exercising this; I want to show mutating API still
394
            # works
395
            loss.backward()
396
            return [p.grad for p in params.values()]
397

398
        inp = torch.randn(3, 3, 250, 250)
399
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])
400

401
    def test_varargs(self):
402
        def f(*args):
403
            return sum(args)
404

405
        self._test(f, [torch.randn(2), torch.randn(2)])
406

407
    def test_proxy_tensor(self):
408
        def f_grad(x):
409
            val = x.cos().cos().sum()
410
            return torch.autograd.grad(val, x)
411

412
        def f_backward(x):
413
            val = x.cos().cos().sum()
414
            val.backward()
415
            return x.grad
416

417
        for f in [f_grad, f_backward]:
418
            self._test(f, [torch.randn(3, requires_grad=True)])
419

420
    def test_pickle_issue89626(self):
421
        import pickle
422
        x = torch.randn(2)
423
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
424
        pickle.dumps(x)
425

426
    def test_inplace_metadata(self):
427
        def f(x):
428
            x = x.clone()
429
            x.unsqueeze_(-1)
430
            assert x.shape[-1] == 1
431
            return x
432

433
        self._test(f, [torch.randn(5)])
434

435
    def test_mode_tracing_factory_function(self):
436
        def f(x):
437
            return x + torch.randn(x.shape)
438

439
        # default behavior should trace factory functions
440
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
441
        self.assertTrue(
442
            any(
443
                node.target == aten.randn.default
444
                for node in traced.graph.nodes
445
            )
446
        )
447

448
    def test_pre_dispatch_functionalization(self):
449
        def f(x):
450
            a = FunctionalTensorMode(pre_dispatch=True)
451
            with a:
452
                x_unwrapped = FunctionalTensor.to_functional(x)
453
                y = torch.matmul(x_unwrapped, x_unwrapped)
454
                y = y + x_unwrapped
455
                y.mul_(5)
456
                y_unwrapped = torch._from_functional_tensor(y.elem)
457
                return y_unwrapped
458

459
        from torch._dispatch.python import enable_python_dispatcher
460

461
        with enable_python_dispatcher():
462
            inp = torch.randn(4, 4)
463
            gm = make_fx(f, pre_dispatch=True)(inp)
464

465
        # TODO actually not decompose
466
        self.assertExpectedInline(gm.code.strip(), """\
467
def forward(self, x_1):
468
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
469
    add = torch.ops.aten.add.Tensor(matmul, x_1);  matmul = x_1 = None
470
    mul = torch.ops.aten.mul.Tensor(add, 5);  add = None
471
    return mul""")
472

473
    def test_pre_dispatch_functionalization_view_op(self):
474
        def f(x):
475
            a = FunctionalTensorMode(pre_dispatch=True)
476
            with a:
477
                x_unwrapped = FunctionalTensor.to_functional(x)
478
                y = torch.matmul(x_unwrapped, x_unwrapped)
479
                x_unwrapped = x_unwrapped.transpose(1, 0)
480
                y = y + x_unwrapped
481
                y = y.view(2, 8)
482
                y_unwrapped = torch._from_functional_tensor(y.elem)
483
                return y_unwrapped
484

485
        from torch._dispatch.python import enable_python_dispatcher
486

487
        with enable_python_dispatcher():
488
            inp = torch.randn(4, 4)
489
            gm = make_fx(f, pre_dispatch=True)(inp)
490

491
        # TODO actually not decompose
492
        self.assertExpectedInline(gm.code.strip(), """\
493
def forward(self, x_1):
494
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
495
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0);  x_1 = None
496
    add = torch.ops.aten.add.Tensor(matmul, transpose);  matmul = transpose = None
497
    view = torch.ops.aten.view.default(add, [2, 8]);  add = None
498
    return view""")
499

500
    def test_val_metadata_mutation(self):
501
        def f(x):
502
            y = x.clone()
503
            y.unsqueeze_(0)
504
            return y
505

506
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
507
        self.assertEqual([
508
            tuple(node.meta['val'].shape)
509
            for node in traced.graph.nodes
510
            if 'val' in node.meta
511
        ], [(3,), (3,), (1, 3)])
512

513
    def test_make_fx_overloads(self):
514
        def f(x):
515
            return x.cos() + torch.randn(x.shape)
516

517
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
518

519
        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
520
                            for node in traced.graph.nodes if node.op == 'call_function'))
521

522
    def test_tensor_constants(self):
523
        def f():
524
            val = torch.tensor(float('inf'))
525
            return torch.full((100, 100), val)
526

527
        self._test(f, [])
528

529
    def test_allclose(self):
530
        def f(a, b):
531
            return torch.allclose(a, b)
532

533
        def test_f():
534
            make_fx(f, tracing_mode=self.tracing_mode)(
535
                torch.zeros(3), torch.zeros(3)
536
            )
537

538
        if self.tracing_mode != "real":
539
            self.assertRaises(DataDependentOutputException, test_f)
540
        else:
541
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
542

543
    def test_constant_proxy_tensor_mut(self):
544
        def f():
545
            val = torch.tensor(float(1))
546
            val.add_(2)
547
            return torch.full((100, 100), val)
548

549
        g = make_fx(f, tracing_mode=self.tracing_mode)()
550
        self.assertEqual(g(), f())
551
        # In case we mutated shared state in the g graph!
552
        self.assertEqual(g(), f())
553

554
    def test_constant_unbind(self):
555
        def f():
556
            val = torch.tensor([2])
557
            r, = torch.unbind(val, 0)
558
            return r.item()
559

560
        g = make_fx(f, tracing_mode=self.tracing_mode)()
561
        self.assertEqual(g(), f())
562

563
    def test_constant_blowup(self):
564
        def f():
565
            val = torch.tensor([2])
566
            blowup = val.repeat(1000)
567
            return bool(blowup.sum().item() == 2)
568

569
        def test_f():
570
            make_fx(f, tracing_mode=self.tracing_mode)()
571

572
        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
573

574
    def test_constant_random(self):
575
        def f():
576
            val = torch.tensor([2.0])
577
            val.normal_()
578
            return bool(val.item() == 2.1)
579

580
        def test_f():
581
            make_fx(f, tracing_mode=self.tracing_mode)()
582

583
        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
584

585
    def test_decomposition_interpreter(self):
586
        def fn(x):
587
            return torch.nn.functional.silu(x)
588

589
        x = torch.rand((4, 4))
590
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)
591

592
        found_silu = False
593
        for n in fx_module.graph.nodes:
594
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
595
                found_silu = True
596

597
        self.assertTrue(found_silu)
598

599
        new_graph = torch.fx.Graph()
600
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
601
        DecompositionInterpreter(
602
            fx_module,
603
            new_graph=new_graph,
604
            decomposition_table=silu_decomp_table,
605
        ).run(x)
606

607
        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
608

609
        for n in decomposed_module.graph.nodes:
610
            self.assertTrue(n.target != torch.ops.aten.silu)
611
            self.assertTrue(n.target != torch.ops.aten.silu.default)
612

613
        self.assertEqual(fx_module(x), decomposed_module(x))
614

615
    def test_make_fx_model_fwd_bwd(self):
616
        class Foo(torch.nn.Module):
617
            def __init__(self):
618
                super().__init__()
619
                self.linear = torch.nn.Linear(5, 5)
620

621
            def forward(self, x):
622
                return self.linear(x).relu()
623

624
        model = Foo()
625

626
        def f(x, params):
627
            out = torch.func.functional_call(model, params, x).sum()
628
            out.backward()
629
            return list(params.values())
630
        input = torch.randn(3, 5, requires_grad=True)
631
        params = dict(model.named_parameters())
632
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
633
        # fx may change the order of parameters in list, so using set() to compare
634
        self.assertTrue(
635
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
636
            or
637
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
638
        )
639
        self.assertTrue(
640
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
641
            or
642
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
643
        )
644

645
    def test_make_fx_model_double_param(self):
646
        class Emformer(torch.nn.Module):
647
            def __init__(
648
                self,
649
                input_dim: int = 256,
650
            ) -> None:
651
                super().__init__()
652

653
                self.layer_norm = torch.nn.LayerNorm(input_dim)
654

655
            def forward(mod_self, x):  # noqa: B902
656
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
657
                y = mod_self.layer_norm(x)
658
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
659
                z = mod_self.layer_norm(y)
660
                return z
661

662

663
        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
664
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
665
        self.assertEqual(len(ops), 2)
666

667

668
    def test_make_fx_model_fwd_bwd_wgtupdate(self):
669
        class Foo(torch.nn.Module):
670
            def __init__(self):
671
                super().__init__()
672
                self.linear = torch.nn.Linear(5, 5)
673

674
            def forward(self, x):
675
                return self.linear(x).relu()
676

677
        model = Foo()
678

679
        def f(args, params, buffers):
680
            for p in params.values():
681
                p.grad = None
682
            if not isinstance(args, Iterable):
683
                args = [args]
684
            params_and_buffers = {**params, **buffers}
685
            out = torch.func.functional_call(model, params_and_buffers, args)
686
            out.sum().backward()
687
            return [p - 1e-4 * p.grad for p in params.values()]
688

689
        input = torch.randn(3, 5, requires_grad=True)
690
        params = dict(model.named_parameters())
691
        buffers = dict(model.named_buffers())
692
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
693
        # fx may change the order of parameters in list, so using set() to compare
694
        # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
695
        self.assertTrue(
696
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
697
            or
698
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
699
        )
700
        self.assertTrue(
701
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
702
            or
703
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
704
        )
705

706
    def test_trace_subclasses(self):
707
        def f1(x):
708
            x = UnwrapTensor(x)
709
            y = x * 2
710
            return y
711

712
        def f2(x):
713
            wrapped = UnwrapTensor(x)
714
            y = x * wrapped
715
            return y
716

717
        inp = [torch.randn(5)]
718
        self._test(f1, inp)
719
        self._test(f2, inp)
720

721
    def test_partial_decomp(self):
722
        def f(a, b, c):
723
            x = torch.addmm(a, b, c)
724
            y = torch.addmm(a, b, c, beta=2, alpha=1)
725
            return x + y
726
        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
727
        fx_g = make_fx(f)(*inps)
728

729
        def addmm(a, b, c, beta=1, alpha=1):
730
            if beta == 1 and alpha == 1:
731
                return NotImplemented
732
            return beta * a + alpha * (b @ c)
733

734
        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)
735

736
        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
737
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
738
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)
739

740
    def test_decomp_of_capture(self):
741
        val = torch.randn(5)
742

743
        def f(x):
744
            return x.t() + val.t()
745

746
        def nop(x):
747
            return x.cos()
748

749
        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
750
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)
751

752

753
    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
754
    def test_amp_cache(self):
755
        layer = torch.nn.Conv2d(3, 3, 3).cuda()
756

757
        def f(x, w):
758
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)
759

760
        inp = torch.randn(4, 3, 10, 10, device='cuda')
761
        with torch.autocast('cuda'):
762
            out_graph = make_fx(f)(inp, layer.weight).graph
763
            out_graph2 = make_fx(f)(inp, layer.weight).graph
764

765
        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
766
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
767
            self.assertEqual(a.op, b.op)
768

769
    def test_strides(self):
770
        def f(x):
771
            self.assertTrue(x.is_contiguous())
772
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
773
            x = x.permute(0, 3, 1, 2)
774
            self.assertFalse(x.is_contiguous())
775
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
776
            return x
777
        make_fx(f)(torch.randn(2, 3, 4, 5))
778

779
        def f(x):
780
            self.assertTrue(x.is_contiguous())
781
            y = x[:, 1]
782
            self.assertFalse(y.is_contiguous())
783
            y = x[:, ::2]
784
            self.assertFalse(y.is_contiguous())
785
            return x.cos()
786

787
        make_fx(f)(torch.randn(2, 3, 4, 5))
788

789
    def test_pr_86917(self):
790
        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
791
        def f(a, b):
792
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)
793

794
        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
795

796
class TestGenericProxyTensorReal(TestGenericProxyTensor):
797
    tracing_mode = "real"
798

799

800
class TestGenericProxyTensorFake(TestGenericProxyTensor):
801
    tracing_mode = "fake"
802

803

804
class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
805
    tracing_mode = "symbolic"
806

807

808
del TestGenericProxyTensor
809

810

811
class TestRealProxyTensor(TestCase):
812
    def test_error_on_data_dependent_ops(self):
813
        def f():
814
            x = torch.randn([])
815
            y = torch.randn([])
816
            assert torch.allclose(x * y, y * x)
817
            z = float(x)
818
            z2 = float(y)
819

820
        # Smoke tests
821
        make_fx(f, _error_on_data_dependent_ops=False)()
822
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()
823

824
class TestFakeProxyTensor(TestCase):
825
    def test_issue82547(self):
826
        x = nn.Parameter(torch.randn(3, 3))
827

828
        def f():
829
            return torch.ops.aten.t.default(x)
830
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())
831

832
        class A(torch.Tensor):
833
            pass
834

835
        x = A(torch.randn(3, 3))
836
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())
837

838
    def test_use_fake_and_tensor(self):
839
        def f(x, y):
840
            z = torch.tensor([2.0, 3.0])
841
            return x + y + z
842

843
        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
844
        x, y = torch.randn(2), torch.randn(2)
845
        self.assertEqual(g(x, y), f(x, y))
846

847
    def test_free_fake(self):
848
        def f(x):
849
            return torch.add(x, y)
850

851
        with FakeTensorMode() as fake_mode:
852
            y = torch.randn(2)
853
            make_fx(f, tracing_mode="real")(torch.randn(2))
854

855
    def test_fused_adam(self):
856
        # See https://github.com/pytorch/pytorch/issues/99356
857
        params = [torch.randn(10, 10) for _ in range(10)]
858
        grads = [torch.randn(10, 10) for _ in range(10)]
859
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
860
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
861
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
862
        state_steps = [torch.tensor(0) for _ in range(10)]
863

864
        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
865
            (new_params, _, _, _, _) = aten._fused_adam.default(
866
                params,
867
                grads,
868
                exp_avgs,
869
                exp_avg_sqs,
870
                max_exp_avg_sqs,
871
                state_steps,
872
                lr=0.1,
873
                beta1=0.9,
874
                beta2=0.999,
875
                weight_decay=0.01,
876
                eps=1e-8,
877
                amsgrad=False,
878
                maximize=False,
879
            )
880

881
            for p, new_p in zip(params, new_params):
882
                p.copy_(new_p)
883

884
            return params
885

886
        gm = make_fx(fused_adam, tracing_mode='fake')(
887
            params,
888
            grads,
889
            exp_avgs,
890
            exp_avg_sqs,
891
            max_exp_avg_sqs,
892
            state_steps,
893
        )
894
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
895
        for n in gm.graph.nodes:
896
            if n.op == "call_function" and n.target in ensure_ops_have_val:
897
                self.assertIn('val', n.meta)
898

899
    def test_alias(self):
900
        def f(x):
901
            return torch.ops.aten.alias(x)
902

903
        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
904
        # NB: this should not have a detach call
905
        self.assertExpectedInline(r, """\
906
def forward(self, x_1):
907
    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
908
    return alias""")
909

910
    def test_meta(self):
911
        def f(x):
912
            a = x.cos()
913
            b = torch.var_mean(a, dim=0)
914
            c = b * 2
915
            return c
916

917
        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
918
        for n in out.graph.nodes:
919
            if n.op == 'output':
920
                continue
921
            self.assertTrue('val' in n.meta)
922

923
def _get_node(fx_g, cond):
924
    for n in fx_g.graph.nodes:
925
        if cond(n):
926
            return n
927
    raise AssertionError
928

929
def _get_free_symbols(shape_env):
930
    vars = tuple(shape_env.var_to_val.keys())
931
    return len([var for var in vars if var not in shape_env.replacements])
932

933
def _trace(f, *args):
934
    inps = [torch.randn(arg) for arg in args]
935
    return make_fx(f, tracing_mode="symbolic")(*inps)
936

937
# TODO: Need to test the guards themselves specifically as well
938
class TestSymbolicTracing(TestCase):
939
    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
940
        """
941
        Tests fn traced with trace_inputs against test_inputs
942
        Also returns shape env
943
        """
944
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
945
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
946
        for input in test_inputs:
947
            input = [torch.randn(shape) for shape in input]
948
            rx, ry = traced_f(*input), fn(*input)
949
            if assert_eq:
950
                self.assertEqual(rx, ry)
951
        return traced_f
952

953

954
    def test_debug_interpreter(self):
955
        import torch.library
956
        from torch.library import Library
957

958
        foo = Library("foo", "DEF")  # noqa: TOR901
959
        foo.define("foo(Tensor self) -> Tensor")
960

961
        # Operator where meta and cpu disagree on strides
962
        @torch.library.impl(foo, "foo", "CPU")
963
        def foo_cpu(x):
964
            return x.clone().T
965

966
        @torch.library.impl(foo, "foo", "Meta")
967
        def foo_meta(x):
968
            return x.clone()
969

970
        def f(x):
971
            return torch.ops.foo.foo.default(x)
972

973
        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
974
        from torch._functorch.compilers import DebugInterpreter
975

976
        interp = DebugInterpreter(gm)
977

978
        # input mismatch is caught (indicates guard problem)
979
        self.assertRaisesRegex(
980
            AssertionError, r"3 != 1",
981
            lambda: interp.run(torch.randn(3, 3).T),
982
        )
983

984
        # Catch the incorrect meta
985
        self.assertRaisesRegex(
986
            AssertionError, r"\(3, 1\) != \(1, 3\)",
987
            lambda: interp.run(torch.randn(3, 3))
988
        )
989

990
    def test_int_input(self):
991
        def f(x, y):
992
            return x.view(y)
993

994
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
995
        self.assertExpectedInline(r, """\
996
def forward(self, x_1, y_1):
997
    view = torch.ops.aten.view.default(x_1, [y_1]);  x_1 = y_1 = None
998
    return view""")
999

1000
    def test_resize_from_zero(self):
1001
        def f(x, y):
1002
            x.resize_(y.size(0))
1003

1004
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
1005
        self.assertExpectedInline(r, """\
1006
def forward(self, x_1, y_1):
1007
    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1008
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]);  x_1 = sym_size_int = None
1009
    return None""")
1010

1011
    def test_broadcast_shapes(self):
1012
        def f(x, y):
1013
            return torch.functional.broadcast_shapes(x.size(), y.size()[0])
1014

1015
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
1016
        self.assertExpectedInline(r, """\
1017
def forward(self, x_1, y_1):
1018
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
1019
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1020
    return (sym_size_int, sym_size_int_1)""")
1021

1022
    def test_deduped_shape(self):
1023
        def f(s0, s1, x, y):
1024
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
1025

1026
        x = torch.empty(3, 1)
1027
        y = torch.empty(5)
1028
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
1029
        shape_env = ShapeEnv()
1030

1031
        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
1032
            x = fake_mode.from_tensor(x)
1033
            y = fake_mode.from_tensor(y)
1034
            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
1035
            self.assertExpectedInline(r, """\
1036
def forward(self, s0_1, s1_1, x_1, y_1):
1037
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
1038
    return ((s0_1, s1_1), empty)""")
1039

1040
    def test_non_deduped_shape(self):
1041
        def f(x, y):
1042
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
1043

1044
        x = torch.empty(3, 1)
1045
        y = torch.empty(5)
1046
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
1047
        shape_env = ShapeEnv()
1048

1049
        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
1050
            x = fake_mode.from_tensor(x)
1051
            y = fake_mode.from_tensor(y)
1052
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
1053
            self.assertExpectedInline(r, """\
1054
def forward(self, x_1, y_1):
1055
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
1056
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
1057
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1058
    return ((sym_size_int, sym_size_int_1), empty)""")
1059

1060
    def test_unary(self):
1061
        def f(x):
1062
            assert x.shape[0] < 20
1063
            return x.cos()
1064
        test_inputs = []
1065
        test_inputs.append([(2, 5)])
1066
        test_inputs.append([(6, 8)])
1067
        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
1068
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
1069
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
1070
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
1071
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
1072

1073
    def test_repeat_interleave(self):
1074
        def f(src_tokens, beam_size_src):
1075
            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)
1076

1077
        prompt_size = 64
1078
        vocab_size = 64
1079
        batch_size = 4
1080
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
1081
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
1082
        self.assertEqual(len(gm.shape_env.guards), 0)
1083

1084
    def test_non_symint_size_spec(self):
1085
        # this isn't really a proxy tensor test, but it's the most convenient
1086
        # way to get a fake tensor with symbolic sizes
1087
        def f(x):
1088
            torch._C._non_sym_sizes(x)
1089
            return x + 1
1090

1091
        x = torch.randn(2, 3)
1092
        make_fx(f, tracing_mode="symbolic")(x)
1093

1094
    # https://github.com/pytorch/pytorch/issues/108195
1095
    def test_symbolic_repeat_interleave(self):
1096
        def f(y, x):
1097
            return y.repeat_interleave(x, dim=1)
1098

1099
        y = torch.tensor([[1, 2], [3, 4]])
1100
        x = torch.tensor([2, 3])
1101
        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
1102
        self.assertExpectedInline(r, """\
1103
def forward(self, y_1, x_1):
1104
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1);  x_1 = None
1105
    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave);  y_1 = repeat_interleave = None
1106
    return index_select""")
1107

1108
    def test_cumsum_unbacked(self):
1109
        def f(x):
1110
            y = x.item()
1111
            z = torch.randn((3, y, 3))
1112
            return z.cumsum(0)
1113

1114
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
1115
        self.assertExpectedInline(
1116
            r, """\
1117
def forward(self, x_1):
1118
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1119
    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1120
    cumsum = torch.ops.aten.cumsum.default(randn, 0);  randn = None
1121
    return cumsum"""  # noqa: B950
1122
        )
1123

1124

1125
    def test_repeat_interleave_unbacked_output_size(self):
1126
        def f(x, y):
1127
            s = x.sum().item()
1128
            return y.repeat_interleave(x, dim=0, output_size=s)
1129

1130
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
1131
        self.assertExpectedInline(
1132
            r, """\
1133
def forward(self, x_1, y_1):
1134
    sum_1 = torch.ops.aten.sum.default(x_1)
1135
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1);  sum_1 = None
1136
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense);  x_1 = _local_scalar_dense = None
1137
    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave);  y_1 = repeat_interleave = None
1138
    return index_select"""  # noqa: B950
1139
        )
1140

1141
    def test_arange_unbacked_output_size(self):
1142
        def f(x):
1143
            return torch.arange(0, x)
1144

1145
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
1146
        self.assertExpectedInline(
1147
            r, """\
1148
def forward(self, x_1):
1149
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1150
    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1151
    return arange"""  # noqa: B950
1152
        )
1153

1154
    def test_adv_index_batch(self):
1155
        def f(src_tokens):
1156
            bsz, src_len = src_tokens.size()[:2]
1157
            start_step = src_tokens.shape[1]
1158
            beam_size = 1
1159
            generate_size = 64
1160
            max_len = src_len + generate_size
1161
            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
1162
            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
1163
            return tokens
1164

1165
        prompt_size = 64
1166
        vocab_size = 64
1167
        batch_size = 4
1168
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
1169
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
1170
        self.assertEqual(len(gm.shape_env.guards), 0)
1171

1172
    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
1173
    def test_cpu_scalar_cuda(self):
1174
        # Extracted from wave2vec2
1175
        def f(a, b):
1176
            return (a * b) @ b
1177

1178
        r = str(
1179
            make_fx(f, tracing_mode="symbolic")(
1180
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
1181
            ).code
1182
        ).strip()
1183
        self.assertExpectedInline(r, """\
1184
def forward(self, a_1, b_1):
1185
    mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None
1186
    mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None
1187
    return mm""")
1188

1189
    def test_binary_broadcast(self):
1190
        def f(a, b):
1191
            c = a * b
1192
            return c
1193

1194
        test_inputs = []
1195
        test_inputs.append([(1, 5), (3, 1)])
1196
        test_inputs.append([(1, 4), (4, 1)])
1197
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
1198
        assert len(shape_env.guards) == 0
1199

1200
    def test_multiply_shape(self):
1201
        def f(a):
1202
            return torch.empty(a.shape[0] * 2)
1203

1204
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1205
        self.assertExpectedInline(r, """\
1206
def forward(self, a_1):
1207
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
1208
    mul = sym_size_int * 2;  sym_size_int = None
1209
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None
1210
    return empty""")
1211

1212
    def test_item(self):
1213
        def f(a):
1214
            r = a.item()
1215
            return r * a
1216

1217
        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
1218
        self.assertExpectedInline(r, """\
1219
def forward(self, a_1):
1220
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
1221
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None
1222
    return mul""")
1223

1224
    def test_tensor_symfloat(self):
1225
        def f(a):
1226
            r = torch.tensor(a.size(0) ** 2.0)
1227
            assert r.dtype is torch.float
1228
            return r
1229

1230
        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
1231
        r = str(gm.code).strip()
1232
        # NB: this specializes, which is fine, the point is to make sure the
1233
        # dtype inference is correct
1234
        self.assertExpectedInline(r, """\
1235
def forward(self, a_1):
1236
    _tensor_constant0 = self._tensor_constant0
1237
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
1238
    return lift_fresh_copy""")
1239
        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))
1240

1241
    def test_item_to_constructor(self):
1242
        def f(a):
1243
            r = a.item()
1244
            return torch.empty(r)
1245

1246
        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
1247
        self.assertExpectedInline(
1248
            r, """\
1249
def forward(self, a_1):
1250
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
1251
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1252
    return empty"""  # noqa: B950
1253
        )
1254

1255

1256
    def test_setitem_symint(self):
1257
        # from moco
1258
        # https://github.com/pytorch/pytorch/issues/101939
1259
        def f(x):
1260
            x[0] = x.size(0)
1261
            return x
1262

1263
        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
1264
        self.assertExpectedInline(
1265
            r, """\
1266
def forward(self, x_1):
1267
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
1268
    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  sym_size_int = None
1269
    select = torch.ops.aten.select.int(x_1, 0, 0)
1270
    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor);  select = scalar_tensor = None
1271
    return x_1"""  # noqa: B950
1272
        )
1273

1274
    def test_dynamic_pointwise_scalar(self):
1275
        def f(gravity, mask):
1276
            gravity[mask, 0] = gravity[mask, 0] * -1
1277

1278
        r = str(make_fx(f, tracing_mode="symbolic")(
1279
            torch.randn((12, 4)),
1280
            torch.randint(0, 2, (12,), dtype=torch.bool)
1281
        ).code).strip()
1282
        self.assertExpectedInline(r, """\
1283
def forward(self, gravity_1, mask_1):
1284
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
1285
    index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None
1286
    mul = torch.ops.aten.mul.Tensor(index, -1);  index = None
1287
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None
1288
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = None
1289
    return None""")
1290

1291
    def test_reflect_r_over_x(self):
1292
        def reflect_R_over_x(R):
1293
            reflect = torch.eye(3, device=R.device)
1294
            reflect[0, 0] = -1
1295
            return reflect @ R @ reflect
1296

1297
        def f(crop_camera, mask):
1298
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])
1299

1300
        r = str(make_fx(f, tracing_mode="symbolic")(
1301
            torch.randn((12, 3, 3)),
1302
            torch.randint(0, 2, (12,), dtype=torch.bool)
1303
        ).code).strip()
1304
        self.assertExpectedInline(r, """\
1305
def forward(self, crop_camera_1, mask_1):
1306
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
1307
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
1308
    _tensor_constant0 = self._tensor_constant0
1309
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
1310
    select = torch.ops.aten.select.int(eye, 0, 0)
1311
    select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None
1312
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = None
1313
    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
1314
    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
1315
    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]);  expand = None
1316
    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
1317
    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
1318
    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]);  index = None
1319
    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None
1320
    bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
1321
    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None
1322
    mul = sym_size_int * 3
1323
    view_3 = torch.ops.aten.view.default(view_2, [mul, 3]);  view_2 = mul = None
1324
    mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None
1325
    view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None
1326
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4);  crop_camera_1 = mask_1 = view_4 = None
1327
    return None""")  # noqa: B950
1328

1329
    def test_unbacked_slice(self):
1330
        def f(x, m):
1331
            x = x[m]
1332
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]
1333

1334
        make_fx(f, tracing_mode="symbolic")(
1335
            torch.randn((12, 3, 3)),
1336
            torch.randint(0, 2, (12,), dtype=torch.bool)
1337
        )
1338

1339
    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
1340
    def test_unbacked_batch_resnet(self):
1341
        mod = torchvision.models.resnet18()
1342

1343
        def f(x, mask, params, buffers):
1344
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
1345
                for s in p.shape:
1346
                    guard_int(s)
1347
            x = x[mask]
1348
            torch._constrain_as_value(x.shape[0], min=1)
1349
            for p in params.values():
1350
                p.grad = None
1351
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
1352

1353
        make_fx(f, tracing_mode="symbolic")(
1354
            torch.randn(3, 3, 250, 250),
1355
            torch.randint(0, 2, (3,), dtype=torch.bool),
1356
            dict(mod.named_parameters()),
1357
            dict(mod.named_buffers()),
1358
        )
1359

1360
    def test_boolean_index(self):
1361
        def f(images, handedness, valid):
1362
            images = images[valid]
1363
            handedness = handedness[valid]
1364
            right_hand_mask = handedness == 1
1365
            images[right_hand_mask] = images[right_hand_mask].flip(-1)
1366

1367
        r = str(make_fx(f, tracing_mode="symbolic")(
1368
            torch.randint(0, 256, (512, 1, 96, 96)),
1369
            torch.randint(0, 1, (512,)),
1370
            torch.randint(0, 2, (512,), dtype=torch.bool)
1371
        ).code).strip()
1372
        self.assertExpectedInline(r, """\
1373
def forward(self, images_1, handedness_1, valid_1):
1374
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None
1375
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None
1376
    eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None
1377
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
1378
    flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None
1379
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = None
1380
    return None""")
1381

1382
    def test_neg_shape(self):
1383
        def f(a):
1384
            return torch.empty(-a.shape[0] + 10)
1385

1386
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
1387
        self.assertExpectedInline(r, """\
1388
def forward(self, a_1):
1389
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
1390
    neg = -sym_size_int;  sym_size_int = None
1391
    add = neg + 10;  neg = None
1392
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None
1393
    return empty""")
1394

1395
    def test_unbacked_unification(self):
1396
        def f(x, y):
1397
            z = torch.zeros(x.item())
1398
            return z + y
1399

1400
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
1401
        self.assertExpectedInline(r, """\
1402
def forward(self, x_1, y_1):
1403
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1404
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1405
    add = torch.ops.aten.add.Tensor(zeros, y_1);  zeros = y_1 = None
1406
    return add""")  # noqa: B950
1407

1408
    def test_view_divisibility_unbacked(self):
1409
        def f(x):
1410
            i0 = x.item()
1411
            r = torch.zeros(i0, 192)
1412
            return r.view(12, -1, 192)
1413
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))
1414

1415
    def test_unbacked_unify_guard(self):
1416
        def f(x, y):
1417
            z = torch.zeros(x.item())
1418
            torch._check(z.size(0) == y.size(0))  # refines i0 = s0
1419
            if z.size(0) == 4:
1420
                return y * 2
1421
            else:
1422
                return y + 2
1423

1424
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
1425
        self.assertExpectedInline(r, """\
1426
def forward(self, x_1, y_1):
1427
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1428
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1429
    add = torch.ops.aten.add.Tensor(y_1, 2);  y_1 = None
1430
    return add""")  # noqa: B950
1431

1432
    def test_unbacked_unify_guard_transitivity(self):
1433
        def f(x1, x2, y):
1434
            z1 = torch.zeros(x1.item())
1435
            z2 = torch.zeros(x2.item())
1436
            torch._check(z1.size(0) == z2.size(0))  # refines i0 = i1
1437
            torch._check(z2.size(0) == y.size(0))  # refines i0 = s0
1438
            if z1.size(0) == 4:
1439
                return y * 2
1440
            else:
1441
                return y + 2
1442

1443
        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.tensor(10), torch.randn(10)).code).strip()
1444
        self.assertExpectedInline(r, """\
1445
def forward(self, x1_1, x2_1, y_1):
1446
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x1_1);  x1_1 = None
1447
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1448
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(x2_1);  x2_1 = None
1449
    zeros_1 = torch.ops.aten.zeros.default([_local_scalar_dense_1], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense_1 = None
1450
    add = torch.ops.aten.add.Tensor(y_1, 2);  y_1 = None
1451
    return add""")  # noqa: B950
1452

1453
    def test_split_unbacked_sizes(self):
1454
        def f(lengths, values):
1455
            # tolist not directly supported atm
1456
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
1457
            for s in sizes:
1458
                torch._constrain_as_size(s)
1459
            return torch.split(values, sizes)
1460

1461
        r = str(make_fx(f, tracing_mode="symbolic")(
1462
            torch.tensor([2, 3, 4]),
1463
            torch.randn(9)
1464
        ).code).strip()
1465
        self.assertExpectedInline(r, """\
1466
def forward(self, lengths_1, values_1):
1467
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
1468
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select);  select = None
1469
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
1470
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1);  select_1 = None
1471
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2);  lengths_1 = None
1472
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2);  select_2 = None
1473
    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense)
1474
    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1)
1475
    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2)
1476
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]);  values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
1477
    getitem = split_with_sizes[0]
1478
    getitem_1 = split_with_sizes[1]
1479
    getitem_2 = split_with_sizes[2];  split_with_sizes = None
1480
    return (getitem, getitem_1, getitem_2)""")  # noqa: B950
1481

1482
    def test_invalidate_nonzero(self):
1483
        ok = False
1484

1485
        def f(a):
1486
            nonlocal ok
1487
            b = a.clone()
1488
            x = b.nonzero()
1489
            x1 = b.nonzero()
1490
            x2 = b.nonzero()
1491
            assert x1.shape[0] == x2.shape[0]
1492
            ok = True
1493
            b.normal_()
1494
            y = b.nonzero()
1495
            try:
1496
                bool(x1.shape[0] == y.shape[0])
1497
                self.fail("didn't raise exception")
1498
            except GuardOnDataDependentSymNode:
1499
                pass
1500

1501
        make_fx(f, tracing_mode="symbolic")(torch.randn(4))
1502

1503
    def test_sqrt_size(self):
1504
        def f(a):
1505
            return a / a.size(-1) ** 0.5
1506

1507
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1508
        self.assertExpectedInline(r, """\
1509
def forward(self, a_1):
1510
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
1511
    pow_1 = sym_size_int ** 0.5;  sym_size_int = None
1512
    div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None
1513
    return div""")
1514

1515
    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
1516

1517
        class Bar(torch.nn.Module):
1518
            def __init__(self):
1519
                super().__init__()
1520

1521
            def forward(self, x):
1522
                return x + 1
1523

1524
        class Foo(torch.nn.Module):
1525
            def __init__(self):
1526
                super().__init__()
1527
                self.bar = Bar()
1528

1529
            def forward(self, x):
1530
                return x + self.bar(x)
1531

1532
        gm = make_fx(Foo())(torch.randn(4, 4))
1533
        for node in gm.graph.nodes:
1534
            self.assertTrue("nn_module_stack" not in node.meta)
1535

1536
        foo = Foo()
1537

1538
        def functional_call(*args, **kwargs):
1539
            with stateless._reparametrize_module(foo, {}):
1540
                return foo(*args, **kwargs)
1541

1542
        functional_call._orig_mod = foo
1543

1544
        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
1545
        found = False
1546
        for node in gm_with_stack.graph.nodes:
1547
            if "nn_module_stack" in node.meta:
1548
                if len(node.meta["nn_module_stack"]) == 1:
1549
                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
1550
                    found = True
1551
                elif len(node.meta["nn_module_stack"]) == 2:
1552
                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
1553
                    found = True
1554
                else:
1555
                    # there can be at most 2 level
1556
                    self.assertTrue(False)
1557

1558
        self.assertTrue(found)
1559

1560
        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
1561
        for node in gm_without_stack.graph.nodes:
1562
            self.assertTrue("nn_module_stack" not in node.meta)
1563

1564
    def test_symint_to_tensor(self):
1565
        def f(a):
1566
            return a / a.shape[0]
1567

1568
        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1569
        self.assertExpectedInline(r, """\
1570
def forward(self, a_1):
1571
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
1572
    div = torch.ops.aten.div.Tensor(a_1, sym_size_int);  a_1 = sym_size_int = None
1573
    return div""")
1574

1575
        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
1576
        self.assertExpectedInline(r, """\
1577
def forward(self, a_1):
1578
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
1579
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
1580
    div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None
1581
    return div""")
1582

1583
    def test_cat(self):
1584
        def f(a, b):
1585
            val = torch.mul(a, b)
1586
            out = torch.cat([val, val])
1587
            if out.shape[0] * out.shape[1] > 20:
1588
                out = out.cos()
1589
            return out
1590

1591
        test_inputs = []
1592
        test_inputs.append([(1, 5), (6, 1)])
1593
        test_inputs.append([(1, 4), (3, 1)])
1594
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
1595
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
1596
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
1597
        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")
1598

1599
    def test_new_empty(self):
1600
        def f(a, b):
1601
            return a.new_empty(b.shape[0], b.shape[1] * 2)
1602

1603
        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env
1604

1605
    def test_size_with_tensor(self):
1606
        # I think I messed up writing this test case originally, I think
1607
        # I'm supposed to hit an error case, but the code here works in both
1608
        # eager and tracing
1609
        def f(tensor):
1610
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
1611
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
1612
            return tensor.new_empty(batch_shape)
1613

1614
        a = torch.randn(3, 800, 1199)
1615
        f(a)
1616
        make_fx(f, tracing_mode="symbolic")(a)
1617

1618
    def test_fake_tensor_as_size(self):
1619
        def f(x):
1620
            r = torch.zeros([x])
1621
            return r
1622

1623
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
1624
        self.assertExpectedInline(fx_g.code.strip(), """\
1625
def forward(self, x_1):
1626
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1627
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1628
    return zeros""")  # noqa: B950
1629

1630
    def test_expand(self):
1631
        def f(a):
1632
            b = torch.mul(a, a)
1633
            c = b.expand(a.shape)
1634
            return c
1635

1636
        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
1637
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
1638

1639
    def test_metadata(self):
1640
        def f(a, b):
1641
            d = a.new_empty(a.shape[0] + b.shape[0])
1642
            return d
1643
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
1644
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
1645
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
1646
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)
1647

1648
    def test_metadata_fresh(self):
1649
        def f(x):
1650
            assert x.shape[0] == 3
1651
            return x.cos()
1652

1653
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
1654
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
1655
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
1656
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
1657
        # Checks if the input expr has been updated even though the constraint
1658
        # happened afterwards
1659
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)
1660

1661
    def test_elementwise_meta_with_sym_numbers(self):
1662
        def f(x, offset, as_sym_float=False):
1663
            x0 = x.size()[0]
1664
            if as_sym_float:
1665
                x0 = torch.sym_float(x0)
1666
            return torch.add(x0, offset)
1667

1668
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
1669
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
1670
        self.assertEqual(meta_add.meta['val'].shape, ())
1671
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)
1672

1673
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
1674
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
1675
        self.assertEqual(meta_add.meta['val'].shape, ())
1676
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)
1677

1678
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
1679
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
1680
        self.assertEqual(meta_add.meta['val'].shape, ())
1681
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)
1682

1683
    def test_return_symint(self):
1684
        def f(x):
1685
            return x.shape[0], x.cos(), x.shape[0] / 5
1686
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
1687

1688
        def f(x):
1689
            return x.shape
1690
        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])
1691

1692
    def test_rmethod(self):
1693
        def f(x):
1694
            return x.size(0) + x
1695
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
1696

1697
    def test_mega_guard(self):
1698
        def f(a, b):
1699
            assert a.shape[0] == b.shape[0] * 2
1700
            return a.cos()
1701
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
1702
        from torch._dynamo.source import LocalSource
1703
        self.assertExpectedInline(
1704
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
1705
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
1706
        )
1707
        self.assertExpectedInline(
1708
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
1709
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
1710
        )
1711

1712
    def test_guard_upperbound_range_refinement(self):
1713
        def f(a):
1714
            assert a.shape[0] > 5 and a.shape[0] > 12
1715
            return a.cos()
1716
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
1717
        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")
1718

1719
    def test_guard_lowerbound_range_refinement(self):
1720
        def f(a):
1721
            assert a.shape[0] < 20 and a.shape[0] < 30
1722
            return a.cos()
1723
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
1724
        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")
1725

1726
    def test_guard_upperbound_range_refinement_multivariate(self):
1727
        def f(a):
1728
            assert a.shape[0] > 5 and a.shape[0] > 12
1729
            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
1730
            return a.cos()
1731
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
1732
        self.assertExpectedInline(show_guards(tensor), """\
1733
L['a'].size()[1] > L['a'].size()[0]
1734
13 <= L['a'].size()[0]
1735
14 <= L['a'].size()[1]""")
1736

1737
    def test_guard_lowerbound_range_refinement_multivariate(self):
1738
        def f(a):
1739
            assert a.shape[0] < 20 and a.shape[0] < 30
1740
            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
1741
            return a.cos()
1742
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
1743
        self.assertExpectedInline(
1744
            show_guards(tensor),
1745
            """\
1746
L['a'].size()[1] < L['a'].size()[0]
1747
L['a'].size()[0] <= 19
1748
L['a'].size()[1] <= 18""")
1749

1750
    def test_sym_storage_offset(self):
1751
        def f(x, y):
1752
            return x + y
1753

1754
        inp = (torch.randn(8)[3:], torch.randn(5))
1755
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
1756
        inp = (torch.randn(8)[3:], torch.randn(5))
1757
        self.assertEqual(fx_g(*inp), f(*inp))
1758

1759
    def _assert_no_guards(self, fx_g, free_symbols):
1760
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
1761
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()
1762

1763
    def test_guards_equal(self):
1764
        def f(a, b):
1765
            return a * b
1766

1767
        # NB: Numbers are carefully chosen to avoid duck shaping from applying
1768

1769
        fx_g = _trace(f, (5, 6), (5, 6))
1770
        self._assert_no_guards(fx_g, 2)
1771

1772
        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
1773
        self._assert_no_guards(fx_g, 3)
1774

1775
        fx_g = _trace(f, (5, 1), (1, 6))
1776
        self._assert_no_guards(fx_g, 2)
1777

1778
        def f(a, b, c, d):
1779
            a = a + b
1780
            cat = torch.cat([c, d])
1781
            return a + cat
1782

1783
        fx_g = _trace(f, 7, 7, 4, 3)
1784
        self._assert_no_guards(fx_g, 2)
1785

1786
        def f(a, b, c, d, e):
1787
            vals = [a, b, c, d, e]
1788
            x = a
1789
            for idx in range(len(vals) - 1):
1790
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
1791
            return x
1792

1793
        fx_g = _trace(f, 2, 4, 8, 16, 32)
1794
        self._assert_no_guards(fx_g, 1)
1795

1796
        def f(a, b):
1797
            a = a.view(b.shape[0])
1798
            return a + b.sum()
1799

1800
        fx_g = _trace(f, (4, 2), 8)
1801
        self._assert_no_guards(fx_g, 2)
1802

1803
        fx_g = _trace(f, (4, 2), (8, 5))
1804
        self._assert_no_guards(fx_g, 3)
1805

1806
        fx_g = _trace(f, (2, 3, 4), 24)
1807
        self._assert_no_guards(fx_g, 3)
1808

1809
    def test_nonidentity_transitive_guards(self):
1810
        def f(a, b, c, d, e):
1811
            vals = [a, b, c, d, e]
1812
            cat_vals = []
1813
            for idx in range(len(vals) - 1):
1814
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
1815
            final_vals = []
1816
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
1817
                final_vals.append(a + b)
1818
            return final_vals
1819

1820
        fx_g = _trace(f, 2, 4, 8, 16, 32)
1821
        self.assertExpectedInline(show_guards(fx_g), """""")
1822

1823
    @torch.fx.experimental._config.patch(translation_validation=True)
1824
    def test_constant_specialization(self):
1825
        def f(t):
1826
            assert t.shape[0] == 10
1827
            return t
1828

1829
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
1830
        self.assertExpectedInline(show_guards(tensor), """""")
1831

1832

1833
make_fx_failures = {
1834
    # unknown
1835
    xfail('allclose'),
1836
    xfail('equal'),
1837
    # empty
1838
    skip('new_empty'),
1839
    skip('empty_like'),
1840
    skip('empty'),
1841
    skip('empty_permuted'),
1842
    # flaky
1843
    skip('linalg.lstsq', 'grad_oriented'),
1844
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
1845
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
1846
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
1847
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
1848

1849
    # data-dependent control flow
1850
    skip('item'),
1851
    xfail('cov'),
1852
    xfail('nn.functional.gaussian_nll_loss'),
1853
    xfail('tensor_split'),
1854
    xfail('corrcoef'),
1855
    xfail('quantile'),
1856
    xfail('nanquantile'),
1857
    xfail('narrow'),
1858

1859
    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
1860
    xfail('sparse.sampled_addmm'),
1861
    xfail('sparse.mm', 'reduce'),
1862

1863
    # proxy tensor doesn't support sparse correctly right now
1864
    skip('to_sparse'),
1865
    # segfaults
1866
    skip('block_diag'),
1867

1868
    # AssertionError: Tensor-likes are not close!
1869
    skip('empty_strided', '', device_type='cpu'),
1870
}
1871

1872
fake_tensor_failures = {
1873
    # ASAN failures due to divide by 0
1874
    skip('nn.functional.nll_loss'),
1875
}
1876

1877
symbolic_tensor_failures = {
1878
    xfail('linalg.eig'),
1879
    xfail('linalg.eigvals'),
1880
    xfail('combinations', ''),
1881
    xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
1882
    xfail('histc', ''),  # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
1883
    xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
1884
    xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
1885
    xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
1886
    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
1887
    xfail('narrow', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
1888
    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
1889
    xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
1890
    xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
1891
    xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t...
1892
    xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t...
1893
    xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
1894
    xfail('resize_as_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
1895
    xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
1896
    xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
1897

1898
    xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
1899

1900
    # many complex operators incorrect striding, metadata
1901
    xfail('fft.fft', ''),
1902
    xfail('fft.hfft2', ''),
1903
    xfail('fft.hfft', ''),
1904
    xfail('fft.hfftn', ''),
1905
    xfail('fft.ifft', ''),
1906
    xfail('fft.ihfft2', ''),
1907
    xfail('fft.ihfft', ''),
1908
    xfail('fft.ihfftn', ''),
1909
    xfail('fft.ihfft2', ''),
1910
    xfail('fft.irfft2', ''),
1911
    xfail('fft.irfft', ''),
1912
    xfail('fft.irfftn', ''),
1913
    xfail('fft.rfft2', ''),
1914
    xfail('fft.rfft', ''),
1915
    xfail('fft.rfftn', ''),
1916
    xfail('stft', '')
1917
}
1918
symbolic_tensor_segfaults = {
1919
    skip('nn.functional.batch_norm')  # Segfault??
1920
}
1921

1922
symbolic_tensor_failures.update(symbolic_tensor_segfaults)
1923

1924
inplace_symbolic_tensor_failures = {
1925
    # bugs
1926
    xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
1927
    # decomp not implemented
1928
    xfail('unique', ''),
1929
}
1930

1931
out_symbolic_tensor_failures = {
1932
    xfail('_native_batch_norm_legit', ''),
1933
    xfail('angle', ''),
1934
    xfail('argmax', ''),
1935
    xfail('argmin', ''),
1936
    xfail('bmm', ''),
1937
    xfail('fft.fft2', ''),
1938
    xfail('fft.fftn', ''),
1939
    xfail('fft.ifft2', ''),
1940
    xfail('fft.ifftn', ''),
1941
    xfail('gather', ''),
1942
    xfail('linalg.cholesky', ''),
1943
    xfail('linalg.cholesky_ex', ''),
1944
    xfail('linalg.det', ''),
1945
    xfail('linalg.det', 'singular'),
1946
    xfail('linalg.inv', ''),
1947
    xfail('linalg.inv_ex', ''),
1948
    xfail('linalg.pinv', ''),
1949
    xfail('linalg.pinv', 'hermitian'),
1950
    xfail('linalg.svdvals', ''),
1951
    xfail('lu', ''),
1952
    xfail('max', 'reduction_with_dim'),
1953
    xfail('min', 'reduction_with_dim'),
1954
    xfail('nn.functional.avg_pool2d', ''),
1955
    xfail('nn.functional.linear', ''),
1956
    xfail('scatter_add', ''),
1957
    xfail('scatter', ''),
1958
    xfail('take_along_dim', ''),
1959
    xfail('topk', ''),
1960
    xfail('triangular_solve', ''),
1961
    xfail('view_copy', ''),
1962

1963
    # SymIntArrayRef expected to contain only concrete
1964
    xfail('ones', ''),
1965
    xfail('randn', ''),
1966
    xfail('zeros', ''),
1967
}
1968

1969
out_symbolic_tensor_segfaults = {
1970
    skip('nanmean', ''),
1971
}
1972

1973
out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)
1974

1975
# Copies inputs to inplace operations to avoid inplace modifications
1976
#   to leaves requiring gradient
1977
def _get_safe_inplace(inplace_variant):
1978
    @functools.wraps(inplace_variant)
1979
    def _fn(t, *args, **kwargs):
1980
        return inplace_variant(t.clone(), *args, **kwargs)
1981

1982
    return _fn
1983

1984
def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
1985
    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
1986
    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
1987

1988
    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
1989
    count = 100
1990
    if out:
1991
        count = 5
1992
    for sample_input in itertools.islice(sample_inputs_itr, count):
1993
        if inplace and sample_input.broadcasts_input:
1994
            continue
1995
        args = [sample_input.input] + list(sample_input.args)
1996
        kwargs = sample_input.kwargs
1997
        if out:
1998
            expected = fn(*args, **kwargs)
1999
            kwargs['out'] = expected
2000

2001
        try:
2002
            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
2003
                                  randomize_data=True)
2004
        except DynamicOutputShapeException:
2005
            self.skipTest("Dynamic output shape operation in trace")
2006

2007

2008
class TestProxyTensorOpInfo(TestCase):
2009
    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
2010
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
2011
    def test_make_fx_exhaustive(self, device, dtype, op):
2012
        _test_make_fx_helper(self, device, dtype, op, "real")
2013

2014
    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
2015
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
2016
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
2017
        _test_make_fx_helper(self, device, dtype, op, "fake")
2018

2019
    @ops(op_db + custom_op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,))
2020
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
2021
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
2022
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
2023
        _test_make_fx_helper(self, device, dtype, op, "symbolic")
2024

2025
    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
2026
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
2027
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
2028
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
2029
        if not op.get_inplace():
2030
            self.skipTest("No inplace variable for this op")
2031
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)
2032

2033
    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
2034
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out',
2035
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures)
2036
    def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op):
2037
        if not op.supports_out:
2038
            self.skipTest("Op doesn't support out")
2039
        _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True)
2040

2041

2042
only_for = ("cpu")
2043
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
2044

2045

2046
if __name__ == '__main__':
2047
    run_tests()
2048

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.