# Owner(s): ["module: fx"]

import builtins
import contextlib
import copy
import functools
import inspect
import math
import numbers
import io
import operator
import os
import pickle
import sys
import torch
import traceback
import typing
import types
import warnings
import unittest
from math import sqrt
from functorch.experimental import control_flow
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta
from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401
from fx.test_matcher_utils import TestMatcher  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

from fx.test_gradual_type import AnnotationsTest  # noqa: F401
from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
    run_tests,
    skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport

class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)

def a_non_torch_leaf(a, b):
    return a + b

# Used for test_autowrap_functions. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    return int(x)

def fx_int_x2(x: float) -> int:
    return int(x) * 2
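
# Illustrative sketch added for exposition; `_autowrap_sketch` is a
# hypothetical helper that no test invokes. Functions passed to
# `Tracer(autowrap_functions=...)` are recorded as `call_function` nodes
# instead of being executed on Proxy arguments, which is why `fx_int` above
# can be applied to a symbolic shape expression during tracing.
def _autowrap_sketch():
    class UsesFxInt(torch.nn.Module):
        def forward(self, x):
            return fx_int(x.shape[0] / 2)

    tracer = Tracer(autowrap_functions=(fx_int,))
    graph = tracer.trace(UsesFxInt())
    return GraphModule(tracer.root, graph)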

# used in test_pytree. It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])

# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

wrap(a_lifted_leaf2)

wrap('len')

wrap('getattr')

def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)
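
# Illustrative sketch added for exposition; `_wrap_sketch` is a hypothetical
# helper that no test invokes. `wrap` registers a function so that symbolic
# tracing emits a `call_function` node for it rather than tracing into its
# body, which is the behavior the wrap-related tests below exercise.
def _wrap_sketch():
    def uses_leaf(y):
        return a_lifted_leaf((y, y), y)

    gm = symbolic_trace(uses_leaf)
    # The generated code should call a_lifted_leaf directly.
    assert 'a_lifted_leaf' in gm.code
    return gm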

@wrap
def wrapped_via_decorator(a):
    return a + 1

wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

def my_decorator(f):
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

@wrap
@my_decorator
def wrapped_decorated_fn(x):
    return x

real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
_sqrt = sqrt

wrap('wrapper_fn')

def wrapper_fn(x):
    return torch.foo(x)

class Pair(NamedTuple):
    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"
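
# Illustrative sketch added for exposition; `_pair_repr_sketch` is a
# hypothetical helper that no test invokes. Node formatting is expected to
# consult `_custom_fx_repr_fn` when it is present, so a Pair argument should
# render as "Pair(x=..., y=...)" rather than the default repr.
def _pair_repr_sketch():
    p = Pair(torch.zeros(1), torch.ones(1))
    return _format_arg(p)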

# for testing pytrees
class Foo:  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Add(torch.nn.Module):
    def forward(self, x):
        return x + x

@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    print(x)
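
# Illustrative sketch added for exposition; `_side_effect_sketch` is a
# hypothetical helper that no test invokes. Because `side_effect_func` is
# marked with `has_side_effect`, dead-code elimination is expected to keep
# the call even though its result is unused.
def _side_effect_sketch():
    def fn(x):
        side_effect_func(x)
        return x

    gm = symbolic_trace(fn)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    # 'side_effect_func' should still appear in gm.code at this point.
    return gm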

class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature-flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        class MySub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        ms = torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        gm2 = symbolic_trace(m2)

        class T(torch.nn.Module):

            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)
        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_close(new_instance(x), torch.relu(x))

    def test_informative_co_filename(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return a * 2

        gm = symbolic_trace(MyModule())
        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)

    def test_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_varargs_concrete(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return x

        args = (torch.rand(1), torch.rand(1))

        t = T()
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)

    def test_args_kwargs_no_self(self):
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_fx_shifts(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_fx_and_or(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x & x, x | x

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))

    def test_matmul_tracing(self):
        const = torch.randn(3)

        def matmul_f(x):
            return x @ const

        mod = symbolic_trace(matmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), matmul_f(inp))

        def rmatmul_f(x):
            return const @ x

        mod = symbolic_trace(rmatmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), rmatmul_f(inp))

    @skipIfNoDynamoSupport
    def test_control_flow_tracing(self):
        def true(x, y):
            return x + y

        def false(x, y):
            return x - y

        def f(x, y):
            x = control_flow.cond(x[0] == 0, true, false, [x, y])

        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
            _ = symbolic_trace(f)

    def test_disallow_override(self):
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                name = target if isinstance(target, str) else torch.typename(target)
                if name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name)

        # Test method
        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()

        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        # Test free function
        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x
        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x
        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)

    def test_leaf_module(self):
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
        sym.lint()

    def test_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

    def test_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf2', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)

    def test_wrapped_via_decorator(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(m).transform()
        self.assertIn('wrapped_via_decorator', transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrap_with_submodule(self):

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        m = symbolic_trace(M())

        self.assertIn("wrapped_with_submodule", m.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), m(input))

    def test_wrapped_retrace(self):
        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)

        retraced = symbolic_trace(m)
        self.assertIn('wrapped_via_decorator', retraced.code)
        self.assertEqual(retraced(0), 1)

    def test_wrap_decorated_function(self):
        def to_trace(y):
            return wrapped_decorated_fn(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_decorated_fn', m.code)
        self.assertEqual(m(1), 1)

    def test_graph_edit_with_proxy(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint()
        self.assertEqual(gm(3, 4), 14)

    def test_concrete_arg_none_assert(self):
        class Foo(torch.nn.Module):
            def forward(self, x, val=None):
                return x if val is None else x + val

        f = Foo()
        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
            traced(torch.randn(5), torch.randn(5))

        x = torch.randn(5)
        torch.testing.assert_close(traced(x), f(x))

    def test_trace_multiple_funcs(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

            def minus_forward(self, x, y):
                return x - y

            def multiply_forward(self, x, y):
                return x * y

        f = Foo()
        x, y = torch.randn(5), torch.randn(5)

        print(torch.__version__)

        tracer = Tracer()
        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

        tracer.traced_func_name = "minus_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.minus_forward(x, y),
        )

        tracer.traced_func_name = "multiply_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.multiply_forward(x, y),
        )

        tracer.traced_func_name = "add_forward"
        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
            tracer.trace(f)


    def test_graph_unique_names(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_stack_traces(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        # saving the original list because we will insert new nodes as a part of a test
        orig_graph_nodes = list(graph.nodes)
        for node in orig_graph_nodes:
            if node.op == 'output':
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

            # verify that copying the node does not lose the stack trace
            new_node = graph.node_copy(node)
            self.assertTrue(new_node.stack_trace is not None)
            assert 'test_fx.py' in new_node.stack_trace

    def test_stack_traces_with_transformer(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        new_gm = Transformer(gm).transform()

        # nodes after Transformer should still preserve the original node's stack trace
        for node in new_gm.graph.nodes:
            if node.op in {'placeholder', 'output'}:
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

    def test_lineno_map(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                a = torch.sin(a)
                b = torch.cos(b)
                return a + b

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        expected = {1: 2, 2: 3, 3: 4, 4: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

        # test custom codegen
        def transform_code(code):
            return ["print('hello!')\n", *code]
        gm.graph.on_generate_code(lambda _: transform_code)
        gm.recompile()
        expected = {2: 2, 3: 3, 4: 4, 5: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

    def test_graph_unique_names_manual(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_unpack(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))

    def test_native_callable(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul"
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, "Unsupported call target " + target
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint()

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)


        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_close(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_close(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_close(imported_out, ref_out)

    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint()
        for node in m_g.graph.nodes:
            self.assertTrue(node.name != "getattr")

    @unittest.skip("Hotfix for SEV remediation")
    def test_trace_buffer_slice(self):
        bs, d_hid = 10, 23

        class ExampleCode(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)
                self.register_buffer('buffer', torch.randn(bs + 100, d_hid))

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                skip_connection = x
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
                x = self.lin(x)
                x = torch.relu(x)
                x = x + skip_connection
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x


        ec = ExampleCode()

        traced = torch.fx.symbolic_trace(ec)

        x = torch.randn(bs, d_hid)
        torch.testing.assert_close(ec(x), traced(x))


    def test_node_tagging(self):
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint()
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')

    def test_tensor_attribute(self):
        class TensorAttribute(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint()
        traced2(torch.rand(4, 4))

    def test_tensor_attribute_coalseced(self):

        def count_attrs(fx_module):
            targets = set()
            for node in traced.graph.nodes:
                if node.op == 'get_attr':
                    targets.add(node.target)
            return len(targets)

        val = torch.tensor(5)

        def f(x):
            return x + val + val
        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 1)

        val2 = torch.tensor(5)

        def f(x):
            val = torch.tensor(5)
            return x + val + val2

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 2)


    def test_symbolic_trace_sequential(self):
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))

    def test_tensor_constant(self):
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint()
        traced(torch.rand(4, 4))

    def test_pickle_graphmodule(self):
        class Nested(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint()
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_pickle_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(loaded(x, y), gm(x, y))

    def test_all_input_nodes(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)
        graph.lint()

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])

    def test_deepcopy_graphmodule_with_transform(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            relu_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(relu_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint()
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))

    def test_deepcopy_with_submods_params(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint()
        copied = copy.deepcopy(traced)
        copied.graph.lint()

    def test_deepcopy_graph_with_tracer_cls(self):
        class TestTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        g = Graph(tracer_cls=TestTracer)
        x = g.placeholder("x")
        g.output(x)

        h = copy.deepcopy(g)
        self.assertIsNotNone(h._tracer_cls)
        self.assertTrue(g._tracer_cls == h._tracer_cls)

    def test_unpack_list_better_error(self):
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)

    def test_pretty_print_targets(self):
        # Test that Graph pretty-print prints friendly name for targets
        # in `operator` and `builtins`

        class SomeMod(torch.nn.Module):
            def forward(self, x):
                return torch.add(x.foo + x.bar, 3.0)

        traced = symbolic_trace(SomeMod())
        graph_str = str(traced.graph)
        self.assertIn('builtins.getattr', graph_str)
        self.assertIn('operator.add', graph_str)
        self.assertIn('torch.add', graph_str)

    def test_pretty_print_node(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param: torch.nn.Parameter = torch.nn.Parameter(
                    torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor, y: int = 2):
                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)

        traced = symbolic_trace(M())

        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])

        FileCheck().check("x").check("placeholder") \
            .check("y").check("placeholder") \
            .check("getitem").check("call_function") \
            .check("param").check("get_attr") \
            .check("add").check("call_function") \
            .check("linear").check("call_module") \
            .check("clamp").check("call_method") \
            .run(all_formatted)

    def test_script_tensor_constant(self):
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)

    def test_autowrap_functions(self):
        class AutowrapFnTest(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2)

        class AutowrapFnTest2(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)

        # Check function(s) are wrapped
        # `int` would normally throw a TypeError as argument can't be `Proxy`
        tracer = Tracer(autowrap_functions=(fx_int,))
        graph = tracer.trace(AutowrapFnTest())
        traced = GraphModule(tracer.root, graph, 'test')
        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
        tracer_2.trace(AutowrapFnTest2())

        # Test scriptability
        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(4)), 2)

    def test_tuple_no_subscript(self):
        def foo(x : Tuple):
            return x[0]

        traced = torch.fx.symbolic_trace(foo)
        x = (torch.randn(5, 3),)
        torch.testing.assert_close(traced(x), x[0])

        bio = io.BytesIO()

        torch.save(traced, bio)

        bio.seek(0)

        loaded = torch.load(bio)

        torch.testing.assert_close(loaded(x), x[0])

    def test_torch_fx_len(self):
        class FXLenTest(torch.nn.Module):
            def forward(self, x):
                return len(x)

        traced = symbolic_trace(FXLenTest())
        self.assertEqual(traced(torch.rand(3, 4)), 3)

        # Test scriptability
        scripted = torch.jit.script(FXLenTest())
        self.assertEqual(scripted(torch.rand(3)), 3)

        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(3)), 3)

        # Test non-proxy len
        class FXLenTest2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = [3, 4, 5]

            def forward(self, x):
                return x + len(self.l)

        traced2 = symbolic_trace(FXLenTest2())
        inp = torch.rand(3, 4)
        self.assertEqual(traced2(inp), inp + 3.0)
        self.assertIs(len, builtins.len)

    def test_torch_fx_getattr(self):
        class FXGetattrTest(torch.nn.Module):
            def forward(self, x):
                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))

        traced = symbolic_trace(FXGetattrTest())
        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))

    def test_sqrt(self):
        class Sqrt1(torch.nn.Module):
            def forward(self, x):
                return sqrt(x.size(0))

        class Sqrt2(torch.nn.Module):
            def forward(self, x):
                return math.sqrt(x.size(0))

        class Sqrt3(torch.nn.Module):
            def forward(self, x):
                return x + math.sqrt(2) + sqrt(2)

        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
        self.assertIs(sqrt, _sqrt)
        self.assertIs(math.sqrt, _sqrt)

    def test_torch_custom_ops(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))
        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        out = gm(input)
        self.assertEqual(out, ref_out)

    def test_torch_op_overloads(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.add.Tensor(a, a)
                return b
        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        out = gm(input)
        self.assertEqual(out, ref_out)

        for node in gm.graph.nodes:
            if node.op == 'call_function':
                assert isinstance(node.target, torch._ops.OpOverload)
                assert node.target.__name__ == 'add.Tensor'

    def test_pickle_torch_custom_ops(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))
        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        self.assertEqual(loaded(input), gm(input))

    def test_pretty_print(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()
        printed = str(traced)
        assert 'SimpleTest()' in printed
        assert 'torch.relu' in printed

    def test_pretty_print_graph(self):
        class KwargPrintTest(torch.nn.Module):
            def forward(self, x):
                return torch.squeeze(x + 3.0, dim=2)
        st = KwargPrintTest()
        traced = symbolic_trace(st)
        traced.graph.lint()
        stringed = str(traced.graph)
        for s in ['args', 'kwargs', 'num_users']:
            assert s in stringed

    def test_custom_proxy_type(self):
        class TensorPair:
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair(x : TensorPair, y : TensorPair):
            s = x.add(y)
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

        ref_out = use_tensor_pair(x, y)

        traced = symbolic_trace(use_tensor_pair)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)

    def test_custom_proxy_type_literal(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_literal(x : TensorPair):
            s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

        ref_out = use_tensor_pair_literal(x)

        traced = symbolic_trace(use_tensor_pair_literal)

        traced_out = traced(x)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)

    def test_custom_proxy_dynamic_value(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = symbolic_trace(use_tensor_pair_ctor)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
1417

1418
    def test_custom_proxy_input_dependent_control_flow(self):
1419
        class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
1420
            def __init__(self, inp):
1421
                if inp.sum() == 0:
1422
                    self.is_zero = True
1423
                    self.tensor = torch.tensor([])
1424
                else:
1425
                    self.is_zero = False
1426
                    self.tensor = inp
1427

1428
            def add(self, other):
1429
                if self.is_zero:
1430
                    return ZeroTensor(other.tensor)
1431
                elif other.is_zero:
1432
                    return self
1433

1434
        def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
1435
            return ZeroTensor(x + y)
1436

1437
        x, y = torch.randn(5, 3), torch.randn(5, 3)
1438

1439
        ref_out = use_zero_tensor(x, y)
1440

1441
        traced = symbolic_trace(use_zero_tensor)
1442

1443
        traced_out = traced(x, y)
1444

1445
        self.assertEqual(traced_out.is_zero, ref_out.is_zero)
1446
        self.assertEqual(traced_out.tensor, ref_out.tensor)
1447

1448
    def test_graph_fns(self):
1449
        g = Graph()
1450
        a = g.placeholder('a')
1451
        b = g.call_module('linear', (a,))
1452
        c = g.get_attr('bias')
1453
        d = g.call_method('add', (b, c))
1454
        e = g.call_function(torch.sin, (d,))
1455
        g.output(e)
1456
        mod = torch.nn.Module()
1457
        mod.linear = torch.nn.Linear(3, 4)
1458
        mod.bias = torch.rand(4)
1459
        gm = GraphModule(mod, g)
1460
        gm.graph.lint()
1461
        input = torch.rand(3)
1462
        r = gm(input)
1463
        ref = torch.sin(mod.linear(input) + mod.bias)
1464
        self.assertEqual(r, ref)
1465

1466
    def test_remove_uses(self):
1467
        g : torch.fx.Graph = Graph()
1468
        x : torch.fx.Node = g.placeholder('x')
1469
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1470
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1471
        g.output(neg)
1472

1473
        neg.replace_all_uses_with(relu)
1474
        g.erase_node(neg)
1475

1476
        self.assertTrue(neg not in relu.users)
1477

1478
    def test_remove_uses_with_custom_filter(self):
1479
        g : torch.fx.Graph = Graph()
1480
        x : torch.fx.Node = g.placeholder('x')
1481
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1482
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1483
        g.output(neg)
1484

1485
        neg.replace_all_uses_with(relu, lambda x: x != neg)
1486

1487
        self.assertTrue(neg in relu.users)
1488

1489

1490
    def test_nonetype_annotation(self):
1491
        eb = torch.nn.EmbeddingBag(3, 4)
1492
        symbolic_trace(eb)
1493

1494
    def test_pickle_nonetype_annotation(self):
1495
        eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
1496
        traced = symbolic_trace(eb)
1497
        pickled = pickle.dumps(traced)
1498
        loaded = pickle.loads(pickled)
1499
        loaded.graph.lint()
1500
        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
1501
        offsets = torch.LongTensor([0, 4])
1502
        self.assertEqual(loaded(input, offsets), traced(input, offsets))
1503

1504
    def test_return_tuple(self):
1505
        class M(torch.nn.Module):
1506
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1507
                return (x, x + x)
1508

1509

1510
        original = M()
1511
        traced = symbolic_trace(original)
1512
        self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
1513

1514
    def test_construct_root_dict(self):
1515
        graph : torch.fx.Graph = torch.fx.Graph()
1516
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
1517
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1518
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1519
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
1520
        graph.output(d)
1521

1522
        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
1523
        add_param : torch.Tensor = torch.rand(3, 4)
1524
        gm : torch.fx.GraphModule = torch.fx.GraphModule(
1525
            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
1526
        gm.graph.lint()
1527

1528
        assert 'self.foo.bar.baz' in gm.code
1529

1530
        x : torch.Tensor = torch.rand(3, 3)
1531
        out : torch.Tensor = gm(x)
1532
        ref_out : torch.Tensor = linear_mod(x) + add_param
1533
        self.assertEqual(out, ref_out)
1534

1535
    def test_symbolic_trace_assert(self):
1536

1537
        class AssertsTensorShape(torch.nn.Module):
1538
            def forward(self, x):
1539
                torch._assert(x.shape[1] > 4, "assert_foobar")
1540
                return x
1541

1542
        m = AssertsTensorShape()
1543
        # verify traceability
1544
        traced = symbolic_trace(m)
1545
        # verify assertion on traced model works correctly at runtime
1546
        traced(torch.rand(4, 5))
1547
        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
1548
            traced(torch.rand(4, 3))
1549
        # verify the symbolically traced module is scriptable
1550
        ms = torch.jit.script(traced)
1551
        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
1552
            ms(torch.rand(4, 3))
1553

1554
    def test_fx_create_arg(self):
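        # __fx_create_arg__ lets a custom object control how it is lowered into the
        # graph when passed to a leaf module: here it is emitted as a call_function
        # node that reconstructs CustomArgObject from its traced fields.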
1555
        class CustomArgObject:
1556
            def __init__(self, x, y):
1557
                self.x = x
1558
                self.y = y
1559

1560
            def __fx_create_arg__(self, tracer: torch.fx.Tracer):
1561
                return tracer.create_node(
1562
                    "call_function",
1563
                    CustomArgObject,
1564
                    args=(
1565
                        tracer.create_arg(self.x),
1566
                        tracer.create_arg(self.y),
1567
                    ),
1568
                    kwargs={},
1569
                )
1570

1571
        class HasCustomArgObjectWhenLeaf(torch.nn.Module):
1572
            def forward(self, o: CustomArgObject):
1573
                # Not normally traceable; good reason to make
1574
                # this module a leaf.
1575
                for x in o.x:
1576
                    o.y += x
1577
                return o.y
1578

1579
        class Root(torch.nn.Module):
1580
            def __init__(self):
1581
                super().__init__()
1582
                self.inner = HasCustomArgObjectWhenLeaf()
1583

1584
            def forward(self, x, y):
1585
                o = CustomArgObject(x, y)
1586
                return self.inner(o)
1587

1588
        class CreateArgTracer(torch.fx.Tracer):
1589
            def is_leaf_module(self, m, module_qualified_name):
1590
                return type(m) is HasCustomArgObjectWhenLeaf
1591

1592
        m = Root()
1593
        graph = CreateArgTracer().trace(m)
1594
        gm = torch.fx.GraphModule(m, graph)
1595
        assert "CustomArgObject(" in gm.code
1596

1597
    def test_trace_fn_constant(self):
1598
        some_constant = torch.rand(3, 4)
1599

1600
        def add_const(x):
1601
            return some_constant + x
1602

1603
        traced = symbolic_trace(add_const)
1604

1605
        input = torch.rand(3, 4)
1606
        self.assertEqual(traced(input), add_const(input))
1607

1608
    def test_copy_no_remap(self):
1609
        traced = symbolic_trace(SimpleTest())
1610
        g = traced.graph
1611
        copied = torch.fx.Graph()
1612
        for node in g.nodes:
1613
            copied.node_copy(node)
1614
        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
1615
            copied.lint()
1616

1617
    def test_wrong_topo(self):
1618
        graph : torch.fx.Graph = torch.fx.Graph()
1619
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
1620
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1621
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1622
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
1623
        graph.output(d)
1624
        nodes = list(graph.nodes)
1625
        nodes[3].append(nodes[2])
1626
        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
1627
            graph.lint()
1628

1629
    def test_wrong_target_type(self):
1630
        graph : torch.fx.Graph = torch.fx.Graph()
1631
        with self.assertRaises(ValueError):
1632
            n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
1633
                              args=(), kwargs={})
1634

1635
    def test_example_shape_prop(self):
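        # ShapeProp interprets the traced module on example inputs and records
        # shape/stride/dtype information into node.meta['tensor_meta'] for each node.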
1636
        class TestCase(torch.nn.Module):
1637
            def __init__(self):
1638
                super().__init__()
1639
                self.attr = torch.randn(3, 4)
1640
                self.submod = torch.nn.Linear(4, 4)
1641

1642
            def forward(self, x):
1643
                return torch.neg(self.submod(x.relu() + self.attr))
1644
        tc = TestCase()
1645
        tc_traced = symbolic_trace(tc)
1646
        ref_out = tc_traced(torch.rand(3, 4))
1647
        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
1648

1649
        # Make sure we're testing all opcodes
1650
        opcodes = set()
1651
        output_shape : Optional[torch.Size] = None
1652
        output_stride : Optional[Tuple[int, ...]] = None
1653
        for node in tc_traced.graph.nodes:
1654
            opcodes.add(node.op)
1655
            if node.op == 'output':
1656
                output_shape = node.args[0].meta['tensor_meta'].shape
1657
                output_stride = node.args[0].meta['tensor_meta'].stride
1658
        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
1659
                                   'call_module', 'output'})
1660

1661
        # Test shape propagation and make sure results match actual
1662
        self.assertEqual(output_shape, ref_out.shape)
1663
        self.assertEqual(output_stride, ref_out.stride())
1664

1665
    def test_shape_prop_layout(self):
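        # Propagate shapes under both contiguous and channels_last inputs; only the
        # placeholder node is checked for channels_last, since the conv implementation
        # may not preserve the memory format of its output.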
1666
        class ConvTest(torch.nn.Module):
1667
            def __init__(self):
1668
                super().__init__()
1669
                self.conv_mod = torch.nn.Conv2d(5, 5, 3)
1670

1671
            def forward(self, x):
1672
                return self.conv_mod(x)
1673

1674
        # contiguous layout
1675
        test_mod = ConvTest()
1676
        traced = symbolic_trace(test_mod)
1677
        x = torch.randn(5, 5, 224, 224)
1678
        shape_prop.ShapeProp(traced).propagate(x)
1679

1680
        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1681
                   for node in traced.graph.nodes)
1682

1683
        x_channels_last = x.contiguous(memory_format=torch.channels_last)
1684
        traced.to(memory_format=torch.channels_last)
1685
        shape_prop.ShapeProp(traced).propagate(x_channels_last)
1686
        for node in traced.graph.nodes:
1687
            # NB: the implementation of conv may not preserve the memory format,
1688
            # unfortunately. The best we can do is just check that the placeholder
1689
            # node is channels-last
1690
            if node.op in {'placeholder'}:
1691
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
1692

1693
    def test_shape_prop_aggregate(self):
1694
        class ReturnTwo(torch.nn.Module):
1695
            def forward(self, x):
1696
                return (3, torch.sum(x))
1697

1698
        class UnderTest(torch.nn.Module):
1699
            def __init__(self):
1700
                super().__init__()
1701
                self.rt = ReturnTwo()
1702

1703
            def forward(self, x):
1704
                return self.rt(x)
1705

1706
        ut = UnderTest()
1707

1708
        class RTTracer(torch.fx.Tracer):
1709
            def is_leaf_module(self, m, module_qualified_name):
1710
                return type(m) is ReturnTwo
1711

1712
        graph = RTTracer().trace(ut)
1713
        mod = torch.fx.GraphModule(ut, graph)
1714

1715
        shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))
1716

1717
        for node in mod.graph.nodes:
1718
            if node.op == 'call_module':
1719
                assert 'tensor_meta' in node.meta
1720
                tensor_meta = node.meta['tensor_meta']
1721
                assert tensor_meta[0] == 3
1722
                assert tensor_meta[1].shape == torch.Size([])
1723

1724
    def test_shape_prop_layout_3d(self):
1725
        class ConvTest3d(torch.nn.Module):
1726
            def __init__(self):
1727
                super().__init__()
1728
                self.conv_mod = torch.nn.Conv3d(5, 5, 3)
1729

1730
            def forward(self, x):
1731
                return self.conv_mod(x)
1732

1733
        test_mod_3d = ConvTest3d()
1734
        traced_3d = symbolic_trace(test_mod_3d)
1735
        x_3d = torch.randn(5, 5, 224, 224, 15)
1736
        shape_prop.ShapeProp(traced_3d).propagate(x_3d)
1737
        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1738
                   for node in traced_3d.graph.nodes)
1739

1740
        x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
1741
        traced_3d.to(memory_format=torch.channels_last_3d)
1742
        shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
1743
        for node in traced_3d.graph.nodes:
1744
            # NB: the implementation of conv may not preserve the memory format,
1745
            # unfortunately. The best we can do is just check that the placeholder
1746
            # node is channels-last
1747
            if node.op in {'placeholder'}:
1748
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
1749

1750
    def test_nn_module_stack(self):
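        # Tracing records the module call hierarchy for each node in
        # node.meta['nn_module_stack'], ordered from the outermost module inward.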
1751
        class SubModule(torch.nn.Module):
1752
            def __init__(self):
1753
                super().__init__()
1754
                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)
1755

1756
            def forward(self, x):
1757
                return self.conv_mod(x)
1758

1759
        class MyModule(torch.nn.Module):
1760
            def __init__(self):
1761
                super().__init__()
1762
                self.sub_mod = SubModule()
1763

1764
            def forward(self, x):
1765
                return self.sub_mod(x)
1766

1767
        m = MyModule()
1768
        gm = torch.fx.symbolic_trace(m)
1769

1770
        mod_stack = {}
1771
        expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
1772
                          ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
1773
        for node in gm.graph.nodes:
1774
            mod_stack = node.meta.get('nn_module_stack', {})
1775
            if mod_stack:
1776
                break
1777
        stack_list = list(mod_stack.items())
1778
        self.assertEqual(stack_list, expected_stack)
1779

1780
    def test_transformer_preserves_nn_module_stack_for_get_attr(self):
1781
        class M(torch.nn.Module):
1782
            def __init__(self):
1783
                super().__init__()
1784
                self.weight = torch.nn.Parameter(torch.ones(1, 1))
1785

1786
            def forward(self, x):
1787
                return self.weight + x
1788

1789
        tracer = torch.fx.Tracer()
1790
        graph = tracer.trace(M())
1791
        gm = GraphModule(tracer.root, graph)
1792
        for node in gm.graph.nodes:
1793
            if node.op == 'get_attr':
1794
                node.meta["nn_module_stack"] = "self"
1795
                node.meta["stack_trace"] = "stack_trace"
1796
                node.meta["source_fn_stack"] = "source_fn_stack"
1797
        new_gm = Transformer(gm).transform()
1798
        for node in new_gm.graph.nodes:
1799
            if node.op == 'get_attr':
1800
                self.assertEqual(node.meta["nn_module_stack"], "self")
1801
                self.assertEqual(node.meta["stack_trace"], "stack_trace")
1802
                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
1803

1804

1805
    def test_interpreter(self):
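        # Interpreter executes the traced graph node by node in Python; its result
        # should match both the GraphModule and the original eager module.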
1806
        class MyModule(torch.nn.Module):
1807
            def __init__(self):
1808
                super().__init__()
1809
                self.param = torch.nn.Parameter(torch.rand(3, 4))
1810
                self.linear = torch.nn.Linear(4, 5)
1811

1812
            def forward(self, x):
1813
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1814

1815
        m = MyModule()
1816
        gm = torch.fx.symbolic_trace(m)
1817

1818
        interpreter = Interpreter(gm)
1819
        input = torch.randn(3, 4)
1820
        self.assertEqual(interpreter.run(input), gm(input))
1821
        self.assertEqual(interpreter.run(input), m(input))
1822

1823
    def test_interpreter_other_graph(self):
1824
        class MyModule(torch.nn.Module):
1825
            def __init__(self):
1826
                super().__init__()
1827
                self.param = torch.nn.Parameter(torch.rand(3, 4))
1828
                self.linear = torch.nn.Linear(4, 5)
1829

1830
            def forward(self, x):
1831
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1832

1833
        m = MyModule()
1834
        gm = torch.fx.symbolic_trace(m)
1835

1836
        interpreter = Interpreter(gm, graph=gm.graph)
1837
        input = torch.randn(3, 4)
1838
        self.assertEqual(interpreter.run(input), gm(input))
1839
        self.assertEqual(interpreter.run(input), m(input))
1840

1841
    def test_interpreter_run_node_override(self):
1842
        class MyModule(torch.nn.Module):
1843
            def __init__(self):
1844
                super().__init__()
1845
                self.param = torch.nn.Parameter(torch.rand(3, 4))
1846
                self.linear = torch.nn.Linear(4, 5)
1847

1848
            def forward(self, x):
1849
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1850

1851
        m = MyModule()
1852
        gm = torch.fx.symbolic_trace(m)
1853

1854
        class RunNodeInterpreter(Interpreter):
1855
            def __init__(self, module):
1856
                super().__init__(module)
1857

1858
            def run_node(self, n : Node) -> Any:
1859
                result = super().run_node(n)
1860
                n.cached_value = result
1861
                return result
1862

1863
        input = torch.randn(3, 4)
1864
        RunNodeInterpreter(gm).run(input)
1865
        for node in gm.graph.nodes:
1866
            assert hasattr(node, 'cached_value')
1867

1868
    def test_interpreter_onthefly_swap(self):
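        # Overriding call_function/call_method on Interpreter swaps sigmoid and neg
        # at execution time without modifying the underlying graph.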
1869

1870
        def fn(x):
1871
            return torch.sigmoid(x).neg()
1872

1873
        gm = torch.fx.symbolic_trace(fn)
1874

1875
        class NegSigmSwapInterpreter(Interpreter):
1876
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1877
                if target == torch.sigmoid:
1878
                    return torch.neg(*args, **kwargs)
1879
                return super().call_function(target, args, kwargs)
1880

1881
            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1882
                if target == 'neg':
1883
                    call_self, *args_tail = args
1884
                    return call_self.sigmoid(*args_tail, **kwargs)
1885
                return super().call_method(target, args, kwargs)
1886

1887
        input = torch.randn(3, 4)
1888
        result = NegSigmSwapInterpreter(gm).run(input)
1889
        self.assertEqual(result, torch.neg(input).sigmoid())
1890

1891
    def test_interpreter_partial_eval(self):
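        # Pre-seeding initial_env with a result for the linear node makes the
        # Interpreter skip executing that node and use the provided value instead.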
1892
        class MyModule(torch.nn.Module):
1893
            def __init__(self):
1894
                super().__init__()
1895
                self.param = torch.nn.Parameter(torch.rand(3, 4))
1896
                self.linear = torch.nn.Linear(4, 5)
1897

1898
            def forward(self, x):
1899
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1900

1901
        gm = torch.fx.symbolic_trace(MyModule())
1902
        interp = Interpreter(gm)
1903
        env = {}
1904
        for node in gm.graph.nodes:
1905
            if node.op == 'call_module' and node.target == 'linear':
1906
                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
1907
                break
1908
        assert len(env) == 1
1909
        x = torch.randn(3, 4)
1910
        result = interp.run(x, initial_env=env)
1911
        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
1912

1913
    def test_interpreter_star_args(self):
1914
        def with_star_args(x, *args):
1915
            return x + args[0]
1916

1917
        gm = torch.fx.symbolic_trace(with_star_args)
1918
        interp = Interpreter(gm)
1919
        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
1920
        self.assertEqual(result, torch.ones(3, 4) * 2.0)
1921

1922
    @skipIfNoTorchVision
1923
    def test_interpreter_noop_resnet18(self):
1924
        rn18 = torchvision_models.resnet18()
1925
        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
1926
        inp = torch.randn(5, 3, 224, 224)
1927
        self.assertEqual(transformed(inp), rn18(inp))
1928

1929
    @skipIfNoTorchVision
1930
    def test_interpreter_gc_values(self):
1931
        rn18 = torchvision_models.resnet18()
1932
        interp = Interpreter(symbolic_trace(rn18))
1933
        inp = torch.rand(5, 3, 224, 224)
1934
        out = interp.run(inp)
1935
        env_key_names = {n.name for n in interp.env.keys()}
1936
        self.assertEqual(env_key_names, {'output'})
1937

1938
    def test_interpreter_default_args(self):
1939
        class Model(torch.nn.Module):
1940
            def forward(self, x, y=3.14159):
1941
                return x + y
1942

1943
        model = Model()
1944
        gm = torch.fx.symbolic_trace(model)
1945

1946
        interp = Interpreter(gm)
1947
        x = torch.randn(5, 3)
1948
        out = interp.run(x)
1949
        torch.testing.assert_close(out, x + 3.14159)
1950

1951
    def test_interpreter_not_enough_args(self):
1952
        class Model(torch.nn.Module):
1953
            def forward(self, x, y):
1954
                return x + y
1955

1956
        model = Model()
1957
        gm = torch.fx.symbolic_trace(model)
1958

1959
        interp = Interpreter(gm)
1960
        x = torch.randn(5, 3)
1961
        with self.assertRaisesRegex(RuntimeError,
1962
                                    'Expected positional argument for parameter y, but one was not passed in'):
1963
            out = interp.run(x)
1964

1965
    def test_transformer_noop(self):
1966
        class MyModule(torch.nn.Module):
1967
            def __init__(self):
1968
                super().__init__()
1969
                self.param = torch.nn.Parameter(torch.rand(3, 4))
1970
                self.linear = torch.nn.Linear(4, 5)
1971

1972
            def forward(self, x):
1973
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1974

1975
        m = MyModule()
1976
        gm = torch.fx.symbolic_trace(m)
1977

1978
        new_gm = Transformer(gm).transform()
1979

1980
        input = torch.randn(3, 4)
1981
        self.assertEqual(new_gm(input), gm(input))
1982

1983
    def test_transformer_op_swap(self):
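        # Unlike the Interpreter-based swap above, Transformer produces a new
        # GraphModule with sigmoid and neg exchanged in the graph itself.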
1984

1985
        def fn(x):
1986
            return torch.sigmoid(x).neg()
1987

1988
        gm = torch.fx.symbolic_trace(fn)
1989

1990
        class NegSigmSwapXformer(Transformer):
1991
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1992
                if target == torch.sigmoid:
1993
                    return torch.neg(*args, **kwargs)
1994
                return super().call_function(target, args, kwargs)
1995

1996
            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1997
                if target == 'neg':
1998
                    call_self, *args_tail = args
1999
                    return call_self.sigmoid(*args_tail, **kwargs)
2000
                return super().call_method(target, args, kwargs)
2001

2002
        transformed = NegSigmSwapXformer(gm).transform()
2003
        input = torch.randn(3, 4)
2004
        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
2005

2006
    def test_transformer_multi_outputs(self):
2007
        class MyModule(torch.nn.Module):
2008
            def __init__(self):
2009
                super().__init__()
2010
                self.param = torch.nn.Parameter(torch.rand(3, 4))
2011
                self.linear = torch.nn.Linear(4, 5)
2012

2013
            def forward(self, x):
2014
                x = x + self.param
2015
                out = self.linear(x)
2016
                return x, out
2017

2018
        m = MyModule()
2019
        gm = torch.fx.symbolic_trace(m)
2020

2021
        new_gm = Transformer(gm).transform()
2022

2023
        input = torch.randn(3, 4)
2024
        self.assertEqual(new_gm(input), gm(input))
2025

2026
    def test_fn_type_annotations(self):
2027
        class Foo(torch.nn.Module):
2028
            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
2029
                return {'a': p.x + p.y + z + i}
2030

2031
        foo_scripted = torch.jit.script(Foo())
2032
        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2033

2034
        fxed = symbolic_trace(Foo())
2035
        fxed_scripted = torch.jit.script(fxed)
2036
        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2037

2038
    def test_fn_type_annotation_empty(self):
2039
        def forward(a : List[torch.Tensor]):
2040
            return a[0]
2041
        torch.jit.script(symbolic_trace(forward))
2042

2043
    def test_wrapped_method(self):
2044
        def wrap_with_relu(fn):
2045
            @functools.wraps(fn)
2046
            def wrapper(*args, **kwargs):
2047
                return torch.relu(fn(*args, **kwargs))
2048
            return wrapper
2049

2050
        class Foo(torch.nn.Module):
2051
            @wrap_with_relu
2052
            def forward(self, x, w):
2053
                return torch.matmul(x, w)
2054

2055
        f = Foo()
2056
        traced = symbolic_trace(f)
2057
        x, w = torch.rand(3, 4), torch.rand(4, 4)
2058
        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
2059

2060
    def test_empty_graph_codegen(self):
2061
        graph = torch.fx.Graph()
2062
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2063
        self.assertEqual(gm(), None)
2064

2065
    def test_sequential(self):
2066
        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
2067
        gm = torch.fx.symbolic_trace(m)
2068
        gm_copy = copy.deepcopy(gm)
2069

2070
    def test_ctx_mgr(self):
2071
        @contextlib.contextmanager
2072
        def do_nothing():
2073
            yield
2074

2075
        class M(torch.nn.Module):
2076
            @do_nothing()
2077
            def forward(self, x):
2078
                return torch.relu(x)
2079

2080
        m = M()
2081
        self.checkGraphModule(m, (torch.rand(3, 4),))
2082

2083
    def test_typename_print(self):
2084
        graph : torch.fx.Graph = torch.fx.Graph()
2085
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2086
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
2087
                                              type_expr=List[float])
2088
        output : torch.fx.Node = graph.output(b)
2089

2090
        self.assertTrue('typing.List[float]' in str(graph))
2091

2092
    def test_layout(self):
2093
        class M(torch.nn.Module):
2094
            def forward(self, x):
2095
                return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)
2096

2097
        traced = symbolic_trace(M())
2098
        x = torch.rand(5, 9, 3, 4)
2099
        self.assertEqual(traced(x), torch.zeros_like(x))
2100

2101
    def test_ellipsis(self):
2102
        class M(torch.nn.Module):
2103
            def forward(self, x, y):
2104
                return x + y[:, 1:10, ...]
2105

2106
        traced = symbolic_trace(M())
2107
        x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
2108
        self.assertEqual(traced(x, y), x + y[:, 1:10, ...])
2109

2110
    def test_inf_nan(self):
2111
        class FooMod(torch.nn.Module):
2112
            def forward(self, x):
2113
                return x + float('inf'), x + float('-inf'), x + float('nan')
2114

2115
        fm = FooMod()
2116
        self.checkGraphModule(fm, (torch.rand(3, 4),))
2117

2118
    def test_inf_nan_kwds(self):
2119
        graph : torch.fx.Graph = torch.fx.Graph()
2120
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2121
        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
2122
        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
2123
        graph.output((b, c))
2124

2125
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2126
        x = torch.rand(3, 4)
2127
        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))
2128

2129
    def test_deepcopy_recursion_depth(self):
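        # A linear chain deeper than the recursion limit: deepcopy of the Graph must
        # not recurse per node, and users/uses must be remapped consistently.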
2130
        depth = sys.getrecursionlimit() + 20
2131

2132
        g = torch.fx.Graph()
2133
        x = g.placeholder('x')
2134
        for i in range(depth):
2135
            x = g.call_function(torch.relu, (x,))
2136
        g.output(x)
2137

2138
        copied_graph = copy.deepcopy(g)
2139

2140
        val_map = {}
2141
        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2142
            val_map[orig_node] = new_node
2143

2144
        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2145
            orig_users = set(orig_node.users.keys())
2146
            orig_users_equiv = {val_map[u] for u in orig_users}
2147
            new_users = set(new_node.users.keys())
2148
            self.assertEqual(orig_users_equiv, new_users)
2149

2150
    @skipIfNoTorchVision
2151
    def test_replace_uses(self):
2152
        rn18 = torchvision_models.resnet18()
2153

2154
        class LowerReluTracer(torch.fx.Tracer):
2155
            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
2156
                if isinstance(m, torch.nn.ReLU):
2157
                    return False
2158
                return super().is_leaf_module(m, qualname)
2159

2160
        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))
2161

2162
        to_erase = []
2163
        for node in rn18_traced.graph.nodes:
2164
            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
2165
                kwargs = node.kwargs.copy()
2166
                # Neg doesn't have in-place
2167
                kwargs.pop('inplace')
2168
                with rn18_traced.graph.inserting_before(node):
2169
                    new_node = rn18_traced.graph.call_function(
2170
                        the_function=torch.neg, args=node.args, kwargs=kwargs)
2171
                node.replace_all_uses_with(replace_with=new_node)
2172
                to_erase.append(node)
2173

2174
        for node in to_erase:
2175
            rn18_traced.graph.erase_node(node)
2176

2177

2178
    def test_replace_input(self):
2179
        graph : torch.fx.Graph = torch.fx.Graph()
2180
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2181
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2182
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2183
        output : torch.fx.Node = graph.output(b)
2184

2185
        b.replace_input_with(x, y)
2186

2187
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2188

2189
        input_x = torch.randn(33, 44)
2190
        input_y = torch.randn(11, 22)
2191
        self.assertEqual(gm(input_x, input_y), torch.relu(input_y))
2192

2193
    def test_insertion_point(self):
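        # graph.inserting_before(b) makes nodes created inside the context appear
        # before b, so the inserted neg feeds relu.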
2194
        graph : torch.fx.Graph = torch.fx.Graph()
2195
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2196
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2197
        output : torch.fx.Node = graph.output(b)
2198

2199
        with graph.inserting_before(b):
2200
            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
2201
            _, *relu_args = b.args
2202
            b.args = (neg, *relu_args)
2203

2204
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2205

2206
        input = torch.randn(33, 44)
2207
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2208

2209
    def test_update_args_api(self):
2210
        graph : torch.fx.Graph = torch.fx.Graph()
2211
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2212
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2213
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2214
        output : torch.fx.Node = graph.output(b)
2215

2216
        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2217
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2218
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2219

2220

2221
        b.update_arg(0, y)
2222
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2223
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2224

2225
    def test_update_kwargs_api(self):
2226
        graph : torch.fx.Graph = torch.fx.Graph()
2227
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2228
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
2229
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
2230
        output : torch.fx.Node = graph.output(b)
2231

2232
        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2233
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2234
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2235

2236

2237
        b.update_kwarg('input', y)
2238
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2239
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2240

2241
    def test_immutable_list_pytree_ops(self):
2242
        rand_tensor = torch.randn(5, 3)
2243
        l = immutable_list([3, [rand_tensor, 42]])
2244

2245
        flattened, spec = pytree.tree_flatten(l)
2246
        assert flattened == [3, rand_tensor, 42]
2247

2248
        unflattened = pytree.tree_unflatten(flattened, spec)
2249
        assert unflattened == l
2250
        assert isinstance(unflattened, immutable_list)
2251

2252
    def test_immutable_dict_pytree_ops(self):
2253
        rand_tensor = torch.randn(5, 3)
2254
        d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})
2255

2256
        flattened, spec = pytree.tree_flatten(d)
2257
        assert flattened == [3, rand_tensor, 42]
2258

2259
        unflattened = pytree.tree_unflatten(flattened, spec)
2260
        assert unflattened == d
2261
        assert isinstance(unflattened, immutable_dict)
2262

2263
    def test_move_before(self):
2264
        graph : torch.fx.Graph = torch.fx.Graph()
2265
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2266
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2267
        output : torch.fx.Node = graph.output(b)
2268

2269
        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
2270
        _, *relu_args = b.args
2271
        b.args = (neg, *relu_args)
2272
        b.prepend(neg)
2273

2274
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2275

2276
        input = torch.randn(33, 44)
2277
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2278

2279
    def test_prepend_self(self):
2280
        graph : torch.fx.Graph = torch.fx.Graph()
2281
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
2282
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2283
        output : torch.fx.Node = graph.output(b)
2284

2285
        b.prepend(b)
2286
        x.append(b)
2287
        self.assertEqual(len(graph.nodes), 3)
2288

2289
    def test_erase_node_error(self):
2290
        st = SimpleTest()
2291
        traced = symbolic_trace(st)
2292

2293
        for node in traced.graph.nodes:
2294
            # Test deleting with uses both in another Node and at the output
2295
            if node.target in [operator.add, torch.relu]:
2296
                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
2297
                    traced.graph.erase_node(node)
2298

2299
    def test_copy_it(self):
2300
        d = immutable_dict([(3, 4), (5, 6)])
2301
        l = immutable_list([(3, 4), (5, 6)])
2302

2303
        self.assertEqual(d, deepcopy(d))
2304
        self.assertEqual(l, deepcopy(l))
2305

2306
    def test_get_torch_func_signature(self):
2307
        for key in dir(torch):
2308
            obj = getattr(torch, key)
2309
            if callable(obj):
2310
                schemas = get_signature_for_torch_op(obj)
2311

2312
    def test_find_uses(self):
2313
        graph = torch.fx.Graph()
2314
        x = torch.fx.Proxy(graph.placeholder('x'))
2315

2316
        y = torch.relu(x)
2317
        z = x + x
2318
        u = torch.neg(x)
2319
        graph.output((y + z + u).node)
2320
        graph.lint()
2321

2322
        users_of_x = x.node.users
2323
        self.assertEqual(len(users_of_x), 3)
2324
        expected_ops = {'relu', 'add', 'neg'}
2325
        for use in users_of_x:
2326
            assert any(use.name.startswith(prefix) for prefix in expected_ops)
2327

2328
    def test_inline_graph(self):
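        # graph_copy copies one graph into another; mapping the second graph's
        # placeholder to the first graph's output via val_map stitches the two
        # into a single relu-then-neg graph.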
2329
        class InlineInto(torch.nn.Module):
2330
            def forward(self, x):
2331
                return torch.relu(x)
2332

2333
        class ToInline(torch.nn.Module):
2334
            def forward(self, x):
2335
                return torch.neg(x)
2336

2337
        inline_into = symbolic_trace(InlineInto())
2338
        to_inline = symbolic_trace(ToInline())
2339

2340
        combined_graph = torch.fx.Graph()
2341
        output_node = combined_graph.graph_copy(inline_into.graph, {})
2342

2343
        input_node = next(iter(to_inline.graph.nodes))
2344
        assert input_node and input_node.op == 'placeholder'
2345

2346
        val_map = {input_node : output_node}
2347
        output = combined_graph.graph_copy(to_inline.graph, val_map)
2348
        combined_graph.output(output)
2349

2350
        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)
2351

2352
        input = torch.rand(3, 4)
2353
        self.assertEqual(combined_module(input), input.relu().neg())
2354

2355
    def test_multi_insert_point(self):
2356
        graph = torch.fx.Graph()
2357
        x = torch.fx.Proxy(graph.placeholder('x'))
2358
        relu = torch.relu(x)
2359

2360
        with graph.inserting_before(relu.node):
2361
            y = torch.neg(x)
2362
            z = torch.tanh(y)
2363

2364
        graph.output((relu.node, z.node))
2365
        graph.lint()
2366

2367
        expected_ops = ['x', 'neg', 'tanh', 'relu']
2368
        for node, expected in zip(graph.nodes, expected_ops):
2369
            assert expected in node.name
2370

2371
    def test_reassign_args_kwargs_uses(self):
2372
        graph = torch.fx.Graph()
2373
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
2374
        z = x + y
2375
        zed = z + z + z
2376
        graph.output(zed.node)
2377
        graph.lint()
2378

2379
        # zed = z + z + z -> zed = z + z + x
2380
        zed.node.args = (zed.node.args[0], x.node)
2381
        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])
2382

2383
        # z = x + y -> z = y + y
2384
        z.node.args = (y.node, y.node)
2385
        self.assertEqual(list(x.node.users.keys()), [zed.node])
2386

2387
    def test_trace_function(self):
2388
        def foo(x, y):
2389
            return torch.relu(x) + y
2390

2391
        x, y = torch.randn(3, 4), torch.randn(3, 4)
2392
        self.checkGraphModule(foo, (x, y))
2393

2394

2395
    def test_trace_return_dataclass(self):
2396
        """
2397
        Test case for Module that return dataclass
2398
        """
2399
        from dataclasses import dataclass
2400

2401
        @dataclass
2402
        class MyOutput:
2403
            foo: torch.Tensor
2404
            bar: torch.Tensor
2405

2406
        class ModuleReturnDataclass(torch.nn.Module):
2407
            def forward(self, d : torch.Tensor):
2408
                return MyOutput(foo=d + d, bar=d * 3)
2409

2410
        module = ModuleReturnDataclass()
2411
        traced_graph = symbolic_trace(module).graph
2412
        print(traced_graph)
2413

2414
        gm = GraphModule(module, traced_graph)
2415
        x = torch.rand(1)
2416

2417
        self.assertEqual(module(x), gm(x))
2418

2419
    def test_trace_return_dataclass_nested(self):
2420
        """
2421
        Test case for Module that return dataclass
2422
        """
2423
        from dataclasses import dataclass
2424

2425
        @dataclass
2426
        class MyOutput:
2427
            foo: torch.Tensor
2428
            bar: torch.Tensor
2429

2430
        class ModuleReturnDataclass(torch.nn.Module):
2431
            def forward(self, d : torch.Tensor):
2432
                return MyOutput(foo=d + d, bar=d * 3)
2433

2434
        class CallsModule(torch.nn.Module):
2435
            def __init__(self):
2436
                super().__init__()
2437
                self.m = ModuleReturnDataclass()
2438

2439
            def forward(self, x):
2440
                tmp = self.m(x)
2441
                return MyOutput(foo=tmp.foo, bar=tmp.bar)
2442

2443
        module = CallsModule()
2444
        traced_graph = symbolic_trace(module).graph
2445
        print(traced_graph)
2446

2447
        gm = GraphModule(module, traced_graph)
2448
        x = torch.rand(1)
2449

2450
        self.assertEqual(module(x), gm(x))
2451

2452

2453
    def test_trace_return_namedtuple(self):
2454
        """
2455
        Test case for Module that return namedtuple
2456
        """
2457
        class MyOutput(NamedTuple):
2458
            foo: torch.Tensor
2459
            bar: torch.Tensor
2460

2461
        class ModuleReturnNamedTuple(torch.nn.Module):
2462
            def forward(self, d : torch.Tensor):
2463
                return MyOutput(foo=d, bar=d)
2464

2465

2466
        module = ModuleReturnNamedTuple()
2467

2468
        traced_graph = symbolic_trace(module).graph
2469
        print(traced_graph)
2470

2471
        gm = GraphModule(module, traced_graph)
2472
        x = torch.rand(1)
2473

2474
        self.assertEqual(module(x), gm(x))
2475

2476
    def test_trace_dict_int_keys(self):
2477
        class ModWithDictArg(torch.nn.Module):
2478
            def forward(self, d : Dict[int, torch.Tensor]):
2479
                return d[42]
2480

2481
        class CallsModWithDict(torch.nn.Module):
2482
            def __init__(self):
2483
                super().__init__()
2484
                self.m = ModWithDictArg()
2485

2486
            def forward(self, x):
2487
                return self.m({42: x})
2488

2489
        class MyTracer(torch.fx.Tracer):
2490
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
2491
                return isinstance(m, ModWithDictArg)
2492

2493
        traced_graph = MyTracer().trace(CallsModWithDict())
2494

2495
    def test_trace_dict_proxy_keys(self):
2496
        class ModWithDictArg(torch.nn.Module):
2497
            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
2498
                return d[42]
2499

2500
        class CallsModWithDict(torch.nn.Module):
2501
            def __init__(self):
2502
                super().__init__()
2503
                self.m = ModWithDictArg()
2504

2505
            def forward(self, x):
2506
                return self.m({x: x})
2507

2508
        class MyTracer(torch.fx.Tracer):
2509
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
2510
                return isinstance(m, ModWithDictArg)
2511

2512
        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
2513
            traced_graph = MyTracer().trace(CallsModWithDict())
2514

2515
    def test_module_deepcopy_edit_nodes(self):
2516
        class Foo(torch.nn.Module):
2517
            def forward(self, x):
2518
                return torch.relu(x)
2519

2520
        traced1 = symbolic_trace(Foo())
2521
        copied = copy.deepcopy(traced1)
2522

2523
        for node in copied.graph.nodes:
2524
            if node.target == torch.relu:
2525
                node.target = torch.neg
2526

2527
        copied.recompile()
2528
        traced1.recompile()
2529

2530
        x = torch.randn(15, 15)
2531
        torch.testing.assert_close(traced1(x), torch.relu(x))
2532
        torch.testing.assert_close(copied(x), torch.neg(x))
2533

2534
    def test_direct_param_use(self):
2535
        class TransposeTest(torch.nn.Module):
2536
            def __init__(self):
2537
                super().__init__()
2538
                self.b = torch.nn.Parameter(torch.rand(4, 3))
2539

2540
            def forward(self, x):
2541
                return self.b
2542

2543
        class Foo(torch.nn.Module):
2544
            def __init__(self):
2545
                super().__init__()
2546
                self.a = TransposeTest()
2547

2548
            def forward(self, x):
2549
                return self.a.b, self.a.b.t(), self.a.b.view(12)
2550

2551
        traced = torch.fx.symbolic_trace(Foo())
2552
        assert all('constant' not in node.target for node in traced.graph.nodes)
2553

2554
    def test_single_default_arg(self):
2555
        class M(torch.nn.Module):
2556
            def forward(self, y=1):
2557
                return y
2558

2559
        m = M()
2560
        self.checkGraphModule(m, ())
2561
        self.checkGraphModule(m, (3,))
2562

2563
    def test_multiple_default_args(self):
2564
        class M(torch.nn.Module):
2565
            def forward(self, y=1, z=2):
2566
                return y + z
2567

2568
        m = M()
2569
        self.checkGraphModule(m, ())
2570
        self.checkGraphModule(m, (3,))
2571
        self.checkGraphModule(m, (3, 4))
2572

2573
    def test_regular_and_default_args(self):
2574
        class M(torch.nn.Module):
2575
            def forward(self, x, y=1):
2576
                return x + y
2577

2578
        m = M()
2579
        self.checkGraphModule(m, (2,))
2580
        self.checkGraphModule(m, (2, 3))
2581

2582
    def test_string_literal_return(self):
2583
        class M(torch.nn.Module):
2584
            def forward(self):
2585
                return "foo"
2586

2587
        m = M()
2588
        self.checkGraphModule(m, ())
2589

2590
    def test_namedtuple_return_qualname(self):
2591
        class NamedTupReturn(torch.nn.Module):
2592
            def forward(self, x):
2593
                return MyNamedTup(x, x)
2594

2595
        traced = symbolic_trace(NamedTupReturn())
2596
        input = torch.rand(3, 4)
2597
        self.assertEqual(traced(input), MyNamedTup(input, input))
2598

2599
    def test_update_args_kwargs_yells_at_you(self):
2600
        symtraced = symbolic_trace(SimpleTest())
2601
        node = next(iter(symtraced.graph.nodes))
2602
        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
2603
            node.__update_args_kwargs((), {})
2604

2605
    def test_torchbind_class_attribute_in_fx(self):
2606
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
2607
            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")
2608

2609
        class FooBar1234(torch.nn.Module):
2610
            def __init__(self):
2611
                super().__init__()
2612
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
2613

2614
            def forward(self):
2615
                return self.f.top()
2616

2617
        m = FooBar1234()
2618
        self.checkGraphModule(m, ())
2619

2620
    def test_torchbind_class_attribute_in_fx_tensor_arg(self):
2621
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
2622
            self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")
2623

2624
        class FooBar2341(torch.nn.Module):
2625
            def __init__(self):
2626
                super().__init__()
2627
                self.f = torch.classes._TorchScriptTesting._ReLUClass()
2628

2629
            def forward(self, x):
2630
                return self.f.run(x)
2631

2632
        m = FooBar2341()
2633

2634
        traced = symbolic_trace(m)
2635
        input = torch.randn(3, 4)
2636
        self.assertEqual(traced(input), m(input))
2637

2638
        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
2639

2640
    def test_script_method_trace(self):
2641
        class Scripted(torch.nn.Module):
2642
            def forward(self, x):
2643
                return torch.relu(x)
2644

2645
        class Holder(torch.nn.Module):
2646
            def __init__(self):
2647
                super().__init__()
2648
                self.s = torch.jit.script(Scripted())
2649

2650
            def forward(self, x):
2651
                return self.s(x)
2652

2653
        h = Holder()
2654
        traced = symbolic_trace(h)
2655
        input = torch.randn(3, 4)
2656
        self.assertEqual(traced(input), h(input))
2657

2658
        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
2659

2660
    def test_namedtuple_return_trace(self):
2661
        class NamedTupReturn(torch.nn.Module):
2662
            def forward(self, x):
2663
                return Pair(x, x)
2664

2665
        traced = symbolic_trace(NamedTupReturn())
2666
        input = torch.rand(3, 4)
2667
        self.assertEqual(traced(input), Pair(input, input))
2668

2669
    def test_named_tuple_inlined(self):
2670
        class NamedTupMod(torch.nn.Module):
2671
            def forward(self, inp):
2672
                return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))
2673

2674
        m = NamedTupMod()
2675
        input = torch.rand(3, 4)
2676
        ref = m(input)
2677
        traced = symbolic_trace(m)
2678

2679
        res = traced(input)
2680
        self.assertEqual(ref, res)
2681

2682
        # Check Pair NamedTuple works when inlined into the function call.
2683
        ph = call_func = None
2684
        for node in traced.graph.nodes:
2685
            if node.op == "placeholder":
2686
                ph = node
2687
            elif node.op == "call_function" and node.target == wrapped_named_tup:
2688
                node.update_arg(0, Pair(ph, 1.2))
2689
                node.update_kwarg("p2", Pair(3.4, ph))
2690
                call_func = node
2691
                break
2692
        self.assertTrue(call_func is not None)
2693
        self.assertTrue(isinstance(call_func.args[0], Pair))
2694
        self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
2695
        self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
2696
        self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")
2697

2698
        traced.graph.eliminate_dead_code()
2699
        traced.recompile()
2700
        res = traced(input)
2701
        self.assertEqual(ref, res)
2702

2703
    def test_return_type_exists(self):
2704
        class ReturnTypeModule(torch.nn.Module):
2705
            def other(self, x: List[str]) -> List[str]:
2706
                return x
2707

2708
            def forward(self, x: List[str]) -> List[str]:
2709
                return self.other(x)
2710

2711
        traced = symbolic_trace(ReturnTypeModule())
2712
        self.assertIn("-> typing_List[str]", traced._code)
2713
        scripted = torch.jit.script(traced)
2714
        self.assertIn("-> List[str]", scripted.code)
2715

2716
    def getitem_inner(self):
2717
        class GetItemBase(torch.nn.Module):
2718
            def __init__(self):
2719
                super().__init__()
2720
                self.register_buffer('pe', torch.randn(8, 8))
2721

2722
        class GetItem1(GetItemBase):
2723
            def forward(self, x):
2724
                return self.pe[:, :x.size(0)]
2725

2726
        class GetItem2(GetItemBase):
2727
            def forward(self, x):
2728
                return self.pe[x.size(0)]
2729

2730
        class GetItem3(GetItemBase):
2731
            def forward(self, x):
2732
                return self.pe[4]  # fx creates `self._tensor_constant0` here
2733

2734
        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
2735
        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
2736
        self.checkGraphModule(GetItem3(), [torch.zeros(4)])
2737

2738
    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
2739
                         "Will be checked in test_getitem_subproc")
2740
    def test_getitem(self):
2741
        self.getitem_inner()
2742

2743
    def test_getitem_subproc(self):
2744
        # need to run this test in a subproc to work around:
2745
        #   https://github.com/pytorch/pytorch/issues/50710
2746
        proc = Process(target=run_getitem_target)
2747
        proc.start()
2748
        proc.join()
2749
        self.assertEqual(proc.exitcode, 0)
2750

2751

2752
    def test_user_friendly_call_provenance_with_function(self):
2753
        def fn(x):
2754
            return wrapper_fn(x)
2755

2756
        traced = torch.fx.symbolic_trace(fn)
2757

2758
        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2759
                                    "being compiled since it was called"
2760
                                    " from 'fn.forward'"):
2761
            scripted = torch.jit.script(traced)
2762

2763
    def test_user_friendly_call_provenance_with_module(self):
2764
        class M(torch.nn.Module):
2765
            def forward(self, x):
2766
                return wrapper_fn(x)
2767

2768
        traced = torch.fx.symbolic_trace(M())
2769

2770
        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2771
                                    "being compiled since it was called"
2772
                                    " from 'M.forward'"):
2773
            scripted = torch.jit.script(traced)
2774

2775
    def test_snake_case(self):
2776
        class M(torch.nn.Module):
2777
            def __init__(self):
2778
                super().__init__()
2779
                self.activations = torch.nn.ModuleDict([
2780
                    ["snake_case", torch.nn.ReLU()],
2781
                    ["PascalCase", torch.nn.LeakyReLU()],
2782
                    ["ALL_CAPS", torch.nn.PReLU()]
2783
                ])
2784

2785
            def forward(self, x):
2786
                a = self.activations["snake_case"](x)
2787
                b = self.activations["PascalCase"](x)
2788
                c = self.activations["ALL_CAPS"](x)
2789
                return a, b, c
2790

2791
        traced = symbolic_trace(M())
2792

2793
        check = [
2794
            ("activations_snake_case", "activations.snake_case"),
2795
            ("activations_pascal_case", "activations.PascalCase"),
2796
            ("activations_all_caps", "activations.ALL_CAPS")
2797
        ]
2798

2799
        i = 0
2800
        for node in traced.graph.nodes:
2801
            if node.op == "placeholder" or node.op == "output":
2802
                continue
2803
            name = check[i][0]
2804
            target = check[i][1]
2805
            self.assertEqual(name, node.name)
2806
            self.assertEqual(target, node.target)
2807
            i += 1
2808
        self.assertEqual(i, 3)
2809

2810
    def test_no_mutation(self):
2811
        from torch.fx.immutable_collections import immutable_list
2812
        x = immutable_list([3, 4])
2813
        with self.assertRaisesRegex(NotImplementedError, "new_args"):
2814
            x[0] = 4
2815

2816
    def test_partial_trace(self):
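        # concrete_args bakes 'y' into the trace as a constant and inserts a
        # torch._assert guard, so calling with the other value fails at runtime.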
2817
        class Foo(torch.nn.Module):
2818
            def forward(self, x, y):
2819
                if y:
2820
                    return 2 * x
2821
                else:
2822
                    return x
2823
        mod = Foo()
2824
        mod_true = symbolic_trace(mod, concrete_args={'y': True})
2825
        mod_false = symbolic_trace(mod, concrete_args={'y': False})
2826
        self.assertEqual(mod_true(3, True), 6)
2827
        print(mod_true.code)
2828
        assert any(i.target == torch._assert for i in mod_true.graph.nodes)
2829
        with self.assertRaises(AssertionError):
2830
            mod_true(3, False)
2831
        self.assertEqual(mod_false(3, False), 3)
2832
        with self.assertRaises(AssertionError):
2833
            mod_false(3, True)
2834

2835
        def f_higher(a, f):
2836
            return f(a)
2837

2838
        nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
2839
        self.assertEqual(nf(3, lambda x: x * 2), 6)
2840

2841
    def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
2842
        class M(torch.nn.Module):
2843
            def __init__(self):
2844
                super().__init__()
2845
                self.W = torch.nn.Parameter(torch.randn(5))
2846

2847
            def forward(self, x):
2848
                return torch.dot(self.W, x)
2849

2850
        traced = torch.fx.symbolic_trace(M())
2851

2852
        out = [n for n in traced.graph.nodes if n.op == "output"][-1]
2853
        with traced.graph.inserting_before(out):
2854
            relu_out = traced.graph.call_method(method_name='relu',
2855
                                                args=(out.args[0],))
2856
        out.args = (relu_out,)
2857

2858
        traced.recompile()
2859

2860
        with self.capture_stderr() as captured:
2861
            with self.assertRaises(TypeError):
2862
                traced(5)
2863

2864
        self.assertRegex(captured[0],
2865
                         r"Call using an FX-traced Module, line .* of the "
2866
                         r"traced Module's generated forward function:")
2867

2868
    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
2869
        class M(torch.nn.Module):
2870
            def __init__(self):
2871
                super().__init__()
2872
                self.linear = torch.nn.Linear(3, 4)
2873

2874
            def forward(self, x):
2875
                return self.linear(x)
2876

2877
        traced = torch.fx.symbolic_trace(M())
2878

2879
        # Do not change this to `capture_stderr` or another context
2880
        # manager without ensuring that the output is as expected
2881
        try:
2882
            traced(torch.rand(5, 5))
2883
        except RuntimeError:
2884
            captured = traceback.format_exc()
2885

2886
        self.assertNotRegex(captured,
2887
                            r"Call using an FX-traced Module, line .* of the "
2888
                            r"traced Module's generated forward function:")
2889

2890
    def test_graph_module_replicate_for_dp(self):
2891
        class Foo(torch.nn.Module):
2892
            def forward(self, x):
2893
                return torch.relu(x)
2894

2895
        gm = torch.fx.symbolic_trace(Foo())
2896

2897
        x = torch.randn(5, 3)
2898
        out = gm(x)
2899

2900
        replica = gm._replicate_for_data_parallel()
2901
        out_replica = replica(x)
2902

2903
        torch.testing.assert_close(out_replica, out)
2904

2905
    def test_ast_rewriter_rewrites_assert(self):
2906
        class M(torch.nn.Module):
2907
            def forward(self, x: torch.Tensor, y: int, z: int):
2908
                assert y == z
2909
                return torch.add(x, x)
2910

2911
        ast_rewriter = RewritingTracer()
2912
        graph = ast_rewriter.trace(M())
2913
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2914

2915
        traced.graph.lint()
2916

2917
    def test_ast_rewriter_rewrites_assert_with_message(self):
2918
        class M(torch.nn.Module):
2919
            def forward(self, x: torch.Tensor, y: int, z: int):
2920
                assert y == z, "msg"
2921
                return torch.add(x, x)
2922

2923
        ast_rewriter = RewritingTracer()
2924
        graph = ast_rewriter.trace(M())
2925
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2926

2927
        traced.graph.lint()
2928

2929
    def test_throw_out_variant(self):
2930
        def foo(x):
2931
            y = torch.rand_like(x)
2932
            torch.sigmoid(x, out=y)
2933
            return y
2934

2935
        class MyTracer(torch.fx.Tracer):
2936
            check_mutable_operations = True
2937

2938
        tracer = MyTracer()
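        # With check_mutable_operations enabled, tracing the out= variant above is
        # rejected instead of being silently recorded.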
2939
        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
2940
            traced_graph = tracer.trace(foo)
2941

2942
    def test_ast_rewriter_reassigns_submodules(self):
2943
        class M(torch.nn.Module):
2944
            def __init__(self):
2945
                super().__init__()
2946
                self.bn = torch.nn.BatchNorm2d(100)
2947

2948
            def forward(self, x: torch.Tensor):
2949
                return torch.add(x, x)
2950

2951
        ast_rewriter = RewritingTracer()
2952
        graph = ast_rewriter.trace(M())
2953
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2954

2955
        traced.graph.lint()
2956

2957
    def test_ast_rewriter_wrap(self):
2958
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
2959

2960
        def to_trace(y):
2961
            return (
2962
                a_lifted_leaf((4, y), 3)
2963
                + a_lifted_leaf((3, 4), 5)
2964
                + a_lifted_leaf((y, y), y)
2965
            )
2966

2967
        ast_rewriter = RewritingTracer()
2968
        graph = ast_rewriter.trace(to_trace)
2969
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2970

2971
        self.assertIn("a_lifted_leaf", traced.code)
2972
        self.assertEqual(27, traced(2))
2973
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
2974

2975
    def test_ast_rewriter_wrap_fn_directly(self):
2976
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
2977

2978
        def to_trace(y):
2979
            return (
2980
                a_lifted_leaf2((4, y), 3)
2981
                + a_lifted_leaf2((3, 4), 5)
2982
                + a_lifted_leaf2((y, y), y)
2983
            )
2984

2985
        ast_rewriter = RewritingTracer()
2986
        graph = ast_rewriter.trace(to_trace)
2987
        traced = GraphModule(ast_rewriter.root, graph, "gm")
2988

2989
        self.assertIn("a_lifted_leaf2", traced.code)
2990
        self.assertEqual(27, traced(2))
2991
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
2992

2993
    def test_profiler_ranges_side_effect(self):
2994
        g = torch.fx.Graph()
2995
        handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
2996
        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
2997
        g.output(None)
2998

2999
        found_targets = {}
3000
        for node in g.nodes:
3001
            if node.op == 'call_function':
3002
                found_targets.setdefault(node.target)
3003
        self.assertEqual(
3004
            list(found_targets.keys()),
3005
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3006
        )
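        # The profiler range calls are side-effectful, so dead-code elimination
        # must keep them even though their results are unused.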
3007

3008
        g.eliminate_dead_code()
3009
        found_targets = {}
3010
        for node in g.nodes:
3011
            if node.op == 'call_function':
3012
                found_targets.setdefault(node.target)
3013
        self.assertEqual(
3014
            list(found_targets.keys()),
3015
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
3016
        )
3017

3018
    def test_ast_rewriter_wrapped_via_decorator(self):
3019
        class F(torch.nn.Module):
3020
            def forward(self, x):
3021
                return wrapped_via_decorator(x)
3022

3023
        ast_rewriter = RewritingTracer()
3024
        graph = ast_rewriter.trace(F())
3025
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3026

3027
        self.assertIn("wrapped_via_decorator", traced.code)
3028
        self.assertEqual(traced(0), 1)
3029
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3030
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3031

3032
    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
3033
        self.assertEqual(wrapped_via_decorator(0), 1)
3034

3035
        def to_trace(y):
3036
            return wrapped_via_decorator(y)
3037

3038
        ast_rewriter = RewritingTracer()
3039
        graph = ast_rewriter.trace(to_trace)
3040
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3041

3042
        self.assertIn("wrapped_via_decorator", traced.code)
3043
        self.assertEqual(traced(0), 1)
3044
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3045
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3046

3047
        transformed = torch.fx.Transformer(traced).transform()
3048
        self.assertIn("wrapped_via_decorator", transformed.code)
3049
        self.assertEqual(transformed(0), 1)
3050
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
3051
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
3052

3053
    def test_ast_rewriter_wrap_with_submodule(self):
3054
        class M(torch.nn.Module):
3055
            def __init__(self):
3056
                super().__init__()
3057
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3058

3059
            def forward(self, x: torch.Tensor):
3060
                return wrapped_with_submodule(x, self.batchnorm1d)
3061

3062
        ast_rewriter = RewritingTracer()
3063
        graph = ast_rewriter.trace(M())
3064
        traced = GraphModule(ast_rewriter.root, graph, "gm")
3065

3066
        self.assertIn("wrapped_with_submodule", traced.code)
3067

3068
        input = torch.rand(3, 2)
3069
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
3070
        self.assertEqual(ref_batchnorm1d(input), traced(input))
3071

3072
    def test_submodule_manipulation_API(self):
3073
        class C(torch.nn.Module):
3074
            def __init__(self):
3075
                super().__init__()
3076
                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
3077
                self.param = torch.nn.Parameter(torch.rand(2, 3))
3078

3079
            def forward(self, x):
3080
                return self.conv(torch.cat([self.param, x]))
3081

3082
        class B(torch.nn.Module):
3083
            def __init__(self):
3084
                super().__init__()
3085
                self.linear = torch.nn.Linear(100, 200)
3086
                self.register_buffer("buf", torch.randn(2, 3))
3087
                self.net_c = C()
3088

3089
            def forward(self, x):
3090
                return self.linear(torch.cat([self.buf, self.net_c(x)]))
3091

3092
        class A(torch.nn.Module):
3093
            def __init__(self):
3094
                super().__init__()
3095
                self.net_b = B()
3096
                self.param = torch.nn.Parameter(torch.rand(2, 3))
3097

3098
            def forward(self, x):
3099
                return self.net_b(x) + self.param
3100

3101
        a = symbolic_trace(A())
3102

3103
        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))
3104

3105
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
3106
        with a.graph.inserting_before(conv):
3107
            with warnings.catch_warnings(record=True) as w:
3108
                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
3109
                                              args=conv.args)
3110
                self.assertEqual(len(w), 0)
3111

3112
        conv.replace_all_uses_with(dropout)
3113
        a.graph.erase_node(conv)
3114
        a.recompile()
3115

3116
        def module_exists(gm: GraphModule, path: str) -> bool:
3117
            return any(path == name for name, _ in gm.named_modules())
3118

3119
        def parameter_exists(gm: GraphModule, path: str) -> bool:
3120
            return (any(path == name for name, _ in gm.named_parameters())
3121
                    and any(path == name for name in gm.state_dict().keys()))
3122

3123
        def buffer_exists(gm: GraphModule, path: str) -> bool:
3124
            return (any(path == name for name, _ in gm.named_buffers())
3125
                    and any(path == name for name in gm.state_dict().keys()))
3126

3127
        # Test that we added the "dropout" submodule
3128
        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))
3129

3130
        # Test `get_submodule` with an added submodule
3131
        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))
3132

3133
        # Test that the "conv" submodule is still there
3134
        self.assertTrue(module_exists(a, "net_b.net_c.conv"))
3135

3136
        # Test `get_submodule` with an original module
3137
        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))
3138

3139
        # Test that the "conv" node is NOT still there
3140
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
3141
        self.assertEqual(conv, [])
3142

3143
        a.delete_submodule("net_b.net_c.conv")
3144

3145
        # Test that the "conv" submodule is now gone
3146
        self.assertFalse(module_exists(a, "net_b.net_c.conv"))
3147

3148
        # Test `get_submodule` with a deleted submodule
3149
        with self.assertRaisesRegex(AttributeError, "has no attribute "
3150
                                    "`conv`"):
3151
            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))
3152

3153
        # Test `get_attr` warnings
3154
        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]
3155

3156
        with a.graph.inserting_before(cat):
3157

3158
            with warnings.catch_warnings(record=True) as w:
3159
                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
3160
                self.assertEqual(len(w), 0)
3161

3162
            with self.assertWarnsRegex(UserWarning, "Attempted to "
3163
                                       "insert a get_attr Node with no "
3164
                                       "underlying reference in the "
3165
                                       "owning GraphModule"):
3166
                bad_param = a.graph.get_attr(qualified_name="net_b.param")
3167
                a.graph.erase_node(bad_param)
3168

3169
        cat.args = (*cat.args, param)
3170

3171
        a.recompile()
3172

3173
        a.graph.lint()
3174

3175
        # Test `get_parameter`
3176
        a.get_parameter("net_b.net_c.param")
3177
        with self.assertRaisesRegex(AttributeError, "is not an "
3178
                                    "nn.Parameter"):
3179
            a.get_parameter("net_b.buf")
3180
        with self.assertRaisesRegex(AttributeError, "has no attribute "
3181
                                    "`param`"):
3182
            a.get_parameter("net_b.param")
3183

3184
        # Test `get_buffer`
3185
        a.get_buffer("net_b.buf")
3186
        with self.assertRaisesRegex(AttributeError, "is not a "
3187
                                    "buffer"):
3188
            a.get_buffer("net_b.net_c.param")
3189
        with self.assertRaisesRegex(AttributeError, "has no attribute "
3190
                                    "`buf`"):
3191
            a.get_buffer("net_b.net_c.buf")
3192

3193
        # Test non-nested attributes
3194
        a.get_submodule("")
3195
        a.get_parameter("param")
3196

3197
        # Insert some unused submodules
3198
        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
3199
        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
3200
        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
3201
        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))
3202

3203
        # Garbage collection
3204
        a.delete_all_unused_submodules()
3205

3206
        # Test that all the unused submodules are gone
3207
        self.assertFalse(module_exists(a, "net_b.embedding"))
3208
        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
3209
        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
3210
        self.assertFalse(module_exists(a, "batch_norm_2d"))
3211

3212
        # Test that we didn't delete any unused Parameters or buffers
3213
        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
3214
        self.assertTrue(buffer_exists(a, "net_b.buf"))
3215

3216
        a.graph.lint()
3217

3218
    def test_delete_unused_submodules_leaf(self):
3219
        class SubModule(torch.nn.Module):
3220
            def __init__(self):
3221
                super().__init__()
3222
                self.linear = torch.nn.Linear(10, 10)
3223
                self.relu = torch.nn.ReLU()
3224

3225
            def forward(self, x):
3226
                x = self.linear(x)
3227
                x = self.relu(x)
3228
                return x
3229

3230
        class Model(torch.nn.Module):
3231
            def __init__(self):
3232
                super().__init__()
3233
                self.submod = SubModule()
3234

3235
            def forward(self, x):
3236
                x = self.submod(x)
3237
                return x
3238

3239
        model = Model()
3240

3241
        class MyCustomTracer(torch.fx.Tracer):
3242
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
3243
                return module_qualified_name == "submod"
3244

3245
        inputs = torch.randn(1, 10)
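        # `submod` is traced as a leaf (a single call_module node), so its inner
        # linear/relu children remain referenced and must survive
        # delete_all_unused_submodules().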
3246
        traced_graph = MyCustomTracer().trace(model)
3247
        gm2 = torch.fx.GraphModule(model, traced_graph)
3248
        gm2.delete_all_unused_submodules()
3249
        torch.testing.assert_close(gm2(inputs), model(inputs))
3250

3251
    def test_fx_stateless(self):
3252
        class MockModule(torch.nn.Module):
3253
            def __init__(self):
3254
                super().__init__()
3255
                self.l1 = torch.nn.Linear(1, 1)
3256
                self.register_buffer('buffer', torch.ones(1))
3257

3258
            def forward(self, x):
3259
                return self.l1(x) + self.buffer
3260

3261
        module = MockModule()
3262
        x = torch.rand((1, 1))
3263
        weight = torch.tensor([[1.0]], requires_grad=True)
3264
        bias = torch.tensor([0.0], requires_grad=True)
3265
        buffer = torch.tensor([0.0])
3266
        parameters = {'l1.weight': weight,
3267
                      'l1.bias': bias,
3268
                      'buffer': buffer}
3269
        fx_module = torch.fx.symbolic_trace(module)
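        # functional_call runs the traced module with the supplied parameters and
        # buffers in place of the module's own state, so gradients flow to `weight`
        # and `bias` rather than to `module`'s parameters (asserted below).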
3270
        res = torch.func.functional_call(fx_module, parameters, x)
3271
        res.backward()
3272
        self.assertIsNotNone(weight.grad)
3273
        self.assertIsNotNone(bias.grad)
3274
        self.assertIsNone(buffer.grad)
3275
        # Gradients were not calculated for the module's own parameters and buffers
3276
        self.assertIsNone(module.l1.weight.grad)
3277
        self.assertIsNone(module.l1.bias.grad)
3278
        self.assertIsNone(module.buffer.grad)
3279

3280
    def test_tracing_graphmodules_as_leaf_submodules(self):
3281
        class A(torch.nn.Module):
3282
            def forward(self, t):
3283
                return t + t
3284

3285
        class B(torch.nn.Module):
3286
            def __init__(self):
3287
                super(type(self), self).__init__()
3288
                self.calling = False
3289
                self.called = False
3290

3291
            def forward(self, t):
3292
                if self.calling:
3293
                    return t - t
3294
                else:
3295
                    return t + t
3296

3297
            def __call__(self, *args):
3298
                self.called = True
3299
                self.calling = True
3300
                return super(type(self), self).__call__(*args)
3301
                self.calling = False  # note: unreachable, the method returns on the line above
3302

3303
        class M(torch.nn.Module):
3304
            def __init__(self, a, b):
3305
                super().__init__()
3306
                self.a = a
3307
                self.b = b
3308

3309
            def forward(self, t):
3310
                x = self.a(t)
3311
                y = self.b(t)
3312
                return x + y
3313

3314
        class LeafTracer(Tracer):
3315
            def is_leaf_module(self, module, name):
3316
                return True
3317

3318
        class LeafTracerNotB(Tracer):
3319
            def is_leaf_module(self, module, name):
3320
                return "b" not in name
3321

3322
        # The recompile() calls below are not strictly required; they
3323
        # exercise the chained __call__ wrappers.
3324

3325
        #
3326
        # Test: B as a regular, non-leaf module
3327
        #
3328
        a = symbolic_trace(A())
3329
        a.recompile()
3330
        m = M(a, B())
3331
        graph = LeafTracerNotB().trace(m)
3332
        gm = GraphModule(m, graph)
3333
        gm.recompile()
3334

3335
        # Test graphmodule/submodule a is not inlined.
3336
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3337
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3338
        self.assertTrue(len(match) == 1)
3339

3340
        # Test submodule b is not treated as leaf.
3341
        self.assertFalse(hasattr(gm, "b"))
3342

3343
        # Test assert custom __call__ on submodule b was honored.
3344
        match = [
3345
            n
3346
            for n in gm.graph.nodes
3347
            if n.op == "call_function" and n.target == operator.sub
3348
        ]
3349
        self.assertTrue(len(match) == 1)
3350

3351
        #
3352
        # Test: B as a regular, leaf module
3353
        # symbolic_trace should only patch torch.nn.Module.__call__,
3354
        # which means B.__call__ should still execute
3355
        #
3356
        a = symbolic_trace(A())
3357
        a.recompile()
3358
        b = B()
3359
        m = M(a, b)
3360
        graph = LeafTracer().trace(m)
3361
        gm = GraphModule(m, graph)
3362
        gm.recompile()
3363

3364
        # Test graphmodule/submodule a is not inlined.
3365
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3366
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3367
        self.assertTrue(len(match) == 1)
3368

3369
        # Test submodule b is leaf:
3370
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3371
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3372
        self.assertTrue(len(match) == 1)
3373

3374
        # Test b.__call__ was run
3375
        self.assertTrue(b.called)
3376
        self.assertTrue(gm.get_submodule("b").called)
3377

3378
        #
3379
        # Test: B as GraphModule leaf
3380
        # __call__ not honored since symbolic_trace directly invokes forward()
3381
        #
3382
        a = symbolic_trace(A())
3383
        a.recompile()
3384
        b = symbolic_trace(B())
3385
        b.recompile()
3386
        m = M(a, b)
3387
        graph = LeafTracer().trace(m)
3388
        gm = GraphModule(m, graph)
3389
        gm.recompile()
3390

3391
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
3392
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
3393
        self.assertTrue(len(match) == 1)
3394

3395
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
3396
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
3397
        self.assertTrue(len(match) == 1)
3398

3399
    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
3400
        class MyModule(torch.nn.Module):
3401
            def __init__(self):
3402
                super().__init__()
3403
                self.register_buffer("my_buff", torch.rand(3, 4))
3404
                self.register_parameter(
3405
                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
3406
                )
3407

3408
            def forward(self, x):
3409
                return x + self.my_buff + self.my_param
3410

3411
        mod = MyModule()
3412
        mod_traced = symbolic_trace(mod)
3413

3414
        # Create new GraphModule based on original, either w/ dict or root module.
3415
        orig_buff = mod_traced.get_buffer("my_buff")
3416
        orig_param = mod_traced.get_parameter("my_param")
3417
        mod_traced_new = GraphModule(
3418
            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
3419
            mod_traced.graph,
3420
        )
3421

3422
        # Check that both my_buff and my_param are found and the same.
3423
        try:
3424
            new_buff = mod_traced_new.get_buffer("my_buff")
3425
        except Exception:
3426
            self.fail("Did not find my_buff")
3427
        self.assertEqual(orig_buff, new_buff)
3428

3429
        try:
3430
            new_param = mod_traced_new.get_parameter("my_param")
3431
        except Exception:
3432
            self.fail("Did not find my_param")
3433
        self.assertEqual(orig_param, new_param)
3434

3435
        x = torch.rand(3, 4)
3436
        orig_out = mod_traced(x)
3437
        submodules_out = mod_traced_new(x)
3438

3439
        self.assertEqual(orig_out, submodules_out)
3440

3441
    def test_graph_module_init_buffer_param_copied_dict_init(self):
3442
        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)
3443

3444
    def test_graph_module_init_buffer_param_copied_mod_init(self):
3445
        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
3446

3447
    def test_annotations_with_no_forward_references(self):
3448
        class A:
3449
            def __call__(self, x: torch.Tensor):
3450
                return torch.add(x, x)
3451

3452
        class M(torch.nn.Module):
3453
            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
3454
                return a(x)
3455

3456
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3457

3458
    def test_annotations_with_forward_references(self):
3459
        class A:
3460
            def __call__(self, x: torch.Tensor):
3461
                return torch.add(x, x)
3462

3463
        class M(torch.nn.Module):
3464
            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
3465
                return a(x)
3466

3467
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3468

3469
    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
3470
        class A:
3471
            def __call__(self, x: torch.Tensor):
3472
                return torch.add(x, x)
3473

3474
        class M(torch.nn.Module):
3475
            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
3476
                return a(x[0])
3477

3478
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3479

3480
    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
3481
        class A:
3482
            def __call__(self, x: torch.Tensor):
3483
                return torch.add(x, x)
3484

3485
        class M(torch.nn.Module):
3486
            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
3487
                return a(x)[0]
3488

3489
        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
3490

3491
    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
3492
                     "`annotations` is not defined in Python <3.7")
3493
    def test_annotation_with_future(self):
3494
        try:
3495
            import fx.test_future    # noqa: F401
3496
        finally:
3497
            del sys.modules["__future__"]
3498

3499
    @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
3500
    def test_annotations_empty_tuple(self):
3501
        class Foo(torch.nn.Module):
3502
            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
3503
                return "foo"
3504

3505
        traced = torch.fx.symbolic_trace(Foo())
3506

3507
        x = ()
3508
        y = ("bar", ())
3509

3510
        traced(x, y)
3511

3512
        FileCheck().check("_Tuple[()]")   \
3513
                   .check("typing_Tuple[str,typing_Tuple[()]]") \
3514
                   .run(traced.code)
3515

3516
        scripted = torch.jit.script(traced)
3517

3518
        scripted(x, y)
3519

3520
        FileCheck().check("Tuple[()]")   \
3521
            .check("Tuple[str, Tuple[()]]")    \
3522
            .run(scripted.code)
3523

3524
    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
3525
    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
3526
    def test_assert(self):
3527
        def f(x):
3528
            assert x > 1
3529
            return x + 1
3530
        try:
3531
            torch.fx.proxy.TracerBase.trace_asserts = True
3532
            traced = symbolic_trace(f)
3533
        finally:
3534
            torch.fx.proxy.TracerBase.trace_asserts = False
3535

3536
        self.assertEqual(f(2), traced(2))
3537
        with self.assertRaises(AssertionError):
3538
            traced(0)
3539

3540
    def test_pytree(self):
3541
        # Used to test that you can use your own placeholder class
3542
        class PHTest(PHBase):
3543
            pass
3544

3545
        def f_sum(x):
3546
            return sum(x)
3547

3548
        def f_sum_dict(x):
3549
            out = 0
3550
            for v in x.values():
3551
                out += v
3552
            return out
3553

3554
        def f_dict_list_map(x):
3555
            new_dict = {}
3556
            for k, v in x.items():
3557
                new_dict[k] = [i + 1 for i in v]
3558
            return new_dict
3559

3560
        def f_dict_add(x):
3561
            return x['a'] + sum(x['z'])
3562

3563
        def f_namedtuple_add(x):
3564
            return x.x + x.y
3565

3566
        pytree.register_pytree_node(
3567
            Foo,
3568
            lambda x: ([x.a, x.b], None),
3569
            lambda x, _: Foo(x[0], x[1]),
3570
        )
3571
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])
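        # Foo is registered with both the generic pytree registry (used by tree_map
        # below) and FX's flatten-spec registry (used when tracing with concrete_args).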
3572

3573
        def f_custom(x):
3574
            return x.a + x.b
3575

3576
        def f_custom_dict(x):
3577
            return f_sum_dict(x.a) + x.b
3578

3579
        def f_return_custom(x):
3580
            return Foo(x.b, x.a)
3581

3582
        tests = [
3583
            (f_sum, [PH, PH, PH]),
3584
            (f_sum, []),
3585
            (f_sum, [PHTest(), PHTest(), PHTest()]),
3586
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
3587
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
3588
            (f_dict_list_map, {5: (PH, PH, PH)}),
3589
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
3590
            (f_dict_add, {'a': PH, 'z': []}),
3591
            (f_custom, Foo(PH, PH)),
3592
            (f_custom, Foo(PH, 3)),
3593
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
3594
            # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
3595
            (f_namedtuple_add, Point(PH, PH)),
3596
        ]
3597

3598
        def verify_pytree(f, inp):
3599
            val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
3600
            num_flat_args = len([i == PH for i in pytree.tree_leaves(inp)])
3601
            orig_out = f(val)
3602
            nf = symbolic_trace(f, concrete_args={'x': inp})
3603
            self.assertEqual(nf(val), orig_out)
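            # Rebuild the same graph with the default CodeGen: inputs and outputs
            # then have to be flattened manually via process_inputs/process_outputs.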
3604

3605
            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
3606
            bare_fx.graph.set_codegen(CodeGen())
3607
            bare_fx.recompile()
3608
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)
3609

3610
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
3611
            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
3612

3613
            nf = symbolic_trace(nf)
3614
            self.assertEqual(nf(val), orig_out)
3615
            assert "tree_flatten_spec" not in nf.code
3616
            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1
3617

3618
            nf = symbolic_trace(nf, concrete_args={'x': inp})
3619
            self.assertEqual(nf(val), orig_out)
3620
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
3621
            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args
3622

3623
            pickled = pickle.dumps(nf)
3624
            nf = pickle.loads(pickled)
3625
            self.assertEqual(nf(val), orig_out)
3626

3627
        for f, inp in tests:
3628
            verify_pytree(f, inp)
3629

3630
    def test_pytree_concrete(self):
3631
        def f(b, a):
3632
            if b:
3633
                return a['a']
3634
            else:
3635
                return a['z']
3636

3637
        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
3638
        nf = symbolic_trace(f, concrete_args=inp)
3639
        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
3640
        self.assertEqual(nf(**val), f(**val))
3641

3642
        nf = symbolic_trace(nf)
3643
        self.assertEqual(nf(**val), f(**val))
3644

3645
    def test_metadata_on_ph(self):
3646
        def f_sum(a: int, b: int) -> int:
3647
            return a + b
3648

3649
        # Due to unflattening of dict, the batch argument
3650
        # will be split into two separate nodes with the names
3651
        # "batch_1" and "batch_2", referring to the keys
3652
        # "f1" and "f2" respectively in the dict.
3653
        def f_dict(a: Dict[str, str]) -> bool:
3654
            return a["f1"] == a["f2"]
3655

3656
        def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
3657
            for node in gm.graph.nodes:
3658
                if node.op == "placeholder":
3659
                    self.assertTrue(node.name in arg_names)
3660
                    self.assertTrue(node.ph_key in metadata)
3661

3662
        verify_metadata(
3663
            gm=symbolic_trace(
3664
                f_sum,
3665
                concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
3666
            ),
3667
            arg_names=["a_1", "b_1"],
3668
            metadata=["a", "b"]
3669
        )
3670
        verify_metadata(
3671
            gm=symbolic_trace(
3672
                f_dict,
3673
                concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
3674
            ),
3675
            arg_names=["a_1", "a_2"],
3676
            metadata=["f1", "f2"]
3677
        )
3678

3679
        # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
3680
        class TaggingTracer(Tracer):
3681
            def create_node(self, kind : str, target : Union[str, Callable],
3682
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
3683
                            type_expr : Optional[Any] = None) -> Node:
3684
                n = super().create_node(kind, target, args, kwargs, name)
3685
                n.tag = "foo"
3686
                return n
3687

3688
        class PHWithTag(PHBase):
3689
            def __init__(self, tag: str):
3690
                super().__init__()
3691

3692
                self.tag = tag
3693

3694
        g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
3695
        for n in g.nodes:
3696
            self.assertTrue(hasattr(n, "tag"))
3697
            # Ensure that tag is still "foo" and not "bar" (from PHWithTag)
3698
            self.assertEqual(n.tag, "foo")
3699

3700
    def test_custom_codegen(self):
3701
        class ListCodeGen(CodeGen):
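            # Custom CodeGen: the generated forward() takes a single list argument
            # and unpacks it into the original free variables; process_inputs adapts
            # the call-site arguments to match that convention.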
3702
            def gen_fn_def(self, free_vars, maybe_return_annotation):
3703
                lst_unpack = f"""
3704
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3705
    {', '.join(free_vars)} = args_list"""
3706
                return lst_unpack
3707

3708
            def additional_globals(self):
3709
                return [('List', typing.List)]
3710

3711
            def process_inputs(self, *inputs):
3712
                assert len(inputs) == 1
3713
                return inputs[0]
3714

3715
        def f(a, b):
3716
            return a + b
3717

3718
        nf = symbolic_trace(f)
3719
        vals = [torch.randn(3), torch.randn(3)]
3720
        self.assertEqual(nf(*vals), f(*vals))
3721

3722
        nf.graph.set_codegen(ListCodeGen())
3723
        nf.recompile()
3724

3725
        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
3726
        bare_fx.graph.set_codegen(CodeGen())
3727
        bare_fx.recompile()
3728

3729
        self.assertEqual(nf(vals), f(*vals))
3730
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))
3731

3732
        ts_f = torch.jit.script(nf)
3733
        self.assertEqual(nf(vals), ts_f(vals))
3734

3735
    def test_custom_codegen_with_transformer(self):
3736
        class ListCodeGen(CodeGen):
3737
            def gen_fn_def(self, free_vars, maybe_return_annotation):
3738
                lst_unpack = f"""
3739
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3740
    {', '.join(free_vars)} = args_list"""
3741
                return lst_unpack
3742

3743
            def additional_globals(self):
3744
                return [('List', typing.List)]
3745

3746
            def process_inputs(self, *inputs):
3747
                assert len(inputs) == 1
3748
                return inputs[0]
3749

3750
        def f(a, b):
3751
            return a + b
3752

3753
        nf = symbolic_trace(f)
3754
        vals = [torch.randn(3), torch.randn(3)]
3755
        self.assertEqual(nf(*vals), f(*vals))
3756

3757
        nf.graph.set_codegen(ListCodeGen())
3758
        nf.recompile()
3759
        self.assertEqual(nf(vals), f(*vals))
3760

3761
        transformed_gm = Transformer(nf).transform()
3762
        self.assertEqual(nf(vals), transformed_gm(vals))
3763

3764
    def test_interpreter_with_codegen(self):
3765
        class ListCodeGen(CodeGen):
3766
            def gen_fn_def(self, free_vars, maybe_return_annotation):
3767
                lst_unpack = f"""
3768
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
3769
    {', '.join(free_vars)} = args_list"""
3770
                return lst_unpack
3771

3772
            def additional_globals(self):
3773
                return [('List', typing.List)]
3774

3775
            def process_inputs(self, *inputs):
3776
                assert len(inputs) == 1
3777
                return inputs[0]
3778

3779
            def generate_output(self, output_args):
3780
                return f'return list({repr(output_args)})'
3781

3782
            def process_outputs(self, outputs):
3783
                return list(outputs)
3784

3785
        def f(a, b):
3786
            a = a + b
3787
            b = a + b
3788
            return a, b
3789

3790
        nf = symbolic_trace(f)
3791
        vals = [torch.randn(3), torch.randn(3)]
3792
        nf.graph.set_codegen(ListCodeGen())
3793
        nf.recompile()
3794
        self.assertEqual(Interpreter(nf).run(vals), nf(vals))
3795

3796
    def test_imul_code_print(self):
3797
        graph = torch.fx.Graph()
3798
        a = graph.placeholder("a")
3799
        b = graph.placeholder("b")
3800
        graph.call_function(operator.imul, (a, b), {})
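        # operator.imul is emitted as an augmented assignment ("a *= b") in the
        # generated Python, which is what the assertIn below checks.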
3801
        graph.output(a)
3802
        gm = torch.fx.GraphModule({}, graph)
3803
        gm.recompile()
3804
        self.assertEqual(gm(2, 3), 6)
3805
        self.assertIn("a *= b", gm.code)
3806

3807
    def test_deepcopy_tracer(self):
3808
        def fn(x, y):
3809
            return (x + y).relu().sin()
3810

3811
        tracer = Tracer()
3812
        tracer_before = copy.deepcopy(tracer)
3813
        tracer.trace(fn)
3814
        tracer_after = copy.deepcopy(tracer)
3815

3816
        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
3817
        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))
3818

3819
    def test_deepcopy_graphmodule(self):
3820
        m = symbolic_trace(SimpleTest())
3821
        m.meta['hello'] = 'world'
3822
        copy_m = copy.deepcopy(m)
3823
        self.assertEqual(copy_m.meta['hello'], 'world')
3824

3825
    def test_deepcopy_no_recursion(self):
3826
        m = symbolic_trace(SimpleTest())
3827
        m.meta['hello'] = m  # circular reference
3828
        copy_m = copy.deepcopy(m)  # finishes
3829
        self.assertEqual(id(copy_m), id(copy_m.meta['hello']))
3830

3831
    def test_enum(self):
3832
        from enum import Enum
3833

3834
        class Foo(Enum):
3835
            A = 1
3836
            B = 2
3837

3838
        def leaf_fn(arr, enum_val):
3839
            # Use the raw enum.
3840
            arr.append(enum_val)
3841
            return arr[-1].value
3842

3843
        def foo(x):
3844
            # Pass the enum as argument.
3845
            return leaf_fn(x, Foo.A)
3846

3847
        traced = torch.fx.symbolic_trace(foo)
3848
        self.assertEqual(foo([]), traced([]))
3849

3850
    def test_insert_arg(self):
3851
        m = symbolic_trace(SimpleTest())
3852
        m.register_buffer("buf", torch.tensor(0))
3853
        output_node = next(iter(reversed(m.graph.nodes)))
3854
        with m.graph.inserting_before(output_node):
3855
            a = m.graph.get_attr("buf")
3856
        r = len(output_node.args)
3857
        output_node.insert_arg(0, a)
3858
        self.assertEqual(len(output_node.args), r + 1)
3859
        self.assertEqual(len(a.users), 1)
3860
        self.assertIs(output_node.args[0], a)
3861
        self.assertIs(next(iter(a.users.keys())), output_node)
3862
        output_node.insert_arg(2, a)
3863
        self.assertEqual(len(output_node.args), r + 2)
3864
        self.assertEqual(len(a.users), 1)
3865
        self.assertIs(output_node.args[2], a)
3866
        self.assertIs(next(iter(a.users.keys())), output_node)
3867
        m.graph.lint()
3868

3869

3870

3871
def run_getitem_target():
3872
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
3873
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
3874
    try:
3875
        TestFX().getitem_inner()
3876
    finally:
3877
        _wrapped_methods_to_patch.pop()
3878

3879

3880
class TestOperatorSignatures(JitTestCase):
3881
    def setUp(self):
3882
        # Checking for mutable operations while tracing is feature-flagged
3883
        # Enable it in testing but not by default
3884
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3885
        torch.fx.proxy.TracerBase.check_mutable_operations = True
3886

3887
    def tearDown(self):
3888
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
3889

3890
    @onlyCPU
3891
    @ops(op_db, allowed_dtypes=(torch.float,))
3892
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
3893
        if not isinstance(op.op, types.BuiltinFunctionType):
3894
            raise unittest.SkipTest("This path doesn't work on Python functions")
3895
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
3896
        schemas = get_signature_for_torch_op(op.op)
3897
        if not schemas:
3898
            raise RuntimeError('No Schemas Returned')
3899
        for sample_input in sample_inputs_itr:
3900
            # Iterate through overloads until we hit a match. If we exit this
3901
            # loop via `else`, we haven't found a match
3902
            for schema in schemas:
3903
                try:
3904
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
3905
                    bound_args.apply_defaults()
3906
                    op(*bound_args.args, **bound_args.kwargs)
3907
                    break
3908
                except TypeError:
3909
                    pass
3910
            else:
3911
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')
3912

3913

3914
class TestFXAPIBackwardCompatibility(JitTestCase):
3915
    def setUp(self):
3916
        super().setUp()
3917
        self.maxDiff = None
3918

3919
        # Checking for mutable operations while tracing is feature-flagged
3920
        # Enable it in testing but not by default
3921
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
3922
        torch.fx.proxy.TracerBase.check_mutable_operations = True
3923

3924
    def tearDown(self):
3925
        super().tearDown()
3926
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
3927

3928

3929
    def _fn_to_stable_annotation_str(self, obj):
3930
        """
3931
        Unfortunately we have to serialize function signatures manually since
3932
        serialization for `inspect.Signature` objects is not stable across
3933
        python versions
3934
        """
3935
        fn_name = torch.typename(obj)
3936

3937
        signature = inspect.signature(obj)
3938

3939
        sig_str = f'{fn_name}{signature}'
3940

3941
        arg_strs = []
3942
        for k, v in signature.parameters.items():
3943
            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
3944
                if v.annotation is not inspect.Signature.empty else ''
3945

3946
            def default_val_str(val):
3947
                if isinstance(val, (tuple, list)):
3948
                    str_pieces = ['(' if isinstance(val, tuple) else '[']
3949
                    str_pieces.append(', '.join(default_val_str(v) for v in val))
3950
                    if isinstance(val, tuple) and len(str_pieces) == 2:
3951
                        str_pieces.append(',')
3952
                    str_pieces.append(')' if isinstance(val, tuple) else ']')
3953
                    return ''.join(str_pieces)
3954

3955
                # Need to fix up some default value strings.
3956
                # First case: modules. Default module `repr` contains the FS path of the module.
3957
                # Don't leak that
3958
                if isinstance(val, types.ModuleType):
3959
                    return f'<module {val.__name__}>'
3960

3961
                # Second case: callables. Callables (such as lambdas) encode their address in
3962
                # their string repr. Don't do that
3963
                if callable(val):
3964
                    return f'<function {val.__name__}>'
3965

3966
                return str(val)
3967

3968
            if v.default is not inspect.Signature.empty:
3969
                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
3970
                maybe_default = f' = {default_val_str}'
3971
            else:
3972
                maybe_default = ''
3973
            maybe_stars = ''
3974
            if v.kind == inspect.Parameter.VAR_POSITIONAL:
3975
                maybe_stars = '*'
3976
            elif v.kind == inspect.Parameter.VAR_KEYWORD:
3977
                maybe_stars = '**'
3978
            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
3979

3980
        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
3981
            if signature.return_annotation is not inspect.Signature.empty else ''
3982

3983
        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
3984

3985
    def _annotation_type_to_stable_str(self, t, sig_str):
3986
        if t is inspect.Signature.empty:
3987
            return ''
3988

3989
        # Forward ref
3990
        if isinstance(t, str):
3991
            return f"'{t}'"
3992
        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
3993
            return t.__forward_arg__
3994
        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
3995
            return t.__forward_arg__
3996

3997
        trivial_mappings = {
3998
            str : 'str',
3999
            int : 'int',
4000
            float: 'float',
4001
            bool: 'bool',
4002
            torch.dtype: 'torch.dtype',
4003
            torch.Tensor: 'torch.Tensor',
4004
            torch.device: 'torch.device',
4005
            torch.memory_format: 'torch.memory_format',
4006
            slice: 'slice',
4007
            torch.nn.Module: 'torch.nn.modules.module.Module',
4008
            torch.fx.Graph : 'torch.fx.graph.Graph',
4009
            torch.fx.Node : 'torch.fx.node.Node',
4010
            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
4011
            torch.fx.node.Target : 'torch.fx.node.Target',
4012
            torch.fx.node.Argument : 'torch.fx.node.Argument',
4013
            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
4014
            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
4015
            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
4016
            Ellipsis : '...',
4017
            typing.Any: 'Any',
4018
            type(None): 'NoneType',
4019
            None: 'None',
4020
            typing.Iterator: 'Iterator',
4021
        }
4022

4023
        mapping = trivial_mappings.get(t, None)
4024
        if mapping:
4025
            return mapping
4026

4027
        # Handle types with contained types
4028
        contained = getattr(t, '__args__', None) or []
4029

4030
        # Callables contain a bare List for arguments
4031
        contained = t if isinstance(t, list) else contained
4032

4033
        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
4034
        if all(isinstance(ct, typing.TypeVar) for ct in contained):
4035
            contained = []
4036

4037
        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
4038
        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
4039

4040

4041
        origin = getattr(t, '__origin__', None)
4042
        if origin is None:
4043
            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
4044
            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
4045

4046
        if origin in {tuple, typing.Tuple}:
4047
            return f'Tuple{contained_type_str}'
4048
        if origin in {typing.Union}:
4049
            # Optional[T] is represented as Union[T, None], so a two-element Union
            # with exactly one NoneType argument is rendered as Optional[T]
4050
            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
4051
                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
4052
                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
4053
            return f'Union{contained_type_str}'
4054
        if origin in {dict, typing.Dict}:
4055
            return f'Dict{contained_type_str}'
4056
        if origin in {list, typing.List}:
4057
            return f'List{contained_type_str}'
4058
        if origin in {type, typing.Type}:
4059
            return f'Type{contained_type_str}'
4060
        if isinstance(t, typing.Callable):
4061
            if len(contained) > 0 and contained[0] is not Ellipsis:
4062
                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
4063
            else:
4064
                return f'Callable{contained_type_str}'
4065

4066
        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
4067
                           f'Please add support for this type and confirm with the '
4068
                           f'FX team that your signature change is valid.')
4069

4070

4071
    def test_function_back_compat(self):
4072
        """
4073
        Test backward compatibility for function signatures with
4074
        @compatibility(is_backward_compatible=True). Currently this checks for
4075
        exact signature matches, which may lead to false positives. If this
4076
        becomes too annoying, we can refine this check to actually parse out
4077
        the saved schema strings and check if the change is truly backward-
4078
        incompatible.
4079
        """
4080
        signature_strs = []
4081

4082
        for obj in _BACK_COMPAT_OBJECTS:
4083
            if not isinstance(obj, type):
4084
                signature_strs.append(self._fn_to_stable_annotation_str(obj))
4085

4086
        signature_strs.sort()
4087

4088
        try:
4089
            self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
4090
        except AssertionError as e:
4091
            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
4092
                  f"as backwards-compatible has experienced a signature change. See the " \
4093
                  f"above exception context for more information. If this change was " \
4094
                  f"unintended, please revert it. If it was intended, check with the FX " \
4095
                  f"team to ensure that the proper deprecation protocols have been followed " \
4096
                  f"and subsequently --accept the change."
4097
            raise AssertionError(msg)  # noqa: TRY200
4098

4099
    def test_class_member_back_compat(self):
4100
        """
4101
        Test backward compatibility for members of classes with
4102
        @compatibility(is_backward_compatible=True). Currently this checks for
4103
        exact matches on the publicly visible members of the class.
4104
        """
4105
        class_method_strs = []
4106

4107
        for obj in _BACK_COMPAT_OBJECTS:
4108
            if isinstance(obj, type):
4109
                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
4110
                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')
4111

4112
        class_method_strs.sort()
4113

4114
        try:
4115
            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
4116
        except AssertionError as e:
4117
            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
4118
                  f"as backwards-compatible has experienced change in its public members. See the " \
4119
                  f"above exception context for more information. If this change was " \
4120
                  f"unintended, please revert it. If it was intended, check with the FX " \
4121
                  f"team to ensure that the proper deprecation protocols have been followed " \
4122
                  f"and subsequently --accept the change."
4123
            raise AssertionError(msg) from e
4124

4125
    def test_public_api_surface(self):
4126
        non_back_compat_objects = {}
4127

4128
        def check_symbols_have_bc_designation(m, prefix):
4129
            if not m.__name__.startswith('torch.fx'):
4130
                return
4131
            if m.__name__.startswith('torch.fx.experimental'):
4132
                return
4133
            for k, v in m.__dict__.items():
4134
                if v is m:
4135
                    continue
4136
                if k.startswith('_'):
4137
                    continue
4138
                if isinstance(v, types.ModuleType):
4139
                    check_symbols_have_bc_designation(v, prefix + [k])
4140
                elif isinstance(v, (type, types.FunctionType)):
4141
                    if v not in _MARKED_WITH_COMPATIBILITY:
4142
                        non_back_compat_objects.setdefault(v)
4143

4144
        check_symbols_have_bc_designation(torch.fx, ['torch', 'fx'])
4145
        check_symbols_have_bc_designation(torch.fx.passes, ['torch', 'fx', 'passes'])
4146

4147
        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
4148
        # Only want objects in torch.fx
4149
        non_back_compat_strs = [
4150
            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
4151
        # Only want objects in public namespaces
4152
        non_back_compat_strs = [
4153
            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
4154
        non_back_compat_strs.sort()
4155

4156
        if len(non_back_compat_strs) != 0:
4157
            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
4158
                                 f"backwards-compatibility classification! Please decorate these "
4159
                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
4160
                                 f"BC guarantees.")
4161

4162
    def test_adding_side_effect_function(self):
4163
        class TestModule(torch.nn.Module):
4164
            def forward(self, x):
4165
                side_effect_func(x)
4166
                return x
4167

4168
        gm = torch.fx.symbolic_trace(TestModule())
4169
        self.assertEqual(len(gm.graph.nodes), 3)
4170
        gm.graph.eliminate_dead_code()
4171
        gm.recompile()
4172
        self.assertEqual(len(gm.graph.nodes), 3)
4173
        found = False
4174
        for node in gm.graph.nodes:
4175
            if node.op == 'call_function' and node.target == side_effect_func:
4176
                found = True
4177
        self.assertTrue(found)
4178

4179
    def test_preserve_unused_attr_after_unpickle(self):
        gm = torch.fx.symbolic_trace(Add())
        gm.add_submodule("foo", Add())
        gm.register_buffer("dummy_buffer", torch.empty(1))
        gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))
        b = io.BytesIO()
        torch.save(gm, b)
        b.seek(0)
        reload_gm = torch.load(b)
        self.assertTrue(hasattr(reload_gm, "foo"))
        self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
        self.assertTrue(hasattr(reload_gm, "dummy_parameter"))

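# TestFunctionalTracing auto-generates one test per public callable in torch.nn.functional.
# Functionals listed in UNTRACEABLE_FUNCTIONALS are expected to fail symbolic tracing with
# the recorded exception and message regex; everything else must trace successfully.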
# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
@unittest.skipIf(
    sys.version_info >= (3, 12), "Failing on python 3.12+"
)
class TestFunctionalTracing(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                    "has_torch_function_variadic", "handle_torch_function",
                    "boolean_dispatch")
    TO_PATCH = {"has_torch_function": None,
                "has_torch_function_unary": None,
                "has_torch_function_variadic": None}

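    # Each constant below is an (expected exception type, message regex) pair that
    # describes why a given functional cannot be symbolically traced.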
    BUILT_IN_FUNC = (AssertionError, "")
    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")

    UNTRACEABLE_FUNCTIONALS = {
        "adaptive_avg_pool1d": BUILT_IN_FUNC,
        "avg_pool1d": BUILT_IN_FUNC,
        "avg_pool2d": BUILT_IN_FUNC,
        "avg_pool3d": BUILT_IN_FUNC,
        "bilinear": BUILT_IN_FUNC,
        "celu_": BUILT_IN_FUNC,
        "channel_shuffle": BUILT_IN_FUNC,
        "native_channel_shuffle": BUILT_IN_FUNC,
        "conv1d": BUILT_IN_FUNC,
        "conv2d": BUILT_IN_FUNC,
        "conv3d": BUILT_IN_FUNC,
        "conv_tbc": BUILT_IN_FUNC,
        "conv_transpose1d": BUILT_IN_FUNC,
        "conv_transpose2d": BUILT_IN_FUNC,
        "conv_transpose3d": BUILT_IN_FUNC,
        "cosine_similarity": BUILT_IN_FUNC,
        "elu_": BUILT_IN_FUNC,
        "gelu": BUILT_IN_FUNC,
        "hardshrink": BUILT_IN_FUNC,
        "hardtanh_": BUILT_IN_FUNC,
        "leaky_relu_": BUILT_IN_FUNC,
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": ARG_TYPE_MISMATCH,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,
        "pixel_unshuffle": BUILT_IN_FUNC,
        "prelu": BUILT_IN_FUNC,
        "relu_": BUILT_IN_FUNC,
        "rrelu_": BUILT_IN_FUNC,
        "selu_": BUILT_IN_FUNC,
        "scaled_dot_product_attention": BUILT_IN_FUNC,
        "softplus": BUILT_IN_FUNC,
        "softshrink": BUILT_IN_FUNC,
        "threshold_": BUILT_IN_FUNC,

        "adaptive_avg_pool2d": LEN_ERROR,
        "adaptive_avg_pool3d": LEN_ERROR,
        "adaptive_max_pool2d_with_indices": LEN_ERROR,
        "adaptive_max_pool3d_with_indices": LEN_ERROR,
        "instance_norm": CONTROL_FLOW,

        "adaptive_max_pool1d": PROXY_ITERABLE,
        "adaptive_max_pool2d": PROXY_ITERABLE,
        "adaptive_max_pool3d": PROXY_ITERABLE,
        "fractional_max_pool2d": PROXY_ITERABLE,
        "fractional_max_pool3d": PROXY_ITERABLE,
        "max_pool1d": PROXY_ITERABLE,
        "max_pool2d": PROXY_ITERABLE,
        "max_pool3d": PROXY_ITERABLE,

        "lp_pool2d": PROXY_ITERATED,
        "lp_pool3d": PROXY_ITERATED,
        "max_unpool1d": PROXY_ITERATED,
        "max_unpool2d": PROXY_ITERATED,
        "max_unpool3d": PROXY_ITERATED,
        "fold": PROXY_ITERATED,
        "unfold": PROXY_ITERATED,

        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
        "alpha_dropout": CONTROL_FLOW,
        "batch_norm": CONTROL_FLOW,
        "binary_cross_entropy": CONTROL_FLOW,
        "binary_cross_entropy_with_logits": CONTROL_FLOW,
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
        "dropout2d": CONTROL_FLOW,
        "dropout3d": CONTROL_FLOW,
        "elu": CONTROL_FLOW,
        "embedding": CONTROL_FLOW,
        "embedding_bag": CONTROL_FLOW,
        "feature_alpha_dropout": CONTROL_FLOW,
        "gaussian_nll_loss": CONTROL_FLOW,
        "glu": CONTROL_FLOW,
        "grid_sample": CONTROL_FLOW,
        "group_norm": CONTROL_FLOW,
        "gumbel_softmax": CONTROL_FLOW,
        "hardsigmoid": CONTROL_FLOW,
        "hardswish": CONTROL_FLOW,
        "hardtanh": CONTROL_FLOW,
        "hinge_embedding_loss": CONTROL_FLOW,
        "huber_loss": CONTROL_FLOW,
        "interpolate": CONTROL_FLOW,
        "kl_div": CONTROL_FLOW,
        "l1_loss": CONTROL_FLOW,
        "leaky_relu": CONTROL_FLOW,
        "local_response_norm": CONTROL_FLOW,
        "margin_ranking_loss": CONTROL_FLOW,
        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "mse_loss": CONTROL_FLOW,
        "multi_head_attention_forward": CONTROL_FLOW,
        "multi_margin_loss": CONTROL_FLOW,
        "multilabel_margin_loss": CONTROL_FLOW,
        "multilabel_soft_margin_loss": CONTROL_FLOW,
        "nll_loss": CONTROL_FLOW,
        "poisson_nll_loss": CONTROL_FLOW,
        "relu": CONTROL_FLOW,
        "relu6": CONTROL_FLOW,
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,
        "triplet_margin_loss": CONTROL_FLOW,
        "triplet_margin_with_distance_loss": CONTROL_FLOW,
        "upsample": CONTROL_FLOW,

        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
    }

    # nn.functional callables that take Tensor inputs but have no type annotations
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )

    # Inconsistent behavior between Python 3.8 and other Python versions:
    # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED`
    # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same
    #                 internal exception above
    # Use the following map to override the expected exception for Python 3.8
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": CONTROL_FLOW
    }

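    # Collect (name, fn) pairs from torch.nn.functional, skipping private helpers,
    # the support functions in IGNORE_FUNCS, non-callables, and (unless listed in
    # FUNCTIONALS_WITHOUT_ANNOTATION) functions whose signature has no
    # Tensor-annotated parameter.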
    @classmethod
    def _get_functional(cls):
        functional_list = []
        for f in dir(torch.nn.functional):
            if not f.islower():
                continue
            # Ignore internal functions
            if f.startswith('_'):
                continue
            # Ignore supporting functions
            if f in cls.IGNORE_FUNCS:
                continue
            fn = getattr(torch.nn.functional, f)
            # Ignore non-callable objects like modules
            if not isinstance(fn, Callable):
                continue
            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
                try:
                    sig = inspect.signature(fn)
                    has_tensor_arg = False
                    for param in sig.parameters.values():
                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
                            has_tensor_arg = True
                    if not has_tensor_arg:
                        continue
                # No signature available, or the object is not supported by inspect.signature
                except ValueError:
                    pass
            functional_list.append((f, fn))
        return functional_list

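    # Build a test that expects the documented tracing failure for untraceable
    # functionals (with the Python 3.8-3.11 overrides above) and otherwise just
    # checks that symbolic_trace(fn) succeeds.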
    @classmethod
    def generate_test_func(cls, func_name, fn):

        def functional_test(self):
            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
                    sys.version_info >= (3, 8) and sys.version_info < (3, 12):
                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            else:
                symbolic_trace(fn)
        return functional_test

    @classmethod
    def generate_tests(cls):
        functional_list = cls._get_functional()
        for func_name, fn in functional_list:
            test_name = "test_nn_functional_" + func_name
            functional_test = cls.generate_test_func(func_name, fn)
            setattr(cls, test_name, functional_test)

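    # Patch the has_torch_function* checks to always return False for the duration of
    # this test class, so tracing never takes the __torch_function__ override path;
    # tearDownClass restores the originals saved in TO_PATCH.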
    @classmethod
    def setUpClass(cls):

        def no(*args, **kwargs):
            return False

        for name in cls.TO_PATCH.keys():
            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

    @classmethod
    def tearDownClass(cls):
        for name in cls.TO_PATCH.keys():
            setattr(torch.nn.functional, name, cls.TO_PATCH[name])

TestFunctionalTracing.generate_tests()


instantiate_device_type_tests(TestOperatorSignatures, globals())

@skipIfTorchDynamo("too slow")
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )

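    # torchvision detection models are expected to fail symbolic tracing with
    # PROXY_ITERATED; googlenet and inception_v3 trace fine but are expected to fail
    # torch.jit.script with INCONSISTENT_TYPE.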
    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
        "ssd300_vgg16": PROXY_ITERATED,
        "fcos_resnet50_fpn": PROXY_ITERATED,
        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
    }
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }

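    # Normalize model outputs before comparison: for segmentation models the lambdas
    # take the "out" entry of the returned dict, and for the detection entries they
    # take element 1 of the model output, so eager, traced and scripted results can be
    # compared directly.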
    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }

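    # Each generated test builds the model in eval mode, expects symbolic tracing to
    # fail for UNTRACEABLE_MODELS, and otherwise checks that the traced GraphModule
    # (and, where scriptable, its TorchScript form) matches the eager output.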
    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs)
            model = model.eval()
            if name in self.UNTRACEABLE_MODELS:
                err, exc = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(err, exc):
                    graph = symbolic_trace(model)
            else:
                out_transform = self.output_transform.get(name, lambda x: x)
                graph : torch.fx.GraphModule = symbolic_trace(model)
                a = out_transform(model(x))
                b = out_transform(graph(x))
                self.assertEqual(a, b)

                if name in self.UNSCRIPTABLE_MODELS:
                    err, exc = self.UNSCRIPTABLE_MODELS[name]
                    with self.assertRaisesRegex(err, exc):
                        script = torch.jit.script(graph)
                else:
                    script = torch.jit.script(graph)
                    c = out_transform(script(x))
                    self.assertEqual(a, c)

        return run_test

    @classmethod
    def generate_classification_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models):
            test_name = 'test_torchvision_models_' + k
            x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_segmentation_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.segmentation):
            test_name = 'test_torchvision_models_segmentation_' + k
            x = torch.rand(1, 3, 32, 32)
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_detection_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.detection):
            test_name = 'test_torchvision_models_detection_' + k
            x = [torch.rand(3, 300, 300)]
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_video_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.video):
            test_name = 'test_torchvision_models_video_' + k
            x = (
                torch.rand(1, 3, 4, 112, 112)
                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
                else torch.rand(1, 3, 16, 224, 224)
            )
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_tests(cls):
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()

if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()

if __name__ == '__main__':
    run_tests()