# Owner(s): ["module: fx"]

import builtins
import contextlib
import copy
import functools
import inspect
import math
import numbers
import io
import operator
import os
import pickle
import sys
import torch
import traceback
import typing
import types
import warnings
import unittest
from math import sqrt
from functorch.experimental import control_flow
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta
from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401
from fx.test_matcher_utils import TestMatcher  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

from fx.test_gradual_type import AnnotationsTest  # noqa: F401
from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
    run_tests,
    skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport

class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)

def a_non_torch_leaf(a, b):
    return a + b

# Used for test_autowrap_functions. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    return int(x)

def fx_int_x2(x: float) -> int:
    return int(x) * 2

# used in test_pytree. It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])

# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

wrap(a_lifted_leaf2)

wrap('len')

wrap('getattr')

def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)

@wrap
def wrapped_via_decorator(a):
    return a + 1

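# Note: this wrap() call registers the function by name before the function
# itself is defined just below; the name is only resolved when tracing starts,
# so this ordering is fine.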
wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

def my_decorator(f):
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

@wrap
@my_decorator
def wrapped_decorated_fn(x):
    return x

real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
_sqrt = sqrt

wrap('wrapper_fn')

def wrapper_fn(x):
    return torch.foo(x)

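# Pair provides _custom_fx_repr_fn, which FX's argument formatter (_format_arg)
# uses to render values of this type when printing node arguments.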
class Pair(NamedTuple):
    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"

# for testing pytrees
class Foo:  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Add(torch.nn.Module):
    def forward(self, x):
        return x + x

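# Marked with has_side_effect so that FX treats calls to this function as impure
# and keeps them in the graph even when their result is unused (e.g. they are
# not removed by dead code elimination).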
@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    print(x)

class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        class MySub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        ms = torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        gm2 = symbolic_trace(m2)

        class T(torch.nn.Module):

            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)
        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_close(new_instance(x), torch.relu(x))

    def test_informative_co_filename(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return a * 2

        gm = symbolic_trace(MyModule())
        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)

    def test_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_varargs_concrete(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return x

        args = (torch.rand(1), torch.rand(1))

        t = T()
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)

    def test_args_kwargs_no_self(self):
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_fx_shifts(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_fx_and_or(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x & x, x | x

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))

    def test_matmul_tracing(self):
        const = torch.randn(3)

        def matmul_f(x):
            return x @ const

        mod = symbolic_trace(matmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), matmul_f(inp))

        def rmatmul_f(x):
            return const @ x

        mod = symbolic_trace(rmatmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), rmatmul_f(inp))

    @skipIfNoDynamoSupport
    def test_control_flow_tracing(self):
        def true(x, y):
            return x + y

        def false(x, y):
            return x - y

        def f(x, y):
            x = control_flow.cond(x[0] == 0, true, false, [x, y])

        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
            _ = symbolic_trace(f)

    def test_disallow_override(self):
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                name = target if isinstance(target, str) else torch.typename(target)
                if name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name)

        # Test method
        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()

        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        # Test free function
        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x
        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x
        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)

    def test_leaf_module(self):
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
        sym.lint()

    def test_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

    def test_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf2', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)

    def test_wrapped_via_decorator(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(m).transform()
        self.assertIn('wrapped_via_decorator', transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrap_with_submodule(self):

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        m = symbolic_trace(M())

        self.assertIn("wrapped_with_submodule", m.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), m(input))

    def test_wrapped_retrace(self):
        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)

        retraced = symbolic_trace(m)
        self.assertIn('wrapped_via_decorator', retraced.code)
        self.assertEqual(retraced(0), 1)

    def test_wrap_decorated_function(self):
        def to_trace(y):
            return wrapped_decorated_fn(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_decorated_fn', m.code)
        self.assertEqual(m(1), 1)

    def test_graph_edit_with_proxy(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint()
        self.assertEqual(gm(3, 4), 14)

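    # Deep-copying a Proxy should yield a new Proxy that wraps a copied Node
    # rather than an alias of the original one.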
    def test_proxy_deepcopy_without_tracer(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return 2 * x

        module = MyModule()
        traced = symbolic_trace(module)
        node = list(traced.graph.nodes)[-2]
        p = torch.fx.Proxy(node, None)
        node.proxy = p
        p2 = copy.deepcopy(p)
        self.assertTrue(isinstance(p2, torch.fx.Proxy))
        self.assertEqual(p2.node.name, node.name)
        self.assertEqual(p2.node.target, node.target)
        self.assertNotEqual(id(p2.node), id(node))

    def test_proxy_deepcopy_with_tracer(self):
        class TestTracer(Tracer):
            def __init__(self, name):
                super().__init__()
                self.name = name

            def is_leaf_module(self, module, name):
                return True

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return 2 * x

        module = MyModule()
        tracer = TestTracer("mytracer")
        traced = symbolic_trace(module)
        node = list(traced.graph.nodes)[-2]
        p = torch.fx.Proxy(node, tracer)
        node.proxy = p
        p2 = copy.deepcopy(p)
        self.assertTrue(isinstance(p2, torch.fx.Proxy))
        self.assertTrue(isinstance(p2.tracer, torch.fx._symbolic_trace.Tracer))
        self.assertEqual(p2.tracer.name, "mytracer")
        self.assertEqual(p2.node.name, node.name)
        self.assertEqual(p2.node.target, node.target)
        self.assertNotEqual(id(p2.node), id(node))
        self.assertNotEqual(id(p2.tracer), id(tracer))

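    # concrete_args={'val': None} specializes the trace to val=None; the generated
    # forward() asserts at call time that the caller does not pass a different value.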
    def test_concrete_arg_none_assert(self):
        class Foo(torch.nn.Module):
            def forward(self, x, val=None):
                return x if val is None else x + val

        f = Foo()
        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
            traced(torch.randn(5), torch.randn(5))

        x = torch.randn(5)
        torch.testing.assert_close(traced(x), f(x))

    def test_trace_multiple_funcs(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

            def minus_forward(self, x, y):
                return x - y

            def multiply_forward(self, x, y):
                return x * y

        f = Foo()
        x, y = torch.randn(5), torch.randn(5)

        print(torch.__version__)

        tracer = Tracer()
        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

        tracer.traced_func_name = "minus_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.minus_forward(x, y),
        )

        tracer.traced_func_name = "multiply_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.multiply_forward(x, y),
        )

        tracer.traced_func_name = "add_forward"
        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
            tracer.trace(f)

    def test_graph_unique_names(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_stack_traces(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        # saving the original list because we will insert new nodes as a part of a test
        orig_graph_nodes = list(graph.nodes)
        for node in orig_graph_nodes:
            if node.op == 'output':
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

            # verify that copying the node does not lose the stack trace
            new_node = graph.node_copy(node)
            self.assertTrue(new_node.stack_trace is not None)
            assert 'test_fx.py' in new_node.stack_trace

    def test_stack_traces_with_transformer(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        new_gm = Transformer(gm).transform()

        # nodes after Transformer should still preserve the original node's stack trace
        for node in new_gm.graph.nodes:
            if node.op in {'placeholder', 'output'}:
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

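    # _lineno_map maps line numbers of the generated forward() code back to the
    # nodes they came from; prepending a line via custom codegen below shifts
    # each generated line number by one while the mapped values stay the same.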
    def test_lineno_map(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                a = torch.sin(a)
                b = torch.cos(b)
                return a + b

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        expected = {1: 2, 2: 3, 3: 4, 4: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

        # test custom codegen
        def transform_code(code):
            return ["print('hello!')\n", *code]
        gm.graph.on_generate_code(lambda _: transform_code)
        gm.recompile()
        expected = {2: 2, 3: 3, 4: 4, 5: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

    def test_graph_unique_names_manual(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_unpack(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))

    def test_native_callable(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format =====
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul"
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, "Unsupported call target " + target
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint()

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_close(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_close(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_close(imported_out, ref_out)

    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint()
        for node in m_g.graph.nodes:
            self.assertTrue(node.name != "getattr")

    @unittest.skip("Hotfix for SEV remediation")
    def test_trace_buffer_slice(self):
        bs, d_hid = 10, 23

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)
                self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                skip_connection = x
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
                x = self.lin(x)
                x = torch.relu(x)
                x = x + skip_connection
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        ec = ExampleCode()

        traced = torch.fx.symbolic_trace(ec)

        x = torch.randn(bs, d_hid)
        torch.testing.assert_close(ec(x), traced(x))

    def test_node_tagging(self):
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint()
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')

    def test_tensor_attribute(self):
        class TensorAttribute(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint()
        traced2(torch.rand(4, 4))

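    # Repeated uses of the same tensor constant should be coalesced into a single
    # get_attr target, while distinct tensors (even with equal values) each get
    # their own target.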
    def test_tensor_attribute_coalseced(self):

        def count_attrs(fx_module):
            targets = set()
            for node in fx_module.graph.nodes:
                if node.op == 'get_attr':
                    targets.add(node.target)
            return len(targets)

        val = torch.tensor(5)

        def f(x):
            return x + val + val
        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 1)

        val2 = torch.tensor(5)

        def f(x):
            val = torch.tensor(5)
            return x + val + val2

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 2)

    def test_symbolic_trace_sequential(self):
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))

    def test_tensor_constant(self):
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint()
        traced(torch.rand(4, 4))

    def test_pickle_graphmodule(self):
        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint()
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_pickle_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(loaded(x, y), gm(x, y))

    def test_all_input_nodes(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)
        graph.lint()

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])

    def test_deepcopy_graphmodule_with_transform(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            relu_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(relu_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint()
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))

    def test_deepcopy_with_submods_params(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint()
        copied = copy.deepcopy(traced)
        copied.graph.lint()

    def test_deepcopy_graph_with_tracer_cls(self):
        class TestTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        g = Graph(tracer_cls=TestTracer)
        x = g.placeholder("x")
        g.output(x)

        h = copy.deepcopy(g)
        self.assertIsNotNone(h._tracer_cls)
        self.assertTrue(g._tracer_cls == h._tracer_cls)

    def test_unpack_list_better_error(self):
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)

    def test_pretty_print_targets(self):
        # Test that Graph pretty-print prints friendly name for targets
        # in `operator` and `builtins`

        class SomeMod(torch.nn.Module):
            def forward(self, x):
                return torch.add(x.foo + x.bar, 3.0)

        traced = symbolic_trace(SomeMod())
        graph_str = str(traced.graph)
        self.assertIn('builtins.getattr', graph_str)
        self.assertIn('operator.add', graph_str)
        self.assertIn('torch.add', graph_str)

    def test_pretty_print_node(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param: torch.nn.Parameter = torch.nn.Parameter(
                    torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor, y: int = 2):
                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)

        traced = symbolic_trace(M())

        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])

        FileCheck().check("x").check("placeholder") \
            .check("y").check("placeholder") \
            .check("getitem").check("call_function") \
            .check("param").check("get_attr") \
            .check("add").check("call_function") \
            .check("linear").check("call_module") \
            .check("clamp").check("call_method") \
            .run(all_formatted)

    def test_script_tensor_constant(self):
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)

    def test_autowrap_functions(self):
        class AutowrapFnTest(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2)

        class AutowrapFnTest2(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)

        # Check function(s) are wrapped
        # `int` would normally throw a TypeError as argument can't be `Proxy`
        tracer = Tracer(autowrap_functions=(fx_int,))
        graph = tracer.trace(AutowrapFnTest())
        traced = GraphModule(tracer.root, graph, 'test')
        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
        tracer_2.trace(AutowrapFnTest2())

        # Test scriptability
        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(4)), 2)

    def test_tuple_no_subscript(self):
        def foo(x : Tuple):
            return x[0]

        traced = torch.fx.symbolic_trace(foo)
        x = (torch.randn(5, 3),)
        torch.testing.assert_close(traced(x), x[0])

        bio = io.BytesIO()

        torch.save(traced, bio)

        bio.seek(0)

        # weights_only=False as this loads a GraphModule
        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
        loaded = torch.load(bio, weights_only=False)

        torch.testing.assert_close(loaded(x), x[0])

    def test_torch_fx_len(self):
        class FXLenTest(torch.nn.Module):
            def forward(self, x):
                return len(x)

        traced = symbolic_trace(FXLenTest())
        self.assertEqual(traced(torch.rand(3, 4)), 3)

        # Test scriptability
        scripted = torch.jit.script(FXLenTest())
        self.assertEqual(scripted(torch.rand(3)), 3)

        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(3)), 3)

        # Test non-proxy len
        class FXLenTest2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = [3, 4, 5]

            def forward(self, x):
                return x + len(self.l)

        traced2 = symbolic_trace(FXLenTest2())
        inp = torch.rand(3, 4)
        self.assertEqual(traced2(inp), inp + 3.0)
        self.assertIs(len, builtins.len)

    def test_torch_fx_getattr(self):
        class FXGetattrTest(torch.nn.Module):
            def forward(self, x):
                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))

        traced = symbolic_trace(FXGetattrTest())
        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))

    def test_sqrt(self):
        class Sqrt1(torch.nn.Module):
            def forward(self, x):
                return sqrt(x.size(0))

        class Sqrt2(torch.nn.Module):
            def forward(self, x):
                return math.sqrt(x.size(0))

        class Sqrt3(torch.nn.Module):
            def forward(self, x):
                return x + math.sqrt(2) + sqrt(2)

        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
        self.assertIs(sqrt, _sqrt)
        self.assertIs(math.sqrt, _sqrt)

    def test_torch_custom_ops(self):
1312
        class M(torch.nn.Module):
1313
            def forward(self, a):
1314
                b = torch.ops.aten.sigmoid(a)
1315
                c = torch.ops.aten.cat([a, b])
1316
                return torch.ops.aten.cat((c, c))
1317
        m = M()
1318
        input = torch.randn(3)
1319
        ref_out = m(input)
1320
        gm = symbolic_trace(m)
1321
        gm.graph.lint()
1322
        out = gm(input)
1323
        self.assertEqual(out, ref_out)
1324

1325
    def test_torch_op_overloads(self):
1326
        class M(torch.nn.Module):
1327
            def forward(self, a):
1328
                b = torch.ops.aten.add.Tensor(a, a)
1329
                return b
1330
        m = M()
1331
        input = torch.randn(3)
1332
        ref_out = m(input)
1333
        gm = symbolic_trace(m)
1334
        gm.graph.lint()
1335
        out = gm(input)
1336
        self.assertEqual(out, ref_out)
1337

1338
        for node in gm.graph.nodes:
1339
            if node.op == 'call_function':
1340
                assert isinstance(node.target, torch._ops.OpOverload)
1341
                assert node.target.__name__ == 'add.Tensor'
1342

1343
    def test_pickle_torch_custom_ops(self):
1344
        class M(torch.nn.Module):
1345
            def forward(self, a):
1346
                b = torch.ops.aten.sigmoid(a)
1347
                c = torch.ops.aten.cat([a, b])
1348
                return torch.ops.aten.cat((c, c))
1349
        m = M()
1350
        input = torch.randn(3)
1351
        ref_out = m(input)
1352
        gm = symbolic_trace(m)
1353
        gm.graph.lint()
1354
        pickled = pickle.dumps(gm)
1355
        loaded = pickle.loads(pickled)
1356
        self.assertEqual(loaded(input), gm(input))
1357

1358
    def test_pretty_print(self):
1359
        st = SimpleTest()
1360
        traced = symbolic_trace(st)
1361
        traced.graph.lint()
1362
        printed = str(traced)
1363
        assert 'SimpleTest()' in printed
1364
        assert 'torch.relu' in printed
1365

1366
    def test_pretty_print_graph(self):
1367
        class KwargPrintTest(torch.nn.Module):
1368
            def forward(self, x):
1369
                return torch.squeeze(x + 3.0, dim=2)
1370
        st = KwargPrintTest()
1371
        traced = symbolic_trace(st)
1372
        traced.graph.lint()
1373
        stringed = str(traced.graph)
1374
        for s in ['args', 'kwargs', 'num_users']:
1375
            assert s in stringed
1376

1377
    def test_custom_proxy_type(self):
1378
        class TensorPair:
1379
            def __init__(self, left, right):
1380
                self.left, self.right = left, right
1381

1382
            def add(self, other):
1383
                l = self.left + other.left
1384
                r = self.right + other.right
1385
                return TensorPair(l, r)
1386

1387
            def mul(self, other):
1388
                l = self.left * other.left
1389
                r = self.right * other.right
1390
                return TensorPair(l, r)
1391

1392
        def use_tensor_pair(x : TensorPair, y : TensorPair):
1393
            s = x.add(y)
1394
            return s.mul(x)
1395

1396
        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1397
        y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1398

1399
        ref_out = use_tensor_pair(x, y)
1400

1401
        traced = symbolic_trace(use_tensor_pair)
1402

1403
        traced_out = traced(x, y)
1404
        self.assertEqual(traced_out.left, ref_out.left)
1405
        self.assertEqual(traced_out.right, ref_out.right)
1406

1407
    def test_custom_proxy_type_literal(self):
1408
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1409
            def __init__(self, left, right):
1410
                self.left, self.right = left, right
1411

1412
            def add(self, other):
1413
                l = self.left + other.left
1414
                r = self.right + other.right
1415
                return TensorPair(l, r)
1416

1417
            def mul(self, other):
1418
                l = self.left * other.left
1419
                r = self.right * other.right
1420
                return TensorPair(l, r)
1421

1422
        def use_tensor_pair_literal(x : TensorPair):
1423
            s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
1424
            return s.mul(x)
1425

1426
        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1427

1428
        ref_out = use_tensor_pair_literal(x)
1429

1430
        traced = symbolic_trace(use_tensor_pair_literal)
1431

1432
        traced_out = traced(x)
1433
        self.assertEqual(traced_out.left, ref_out.left)
1434
        self.assertEqual(traced_out.right, ref_out.right)
1435

1436
    def test_custom_proxy_dynamic_value(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = symbolic_trace(use_tensor_pair_ctor)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)

    def test_custom_proxy_input_dependent_control_flow(self):
        class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, inp):
                if inp.sum() == 0:
                    self.is_zero = True
                    self.tensor = torch.tensor([])
                else:
                    self.is_zero = False
                    self.tensor = inp

            def add(self, other):
                if self.is_zero:
                    return ZeroTensor(other.tensor)
                elif other.is_zero:
                    return self

        def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
            return ZeroTensor(x + y)

        x, y = torch.randn(5, 3), torch.randn(5, 3)

        ref_out = use_zero_tensor(x, y)

        traced = symbolic_trace(use_zero_tensor)

        traced_out = traced(x, y)

        self.assertEqual(traced_out.is_zero, ref_out.is_zero)
        self.assertEqual(traced_out.tensor, ref_out.tensor)

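    # Note on the tests above: wrapping a class with torch.fx.ProxyableClassMeta
    # lets symbolic tracing record construction of that class as a call_function
    # node instead of attempting to trace into __init__. A minimal sketch of the
    # pattern (illustrative only; the names `Pair` and `f` are hypothetical):
    #
    #     class Pair(metaclass=torch.fx.ProxyableClassMeta):
    #         def __init__(self, a, b):
    #             self.a, self.b = a, b
    #
    #     def f(p: Pair, t: torch.Tensor):
    #         return Pair(p.a + t, p.b)
    #
    #     gm = torch.fx.symbolic_trace(f)  # Pair(...) appears as a node in gm.graph
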
    def test_graph_fns(self):
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        gm.graph.lint()
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)

    def test_remove_uses(self):
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertTrue(neg not in relu.users)

    def test_remove_uses_with_custom_filter(self):
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu, lambda x: x != neg)

        self.assertTrue(neg in relu.users)

    def test_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(3, 4)
        symbolic_trace(eb)

    def test_pickle_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
        traced = symbolic_trace(eb)
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 4])
        self.assertEqual(loaded(input, offsets), traced(input, offsets))

    def test_return_tuple(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return (x, x + x)

        original = M()
        traced = symbolic_trace(original)
        self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))

    def test_construct_root_dict(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)

        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
        add_param : torch.Tensor = torch.rand(3, 4)
        gm : torch.fx.GraphModule = torch.fx.GraphModule(
            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
        gm.graph.lint()

        assert 'self.foo.bar.baz' in gm.code

        x : torch.Tensor = torch.rand(3, 3)
        out : torch.Tensor = gm(x)
        ref_out : torch.Tensor = linear_mod(x) + add_param
        self.assertEqual(out, ref_out)

    def test_symbolic_trace_assert(self):

        class AssertsTensorShape(torch.nn.Module):
            def forward(self, x):
                torch._assert(x.shape[1] > 4, "assert_foobar")
                return x

        m = AssertsTensorShape()
        # verify traceability
        traced = symbolic_trace(m)
        # verify assertion on traced model works correctly at runtime
        traced(torch.rand(4, 5))
        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
            traced(torch.rand(4, 3))
        # verify the symbolically traced module is scriptable
        ms = torch.jit.script(m)
        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
            ms(torch.rand(4, 3))

    def test_fx_create_arg(self):
        class CustomArgObject:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __fx_create_arg__(self, tracer: torch.fx.Tracer):
                return tracer.create_node(
                    "call_function",
                    CustomArgObject,
                    args=(
                        tracer.create_arg(self.x),
                        tracer.create_arg(self.y),
                    ),
                    kwargs={},
                )

        class HasCustomArgObjectWhenLeaf(torch.nn.Module):
            def forward(self, o: CustomArgObject):
                # Not normally traceable; good reason to make
                # this module a leaf.
                for x in o.x:
                    o.y += x
                return o.y

        class Root(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner = HasCustomArgObjectWhenLeaf()

            def forward(self, x, y):
                o = CustomArgObject(x, y)
                return self.inner(o)

        class CreateArgTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return type(m) is HasCustomArgObjectWhenLeaf

        m = Root()
        graph = CreateArgTracer().trace(m)
        gm = torch.fx.GraphModule(m, graph)
        assert "CustomArgObject(" in gm.code

    def test_trace_fn_constant(self):
        some_constant = torch.rand(3, 4)

        def add_const(x):
            return some_constant + x

        traced = symbolic_trace(add_const)

        input = torch.rand(3, 4)
        self.assertEqual(traced(input), add_const(input))

    def test_copy_no_remap(self):
        traced = symbolic_trace(SimpleTest())
        g = traced.graph
        copied = torch.fx.Graph()
        for node in g.nodes:
            copied.node_copy(node)
        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
            copied.lint()

    def test_wrong_topo(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        nodes = list(graph.nodes)
        nodes[3].append(nodes[2])
        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
            graph.lint()

    def test_wrong_target_type(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        with self.assertRaises(ValueError):
            n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
                              args=(), kwargs={})

    def test_example_shape_prop(self):
        class TestCase(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.randn(3, 4)
                self.submod = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.neg(self.submod(x.relu() + self.attr))
        tc = TestCase()
        tc_traced = symbolic_trace(tc)
        ref_out = tc_traced(torch.rand(3, 4))
        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

        # Make sure we're testing all opcodes
        opcodes = set()
        output_shape : Optional[torch.Size] = None
        output_stride : Optional[Tuple[int]] = None
        for node in tc_traced.graph.nodes:
            opcodes.add(node.op)
            if node.op == 'output':
                output_shape = node.args[0].meta['tensor_meta'].shape
                output_stride = node.args[0].meta['tensor_meta'].stride
        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                                   'call_module', 'output'})

        # Test shape propagation and make sure results match actual
        self.assertEqual(output_shape, ref_out.shape)
        self.assertEqual(output_stride, ref_out.stride())

    def test_shape_prop_layout(self):
        class ConvTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_mod = torch.nn.Conv2d(5, 5, 3)

            def forward(self, x):
                return self.conv_mod(x)

        # contiguous layout
        test_mod = ConvTest()
        traced = symbolic_trace(test_mod)
        x = torch.randn(5, 5, 224, 224)
        shape_prop.ShapeProp(traced).propagate(x)

        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
                   for node in traced.graph.nodes)

        x_channels_last = x.contiguous(memory_format=torch.channels_last)
        traced.to(memory_format=torch.channels_last)
        shape_prop.ShapeProp(traced).propagate(x_channels_last)
        for node in traced.graph.nodes:
            # NB: the implementation of conv may not preserve the memory format,
            # unfortunately. The best we can do is just check that the placeholder
            # node is channels-last
            if node.op in {'placeholder'}:
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)

    def test_shape_prop_aggregate(self):
        class ReturnTwo(torch.nn.Module):
            def forward(self, x):
                return (3, torch.sum(x))

        class UnderTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.rt = ReturnTwo()

            def forward(self, x):
                return self.rt(x)

        ut = UnderTest()

        class RTTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return type(m) is ReturnTwo

        graph = RTTracer().trace(ut)
        mod = torch.fx.GraphModule(ut, graph)

        shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))

        for node in mod.graph.nodes:
            if node.op == 'call_module':
                assert 'tensor_meta' in node.meta
                tensor_meta = node.meta['tensor_meta']
                assert tensor_meta[0] == 3
                assert tensor_meta[1].shape == torch.Size([])

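    # Context for the ShapeProp tests in this file: ShapeProp runs the traced
    # module on example inputs and records per-node tensor metadata (shape,
    # dtype, stride, memory_format, ...) under node.meta['tensor_meta'], which
    # the assertions above and below read back. A short sketch of typical usage,
    # mirroring these tests (`model` and `example_input` are placeholders):
    #
    #     gm = torch.fx.symbolic_trace(model)
    #     shape_prop.ShapeProp(gm).propagate(example_input)
    #     shapes = {n.name: n.meta['tensor_meta'].shape
    #               for n in gm.graph.nodes if 'tensor_meta' in n.meta}
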
    def test_shape_prop_layout_3d(self):
        class ConvTest3d(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_mod = torch.nn.Conv3d(5, 5, 3)

            def forward(self, x):
                return self.conv_mod(x)

        test_mod_3d = ConvTest3d()
        traced_3d = symbolic_trace(test_mod_3d)
        x_3d = torch.randn(5, 5, 224, 224, 15)
        shape_prop.ShapeProp(traced_3d).propagate(x_3d)
        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
                   for node in traced_3d.graph.nodes)

        x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
        traced_3d.to(memory_format=torch.channels_last_3d)
        shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
        for node in traced_3d.graph.nodes:
            # NB: the implementation of conv may not preserve the memory format,
            # unfortunately. The best we can do is just check that the placeholder
            # node is channels-last
            if node.op in {'placeholder'}:
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)

    def test_nn_module_stack(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)

            def forward(self, x):
                return self.conv_mod(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub_mod = SubModule()

            def forward(self, x):
                return self.sub_mod(x)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        mod_stack = {}
        expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
                          ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
        for node in gm.graph.nodes:
            mod_stack = node.meta.get('nn_module_stack', {})
            if mod_stack:
                break
        stack_list = list(mod_stack.items())
        self.assertEqual(stack_list, expected_stack)

    def test_transformer_preserves_nn_module_stack_for_get_attr(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))

            def forward(self, x):
                return self.weight + x

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        for node in gm.graph.nodes:
            if node.op == 'get_attr':
                node.meta["nn_module_stack"] = "self"
                node.meta["stack_trace"] = "stack_trace"
                node.meta["source_fn_stack"] = "source_fn_stack"
        new_gm = Transformer(gm).transform()
        for node in new_gm.graph.nodes:
            if node.op == 'get_attr':
                self.assertEqual(node.meta["nn_module_stack"], "self")
                self.assertEqual(node.meta["stack_trace"], "stack_trace")
                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")

    def test_interpreter(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        interpreter = Interpreter(gm)
        input = torch.randn(3, 4)
        self.assertEqual(interpreter.run(input), gm(input))
        self.assertEqual(interpreter.run(input), m(input))

    def test_interpreter_other_graph(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        interpreter = Interpreter(gm, graph=gm.graph)
        input = torch.randn(3, 4)
        self.assertEqual(interpreter.run(input), gm(input))
        self.assertEqual(interpreter.run(input), m(input))

    def test_interpreter_run_node_override(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        class RunNodeInterpreter(Interpreter):
            def __init__(self, module):
                super().__init__(module)

            def run_node(self, n : Node) -> Any:
                result = super().run_node(n)
                n.cached_value = result
                return result

        input = torch.randn(3, 4)
        RunNodeInterpreter(gm).run(input)
        for node in gm.graph.nodes:
            assert hasattr(node, 'cached_value')

    def test_interpreter_onthefly_swap(self):

        def fn(x):
            return torch.sigmoid(x).neg()

        gm = torch.fx.symbolic_trace(fn)

        class NegSigmSwapInterpreter(Interpreter):
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(target, args, kwargs)

            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == 'neg':
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(target, args, kwargs)

        input = torch.randn(3, 4)
        result = NegSigmSwapInterpreter(gm).run(input)
        self.assertEqual(result, torch.neg(input).sigmoid())

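    # NegSigmSwapInterpreter above illustrates the general Interpreter-override
    # pattern: subclass Interpreter and override the per-opcode handlers
    # (call_function, call_method, call_module, ...) to intercept specific
    # targets while delegating everything else to super(). A minimal,
    # illustrative sketch (the name `ClampingInterpreter` is hypothetical and
    # not used by these tests):
    #
    #     class ClampingInterpreter(Interpreter):
    #         def call_function(self, target, args, kwargs):
    #             out = super().call_function(target, args, kwargs)
    #             return out.clamp(-1, 1) if isinstance(out, torch.Tensor) else out
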
    def test_interpreter_partial_eval(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        gm = torch.fx.symbolic_trace(MyModule())
        interp = Interpreter(gm)
        env = {}
        for node in gm.graph.nodes:
            if node.op == 'call_module' and node.target == 'linear':
                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
                break
        assert len(env) == 1
        x = torch.randn(3, 4)
        result = interp.run(x, initial_env=env)
        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))

    def test_interpreter_star_args(self):
        def with_star_args(x, *args):
            return x + args[0]

        gm = torch.fx.symbolic_trace(with_star_args)
        interp = Interpreter(gm)
        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
        self.assertEqual(result, torch.ones(3, 4) * 2.0)

    @skipIfNoTorchVision
    def test_interpreter_noop_resnet18(self):
        rn18 = torchvision_models.resnet18()
        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
        inp = torch.randn(5, 3, 224, 224)
        self.assertEqual(transformed(inp), rn18(inp))

    @skipIfNoTorchVision
    def test_interpreter_gc_values(self):
        rn18 = torchvision_models.resnet18()
        interp = Interpreter(symbolic_trace(rn18))
        inp = torch.rand(5, 3, 224, 224)
        out = interp.run(inp)
        env_key_names = {n.name for n in interp.env.keys()}
        self.assertEqual(env_key_names, {'output'})

    def test_interpreter_default_args(self):
        class Model(torch.nn.Module):
            def forward(self, x, y=3.14159):
                return x + y

        model = Model()
        gm = torch.fx.symbolic_trace(model)

        interp = Interpreter(gm)
        x = torch.randn(5, 3)
        out = interp.run(x)
        torch.testing.assert_close(out, x + 3.14159)

    def test_interpreter_not_enough_args(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = Model()
        gm = torch.fx.symbolic_trace(model)

        interp = Interpreter(gm)
        x = torch.randn(5, 3)
        with self.assertRaisesRegex(RuntimeError,
                                    'Expected positional argument for parameter y, but one was not passed in'):
            out = interp.run(x)

    def test_transformer_noop(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        new_gm = Transformer(gm).transform()

        input = torch.randn(3, 4)
        self.assertEqual(new_gm(input), gm(input))

    def test_transformer_op_swap(self):

        def fn(x):
            return torch.sigmoid(x).neg()

        gm = torch.fx.symbolic_trace(fn)

        class NegSigmSwapXformer(Transformer):
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(target, args, kwargs)

            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == 'neg':
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(target, args, kwargs)

        transformed = NegSigmSwapXformer(gm).transform()
        input = torch.randn(3, 4)
        self.assertEqual(transformed(input), torch.neg(input).sigmoid())

    def test_transformer_multi_outputs(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                x = x + self.param
                out = self.linear(x)
                return x, out

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        new_gm = Transformer(gm).transform()

        input = torch.randn(3, 4)
        self.assertEqual(new_gm(input), gm(input))

    def test_fn_type_annotations(self):
        class Foo(torch.nn.Module):
            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
                return {'a': p.x + p.y + z + i}

        foo_scripted = torch.jit.script(Foo())
        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

        fxed = symbolic_trace(Foo())
        fxed_scripted = torch.jit.script(fxed)
        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

    def test_fn_type_annotation_empty(self):
        def forward(a : List[torch.Tensor]):
            return a[0]
        torch.jit.script(symbolic_trace(forward))

    def test_wrapped_method(self):
        def wrap_with_relu(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                return torch.relu(fn(*args, **kwargs))
            return wrapper

        class Foo(torch.nn.Module):
            @wrap_with_relu
            def forward(self, x, w):
                return torch.matmul(x, w)

        f = Foo()
        traced = symbolic_trace(f)
        x, w = torch.rand(3, 4), torch.rand(4, 4)
        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))

    def test_empty_graph_codegen(self):
        graph = torch.fx.Graph()
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(gm(), None)

    def test_sequential(self):
        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
        gm = torch.fx.symbolic_trace(m)
        gm_copy = copy.deepcopy(gm)

    def test_ctx_mgr(self):
        @contextlib.contextmanager
        def do_nothing():
            yield

        class M(torch.nn.Module):
            @do_nothing()
            def forward(self, x):
                return torch.relu(x)

        m = M()
        self.checkGraphModule(m, (torch.rand(3, 4),))

    def test_typename_print(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
                                              type_expr=List[float])
        output : torch.fx.Node = graph.output(b)

        self.assertTrue('typing.List[float]' in str(graph))

    def test_layout(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)

        traced = symbolic_trace(M())
        x = torch.rand(5, 9, 3, 4)
        self.assertEqual(traced(x), torch.zeros_like(x))

    def test_ellipsis(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y[:, 1:10, ...]

        traced = symbolic_trace(M())
        x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
        self.assertEqual(traced(x, y), x + y[:, 1:10, ...])

    def test_inf_nan(self):
        class FooMod(torch.nn.Module):
            def forward(self, x):
                return x + float('inf'), x + float('-inf'), x + float('nan')

        fm = FooMod()
        self.checkGraphModule(fm, (torch.rand(3, 4),))

    def test_inf_nan_kwds(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
        graph.output((b, c))

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        x = torch.rand(3, 4)
        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))

    def test_deepcopy_recursion_depth(self):
        depth = sys.getrecursionlimit() + 20

        g = torch.fx.Graph()
        x = g.placeholder('x')
        for i in range(depth):
            x = g.call_function(torch.relu, (x,))
        g.output(x)

        copied_graph = copy.deepcopy(g)

        val_map = {}
        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            val_map[orig_node] = new_node

        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            orig_users = set(orig_node.users.keys())
            orig_users_equiv = {val_map[u] for u in orig_users}
            new_users = set(new_node.users.keys())
            self.assertEqual(orig_users_equiv, new_users)

    @skipIfNoTorchVision
    def test_replace_uses(self):
        rn18 = torchvision_models.resnet18()

        class LowerReluTracer(torch.fx.Tracer):
            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
                if isinstance(m, torch.nn.ReLU):
                    return False
                return super().is_leaf_module(m, qualname)

        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

        to_erase = []
        for node in rn18_traced.graph.nodes:
            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
                kwargs = node.kwargs.copy()
                # Neg doesn't have in-place
                kwargs.pop('inplace')
                with rn18_traced.graph.inserting_before(node):
                    new_node = rn18_traced.graph.call_function(
                        the_function=torch.neg, args=node.args, kwargs=kwargs)
                node.replace_all_uses_with(replace_with=new_node)
                to_erase.append(node)

        for node in to_erase:
            rn18_traced.graph.erase_node(node)

    def test_replace_input(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        b.replace_input_with(x, y)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input_x = torch.randn(33, 44)
        input_y = torch.randn(11, 22)
        self.assertEqual(gm(input_x, input_y), torch.relu(input_y))

    def test_insertion_point(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        with graph.inserting_before(b):
            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
            _, *relu_args = b.args
            b.args = (neg, *relu_args)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

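    # test_insertion_point relies on graph.inserting_before(node) to scope where
    # newly created nodes are placed; reassigning node.args afterwards rewires
    # the users of the old input. A minimal sketch of the same idiom
    # (illustrative only, not part of the tests; `graph`, `b`, and `x` stand in
    # for nodes built as above):
    #
    #     with graph.inserting_before(b):
    #         scaled = graph.call_function(torch.mul, args=(x, 2.0))
    #     b.args = (scaled,)
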
    def test_update_args_api(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))

        b.update_arg(0, y)
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))

    def test_update_kwargs_api(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
        output : torch.fx.Node = graph.output(b)

        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))

        b.update_kwarg('input', y)
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))

    def test_immutable_list_pytree_ops(self):
        rand_tensor = torch.randn(5, 3)
        l = immutable_list([3, [rand_tensor, 42]])

        flattened, spec = pytree.tree_flatten(l)
        assert flattened == [3, rand_tensor, 42]

        unflattened = pytree.tree_unflatten(flattened, spec)
        assert unflattened == l
        assert isinstance(unflattened, immutable_list)

    def test_immutable_dict_pytree_ops(self):
        rand_tensor = torch.randn(5, 3)
        d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})

        flattened, spec = pytree.tree_flatten(d)
        assert flattened == [3, rand_tensor, 42]

        unflattened = pytree.tree_unflatten(flattened, spec)
        assert unflattened == d
        assert isinstance(unflattened, immutable_dict)

    def test_move_before(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
        _, *relu_args = b.args
        b.args = (neg, *relu_args)
        b.prepend(neg)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_prepend_self(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        b.prepend(b)
        x.append(b)
        self.assertEqual(len(graph.nodes), 3)

    def test_erase_node_error(self):
        st = SimpleTest()
        traced = symbolic_trace(st)

        for node in traced.graph.nodes:
            # Test deleting with uses both in another Node and at the output
            if node.target in [operator.add, torch.relu]:
                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
                    traced.graph.erase_node(node)

    def test_copy_it(self):
        d = immutable_dict([(3, 4), (5, 6)])
        l = immutable_list([(3, 4), (5, 6)])

        self.assertEqual(d, deepcopy(d))
        self.assertEqual(l, deepcopy(l))

    def test_get_torch_func_signature(self):
        for key in dir(torch):
            obj = getattr(torch, key)
            if callable(obj):
                schemas = get_signature_for_torch_op(obj)

    def test_find_uses(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))

        y = torch.relu(x)
        z = x + x
        u = torch.neg(x)
        graph.output((y + z + u).node)
        graph.lint()

        users_of_x = x.node.users
        self.assertEqual(len(users_of_x), 3)
        expected_ops = {'relu', 'add', 'neg'}
        for use in users_of_x:
            assert any(use.name.startswith(prefix) for prefix in expected_ops)

    def test_inline_graph(self):
        class InlineInto(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class ToInline(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        inline_into = symbolic_trace(InlineInto())
        to_inline = symbolic_trace(ToInline())

        combined_graph = torch.fx.Graph()
        output_node = combined_graph.graph_copy(inline_into.graph, {})

        input_node = next(iter(to_inline.graph.nodes))
        assert input_node and input_node.op == 'placeholder'

        val_map = {input_node : output_node}
        output = combined_graph.graph_copy(to_inline.graph, val_map)
        combined_graph.output(output)

        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)

        input = torch.rand(3, 4)
        self.assertEqual(combined_module(input), input.relu().neg())

    def test_multi_insert_point(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))
        relu = torch.relu(x)

        with graph.inserting_before(relu.node):
            y = torch.neg(x)
            z = torch.tanh(y)

        graph.output((relu.node, z.node))
        graph.lint()

        expected_ops = ['x', 'neg', 'tanh', 'relu']
        for node, expected in zip(graph.nodes, expected_ops):
            assert expected in node.name

    def test_reassign_args_kwargs_uses(self):
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        zed = z + z + z
        graph.output(zed.node)
        graph.lint()

        # zed = z + z + z -> zed = z + z + x
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])

        # z = x + y -> z = y + y
        z.node.args = (y.node, y.node)
        self.assertEqual(list(x.node.users.keys()), [zed.node])

    def test_trace_function(self):
        def foo(x, y):
            return torch.relu(x) + y

        x, y = torch.randn(3, 4), torch.randn(3, 4)
        self.checkGraphModule(foo, (x, y))

    def test_trace_return_dataclass(self):
        """
        Test case for a Module that returns a dataclass
        """
        from dataclasses import dataclass

        @dataclass
        class MyOutput:
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnDataclass(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d + d, bar=d * 3)

        module = ModuleReturnDataclass()
        traced_graph = symbolic_trace(module).graph
        print(traced_graph)

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_return_dataclass_nested(self):
        """
        Test case for a Module that returns a dataclass, nested inside another Module
        """
        from dataclasses import dataclass

        @dataclass
        class MyOutput:
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnDataclass(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d + d, bar=d * 3)

        class CallsModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = ModuleReturnDataclass()

            def forward(self, x):
                tmp = self.m(x)
                return MyOutput(foo=tmp.foo, bar=tmp.bar)

        module = CallsModule()
        traced_graph = symbolic_trace(module).graph
        print(traced_graph)

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_return_namedtuple(self):
        """
        Test case for a Module that returns a namedtuple
        """
        class MyOutput(NamedTuple):
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnNamedTuple(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d, bar=d)

        module = ModuleReturnNamedTuple()

        traced_graph = symbolic_trace(module).graph
        print(traced_graph)

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_dict_int_keys(self):
        class ModWithDictArg(torch.nn.Module):
            def forward(self, d : Dict[int, torch.Tensor]):
                return d[42]

        class CallsModWithDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = ModWithDictArg()

            def forward(self, x):
                return self.m({42: x})

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return isinstance(m, ModWithDictArg)

        traced_graph = MyTracer().trace(CallsModWithDict())

    def test_trace_dict_proxy_keys(self):
        class ModWithDictArg(torch.nn.Module):
            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
                return d[42]

        class CallsModWithDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = ModWithDictArg()

            def forward(self, x):
                return self.m({x: x})

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return isinstance(m, ModWithDictArg)

        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
            traced_graph = MyTracer().trace(CallsModWithDict())

    def test_module_deepcopy_edit_nodes(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        traced1 = symbolic_trace(Foo())
        copied = copy.deepcopy(traced1)

        for node in copied.graph.nodes:
            if node.target == torch.relu:
                node.target = torch.neg

        copied.recompile()
        traced1.recompile()

        x = torch.randn(15, 15)
        torch.testing.assert_close(traced1(x), torch.relu(x))
        torch.testing.assert_close(copied(x), torch.neg(x))

    def test_direct_param_use(self):
        class TransposeTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.b

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = TransposeTest()

            def forward(self, x):
                return self.a.b, self.a.b.t(), self.a.b.view(12)

        traced = torch.fx.symbolic_trace(Foo())
        assert all('constant' not in node.target for node in traced.graph.nodes)

    def test_single_default_arg(self):
        class M(torch.nn.Module):
            def forward(self, y=1):
                return y

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))

    def test_multiple_default_args(self):
        class M(torch.nn.Module):
            def forward(self, y=1, z=2):
                return y + z

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))
        self.checkGraphModule(m, (3, 4))

    def test_regular_and_default_args(self):
        class M(torch.nn.Module):
            def forward(self, x, y=1):
                return x + y

        m = M()
        self.checkGraphModule(m, (2,))
        self.checkGraphModule(m, (2, 3))

    def test_string_literal_return(self):
        class M(torch.nn.Module):
            def forward(self):
                return "foo"

        m = M()
        self.checkGraphModule(m, ())

    def test_namedtuple_return_qualname(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return MyNamedTup(x, x)

        traced = symbolic_trace(NamedTupReturn())
        input = torch.rand(3, 4)
        self.assertEqual(traced(input), MyNamedTup(input, input))

    def test_update_args_kwargs_yells_at_you(self):
        symtraced = symbolic_trace(SimpleTest())
        node = next(iter(symtraced.graph.nodes))
        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
            node.__update_args_kwargs((), {})

    def test_torchbind_class_attribute_in_fx(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")

        class FooBar1234(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        m = FooBar1234()
        self.checkGraphModule(m, ())

    def test_torchbind_class_attribute_in_fx_tensor_arg(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")

        class FooBar2341(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._ReLUClass()

            def forward(self, x):
                return self.f.run(x)

        m = FooBar2341()

        traced = symbolic_trace(m)
        input = torch.randn(3, 4)
        self.assertEqual(traced(input), m(input))

        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))

    def test_script_method_trace(self):
        class Scripted(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class Holder(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.s = torch.jit.script(Scripted())

            def forward(self, x):
                return self.s(x)

        h = Holder()
        traced = symbolic_trace(h)
        input = torch.randn(3, 4)
        self.assertEqual(traced(input), h(input))

        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))

    def test_namedtuple_return_trace(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return Pair(x, x)

        traced = symbolic_trace(NamedTupReturn())
        input = torch.rand(3, 4)
        self.assertEqual(traced(input), Pair(input, input))

    def test_named_tuple_inlined(self):
        class NamedTupMod(torch.nn.Module):
            def forward(self, inp):
                return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))

        m = NamedTupMod()
        input = torch.rand(3, 4)
        ref = m(input)
        traced = symbolic_trace(m)

        res = traced(input)
        self.assertEqual(ref, res)

        # Check Pair NamedTuple works when inlined into the function call.
        ph = call_func = None
        for node in traced.graph.nodes:
            if node.op == "placeholder":
                ph = node
            elif node.op == "call_function" and node.target == wrapped_named_tup:
                node.update_arg(0, Pair(ph, 1.2))
                node.update_kwarg("p2", Pair(3.4, ph))
                call_func = node
                break
        self.assertTrue(call_func is not None)
        self.assertTrue(isinstance(call_func.args[0], Pair))
        self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
        self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
        self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")

        traced.graph.eliminate_dead_code()
        traced.recompile()
        res = traced(input)
        self.assertEqual(ref, res)

    def test_return_type_exists(self):
        class ReturnTypeModule(torch.nn.Module):
            def other(self, x: List[str]) -> List[str]:
                return x

            def forward(self, x: List[str]) -> List[str]:
                return self.other(x)

        traced = symbolic_trace(ReturnTypeModule())
        self.assertIn("-> typing_List[str]", traced._code)
        scripted = torch.jit.script(traced)
        self.assertIn("-> List[str]", scripted.code)

    def getitem_inner(self):
        class GetItemBase(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pe = torch.nn.Buffer(torch.randn(8, 8))

        class GetItem1(GetItemBase):
            def forward(self, x):
                return self.pe[:, :x.size(0)]

        class GetItem2(GetItemBase):
            def forward(self, x):
                return self.pe[x.size(0)]

        class GetItem3(GetItemBase):
            def forward(self, x):
                return self.pe[4]  # fx creates `self._tensor_constant0` here

        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
        self.checkGraphModule(GetItem3(), [torch.zeros(4)])

    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
                         "Will be checked in test_getitem_subproc")
    def test_getitem(self):
        self.getitem_inner()

    def test_getitem_subproc(self):
        # need to run this test in a subproc to work around:
        #   https://github.com/pytorch/pytorch/issues/50710
        proc = Process(target=run_getitem_target)
        proc.start()
        proc.join()
        self.assertEqual(proc.exitcode, 0)

    def test_user_friendly_call_provenance_with_function(self):
        def fn(x):
            return wrapper_fn(x)

        traced = torch.fx.symbolic_trace(fn)

        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
                                    "being compiled since it was called"
                                    " from 'fn.forward'"):
            scripted = torch.jit.script(traced)

    def test_user_friendly_call_provenance_with_module(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return wrapper_fn(x)

        traced = torch.fx.symbolic_trace(M())

        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
                                    "being compiled since it was called"
                                    " from 'M.forward'"):
            scripted = torch.jit.script(traced)

    def test_snake_case(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict([
                    ["snake_case", torch.nn.ReLU()],
                    ["PascalCase", torch.nn.LeakyReLU()],
                    ["ALL_CAPS", torch.nn.PReLU()]
                ])

            def forward(self, x):
                a = self.activations["snake_case"](x)
                b = self.activations["PascalCase"](x)
                c = self.activations["ALL_CAPS"](x)
                return a, b, c

        traced = symbolic_trace(M())

        check = [
            ("activations_snake_case", "activations.snake_case"),
            ("activations_pascal_case", "activations.PascalCase"),
            ("activations_all_caps", "activations.ALL_CAPS")
        ]

        i = 0
        for node in traced.graph.nodes:
            if node.op == "placeholder" or node.op == "output":
                continue
            name = check[i][0]
            target = check[i][1]
            self.assertEqual(name, node.name)
            self.assertEqual(target, node.target)
            i += 1
        self.assertEqual(i, 3)

    def test_no_mutation(self):
        from torch.fx.immutable_collections import immutable_list
        x = immutable_list([3, 4])
        with self.assertRaisesRegex(NotImplementedError, "new_args"):
            x[0] = 4
2852

2853
    def test_partial_trace(self):
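        # Tracing with concrete_args specializes the graph on the given value
        # of `y` and inserts a torch._assert guard that checks it at runtime.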
2854
        class Foo(torch.nn.Module):
2855
            def forward(self, x, y):
2856
                if y:
2857
                    return 2 * x
2858
                else:
2859
                    return x
2860
        mod = Foo()
2861
        mod_true = symbolic_trace(mod, concrete_args={'y': True})
2862
        mod_false = symbolic_trace(mod, concrete_args={'y': False})
2863
        self.assertEqual(mod_true(3, True), 6)
2864
        print(mod_true.code)
2865
        assert any(i.target == torch._assert for i in mod_true.graph.nodes)
2866
        with self.assertRaises(AssertionError):
2867
            mod_true(3, False)
2868
        self.assertEqual(mod_false(3, False), 3)
2869
        with self.assertRaises(AssertionError):
2870
            mod_false(3, True)
2871

2872
        def f_higher(a, f):
2873
            return f(a)
2874

2875
        nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
2876
        self.assertEqual(nf(3, lambda x: x * 2), 6)
2877

2878
    def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
2879
        class M(torch.nn.Module):
2880
            def __init__(self) -> None:
2881
                super().__init__()
2882
                self.W = torch.nn.Parameter(torch.randn(5))
2883

2884
            def forward(self, x):
2885
                return torch.dot(self.W, x)
2886

2887
        traced = torch.fx.symbolic_trace(M())
2888

2889
        out = [n for n in traced.graph.nodes if n.op == "output"][-1]
2890
        with traced.graph.inserting_before(out):
2891
            relu_out = traced.graph.call_method(method_name='relu',
2892
                                                args=(out.args[0],))
2893
        out.args = (relu_out,)
2894

2895
        traced.recompile()
2896

2897
        with self.capture_stderr() as captured:
2898
            with self.assertRaises(TypeError):
2899
                traced(5)
2900

2901
        self.assertRegex(captured[0],
2902
                         r"Call using an FX-traced Module, line .* of the "
2903
                         r"traced Module's generated forward function:")
2904

2905
    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
2906
        class M(torch.nn.Module):
2907
            def __init__(self) -> None:
2908
                super().__init__()
2909
                self.linear = torch.nn.Linear(3, 4)
2910

2911
            def forward(self, x):
2912
                return self.linear(x)
2913

2914
        traced = torch.fx.symbolic_trace(M())
2915

2916
        # Do not change this to `capture_stderr` or another context
2917
        # manager without ensuring that the output is as expected
2918
        try:
2919
            traced(torch.rand(5, 5))
2920
        except RuntimeError:
2921
            captured = traceback.format_exc()
2922

2923
        self.assertNotRegex(captured,
2924
                            r"Call using an FX-traced Module, line .* of the "
2925
                            r"traced Module's generated forward function:")
2926

2927
    def test_graph_module_replicate_for_dp(self):
2928
        class Foo(torch.nn.Module):
2929
            def forward(self, x):
2930
                return torch.relu(x)
2931

2932
        gm = torch.fx.symbolic_trace(Foo())
2933

2934
        x = torch.randn(5, 3)
2935
        out = gm(x)
2936

2937
        replica = gm._replicate_for_data_parallel()
2938
        out_replica = replica(x)
2939

2940
        torch.testing.assert_close(out_replica, out)
2941

2942
    def test_ast_rewriter_rewrites_assert(self):
2943
        class M(torch.nn.Module):
2944
            def forward(self, x: torch.Tensor, y: int, z: int):
2945
                assert y == z
2946
                return torch.add(x, x)
2947

2948
        ast_rewriter = RewritingTracer()
2949
        graph = ast_rewriter.trace(M())
2950
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2951

2952
        traced.graph.lint()
2953

2954
    def test_ast_rewriter_rewrites_assert_with_message(self):
2955
        class M(torch.nn.Module):
2956
            def forward(self, x: torch.Tensor, y: int, z: int):
2957
                assert y == z, "msg"
2958
                return torch.add(x, x)
2959

2960
        ast_rewriter = RewritingTracer()
2961
        graph = ast_rewriter.trace(M())
2962
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2963

2964
        traced.graph.lint()
2965

2966
    def test_throw_out_variant(self):
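        # With check_mutable_operations enabled, tracing a call that writes to
        # an `out=` tensor should raise instead of silently recording it.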
2967
        def foo(x):
2968
            y = torch.rand_like(x)
2969
            torch.sigmoid(x, out=y)
2970
            return y
2971

2972
        class MyTracer(torch.fx.Tracer):
2973
            check_mutable_operations = True
2974

2975
        tracer = MyTracer()
2976
        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
2977
            traced_graph = tracer.trace(foo)
2978

2979
    def test_ast_rewriter_reassigns_submodules(self):
2980
        class M(torch.nn.Module):
2981
            def __init__(self) -> None:
2982
                super().__init__()
2983
                self.bn = torch.nn.BatchNorm2d(100)
2984

2985
            def forward(self, x: torch.Tensor):
2986
                return torch.add(x, x)
2987

2988
        ast_rewriter = RewritingTracer()
2989
        graph = ast_rewriter.trace(M())
2990
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2991

2992
        traced.graph.lint()
2993

2994
    def test_ast_rewriter_wrap(self):
2995
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
2996

2997
        def to_trace(y):
2998
            return (
2999
                a_lifted_leaf((4, y), 3)
3000
                + a_lifted_leaf((3, 4), 5)
3001
                + a_lifted_leaf((y, y), y)
3002
            )
3003

3004
        ast_rewriter = RewritingTracer()
3005
        graph = ast_rewriter.trace(to_trace)
3006
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3007

3008
        self.assertIn("a_lifted_leaf", traced.code)
3009
        self.assertEqual(27, traced(2))
3010
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
3011

3012
    def test_ast_rewriter_wrap_fn_directly(self):
3013
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
3014

3015
        def to_trace(y):
3016
            return (
3017
                a_lifted_leaf2((4, y), 3)
3018
                + a_lifted_leaf2((3, 4), 5)
3019
                + a_lifted_leaf2((y, y), y)
3020
            )
3021

3022
        ast_rewriter = RewritingTracer()
3023
        graph = ast_rewriter.trace(to_trace)
3024
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3025

3026
        self.assertIn("a_lifted_leaf2", traced.code)
3027
        self.assertEqual(27, traced(2))
3028
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
3029

3030
    def test_profiler_ranges_side_effect(self):
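        # Profiler range enter/exit calls are side-effecting, so dead-code
        # elimination must keep them even though their results are unused.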
3031
        g = torch.fx.Graph()
3032
        handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
3033
        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
3034
        g.output(None)
3035

3036
        found_targets = {}
3037
        for node in g.nodes:
3038
            if node.op == 'call_function':
3039
                found_targets.setdefault(node.target)
3040
        self.assertEqual(
3041
            list(found_targets.keys()),
3042
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3043
        )
3044

3045
        g.eliminate_dead_code()
3046
        found_targets = {}
3047
        for node in g.nodes:
3048
            if node.op == 'call_function':
3049
                found_targets.setdefault(node.target)
3050
        self.assertEqual(
3051
            list(found_targets.keys()),
3052
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3053
        )
3054

3055
    def test_ast_rewriter_wrapped_via_decorator(self):
3056
        class F(torch.nn.Module):
3057
            def forward(self, x):
3058
                return wrapped_via_decorator(x)
3059

3060
        ast_rewriter = RewritingTracer()
3061
        graph = ast_rewriter.trace(F())
3062
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3063

3064
        self.assertIn("wrapped_via_decorator", traced.code)
3065
        self.assertEqual(traced(0), 1)
3066
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3067
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3068

3069
    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
3070
        self.assertEqual(wrapped_via_decorator(0), 1)
3071

3072
        def to_trace(y):
3073
            return wrapped_via_decorator(y)
3074

3075
        ast_rewriter = RewritingTracer()
3076
        graph = ast_rewriter.trace(to_trace)
3077
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3078

3079
        self.assertIn("wrapped_via_decorator", traced.code)
3080
        self.assertEqual(traced(0), 1)
3081
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3082
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3083

3084
        transformed = torch.fx.Transformer(traced).transform()
3085
        self.assertIn("wrapped_via_decorator", transformed.code)
3086
        self.assertEqual(transformed(0), 1)
3087
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3088
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3089

3090
    def test_ast_rewriter_wrap_with_submodule(self):
3091
        class M(torch.nn.Module):
3092
            def __init__(self) -> None:
3093
                super().__init__()
3094
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3095

3096
            def forward(self, x: torch.Tensor):
3097
                return wrapped_with_submodule(x, self.batchnorm1d)
3098

3099
        ast_rewriter = RewritingTracer()
3100
        graph = ast_rewriter.trace(M())
3101
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3102

3103
        self.assertIn("wrapped_with_submodule", traced.code)
3104

3105
        input = torch.rand(3, 2)
3106
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3107
        self.assertEqual(ref_batchnorm1d(input), traced(input))
3108

3109
    def test_submodule_manipulation_API(self):
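        # Exercises add_submodule, delete_submodule, get_parameter, get_buffer
        # and delete_all_unused_submodules on a nested module hierarchy.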
3110
        class C(torch.nn.Module):
3111
            def __init__(self) -> None:
3112
                super().__init__()
3113
                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
3114
                self.param = torch.nn.Parameter(torch.rand(2, 3))
3115

3116
            def forward(self, x):
3117
                return self.conv(torch.cat([self.param, x]))
3118

3119
        class B(torch.nn.Module):
3120
            def __init__(self) -> None:
3121
                super().__init__()
3122
                self.linear = torch.nn.Linear(100, 200)
3123
                self.buf = torch.nn.Buffer(torch.randn(2, 3))
3124
                self.net_c = C()
3125

3126
            def forward(self, x):
3127
                return self.linear(torch.cat([self.buf, self.net_c(x)]))
3128

3129
        class A(torch.nn.Module):
3130
            def __init__(self) -> None:
3131
                super().__init__()
3132
                self.net_b = B()
3133
                self.param = torch.nn.Parameter(torch.rand(2, 3))
3134

3135
            def forward(self, x):
3136
                return self.net_b(x) + self.param
3137

3138
        a = symbolic_trace(A())
3139

3140
        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))
3141

3142
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
3143
        with a.graph.inserting_before(conv):
3144
            with warnings.catch_warnings(record=True) as w:
3145
                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
3146
                                              args=conv.args)
3147
                self.assertEqual(len(w), 0)
3148

3149
        conv.replace_all_uses_with(dropout)
3150
        a.graph.erase_node(conv)
3151
        a.recompile()
3152

3153
        def module_exists(gm: GraphModule, path: str) -> bool:
3154
            return any(path == name for name, _ in gm.named_modules())
3155

3156
        def parameter_exists(gm: GraphModule, path: str) -> bool:
3157
            return (any(path == name for name, _ in gm.named_parameters())
3158
                    and any(path == name for name in gm.state_dict().keys()))
3159

3160
        def buffer_exists(gm: GraphModule, path: str) -> bool:
3161
            return (any(path == name for name, _ in gm.named_buffers())
3162
                    and any(path == name for name in gm.state_dict().keys()))
3163

3164
        # Test that we added the "dropout" submodule
3165
        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))
3166

3167
        # Test `get_submodule` with an added submodule
3168
        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))
3169

3170
        # Test that the "conv" submodule is still there
3171
        self.assertTrue(module_exists(a, "net_b.net_c.conv"))
3172

3173
        # Test `get_submodule` with an original module
3174
        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))
3175

3176
        # Test that the "conv" node is NOT still there
3177
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
3178
        self.assertEqual(conv, [])
3179

3180
        a.delete_submodule("net_b.net_c.conv")
3181

3182
        # Test that the "conv" submodule is now gone
3183
        self.assertFalse(module_exists(a, "net_b.net_c.conv"))
3184

3185
        # Test `get_submodule` with a deleted submodule
3186
        with self.assertRaisesRegex(AttributeError, "has no attribute "
3187
                                    "`conv`"):
3188
            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))
3189

3190
        # Test `get_attr` warnings
3191
        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]
3192

3193
        with a.graph.inserting_before(cat):
3194

3195
            with warnings.catch_warnings(record=True) as w:
3196
                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
3197
                self.assertEqual(len(w), 0)
3198

3199
            with self.assertWarnsRegex(UserWarning, "Attempted to "
3200
                                       "insert a get_attr Node with no "
3201
                                       "underlying reference in the "
3202
                                       "owning GraphModule"):
3203
                bad_param = a.graph.get_attr(qualified_name="net_b.param")
3204
                a.graph.erase_node(bad_param)
3205

3206
        cat.args = (*cat.args, param)
3207

3208
        a.recompile()
3209

3210
        a.graph.lint()
3211

3212
        # Test `get_parameter`
3213
        a.get_parameter("net_b.net_c.param")
3214
        with self.assertRaisesRegex(AttributeError, "is not an "
3215
                                    "nn.Parameter"):
3216
            a.get_parameter("net_b.buf")
3217
        with self.assertRaisesRegex(AttributeError, "has no attribute "
3218
                                    "`param`"):
3219
            a.get_parameter("net_b.param")
3220

3221
        # Test `get_buffer`
3222
        a.get_buffer("net_b.buf")
3223
        with self.assertRaisesRegex(AttributeError, "is not a "
3224
                                    "buffer"):
3225
            a.get_buffer("net_b.net_c.param")
3226
        with self.assertRaisesRegex(AttributeError, "has no attribute "
3227
                                    "`buf`"):
3228
            a.get_buffer("net_b.net_c.buf")
3229

3230
        # Test non-nested attributes
3231
        a.get_submodule("")
3232
        a.get_parameter("param")
3233

3234
        # Insert some unused submodules
3235
        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
3236
        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
3237
        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
3238
        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))
3239

3240
        # Garbage collection
3241
        a.delete_all_unused_submodules()
3242

3243
        # Test that all the unused submodules are gone
3244
        self.assertFalse(module_exists(a, "net_b.embedding"))
3245
        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
3246
        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
3247
        self.assertFalse(module_exists(a, "batch_norm_2d"))
3248

3249
        # Test that we didn't delete any unused Parameters or buffers
3250
        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
3251
        self.assertTrue(buffer_exists(a, "net_b.buf"))
3252

3253
        a.graph.lint()
3254

3255
    def test_delete_unused_submodules_leaf(self):
3256
        class SubModule(torch.nn.Module):
3257
            def __init__(self) -> None:
3258
                super().__init__()
3259
                self.linear = torch.nn.Linear(10, 10)
3260
                self.relu = torch.nn.ReLU()
3261

3262
            def forward(self, x):
3263
                x = self.linear(x)
3264
                x = self.relu(x)
3265
                return x
3266

3267
        class Model(torch.nn.Module):
3268
            def __init__(self) -> None:
3269
                super().__init__()
3270
                self.submod = SubModule()
3271

3272
            def forward(self, x):
3273
                x = self.submod(x)
3274
                return x
3275

3276
        model = Model()
3277

3278
        class MyCustomTracer(torch.fx.Tracer):
3279
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
3280
                return module_qualified_name == "submod"
3281

3282
        inputs = torch.randn(1, 10)
3283
        traced_graph = MyCustomTracer().trace(model)
3284
        gm2 = torch.fx.GraphModule(model, traced_graph)
3285
        gm2.delete_all_unused_submodules()
3286
        torch.testing.assert_close(gm2(inputs), model(inputs))
3287

3288
    def test_fx_stateless(self):
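        # torch.func.functional_call should run the traced module with the
        # supplied parameters/buffers, so gradients flow to those tensors
        # rather than to the module's own state.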
3289
        class MockModule(torch.nn.Module):
3290
            def __init__(self) -> None:
3291
                super().__init__()
3292
                self.l1 = torch.nn.Linear(1, 1)
3293
                self.buffer = torch.nn.Buffer(torch.ones(1))
3294

3295
            def forward(self, x):
3296
                return self.l1(x) + self.buffer
3297

3298
        module = MockModule()
3299
        x = torch.rand((1, 1))
3300
        weight = torch.tensor([[1.0]], requires_grad=True)
3301
        bias = torch.tensor([0.0], requires_grad=True)
3302
        buffer = torch.tensor([0.0])
3303
        parameters = {'l1.weight': weight,
3304
                      'l1.bias': bias,
3305
                      'buffer': buffer}
3306
        fx_module = torch.fx.symbolic_trace(module)
3307
        res = torch.func.functional_call(fx_module, parameters, x)
3308
        res.backward()
3309
        self.assertIsNotNone(weight.grad)
3310
        self.assertIsNotNone(bias.grad)
3311
        self.assertIsNone(buffer.grad)
3312
        # Gradients were not computed for the module's own parameters and buffers
3313
        self.assertIsNone(module.l1.weight.grad)
3314
        self.assertIsNone(module.l1.bias.grad)
3315
        self.assertIsNone(module.buffer.grad)
3316

3317
    def test_tracing_graphmodules_as_leaf_submodules(self):
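        # Covers three setups: B traced inline, B as a leaf nn.Module, and B as
        # a GraphModule leaf, checking in each case whether B.__call__ runs.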
3318
        class A(torch.nn.Module):
3319
            def forward(self, t):
3320
                return t + t
3321

3322
        class B(torch.nn.Module):
3323
            def __init__(self) -> None:
3324
                super(type(self), self).__init__()
3325
                self.calling = False
3326
                self.called = False
3327

3328
            def forward(self, t):
3329
                if self.calling:
3330
                    return t - t
3331
                else:
3332
                    return t + t
3333

3334
            def __call__(self, *args):
                self.called = True
                self.calling = True
                try:
                    return super(type(self), self).__call__(*args)
                finally:
                    # Reset the flag once the wrapped call has finished.
                    self.calling = False
3339

3340
        class M(torch.nn.Module):
3341
            def __init__(self, a, b):
3342
                super().__init__()
3343
                self.a = a
3344
                self.b = b
3345

3346
            def forward(self, t):
3347
                x = self.a(t)
3348
                y = self.b(t)
3349
                return x + y
3350

3351
        class LeafTracer(Tracer):
3352
            def is_leaf_module(self, module, name):
3353
                return True
3354

3355
        class LeafTracerNotB(Tracer):
3356
            def is_leaf_module(self, module, name):
3357
                return False if "b" in name else True
3358

3359
        # The recompile() calls below are not strictly required; they are
        # included because they chain additional __call__ wrappers.
3361

3362
        #
3363
        # Test: B as a regular, non-leaf module
3364
        #
3365
        a = symbolic_trace(A())
3366
        a.recompile()
3367
        m = M(a, B())
3368
        graph = LeafTracerNotB().trace(m)
3369
        gm = GraphModule(m, graph)
3370
        gm.recompile()
3371

3372
        # Test graphmodule/submodule a is not inlined.
3373
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3374
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3375
        self.assertTrue(len(match) == 1)
3376

3377
        # Test submodule b is not treated as leaf.
3378
        self.assertFalse(hasattr(gm, "b"))
3379

3380
        # Assert that the custom __call__ on submodule b was honored.
3381
        match = [
3382
            n
3383
            for n in gm.graph.nodes
3384
            if n.op == "call_function" and n.target == operator.sub
3385
        ]
3386
        self.assertTrue(len(match) == 1)
3387

3388
        #
3389
        # Test: B as a regular, leaf module
3390
        # symbolic_trace should only patch torch.nn.Module.__call__,
3391
        # which means B.__call__ should still execute
3392
        #
3393
        a = symbolic_trace(A())
3394
        a.recompile()
3395
        b = B()
3396
        m = M(a, b)
3397
        graph = LeafTracer().trace(m)
3398
        gm = GraphModule(m, graph)
3399
        gm.recompile()
3400

3401
        # Test graphmodule/submodule a is not inlined.
3402
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3403
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3404
        self.assertTrue(len(match) == 1)
3405

3406
        # Test submodule b is leaf:
3407
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3408
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3409
        self.assertTrue(len(match) == 1)
3410

3411
        # Test b.__call__ was run
3412
        self.assertTrue(b.called)
3413
        self.assertTrue(gm.get_submodule("b").called)
3414

3415
        #
3416
        # Test: B as GraphModule leaf
3417
        # __call__ not honored since symbolic_trace directly invokes forward()
3418
        #
3419
        a = symbolic_trace(A())
3420
        a.recompile()
3421
        b = symbolic_trace(B())
3422
        b.recompile()
3423
        m = M(a, b)
3424
        graph = LeafTracer().trace(m)
3425
        gm = GraphModule(m, graph)
3426
        gm.recompile()
3427

3428
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3429
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3430
        self.assertTrue(len(match) == 1)
3431

3432
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3433
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3434
        self.assertTrue(len(match) == 1)
3435

3436
    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
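        # GraphModule.__init__ should copy buffers and parameters referenced by
        # the graph whether the root is a module or a plain dict of attributes.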
3437
        class MyModule(torch.nn.Module):
3438
            def __init__(self) -> None:
3439
                super().__init__()
3440
                self.my_buff = torch.nn.Buffer(torch.rand(3, 4))
3441
                self.register_parameter(
3442
                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
3443
                )
3444

3445
            def forward(self, x):
3446
                return x + self.my_buff + self.my_param
3447

3448
        mod = MyModule()
3449
        mod_traced = symbolic_trace(mod)
3450

3451
        # Create new GraphModule based on original, either w/ dict or root module.
3452
        orig_buff = mod_traced.get_buffer("my_buff")
3453
        orig_param = mod_traced.get_parameter("my_param")
3454
        mod_traced_new = GraphModule(
3455
            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
3456
            mod_traced.graph,
3457
        )
3458

3459
        # Check that both my_buff and my_param are found and the same.
3460
        try:
3461
            new_buff = mod_traced_new.get_buffer("my_buff")
3462
        except Exception:
3463
            self.fail("Did not find my_buff")
3464
        self.assertEqual(orig_buff, new_buff)
3465

3466
        try:
3467
            new_param = mod_traced_new.get_parameter("my_param")
3468
        except Exception:
3469
            self.fail("Did not find my_param")
3470
        self.assertEqual(orig_param, new_param)
3471

3472
        x = torch.rand(3, 4)
3473
        orig_out = mod_traced(x)
3474
        submodules_out = mod_traced_new(x)
3475

3476
        self.assertEqual(orig_out, submodules_out)
3477

3478
    def test_graph_module_init_buffer_param_copied_dict_init(self):
3479
        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)
3480

3481
    def test_graph_module_init_buffer_param_copied_mod_init(self):
3482
        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
3483

3484
    def test_annotations_with_no_forward_references(self):
3485
        class A:
3486
            def __call__(self, x: torch.Tensor):
3487
                return torch.add(x, x)
3488

3489
        class M(torch.nn.Module):
3490
            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
3491
                return a(x)
3492

3493
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3494

3495
    def test_annotations_with_forward_references(self):
3496
        class A:
3497
            def __call__(self, x: torch.Tensor):
3498
                return torch.add(x, x)
3499

3500
        class M(torch.nn.Module):
3501
            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
3502
                return a(x)
3503

3504
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3505

3506
    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
3507
        class A:
3508
            def __call__(self, x: torch.Tensor):
3509
                return torch.add(x, x)
3510

3511
        class M(torch.nn.Module):
3512
            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
3513
                return a(x[0])
3514

3515
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3516

3517
    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
3518
        class A:
3519
            def __call__(self, x: torch.Tensor):
3520
                return torch.add(x, x)
3521

3522
        class M(torch.nn.Module):
3523
            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
3524
                return a(x)[0]
3525

3526
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3527

3528
    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
3529
                     "`annotations` is not defined in Python <3.7")
3530
    def test_annotation_with_future(self):
3531
        try:
3532
            import fx.test_future    # noqa: F401
3533
        finally:
3534
            del sys.modules["__future__"]
3535

3536
    @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
3537
    def test_annotations_empty_tuple(self):
3538
        class Foo(torch.nn.Module):
3539
            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
3540
                return "foo"
3541

3542
        traced = torch.fx.symbolic_trace(Foo())
3543

3544
        x = ()
3545
        y = ("bar", ())
3546

3547
        traced(x, y)
3548

3549
        FileCheck().check("_Tuple[()]")   \
3550
                   .check("typing_Tuple[str,typing_Tuple[()]]") \
3551
                   .run(traced.code)
3552

3553
        scripted = torch.jit.script(traced)
3554

3555
        scripted(x, y)
3556

3557
        FileCheck().check("Tuple[()]")   \
3558
            .check("Tuple[str, Tuple[()]]")    \
3559
            .run(scripted.code)
3560

3561
    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
3562
    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
3563
    def test_assert(self):
3564
        def f(x):
3565
            assert x > 1
3566
            return x + 1
3567
        try:
3568
            torch.fx.proxy.TracerBase.trace_asserts = True
3569
            traced = symbolic_trace(f)
3570
        finally:
3571
            torch.fx.proxy.TracerBase.trace_asserts = False
3572

3573
        self.assertEqual(f(2), traced(2))
3574
        with self.assertRaises(AssertionError):
3575
            traced(0)
3576

3577
    def test_pytree(self):
3578
        # Used to test that you can use your own placeholder class
3579
        class PHTest(PHBase):
3580
            pass
3581

3582
        def f_sum(x):
3583
            return sum(x)
3584

3585
        def f_sum_dict(x):
3586
            out = 0
3587
            for v in x.values():
3588
                out += v
3589
            return out
3590

3591
        def f_dict_list_map(x):
3592
            new_dict = {}
3593
            for k, v in x.items():
3594
                new_dict[k] = [i + 1 for i in v]
3595
            return new_dict
3596

3597
        def f_dict_add(x):
3598
            return x['a'] + sum(x['z'])
3599

3600
        def f_namedtuple_add(x):
3601
            return x.x + x.y
3602

3603
        pytree.register_pytree_node(
3604
            Foo,
3605
            lambda x: ([x.a, x.b], None),
3606
            lambda x, _: Foo(x[0], x[1]),
3607
        )
3608
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])
3609

3610
        def f_custom(x):
3611
            return x.a + x.b
3612

3613
        def f_custom_dict(x):
3614
            return f_sum_dict(x.a) + x.b
3615

3616
        def f_return_custom(x):
3617
            return Foo(x.b, x.a)
3618

3619
        tests = [
3620
            (f_sum, [PH, PH, PH]),
3621
            (f_sum, []),
3622
            (f_sum, [PHTest(), PHTest(), PHTest()]),
3623
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
3624
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
3625
            (f_dict_list_map, {5: (PH, PH, PH)}),
3626
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
3627
            (f_dict_add, {'a': PH, 'z': []}),
3628
            (f_custom, Foo(PH, PH)),
3629
            (f_custom, Foo(PH, 3)),
3630
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
3631
            # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
3632
            (f_namedtuple_add, Point(PH, PH)),
3633
        ]
3634

3635
        def verify_pytree(f, inp):
3636
            val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
3637
            num_flat_args = len(pytree.tree_leaves(inp))
3638
            orig_out = f(val)
3639
            nf = symbolic_trace(f, concrete_args={'x': inp})
3640
            self.assertEqual(nf(val), orig_out)
3641

3642
            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
3643
            bare_fx.graph.set_codegen(CodeGen())
3644
            bare_fx.recompile()
3645
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)
3646

3647
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
3648
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args
3649

3650
            nf = symbolic_trace(nf)
3651
            self.assertEqual(nf(val), orig_out)
3652
            assert "tree_flatten_spec" not in nf.code
3653
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1
3654

3655
            nf = symbolic_trace(nf, concrete_args={'x': inp})
3656
            self.assertEqual(nf(val), orig_out)
3657
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
3658
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args
3659

3660
            pickled = pickle.dumps(nf)
3661
            nf = pickle.loads(pickled)
3662
            self.assertEqual(nf(val), orig_out)
3663

3664
        for f, inp in tests:
3665
            verify_pytree(f, inp)
3666

3667
    def test_pytree_concrete(self):
3668
        def f(b, a):
3669
            if b:
3670
                return a['a']
3671
            else:
3672
                return a['z']
3673

3674
        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
3675
        nf = symbolic_trace(f, concrete_args=inp)
3676
        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
3677
        self.assertEqual(nf(**val), f(**val))
3678

3679
        nf = symbolic_trace(nf)
3680
        self.assertEqual(nf(**val), f(**val))
3681

3682
    def test_metadata_on_ph(self):
3683
        def f_sum(a: int, b: int) -> int:
3684
            return a + b
3685

3686
        # Due to unflattening of the dict, the `a` argument
        # will be split into two separate placeholder nodes named
        # "a_1" and "a_2", referring to the keys
        # "f1" and "f2" respectively in the dict.
3690
        def f_dict(a: Dict[str, str]) -> bool:
3691
            return a["f1"] == a["f2"]
3692

3693
        def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
3694
            for node in gm.graph.nodes:
3695
                if node.op == "placeholder":
3696
                    self.assertTrue(node.name in arg_names)
3697
                    self.assertTrue(node.ph_key in metadata)
3698

3699
        verify_metadata(
3700
            gm=symbolic_trace(
3701
                f_sum,
3702
                concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
3703
            ),
3704
            arg_names=["a_1", "b_1"],
3705
            metadata=["a", "b"]
3706
        )
3707
        verify_metadata(
3708
            gm=symbolic_trace(
3709
                f_dict,
3710
                concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
3711
            ),
3712
            arg_names=["a_1", "a_2"],
3713
            metadata=["f1", "f2"]
3714
        )
3715

3716
        # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
3717
        class TaggingTracer(Tracer):
3718
            def create_node(self, kind : str, target : Union[str, Callable],
3719
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
3720
                            type_expr : Optional[Any] = None) -> Node:
3721
                n = super().create_node(kind, target, args, kwargs, name)
3722
                n.tag = "foo"
3723
                return n
3724

3725
        class PHWithTag(PHBase):
3726
            def __init__(self, tag: str):
3727
                super().__init__()
3728

3729
                self.tag = tag
3730

3731
        g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
3732
        for n in g.nodes:
3733
            self.assertTrue(hasattr(n, "tag"))
3734
            # Ensure that tag is still "foo" and not "bar" (from PHWithTag)
3735
            self.assertEqual(n.tag, "foo")
3736

3737
    def test_custom_codegen(self):
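        # A custom CodeGen can change the generated forward() signature; here
        # ListCodeGen makes the GraphModule accept a single list of tensors.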
3738
        class ListCodeGen(CodeGen):
3739
            def gen_fn_def(self, free_vars, maybe_return_annotation):
3740
                lst_unpack = f"""
3741
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3742
    {', '.join(free_vars)} = args_list"""
3743
                return lst_unpack
3744

3745
            def additional_globals(self):
3746
                return [('List', typing.List)]
3747

3748
            def process_inputs(self, *inputs):
3749
                assert len(inputs) == 1
3750
                return inputs[0]
3751

3752
        def f(a, b):
3753
            return a + b
3754

3755
        nf = symbolic_trace(f)
3756
        vals = [torch.randn(3), torch.randn(3)]
3757
        self.assertEqual(nf(*vals), f(*vals))
3758

3759
        nf.graph.set_codegen(ListCodeGen())
3760
        nf.recompile()
3761

3762
        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
3763
        bare_fx.graph.set_codegen(CodeGen())
3764
        bare_fx.recompile()
3765

3766
        self.assertEqual(nf(vals), f(*vals))
3767
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))
3768

3769
        ts_f = torch.jit.script(nf)
3770
        self.assertEqual(nf(vals), ts_f(vals))
3771

3772
    def test_custom_codegen_with_transformer(self):
3773
        class ListCodeGen(CodeGen):
3774
            def gen_fn_def(self, free_vars, maybe_return_annotation):
3775
                lst_unpack = f"""
3776
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3777
    {', '.join(free_vars)} = args_list"""
3778
                return lst_unpack
3779

3780
            def additional_globals(self):
3781
                return [('List', typing.List)]
3782

3783
            def process_inputs(self, *inputs):
3784
                assert len(inputs) == 1
3785
                return inputs[0]
3786

3787
        def f(a, b):
3788
            return a + b
3789

3790
        nf = symbolic_trace(f)
3791
        vals = [torch.randn(3), torch.randn(3)]
3792
        self.assertEqual(nf(*vals), f(*vals))
3793

3794
        nf.graph.set_codegen(ListCodeGen())
3795
        nf.recompile()
3796
        self.assertEqual(nf(vals), f(*vals))
3797

3798
        transformed_gm = Transformer(nf).transform()
3799
        self.assertEqual(nf(vals), transformed_gm(vals))
3800

3801
    def test_interpreter_with_codegen(self):
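        # Interpreter.run should honor the graph's custom CodeGen for both
        # input and output processing.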
3802
        class ListCodeGen(CodeGen):
3803
            def gen_fn_def(self, free_vars, maybe_return_annotation):
3804
                lst_unpack = f"""
3805
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3806
    {', '.join(free_vars)} = args_list"""
3807
                return lst_unpack
3808

3809
            def additional_globals(self):
3810
                return [('List', typing.List)]
3811

3812
            def process_inputs(self, *inputs):
3813
                assert len(inputs) == 1
3814
                return inputs[0]
3815

3816
            def generate_output(self, output_args):
3817
                return f'return list({repr(output_args)})'
3818

3819
            def process_outputs(self, outputs):
3820
                return list(outputs)
3821

3822
        def f(a, b):
3823
            a = a + b
3824
            b = a + b
3825
            return a, b
3826

3827
        nf = symbolic_trace(f)
3828
        vals = [torch.randn(3), torch.randn(3)]
3829
        nf.graph.set_codegen(ListCodeGen())
3830
        nf.recompile()
3831
        self.assertEqual(Interpreter(nf).run(vals), nf(vals))
3832

3833
    def test_imul_code_print(self):
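        # operator.imul on a placeholder should be emitted as an augmented
        # assignment (`a *= b`) in the generated code.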
3834
        graph = torch.fx.Graph()
3835
        a = graph.placeholder("a")
3836
        b = graph.placeholder("b")
3837
        graph.call_function(operator.imul, (a, b), {})
3838
        graph.output(a)
3839
        gm = torch.fx.GraphModule({}, graph)
3840
        gm.recompile()
3841
        self.assertEqual(gm(2, 3), 6)
3842
        self.assertIn("a *= b", gm.code)
3843

3844
    def test_deepcopy_tracer(self):
3845
        def fn(x, y):
3846
            return (x + y).relu().sin()
3847

3848
        tracer = Tracer()
3849
        tracer_before = copy.deepcopy(tracer)
3850
        tracer.trace(fn)
3851
        tracer_after = copy.deepcopy(tracer)
3852

3853
        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
3854
        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))
3855

3856
    def test_deepcopy_graphmodule(self):
3857
        m = symbolic_trace(SimpleTest())
3858
        m.meta['hello'] = 'world'
3859
        copy_m = copy.deepcopy(m)
3860
        self.assertEqual(copy_m.meta['hello'], 'world')
3861

3862
    def test_deepcopy_no_recursion(self):
3863
        m = symbolic_trace(SimpleTest())
3864
        m.meta['hello'] = m  # circular reference
3865
        copy_m = copy.deepcopy(m)  # finishes
3866
        self.assertEqual(id(copy_m), id(copy_m.meta['hello']))
3867

3868
    def test_enum(self):
3869
        from enum import Enum
3870

3871
        class Foo(Enum):
3872
            A = 1
3873
            B = 2
3874

3875
        def leaf_fn(arr, enum_val):
3876
            # Use the raw enum.
3877
            arr.append(enum_val)
3878
            return arr[-1].value
3879

3880
        def foo(x):
3881
            # Pass the enum as argument.
3882
            return leaf_fn(x, Foo.A)
3883

3884
        traced = torch.fx.symbolic_trace(foo)
3885
        self.assertEqual(foo([]), traced([]))
3886

3887
    def test_insert_arg(self):
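        # Node.insert_arg should splice a new argument into the output node's
        # args while keeping the node's user bookkeeping consistent.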
3888
        m = symbolic_trace(SimpleTest())
3889
        m.buf = torch.nn.Buffer(torch.tensor(0))
3890
        output_node = next(iter(reversed(m.graph.nodes)))
3891
        with m.graph.inserting_before(output_node):
3892
            a = m.graph.get_attr("buf")
3893
        r = len(output_node.args)
3894
        output_node.insert_arg(0, a)
3895
        self.assertEqual(len(output_node.args), r + 1)
3896
        self.assertEqual(len(a.users), 1)
3897
        self.assertIs(output_node.args[0], a)
3898
        self.assertIs(next(iter(a.users.keys())), output_node)
3899
        output_node.insert_arg(2, a)
3900
        self.assertEqual(len(output_node.args), r + 2)
3901
        self.assertEqual(len(a.users), 1)
3902
        self.assertIs(output_node.args[2], a)
3903
        self.assertIs(next(iter(a.users.keys())), output_node)
3904
        m.graph.lint()
3905

3906
    def test_delete_unused_values(self):
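        # Dead-code elimination must keep the side-effecting in-place copy_
        # while still clearing its unused result binding.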
3907
        from torch.fx.experimental.proxy_tensor import make_fx
3908

3909
        # disable mutable checking temporarily
3910
        orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3911
        torch.fx.proxy.TracerBase.check_mutable_operations = False
3912

3913
        def fn(a, b, c, d):
3914
            x = a + b
3915
            y = c + d
3916
            y.copy_(x)
3917
            x = torch.relu(x)
3918
            return x
3919

3920
        a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4))
3921
        fx_fn = make_fx(fn)(a, b, c, d)
3922
        print(fx_fn)
3923

3924
        fx_fn.graph.eliminate_dead_code()
3925
        py_code = fx_fn.recompile()
3926
        self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src)
3927
        self.assertTrue("copy_ = None" in py_code.src)
3928

3929
        # restore the mutable-operation checking flag
3930
        torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
3931

3932
def run_getitem_target():
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
    try:
        TestFX().getitem_inner()
    finally:
        _wrapped_methods_to_patch.pop()
3939

3940

3941
class TestOperatorSignatures(JitTestCase):
3942
    def setUp(self):
3943
        # Checking for mutable operations while tracing is feature-flagged.
        # Enable it in testing but not by default.
3945
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3946
        torch.fx.proxy.TracerBase.check_mutable_operations = True
3947

3948
    def tearDown(self):
3949
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
3950

3951
    @onlyCPU
3952
    @ops(op_db, allowed_dtypes=(torch.float,))
3953
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
3954
        if not isinstance(op.op, types.BuiltinFunctionType):
3955
            raise unittest.SkipTest("This path doesn't work on Python functions")
3956
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
3957
        schemas = get_signature_for_torch_op(op.op)
3958
        if not schemas:
3959
            raise RuntimeError('No Schemas Returned')
3960
        for sample_input in sample_inputs_itr:
3961
            # Iterate through overloads until we hit a match. If we exit this
3962
            # loop via `else`, we haven't found a match
3963
            for schema in schemas:
3964
                try:
3965
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
3966
                    bound_args.apply_defaults()
3967
                    op(*bound_args.args, **bound_args.kwargs)
3968
                    break
3969
                except TypeError:
                    pass
3971
            else:
3972
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')
3973

3974

3975
class TestFXAPIBackwardCompatibility(JitTestCase):
3976
    def setUp(self):
3977
        super().setUp()
3978
        self.maxDiff = None
3979

3980
        # Checking for mutable operations while tracing is feature-flagged.
        # Enable it in testing but not by default.
3982
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3983
        torch.fx.proxy.TracerBase.check_mutable_operations = True
3984

3985
    def tearDown(self):
3986
        super().tearDown()
3987
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
3988

3989

3990
    def _fn_to_stable_annotation_str(self, obj):
3991
        """
3992
        Unfortunately we have to serialize function signatures manually since
        serialization for `inspect.Signature` objects is not stable across
        Python versions.
3995
        """
3996
        fn_name = torch.typename(obj)
3997

3998
        signature = inspect.signature(obj)
3999

4000
        sig_str = f'{fn_name}{signature}'
4001

4002
        arg_strs = []
4003
        for k, v in signature.parameters.items():
4004
            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
4005
                if v.annotation is not inspect.Signature.empty else ''
4006

4007
            def default_val_str(val):
4008
                if isinstance(val, (tuple, list)):
4009
                    str_pieces = ['(' if isinstance(val, tuple) else '[']
4010
                    str_pieces.append(', '.join(default_val_str(v) for v in val))
4011
                    if isinstance(val, tuple) and len(str_pieces) == 2:
4012
                        str_pieces.append(',')
4013
                    str_pieces.append(')' if isinstance(val, tuple) else ']')
4014
                    return ''.join(str_pieces)
4015

4016
                # Need to fix up some default value strings.
4017
                # First case: modules. Default module `repr` contains the FS path of the module.
4018
                # Don't leak that
4019
                if isinstance(val, types.ModuleType):
4020
                    return f'<module {val.__name__}>'
4021

4022
                # Second case: callables. Callables (such as lambdas) encode their address in
4023
                # their string repr. Don't do that
4024
                if callable(val):
4025
                    return f'<function {val.__name__}>'
4026

4027
                return str(val)
4028

4029
            if v.default is not inspect.Signature.empty:
4030
                default_repr = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
                maybe_default = f' = {default_repr}'
4032
            else:
4033
                maybe_default = ''
4034
            maybe_stars = ''
4035
            if v.kind == inspect.Parameter.VAR_POSITIONAL:
4036
                maybe_stars = '*'
4037
            elif v.kind == inspect.Parameter.VAR_KEYWORD:
4038
                maybe_stars = '**'
4039
            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
4040

4041
        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
4042
            if signature.return_annotation is not inspect.Signature.empty else ''
4043

4044
        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
4045

4046
    def _annotation_type_to_stable_str(self, t, sig_str):
4047
        if t is inspect.Signature.empty:
4048
            return ''
4049

4050
        # Forward ref
4051
        if isinstance(t, str):
4052
            return f"'{t}'"
4053
        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
4054
            return t.__forward_arg__
4055
        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
4056
            return t.__forward_arg__
4057

4058
        trivial_mappings = {
4059
            str : 'str',
4060
            int : 'int',
4061
            float: 'float',
4062
            bool: 'bool',
4063
            torch.dtype: 'torch.dtype',
4064
            torch.Tensor: 'torch.Tensor',
4065
            torch.device: 'torch.device',
4066
            torch.memory_format: 'torch.memory_format',
4067
            slice: 'slice',
4068
            torch.nn.Module: 'torch.nn.modules.module.Module',
4069
            torch.fx.Graph : 'torch.fx.graph.Graph',
4070
            torch.fx.Node : 'torch.fx.node.Node',
4071
            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
4072
            torch.fx.node.Target : 'torch.fx.node.Target',
4073
            torch.fx.node.Argument : 'torch.fx.node.Argument',
4074
            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
4075
            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
4076
            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
4077
            Ellipsis : '...',
4078
            typing.Any: 'Any',
4079
            type(None): 'NoneType',
4080
            None: 'None',
4081
            typing.Iterator: 'Iterator',
4082
        }
4083

4084
        mapping = trivial_mappings.get(t, None)
4085
        if mapping:
4086
            return mapping
4087

4088
        # Handle types with contained types
4089
        contained = getattr(t, '__args__', None) or []
4090

4091
        # Callables contain a bare List for arguments
4092
        contained = t if isinstance(t, list) else contained
4093

4094
        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
4095
        if all(isinstance(ct, typing.TypeVar) for ct in contained):
4096
            contained = []
4097

4098
        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
4099
        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
4100

4101

4102
        origin = getattr(t, '__origin__', None)
4103
        if origin is None:
4104
            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
4105
            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
4106

4107
        if origin in {tuple, typing.Tuple}:
4108
            return f'Tuple{contained_type_str}'
4109
        if origin in {typing.Union}:
4110
            # Annoying hack to detect Optional
4111
            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
4112
                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
4113
                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
4114
            return f'Union{contained_type_str}'
4115
        if origin in {dict, typing.Dict}:
4116
            return f'Dict{contained_type_str}'
4117
        if origin in {list, typing.List}:
4118
            return f'List{contained_type_str}'
4119
        if origin in {type, typing.Type}:
4120
            return f'Type{contained_type_str}'
4121
        if isinstance(t, typing.Callable):
4122
            if len(contained) > 0 and contained[0] is not Ellipsis:
4123
                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
4124
            else:
4125
                return f'Callable{contained_type_str}'
4126

4127
        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
                           'Please add support for this type and confirm with the '
                           'FX team that your signature change is valid.')
4130

4131

4132
    def test_function_back_compat(self):
4133
        """
4134
        Test backward compatibility for function signatures with
4135
        @compatibility(is_backward_compatible=True). Currently this checks for
4136
        exact signature matches, which may lead to false positives. If this
4137
        becomes too annoying, we can refine this check to actually parse out
4138
        the saved schema strings and check if the change is truly backward-
4139
        incompatible.
4140
        """
4141
        signature_strs = []
4142

4143
        for obj in _BACK_COMPAT_OBJECTS:
4144
            if not isinstance(obj, type):
4145
                signature_strs.append(self._fn_to_stable_annotation_str(obj))
4146

4147
        signature_strs.sort()
4148

4149
        try:
4150
            self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
4151
        except AssertionError as e:
4152
            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
4153
                  f"as backwards-compatible has experienced a signature change. See the " \
4154
                  f"above exception context for more information. If this change was " \
4155
                  f"unintended, please revert it. If it was intended, check with the FX " \
4156
                  f"team to ensure that the proper deprecation protocols have been followed " \
4157
                  f"and subsequently --accept the change."
4158
            raise AssertionError(msg) from e
4159

4160
    def test_class_member_back_compat(self):
4161
        """
4162
        Test backward compatibility for members of classes with
4163
        @compatibility(is_backward_compatible=True). Currently this checks for
4164
        exact matches on the publicly visible members of the class.
4165
        """
4166
        class_method_strs = []

        for obj in _BACK_COMPAT_OBJECTS:
            if isinstance(obj, type):
                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')

        class_method_strs.sort()

        try:
            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
        except AssertionError as e:
            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
                  f"as backwards-compatible has experienced a change in its public members. See the " \
                  f"above exception context for more information. If this change was " \
                  f"unintended, please revert it. If it was intended, check with the FX " \
                  f"team to ensure that the proper deprecation protocols have been followed " \
                  f"and subsequently --accept the change."
            raise AssertionError(msg) from e

    def test_public_api_surface(self):
        non_back_compat_objects = {}

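        # `non_back_compat_objects` is a dict used as an ordered set: the recursive walk
        # below records every public torch.fx class or function that is not marked with
        # @compatibility.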
        def check_symbols_have_bc_designation(m, seen):
            if not m.__name__.startswith('torch.fx'):
                return
            if m.__name__.startswith('torch.fx.experimental'):
                return
            # It's really common for inner functions to point to random modules
            # - make sure we don't recurse into modules we've already checked.
            seen.add(m.__name__)
            for k, v in m.__dict__.items():
                if hasattr(v, '__name__') and v.__name__ in seen:
                    continue
                if v is m:
                    continue
                if k.startswith('_'):
                    continue
                if isinstance(v, types.ModuleType):
                    check_symbols_have_bc_designation(v, seen)
                elif isinstance(v, (type, types.FunctionType)):
                    if v not in _MARKED_WITH_COMPATIBILITY:
                        non_back_compat_objects.setdefault(v)

        check_symbols_have_bc_designation(torch.fx, set())
        check_symbols_have_bc_designation(torch.fx.passes, set())

        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
        # Only want objects in torch.fx
        non_back_compat_strs = [
            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
        # Only want objects in public namespaces
        non_back_compat_strs = [
            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
        non_back_compat_strs.sort()

        if len(non_back_compat_strs) != 0:
            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
                                 f"backwards-compatibility classification! Please decorate these "
                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
                                 f"BC guarantees.")

    def test_adding_side_effect_function(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                side_effect_func(x)
                return x

        gm = torch.fx.symbolic_trace(TestModule())
        self.assertEqual(len(gm.graph.nodes), 3)
        gm.graph.eliminate_dead_code()
        gm.recompile()
        self.assertEqual(len(gm.graph.nodes), 3)
        found = False
        for node in gm.graph.nodes:
            if node.op == 'call_function' and node.target == side_effect_func:
                found = True
        self.assertTrue(found)

    def test_preserve_unused_attr_after_unpickle(self):
        gm = torch.fx.symbolic_trace(Add())
        gm.add_submodule("foo", Add())
        gm.dummy_buffer = torch.nn.Buffer(torch.empty(1))
        gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))
        b = io.BytesIO()
        torch.save(gm, b)
        b.seek(0)
        # weights_only=False as this loads a GraphModule
        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
        reload_gm = torch.load(b, weights_only=False)
        self.assertTrue(hasattr(reload_gm, "foo"))
        self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
        self.assertTrue(hasattr(reload_gm, "dummy_parameter"))

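# TestFunctionalTracing generates one test per public torch.nn.functional callable
# (see generate_tests() below). Each generated test, e.g. test_nn_functional_relu,
# either symbolically traces the functional or asserts the tracing error expected
# from the UNTRACEABLE_FUNCTIONALS tables defined on the class.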
# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
@unittest.skipIf(
    sys.version_info >= (3, 12), "Failing on python 3.12+"
)
class TestFunctionalTracing(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                    "has_torch_function_variadic", "handle_torch_function",
                    "boolean_dispatch")
    TO_PATCH = {"has_torch_function": None,
                "has_torch_function_unary": None,
                "has_torch_function_variadic": None}

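    # Each constant below is an (exception type, message regex) pair describing why a
    # given functional cannot be symbolically traced; the generated tests consume them
    # roughly as:
    #     with self.assertRaisesRegex(exc, err):
    #         symbolic_trace(fn)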
    BUILT_IN_FUNC = (AssertionError, "")
    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")

    UNTRACEABLE_FUNCTIONALS = {
        "adaptive_avg_pool1d": BUILT_IN_FUNC,
        "avg_pool1d": BUILT_IN_FUNC,
        "avg_pool2d": BUILT_IN_FUNC,
        "avg_pool3d": BUILT_IN_FUNC,
        "bilinear": BUILT_IN_FUNC,
        "celu_": BUILT_IN_FUNC,
        "channel_shuffle": BUILT_IN_FUNC,
        "native_channel_shuffle": BUILT_IN_FUNC,
        "conv1d": BUILT_IN_FUNC,
        "conv2d": BUILT_IN_FUNC,
        "conv3d": BUILT_IN_FUNC,
        "conv_tbc": BUILT_IN_FUNC,
        "conv_transpose1d": BUILT_IN_FUNC,
        "conv_transpose2d": BUILT_IN_FUNC,
        "conv_transpose3d": BUILT_IN_FUNC,
        "cosine_similarity": BUILT_IN_FUNC,
        "elu_": BUILT_IN_FUNC,
        "gelu": BUILT_IN_FUNC,
        "hardshrink": BUILT_IN_FUNC,
        "hardtanh_": BUILT_IN_FUNC,
        "leaky_relu_": BUILT_IN_FUNC,
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": ARG_TYPE_MISMATCH,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,
        "pixel_unshuffle": BUILT_IN_FUNC,
        "prelu": BUILT_IN_FUNC,
        "relu_": BUILT_IN_FUNC,
        "rrelu_": BUILT_IN_FUNC,
        "selu_": BUILT_IN_FUNC,
        "scaled_dot_product_attention": BUILT_IN_FUNC,
        "softplus": BUILT_IN_FUNC,
        "softshrink": BUILT_IN_FUNC,
        "threshold_": BUILT_IN_FUNC,

        "adaptive_avg_pool2d": LEN_ERROR,
        "adaptive_avg_pool3d": LEN_ERROR,
        "adaptive_max_pool2d_with_indices": LEN_ERROR,
        "adaptive_max_pool3d_with_indices": LEN_ERROR,
        "instance_norm": CONTROL_FLOW,

        "adaptive_max_pool1d": PROXY_ITERABLE,
        "adaptive_max_pool2d": PROXY_ITERABLE,
        "adaptive_max_pool3d": PROXY_ITERABLE,
        "fractional_max_pool2d": PROXY_ITERABLE,
        "fractional_max_pool3d": PROXY_ITERABLE,
        "max_pool1d": PROXY_ITERABLE,
        "max_pool2d": PROXY_ITERABLE,
        "max_pool3d": PROXY_ITERABLE,

        "lp_pool2d": PROXY_ITERATED,
        "lp_pool3d": PROXY_ITERATED,
        "max_unpool1d": PROXY_ITERATED,
        "max_unpool2d": PROXY_ITERATED,
        "max_unpool3d": PROXY_ITERATED,
        "fold": PROXY_ITERATED,
        "unfold": PROXY_ITERATED,

        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "rms_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
        "alpha_dropout": CONTROL_FLOW,
        "batch_norm": CONTROL_FLOW,
        "binary_cross_entropy": CONTROL_FLOW,
        "binary_cross_entropy_with_logits": CONTROL_FLOW,
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
        "dropout2d": CONTROL_FLOW,
        "dropout3d": CONTROL_FLOW,
        "elu": CONTROL_FLOW,
        "embedding": CONTROL_FLOW,
        "embedding_bag": CONTROL_FLOW,
        "feature_alpha_dropout": CONTROL_FLOW,
        "gaussian_nll_loss": CONTROL_FLOW,
        "glu": CONTROL_FLOW,
        "grid_sample": CONTROL_FLOW,
        "group_norm": CONTROL_FLOW,
        "gumbel_softmax": CONTROL_FLOW,
        "hardsigmoid": CONTROL_FLOW,
        "hardswish": CONTROL_FLOW,
        "hardtanh": CONTROL_FLOW,
        "hinge_embedding_loss": CONTROL_FLOW,
        "huber_loss": CONTROL_FLOW,
        "interpolate": CONTROL_FLOW,
        "kl_div": CONTROL_FLOW,
        "l1_loss": CONTROL_FLOW,
        "leaky_relu": CONTROL_FLOW,
        "local_response_norm": CONTROL_FLOW,
        "margin_ranking_loss": CONTROL_FLOW,
        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "mse_loss": CONTROL_FLOW,
        "multi_head_attention_forward": CONTROL_FLOW,
        "multi_margin_loss": CONTROL_FLOW,
        "multilabel_margin_loss": CONTROL_FLOW,
        "multilabel_soft_margin_loss": CONTROL_FLOW,
        "nll_loss": CONTROL_FLOW,
        "poisson_nll_loss": CONTROL_FLOW,
        "relu": CONTROL_FLOW,
        "relu6": CONTROL_FLOW,
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,
        "triplet_margin_loss": CONTROL_FLOW,
        "triplet_margin_with_distance_loss": CONTROL_FLOW,
        "upsample": CONTROL_FLOW,

        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
    }

    # nn.functional callables that take Tensor inputs but are missing type annotations
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )

    # Inconsistent behavior between Python 3.8 and other Python versions:
    # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED`
    # - Other Python versions: Raise `argument of type 'Proxy' is not iterable` due to the
    #                          same internal exception above
    # Use the following map to override the expected exception for Python 3.8
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": CONTROL_FLOW
    }

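    # A functional is picked up here if it is a lowercase, public, callable attribute of
    # torch.nn.functional with at least one Tensor-annotated parameter (or if it is
    # explicitly listed in FUNCTIONALS_WITHOUT_ANNOTATION above).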
    @classmethod
    def _get_functional(cls):
        functional_list = []
        for f in dir(torch.nn.functional):
            if not f.islower():
                continue
            # Ignore internal functions
            if f.startswith('_'):
                continue
            # Ignore supporting functions
            if f in cls.IGNORE_FUNCS:
                continue
            fn = getattr(torch.nn.functional, f)
            # Ignore non-callable objects like modules
            if not isinstance(fn, Callable):
                continue
            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
                try:
                    sig = inspect.signature(fn)
                    has_tensor_arg = False
                    for param in sig.parameters.values():
                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
                            has_tensor_arg = True
                    if not has_tensor_arg:
                        continue
                # No signature, or the object is not supported by inspect.signature
                except ValueError:
                    pass
            functional_list.append((f, fn))
        return functional_list

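    # Returns a closure that becomes a test method: it asserts the expected tracing
    # failure when the functional appears in one of the UNTRACEABLE_FUNCTIONALS maps,
    # and otherwise simply checks that symbolic_trace(fn) succeeds.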
    @classmethod
    def generate_test_func(cls, func_name, fn):

        def functional_test(self):
            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
                    sys.version_info >= (3, 8) and sys.version_info < (3, 12):
                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            else:
                symbolic_trace(fn)
        return functional_test

    @classmethod
    def generate_tests(cls):
        functional_list = cls._get_functional()
        for func_name, fn in functional_list:
            test_name = "test_nn_functional_" + func_name
            functional_test = cls.generate_test_func(func_name, fn)
            setattr(cls, test_name, functional_test)

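    # Patch has_torch_function* to always return False for the duration of this test
    # class (originals are restored in tearDownClass), so that the functionals' plain
    # Python paths, rather than the __torch_function__ dispatch path, run during tracing.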
    @classmethod
    def setUpClass(cls):

        def no(*args, **kwargs):
            return False

        for name in cls.TO_PATCH.keys():
            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

    @classmethod
    def tearDownClass(cls):
        for name in cls.TO_PATCH.keys():
            setattr(torch.nn.functional, name, cls.TO_PATCH[name])

TestFunctionalTracing.generate_tests()


instantiate_device_type_tests(TestOperatorSignatures, globals())

@skipIfTorchDynamo("too slow")
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )

    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
        "ssd300_vgg16": PROXY_ITERATED,
        "fcos_resnet50_fpn": PROXY_ITERATED,
        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
    }
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }

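    # Per-model hooks that pull a comparable Tensor out of a model's return value,
    # e.g. the "out" entry for segmentation models and the second tuple element for
    # detection models; eager, traced, and scripted outputs are compared through this.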
    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }

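    # The generated test builds the model in eval mode, then either asserts the expected
    # tracing failure (UNTRACEABLE_MODELS) or checks that eager, traced, and (unless the
    # model is listed in UNSCRIPTABLE_MODELS) scripted execution all agree on the output.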
    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs)
            model = model.eval()
            if name in self.UNTRACEABLE_MODELS:
                exc, err = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(exc, err):
                    graph = symbolic_trace(model)
            else:
                out_transform = self.output_transform.get(name, lambda x: x)
                graph : torch.fx.GraphModule = symbolic_trace(model)
                a = out_transform(model(x))
                b = out_transform(graph(x))
                self.assertEqual(a, b)

                if name in self.UNSCRIPTABLE_MODELS:
                    exc, err = self.UNSCRIPTABLE_MODELS[name]
                    with self.assertRaisesRegex(exc, err):
                        script = torch.jit.script(graph)
                else:
                    script = torch.jit.script(graph)
                    c = out_transform(script(x))
                    self.assertEqual(a, c)

        return run_test

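    # One test per torchvision model, grouped by task (classification, detection,
    # segmentation, video); input sizes follow each family's expectations, e.g.
    # 299x299 images for inception_v3 and 16-frame clips for the MViT/S3D video models.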
    @classmethod
    def generate_classification_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models):
            test_name = 'test_torchvision_models_' + k
            x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_segmentation_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.segmentation):
            test_name = 'test_torchvision_models_segmentation_' + k
            x = torch.rand(1, 3, 32, 32)
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_detection_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.detection):
            test_name = 'test_torchvision_models_detection_' + k
            x = [torch.rand(3, 300, 300)]
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_video_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.video):
            test_name = 'test_torchvision_models_video_' + k
            x = (
                torch.rand(1, 3, 4, 112, 112)
                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
                else torch.rand(1, 3, 16, 224, 224)
            )
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_tests(cls):
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()

if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()

if __name__ == '__main__':
    run_tests()