# Owner(s): ["module: fx"]

# Standard-library imports restored for the names this excerpt actually uses
# (builtins, contextlib, copy, functools, io, math, numbers, operator, os,
# pickle, sys, unittest); the original header block was elided.
import builtins
import contextlib
import copy
import functools
import io
import math
import numbers
import operator
import os
import pickle
import sys
import unittest
from math import sqrt

import torch

from functorch.experimental import control_flow
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta

from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401
from fx.test_matcher_utils import TestMatcher  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

from fx.test_gradual_type import AnnotationsTest  # noqa: F401
from fx.test_gradual_type import TypeCheckerTest  # noqa: F401

from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union

from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

from torch.testing._internal.common_quantization import skipIfNoDynamoSupport

# Saved reference; test_sqrt below checks that tracing restores the patched
# `sqrt`/`math.sqrt` to this original function object.
_sqrt = math.sqrt
class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)


def a_non_torch_leaf(a, b):
    return a + b


# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    return int(x)

def fx_int_x2(x: float) -> int:
    return int(x) * 2

# used in test_pytree. It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])

# Test wrap() passing both a function name as well as a function
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

wrap(a_lifted_leaf2)

# `len` and `getattr` are wrapped so that test_torch_fx_len and
# test_torch_fx_getattr below can trace calls to them.
wrap('len')
wrap('getattr')

def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)

@wrap
def wrapped_via_decorator(a):
    return a + 1

wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

def a_decorator(f):
    # The enclosing decorator was elided in the excerpt; `a_decorator` is a
    # stand-in name around the surviving inner lines.
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

@a_decorator
def wrapped_decorated_fn(x):
    return x

wrap(wrapped_decorated_fn)

real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifted_leaf = a_lifted_leaf
real_a_lifted_leaf2 = a_lifted_leaf2
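
# A minimal sketch (not from the original suite) of what `wrap` buys us: a
# wrapped function is kept opaque during tracing and shows up as a single
# `call_function` node. `_demo_wrap_leaf` and `_demo_wrap_usage` are
# hypothetical helpers, not part of the real test file.
def _demo_wrap_leaf(a, b):
    return a + b

wrap(_demo_wrap_leaf)

def _demo_wrap_usage():
    def fn(x):
        return _demo_wrap_leaf(x, x)
    gm = symbolic_trace(fn)
    # The leaf is referenced by name in the generated code, not traced into.
    assert '_demo_wrap_leaf' in gm.code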
class Pair(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"

class Foo:  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Add(torch.nn.Module):
    def forward(self, x):
        return x + x

@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    print(x)
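
# Sketch (not part of the original suite): `has_side_effect` marks a call so
# that `Graph.eliminate_dead_code()` keeps its node even though the result is
# unused. `_demo_side_effect` is a hypothetical helper.
def _demo_side_effect():
    def fn(x):
        side_effect_func(x)  # result unused, but the node must survive DCE
        return x + 1
    gm = symbolic_trace(fn)
    gm.graph.eliminate_dead_code()
    assert any(n.target is side_effect_func for n in gm.graph.nodes)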
class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)
    def test_graph_module(self):
        class MySub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        ms = torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        gm2 = symbolic_trace(m2)

        class T(torch.nn.Module):
            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)

        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_close(new_instance(x), torch.relu(x))
    def test_informative_co_filename(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                ...

        gm = symbolic_trace(MyModule())
        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)

    def test_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)

        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))
    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return torch.relu(x)

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_varargs_concrete(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return torch.relu(x)

        t = T()
        args = (torch.rand(1), torch.rand(1))
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)

    def test_args_kwargs_no_self(self):
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
    def test_fx_shifts(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_fx_and_or(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x & x, x | x

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))
    def test_matmul_tracing(self):
        const = torch.randn(3)

        def matmul_f(x):
            return torch.matmul(x, const)

        mod = symbolic_trace(matmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), matmul_f(inp))

        def rmatmul_f(x):
            return torch.matmul(const, x)

        mod = symbolic_trace(rmatmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), rmatmul_f(inp))

    @skipIfNoDynamoSupport
    def test_control_flow_tracing(self):
        # (the branch functions were elided in the excerpt; simple bodies are
        # used here, since tracing fails on the Proxy predicate before the
        # branches are ever examined)
        def true(x, y):
            return x + y

        def false(x, y):
            return x - y

        def f(x, y):
            x = control_flow.cond(x[0] == 0, true, false, [x, y])
            return x

        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
            _ = symbolic_trace(f)
    def test_disallow_override(self):
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                name = target if isinstance(target, str) else torch.typename(target)
                if name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name)

        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x

        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x

        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)
    def test_leaf_module(self):
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
    def test_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf, real_a_lifted_leaf)

    def test_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf2', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf2, real_a_lifted_leaf2)
    def test_wrapped_via_decorator(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(m).transform()
        self.assertIn('wrapped_via_decorator', transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
    def test_wrap_with_submodule(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        m = symbolic_trace(M())

        self.assertIn("wrapped_with_submodule", m.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), m(input))

    def test_wrapped_retrace(self):
        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)

        retraced = symbolic_trace(m)
        self.assertIn('wrapped_via_decorator', retraced.code)
        self.assertEqual(retraced(0), 1)

    def test_wrap_decorated_function(self):
        def to_trace(y):
            return wrapped_decorated_fn(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_decorated_fn', m.code)
        self.assertEqual(m(1), 1)
    def test_graph_edit_with_proxy(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # Test that we can use Proxy objects to generate more graph code
        # later, for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint()
        self.assertEqual(gm(3, 4), 14)
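
    def _proxy_append_sketch(self):
        # Sketch (not part of the original suite): a Proxy wrapped around an
        # existing Node appends new nodes to that node's graph as ordinary
        # Python expressions are evaluated, which is what the test above
        # relies on after `graph_copy`.
        g = torch.fx.Graph()
        x = g.placeholder('x')
        p = Proxy(x)
        y = torch.relu(p * 2)  # creates mul and relu call_function nodes in g
        g.output(y.node)
        gm = GraphModule(torch.nn.Module(), g)
        torch.testing.assert_close(gm(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 4.0]))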
    def test_proxy_deepcopy_without_tracer(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                ...

        module = MyModule()
        traced = symbolic_trace(module)
        node = list(traced.graph.nodes)[-2]
        p = torch.fx.Proxy(node, None)
        p2 = copy.deepcopy(p)
        self.assertTrue(isinstance(p2, torch.fx.Proxy))
        self.assertEqual(p2.node.name, node.name)
        self.assertEqual(p2.node.target, node.target)
        self.assertNotEqual(id(p2.node), id(node))

    def test_proxy_deepcopy_with_tracer(self):
        class TestTracer(Tracer):
            def __init__(self, name):
                super().__init__()
                self.name = name

            def is_leaf_module(self, module, name):
                return True

        class MyModule(torch.nn.Module):
            def forward(self, x):
                ...

        module = MyModule()
        tracer = TestTracer("mytracer")
        traced = symbolic_trace(module)
        node = list(traced.graph.nodes)[-2]
        p = torch.fx.Proxy(node, tracer)
        p2 = copy.deepcopy(p)
        self.assertTrue(isinstance(p2, torch.fx.Proxy))
        self.assertTrue(isinstance(p2.tracer, torch.fx._symbolic_trace.Tracer))
        self.assertEqual(p2.tracer.name, "mytracer")
        self.assertEqual(p2.node.name, node.name)
        self.assertEqual(p2.node.target, node.target)
        self.assertNotEqual(id(p2.node), id(node))
        self.assertNotEqual(id(p2.tracer), id(tracer))
    def test_concrete_arg_none_assert(self):
        class Foo(torch.nn.Module):
            def forward(self, x, val=None):
                return x if val is None else x + val

        f = Foo()
        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
            traced(torch.randn(5), torch.randn(5))

        x = torch.randn(5)
        torch.testing.assert_close(traced(x), f(x))
    def test_trace_multiple_funcs(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

            def minus_forward(self, x, y):
                return x - y

            def multiply_forward(self, x, y):
                return x * y

        f = Foo()
        x, y = torch.randn(5), torch.randn(5)

        tracer = Tracer()
        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

        tracer.traced_func_name = "minus_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.minus_forward(x, y),
        )

        tracer.traced_func_name = "multiply_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.multiply_forward(x, y),
        )

        tracer.traced_func_name = "add_forward"
        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
            tracer.trace(f)
    def test_graph_unique_names(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # Test that we can use Proxy objects to generate more graph code
        # later, for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)
    def test_stack_traces(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        # saving the original list because we will insert new nodes as a part of a test
        orig_graph_nodes = list(graph.nodes)
        for node in orig_graph_nodes:
            if node.op == 'output':
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

            # verify that copying the node does not lose the stack trace
            new_node = graph.node_copy(node)
            self.assertTrue(new_node.stack_trace is not None)
            assert 'test_fx.py' in new_node.stack_trace
    def test_stack_traces_with_transformer(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        new_gm = Transformer(gm).transform()

        # nodes after Transformer should still preserve the original node's stack trace
        for node in new_gm.graph.nodes:
            if node.op in {'placeholder', 'output'}:
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace
    def test_lineno_map(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                # (body elided in the excerpt; reconstructed as one op per
                # line so the four graph nodes map onto forward lines 2-5)
                a = a + 1
                b = b + 1
                c = a + b
                return c

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        expected = {1: 2, 2: 3, 3: 4, 4: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

        # test custom codegen
        def transform_code(code):
            return ["print('hello!')\n", *code]
        gm.graph.on_generate_code(lambda _: transform_code)
        gm.recompile()
        expected = {2: 2, 3: 3, 4: 4, 5: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
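
    # Note on the second `expected` above: `transform_code` prepends exactly
    # one line to the generated code, so every generated line number recorded
    # in `_lineno_map` shifts down by one relative to the first mapping.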
    def test_graph_unique_names_manual(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)

        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)
    def test_unpack(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))
    def test_native_callable(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul",
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value
            #
            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint()

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_close(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_close(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_close(imported_out, ref_out)
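
    # For reference: lowering MySimpleMod's `3.0 * x + x` above yields
    # instruction triples shaped roughly like ('mul', ['constant_0', 'x'], m)
    # and ('add', [m, 'x'], out), where m/out are tracer-assigned node names.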
    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint()
        for node in m_g.graph.nodes:
            self.assertTrue(node.name != "getattr")
    @unittest.skip("Hotfix for SEV remediation")
    def test_trace_buffer_slice(self):
        bs, d_hid = 10, 23  # sizes elided in the excerpt; any values work

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)
                self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                skip_connection = x
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
                x = self.lin(x)
                x = torch.relu(x)
                x = x + skip_connection
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        ec = ExampleCode()

        traced = torch.fx.symbolic_trace(ec)

        x = torch.randn(bs, d_hid)
        torch.testing.assert_close(ec(x), traced(x))
    def test_node_tagging(self):
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint()
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')
    def test_tensor_attribute(self):
        class TensorAttribute(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint()
        traced2(torch.rand(4, 4))
    def test_tensor_attribute_coalesced(self):

        def count_attrs(fx_module):
            targets = set()
            for node in fx_module.graph.nodes:
                if node.op == 'get_attr':
                    targets.add(node.target)
            return len(targets)

        val = torch.tensor(5)

        def f(x):
            return x + val + val

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 1)

        val2 = torch.tensor(5)

        def f(x):
            val = torch.tensor(5)
            return x + val + val2

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 2)
    def test_symbolic_trace_sequential(self):
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))
    def test_tensor_constant(self):
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint()
        traced(torch.rand(4, 4))
    def test_pickle_graphmodule(self):
        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint()
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_pickle_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)

        gm = GraphModule(torch.nn.Module(), graph)
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(loaded(x, y), gm(x, y))
    def test_all_input_nodes(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])
    def test_deepcopy_graphmodule_with_transform(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            relu_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(relu_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint()
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))
    def test_deepcopy_with_submods_params(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint()
        copied = copy.deepcopy(traced)
        copied.graph.lint()
    def test_deepcopy_graph_with_tracer_cls(self):
        class TestTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        g = Graph(tracer_cls=TestTracer)
        x = g.placeholder("x")
        g.output(x)

        h = copy.deepcopy(g)
        self.assertIsNotNone(h._tracer_cls)
        self.assertTrue(g._tracer_cls == h._tracer_cls)
    def test_unpack_list_better_error(self):
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)
    def test_pretty_print_targets(self):
        # Test that Graph pretty-print prints friendly name for targets
        # in `operator` and `builtins`

        class SomeMod(torch.nn.Module):
            def forward(self, x):
                return torch.add(x.foo + x.bar, 3.0)

        traced = symbolic_trace(SomeMod())
        graph_str = str(traced.graph)
        self.assertIn('builtins.getattr', graph_str)
        self.assertIn('operator.add', graph_str)
        self.assertIn('torch.add', graph_str)
    def test_pretty_print_node(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param: torch.nn.Parameter = torch.nn.Parameter(
                    torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor, y: int = 2):
                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)

        traced = symbolic_trace(M())

        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])

        FileCheck().check("x").check("placeholder") \
            .check("y").check("placeholder") \
            .check("getitem").check("call_function") \
            .check("param").check("get_attr") \
            .check("add").check("call_function") \
            .check("linear").check("call_module") \
            .check("clamp").check("call_method") \
            .run(all_formatted)
    def test_script_tensor_constant(self):
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)
    def test_autowrap_functions(self):
        class AutowrapFnTest(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2)

        class AutowrapFnTest2(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)

        # Check function(s) are wrapped
        # `int` would normally throw a TypeError as argument can't be `Proxy`
        tracer = Tracer(autowrap_functions=(fx_int,))
        graph = tracer.trace(AutowrapFnTest())
        traced = GraphModule(tracer.root, graph, 'test')
        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
        tracer_2.trace(AutowrapFnTest2())

        # Test scriptability
        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(4)), 2)
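
    # Unlike the module-level `wrap` calls near the top of this file,
    # `autowrap_functions` is scoped to a single Tracer instance: tracing
    # AutowrapFnTest with a default Tracer would fail when `int()` inside
    # `fx_int` receives a Proxy.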
    def test_tuple_no_subscript(self):
        def foo(x : Tuple):
            return x[0]

        traced = torch.fx.symbolic_trace(foo)
        x = (torch.randn(5, 3),)
        torch.testing.assert_close(traced(x), x[0])

        bio = io.BytesIO()

        torch.save(traced, bio)

        bio.seek(0)

        # weights_only=False as this loads a GraphModule
        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
        loaded = torch.load(bio, weights_only=False)

        torch.testing.assert_close(loaded(x), x[0])
    def test_torch_fx_len(self):
        class FXLenTest(torch.nn.Module):
            def forward(self, x):
                return len(x)

        traced = symbolic_trace(FXLenTest())
        self.assertEqual(traced(torch.rand(3, 4)), 3)

        # Test scriptability
        scripted = torch.jit.script(FXLenTest())
        self.assertEqual(scripted(torch.rand(3)), 3)

        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(3)), 3)

        # Test non-proxy len
        class FXLenTest2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = [3, 4, 5]

            def forward(self, x):
                return x + len(self.l)

        traced2 = symbolic_trace(FXLenTest2())
        inp = torch.rand(3, 4)
        self.assertEqual(traced2(inp), inp + 3.0)
        self.assertIs(len, builtins.len)
    def test_torch_fx_getattr(self):
        class FXGetattrTest(torch.nn.Module):
            def forward(self, x):
                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))

        traced = symbolic_trace(FXGetattrTest())
        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
    def test_sqrt(self):
        class Sqrt1(torch.nn.Module):
            def forward(self, x):
                return sqrt(x.size(0))

        class Sqrt2(torch.nn.Module):
            def forward(self, x):
                return math.sqrt(x.size(0))

        class Sqrt3(torch.nn.Module):
            def forward(self, x):
                return x + math.sqrt(2) + sqrt(2)

        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
        self.assertIs(sqrt, _sqrt)
        self.assertIs(math.sqrt, _sqrt)
    def test_torch_custom_ops(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))

        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        out = gm(input)
        self.assertEqual(out, ref_out)

    def test_torch_op_overloads(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.add.Tensor(a, a)
                return b

        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        out = gm(input)
        self.assertEqual(out, ref_out)

        for node in gm.graph.nodes:
            if node.op == 'call_function':
                assert isinstance(node.target, torch._ops.OpOverload)
                assert node.target.__name__ == 'add.Tensor'
    def test_pickle_torch_custom_ops(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))

        m = M()
        input = torch.randn(3)
        gm = symbolic_trace(m)

        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        self.assertEqual(loaded(input), gm(input))
    def test_pretty_print(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()
        printed = str(traced)
        assert 'SimpleTest()' in printed
        assert 'torch.relu' in printed

    def test_pretty_print_graph(self):
        class KwargPrintTest(torch.nn.Module):
            def forward(self, x):
                return torch.squeeze(x + 3.0, dim=2)
        st = KwargPrintTest()
        traced = symbolic_trace(st)
        traced.graph.lint()
        stringed = str(traced.graph)
        for s in ['args', 'kwargs', 'num_users']:
            assert s in stringed
    def test_custom_proxy_type(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair(x : TensorPair, y : TensorPair):
            s = x.add(y)
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

        ref_out = use_tensor_pair(x, y)

        traced = symbolic_trace(use_tensor_pair)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
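
    # ProxyableClassMeta in brief: during tracing, constructing `TensorPair(l, r)`
    # itself becomes a `call_function` node whose target is the class, so these
    # structured values can flow through a traced program. A minimal sketch:
    #
    #     class P(metaclass=torch.fx.ProxyableClassMeta):
    #         def __init__(self, v):
    #             self.v = v
    #
    #     def f(p: P):
    #         return P(p.v + 1)   # appears as call_function(P, ...) in the graph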
    def test_custom_proxy_type_literal(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_literal(x : TensorPair):
            s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

        ref_out = use_tensor_pair_literal(x)

        traced = symbolic_trace(use_tensor_pair_literal)

        traced_out = traced(x)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
    def test_custom_proxy_dynamic_value(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = symbolic_trace(use_tensor_pair_ctor)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
    def test_custom_proxy_input_dependent_control_flow(self):
        class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, inp):
                if inp.sum() == 0:
                    self.is_zero = True
                    self.tensor = torch.tensor([])
                else:
                    self.is_zero = False
                    self.tensor = inp

            def add(self, other):
                if self.is_zero:
                    return ZeroTensor(other.tensor)
                elif other.is_zero:
                    return self
                else:
                    return ZeroTensor(self.tensor + other.tensor)

        def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
            return ZeroTensor(x + y)

        x, y = torch.randn(5, 3), torch.randn(5, 3)

        ref_out = use_zero_tensor(x, y)

        traced = symbolic_trace(use_zero_tensor)

        traced_out = traced(x, y)

        self.assertEqual(traced_out.is_zero, ref_out.is_zero)
        self.assertEqual(traced_out.tensor, ref_out.tensor)
    def test_graph_fns(self):
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        gm.graph.lint()
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)
    def test_remove_uses(self):
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertTrue(neg not in relu.users)

    def test_remove_uses_with_custom_filter(self):
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu, lambda x: x != neg)

        self.assertTrue(neg in relu.users)
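
    # The optional predicate to `replace_all_uses_with` selects which user
    # nodes get rewritten; and since `neg` is never erased in this second
    # test, it still consumes `relu`, which is why `neg in relu.users` holds.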
    def test_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(3, 4)
        symbolic_trace(eb)

    def test_pickle_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
        traced = symbolic_trace(eb)
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 4])
        self.assertEqual(loaded(input, offsets), traced(input, offsets))
    def test_return_tuple(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return (x, x + x)

        original = M()
        traced = symbolic_trace(original)
        self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
    def test_construct_root_dict(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)

        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
        add_param : torch.Tensor = torch.rand(3, 4)
        gm : torch.fx.GraphModule = torch.fx.GraphModule(
            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
        gm.graph.lint()

        assert 'self.foo.bar.baz' in gm.code

        x : torch.Tensor = torch.rand(3, 3)
        out : torch.Tensor = gm(x)
        ref_out : torch.Tensor = linear_mod(x) + add_param
        self.assertEqual(out, ref_out)
    def test_symbolic_trace_assert(self):
        class AssertsTensorShape(torch.nn.Module):
            def forward(self, x):
                torch._assert(x.shape[1] > 4, "assert_foobar")
                return x

        m = AssertsTensorShape()
        # verify traceability
        traced = symbolic_trace(m)
        # verify assertion on traced model works correctly at runtime
        traced(torch.rand(4, 5))
        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
            traced(torch.rand(4, 3))
        # verify the symbolically traced module is scriptable
        ms = torch.jit.script(m)
        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
            ms(torch.rand(4, 3))
    def test_fx_create_arg(self):
        class CustomArgObject:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __fx_create_arg__(self, tracer: torch.fx.Tracer):
                return tracer.create_node(
                    "call_function",
                    CustomArgObject,
                    args=(
                        tracer.create_arg(self.x),
                        tracer.create_arg(self.y),
                    ),
                    kwargs={},
                )

        class HasCustomArgObjectWhenLeaf(torch.nn.Module):
            def forward(self, o: CustomArgObject):
                # Not normally traceable; good reason to make
                # this module a leaf.
                for x in o.x:
                    o.y += x
                return o.y

        class Root(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner = HasCustomArgObjectWhenLeaf()

            def forward(self, x, y):
                o = CustomArgObject(x, y)
                return self.inner(o)

        class CreateArgTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return type(m) is HasCustomArgObjectWhenLeaf

        m = Root()
        graph = CreateArgTracer().trace(m)
        gm = torch.fx.GraphModule(m, graph)
        assert "CustomArgObject(" in gm.code
    def test_trace_fn_constant(self):
        some_constant = torch.rand(3, 4)

        def add_const(x):
            return some_constant + x

        traced = symbolic_trace(add_const)

        input = torch.rand(3, 4)
        self.assertEqual(traced(input), add_const(input))

    def test_copy_no_remap(self):
        traced = symbolic_trace(SimpleTest())
        g = traced.graph
        copied = torch.fx.Graph()
        for node in g.nodes:
            copied.node_copy(node)
        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
            copied.lint()
    def test_wrong_topo(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        nodes = list(graph.nodes)
        nodes[3].append(nodes[2])
        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
            graph.lint()

    def test_wrong_target_type(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        with self.assertRaises(ValueError):
            n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
                              args=(), kwargs={})
    def test_example_shape_prop(self):
        class TestCase(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.randn(3, 4)
                self.submod = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.neg(self.submod(x.relu() + self.attr))

        tc = TestCase()
        tc_traced = symbolic_trace(tc)
        ref_out = tc_traced(torch.rand(3, 4))
        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

        # Make sure we're testing all opcodes
        opcodes = set()
        output_shape : Optional[torch.Size] = None
        output_stride : Optional[Tuple[int]] = None
        for node in tc_traced.graph.nodes:
            opcodes.add(node.op)
            if node.op == 'output':
                output_shape = node.args[0].meta['tensor_meta'].shape
                output_stride = node.args[0].meta['tensor_meta'].stride
        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                                   'call_module', 'output'})

        # Test shape propagation and make sure results match actual
        self.assertEqual(output_shape, ref_out.shape)
        self.assertEqual(output_stride, ref_out.stride())
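
    # `ShapeProp` executes the module node by node and records a TensorMetadata
    # entry (shape, dtype, stride, memory_format, ...) in node.meta['tensor_meta'],
    # which passes can then consult without re-running the graph.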
    def test_shape_prop_layout(self):
        class ConvTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_mod = torch.nn.Conv2d(5, 5, 3)

            def forward(self, x):
                return self.conv_mod(x)

        test_mod = ConvTest()
        traced = symbolic_trace(test_mod)
        x = torch.randn(5, 5, 224, 224)
        shape_prop.ShapeProp(traced).propagate(x)

        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
                   for node in traced.graph.nodes)

        x_channels_last = x.contiguous(memory_format=torch.channels_last)
        traced.to(memory_format=torch.channels_last)
        shape_prop.ShapeProp(traced).propagate(x_channels_last)
        for node in traced.graph.nodes:
            # NB: the implementation of conv may not preserve the memory format,
            # unfortunately. The best we can do is just check that the placeholder
            # node is channels-last
            if node.op in {'placeholder'}:
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
    def test_shape_prop_aggregate(self):
        class ReturnTwo(torch.nn.Module):
            def forward(self, x):
                return (3, torch.sum(x))

        class UnderTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.rt = ReturnTwo()

            def forward(self, x):
                return self.rt(x)

        ut = UnderTest()

        class RTTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return type(m) is ReturnTwo

        graph = RTTracer().trace(ut)
        mod = torch.fx.GraphModule(ut, graph)

        shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))

        for node in mod.graph.nodes:
            if node.op == 'call_module':
                assert 'tensor_meta' in node.meta
                tensor_meta = node.meta['tensor_meta']
                assert tensor_meta[0] == 3
                assert tensor_meta[1].shape == torch.Size([])
    def test_shape_prop_layout_3d(self):
        class ConvTest3d(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_mod = torch.nn.Conv3d(5, 5, 3)

            def forward(self, x):
                return self.conv_mod(x)

        test_mod_3d = ConvTest3d()
        traced_3d = symbolic_trace(test_mod_3d)
        x_3d = torch.randn(5, 5, 224, 224, 15)
        shape_prop.ShapeProp(traced_3d).propagate(x_3d)
        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
                   for node in traced_3d.graph.nodes)

        x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
        traced_3d.to(memory_format=torch.channels_last_3d)
        shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
        for node in traced_3d.graph.nodes:
            # NB: the implementation of conv may not preserve the memory format,
            # unfortunately. The best we can do is just check that the placeholder
            # node is channels-last
            if node.op in {'placeholder'}:
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
    def test_nn_module_stack(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)

            def forward(self, x):
                return self.conv_mod(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub_mod = SubModule()

            def forward(self, x):
                return self.sub_mod(x)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
                          ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
        for node in gm.graph.nodes:
            mod_stack = node.meta.get('nn_module_stack', {})
            if mod_stack:
                break
        stack_list = list(mod_stack.items())
        self.assertEqual(stack_list, expected_stack)
    def test_transformer_preserves_nn_module_stack_for_get_attr(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))

            def forward(self, x):
                return self.weight + x

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        for node in gm.graph.nodes:
            if node.op == 'get_attr':
                node.meta["nn_module_stack"] = "self"
                node.meta["stack_trace"] = "stack_trace"
                node.meta["source_fn_stack"] = "source_fn_stack"
        new_gm = Transformer(gm).transform()
        for node in new_gm.graph.nodes:
            if node.op == 'get_attr':
                self.assertEqual(node.meta["nn_module_stack"], "self")
                self.assertEqual(node.meta["stack_trace"], "stack_trace")
                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
    def test_interpreter(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        interpreter = Interpreter(gm)
        input = torch.randn(3, 4)
        self.assertEqual(interpreter.run(input), gm(input))
        self.assertEqual(interpreter.run(input), m(input))
    def test_interpreter_other_graph(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        interpreter = Interpreter(gm, graph=gm.graph)
        input = torch.randn(3, 4)
        self.assertEqual(interpreter.run(input), gm(input))
        self.assertEqual(interpreter.run(input), m(input))
    def test_interpreter_run_node_override(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        class RunNodeInterpreter(Interpreter):
            def __init__(self, module):
                super().__init__(module)

            def run_node(self, n : Node) -> Any:
                result = super().run_node(n)
                n.cached_value = result
                return result

        input = torch.randn(3, 4)
        RunNodeInterpreter(gm).run(input)
        for node in gm.graph.nodes:
            assert hasattr(node, 'cached_value')
    def test_interpreter_onthefly_swap(self):
        def fn(x):
            return torch.sigmoid(x).neg()

        gm = torch.fx.symbolic_trace(fn)

        class NegSigmSwapInterpreter(Interpreter):
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(target, args, kwargs)

            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == 'neg':
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(target, args, kwargs)

        input = torch.randn(3, 4)
        result = NegSigmSwapInterpreter(gm).run(input)
        self.assertEqual(result, torch.neg(input).sigmoid())
    def test_interpreter_partial_eval(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        gm = torch.fx.symbolic_trace(MyModule())
        interp = Interpreter(gm)
        env = {}
        for node in gm.graph.nodes:
            if node.op == 'call_module' and node.target == 'linear':
                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
                break
        assert len(env) == 1
        x = torch.randn(3, 4)
        result = interp.run(x, initial_env=env)
        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
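
    # `initial_env` seeds the Interpreter's node -> value environment: any node
    # already present is treated as computed and skipped, so above the `linear`
    # call never runs and only downstream nodes (the clamp) execute.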
    def test_interpreter_star_args(self):
        def with_star_args(x, *args):
            return x + args[0]

        gm = torch.fx.symbolic_trace(with_star_args)
        interp = Interpreter(gm)
        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
        self.assertEqual(result, torch.ones(3, 4) * 2.0)
    @skipIfNoTorchVision
    def test_interpreter_noop_resnet18(self):
        rn18 = torchvision_models.resnet18()
        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
        inp = torch.randn(5, 3, 224, 224)
        self.assertEqual(transformed(inp), rn18(inp))

    @skipIfNoTorchVision
    def test_interpreter_gc_values(self):
        rn18 = torchvision_models.resnet18()
        interp = Interpreter(symbolic_trace(rn18))
        inp = torch.rand(5, 3, 224, 224)
        out = interp.run(inp)
        env_key_names = {n.name for n in interp.env.keys()}
        self.assertEqual(env_key_names, {'output'})
    def test_interpreter_default_args(self):
        class Model(torch.nn.Module):
            def forward(self, x, y=3.14159):
                return x + y

        model = Model()
        gm = torch.fx.symbolic_trace(model)

        interp = Interpreter(gm)
        x = torch.randn(5, 3)
        out = interp.run(x)
        torch.testing.assert_close(out, x + 3.14159)

    def test_interpreter_not_enough_args(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = Model()
        gm = torch.fx.symbolic_trace(model)

        interp = Interpreter(gm)
        x = torch.randn(5, 3)
        with self.assertRaisesRegex(RuntimeError,
                                    'Expected positional argument for parameter y, but one was not passed in'):
            interp.run(x)
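
    # The Interpreter binds positional inputs to placeholder nodes in graph
    # order; a placeholder carrying a default value (from the traced signature)
    # falls back to it, while one without raises the RuntimeError checked above.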
    def test_transformer_noop(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        new_gm = Transformer(gm).transform()

        input = torch.randn(3, 4)
        self.assertEqual(new_gm(input), gm(input))
    def test_transformer_op_swap(self):
        def fn(x):
            return torch.sigmoid(x).neg()

        gm = torch.fx.symbolic_trace(fn)

        class NegSigmSwapXformer(Transformer):
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(target, args, kwargs)

            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == 'neg':
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(target, args, kwargs)

        transformed = NegSigmSwapXformer(gm).transform()
        input = torch.randn(3, 4)
        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
    def test_transformer_multi_outputs(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                x = x + self.param
                out = self.linear(x)
                return x, out

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        new_gm = Transformer(gm).transform()

        input = torch.randn(3, 4)
        self.assertEqual(new_gm(input), gm(input))
2070
def test_fn_type_annotations(self):
2071
class Foo(torch.nn.Module):
2072
def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
2073
return {'a': p.x + p.y + z + i}
2075
foo_scripted = torch.jit.script(Foo())
2076
foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2078
fxed = symbolic_trace(Foo())
2079
fxed_scripted = torch.jit.script(fxed)
2080
fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
2082
def test_fn_type_annotation_empty(self):
2083
def forward(a : List[torch.Tensor]):
2085
torch.jit.script(symbolic_trace(forward))
    def test_wrapped_method(self):
        def wrap_with_relu(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                return torch.relu(fn(*args, **kwargs))
            return wrapper

        class Foo(torch.nn.Module):
            @wrap_with_relu
            def forward(self, x, w):
                return torch.matmul(x, w)

        f = Foo()
        traced = symbolic_trace(f)
        x, w = torch.rand(3, 4), torch.rand(4, 4)
        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))

    def test_empty_graph_codegen(self):
        graph = torch.fx.Graph()
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(gm(), None)

    def test_sequential(self):
        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
        gm = torch.fx.symbolic_trace(m)
        gm_copy = copy.deepcopy(gm)

    def test_ctx_mgr(self):
        @contextlib.contextmanager
        def do_nothing():
            yield

        class M(torch.nn.Module):
            @do_nothing()
            def forward(self, x):
                return torch.relu(x)

        m = M()
        self.checkGraphModule(m, (torch.rand(3, 4),))
    def test_typename_print(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
                                              type_expr=List[float])
        output : torch.fx.Node = graph.output(b)

        self.assertTrue('typing.List[float]' in str(graph))

    def test_layout(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)

        traced = symbolic_trace(M())
        x = torch.rand(5, 9, 3, 4)
        self.assertEqual(traced(x), torch.zeros_like(x))

    def test_ellipsis(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y[:, 1:10, ...]

        traced = symbolic_trace(M())
        x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
        self.assertEqual(traced(x, y), x + y[:, 1:10, ...])

    def test_inf_nan(self):
        class FooMod(torch.nn.Module):
            def forward(self, x):
                return x + float('inf'), x + float('-inf'), x + float('nan')

        fm = FooMod()
        self.checkGraphModule(fm, (torch.rand(3, 4),))
    def test_inf_nan_kwds(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
        graph.output((b, c))

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        x = torch.rand(3, 4)
        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))

    def test_deepcopy_recursion_depth(self):
        depth = sys.getrecursionlimit() + 20

        g = torch.fx.Graph()
        x = g.placeholder('x')
        for i in range(depth):
            x = g.call_function(torch.relu, (x,))
        g.output(x)

        copied_graph = copy.deepcopy(g)

        val_map = {}
        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            val_map[orig_node] = new_node

        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            orig_users = set(orig_node.users.keys())
            orig_users_equiv = {val_map[u] for u in orig_users}
            new_users = set(new_node.users.keys())
            self.assertEqual(orig_users_equiv, new_users)
    @skipIfNoTorchVision
    def test_replace_uses(self):
        rn18 = torchvision_models.resnet18()

        class LowerReluTracer(torch.fx.Tracer):
            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
                if isinstance(m, torch.nn.ReLU):
                    return False
                return super().is_leaf_module(m, qualname)

        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

        to_erase = []
        for node in rn18_traced.graph.nodes:
            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
                # Neg doesn't have in-place, so drop the `inplace` kwarg
                kwargs = node.kwargs.copy()
                kwargs.pop('inplace')
                with rn18_traced.graph.inserting_before(node):
                    new_node = rn18_traced.graph.call_function(
                        the_function=torch.neg, args=node.args, kwargs=kwargs)
                node.replace_all_uses_with(replace_with=new_node)
                to_erase.append(node)

        for node in to_erase:
            rn18_traced.graph.erase_node(node)
    def test_replace_input(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        b.replace_input_with(x, y)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input_x = torch.randn(33, 44)
        input_y = torch.randn(11, 22)
        self.assertEqual(gm(input_x, input_y), torch.relu(input_y))

    def test_insertion_point(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        with graph.inserting_before(b):
            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
            _, *relu_args = b.args
            b.args = (neg, *relu_args)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))
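    # Editorial note: `inserting_before(n)` only sets the insertion point for
    # nodes created inside the `with` block; rewiring an existing node's
    # inputs, as above, is done by assigning to `Node.args`. A one-line sketch
    # of the same rewiring via the narrower API from test_replace_input:
    #
    #     b.replace_input_with(x, neg)   # swap input x for neg on node b only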
    def test_update_args_api(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))

        b.update_arg(0, y)
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))

    def test_update_kwargs_api(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
        output : torch.fx.Node = graph.output(b)

        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))

        b.update_kwarg('input', y)
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
    def test_immutable_list_pytree_ops(self):
        rand_tensor = torch.randn(5, 3)
        l = immutable_list([3, [rand_tensor, 42]])

        flattened, spec = pytree.tree_flatten(l)
        assert flattened == [3, rand_tensor, 42]

        unflattened = pytree.tree_unflatten(flattened, spec)
        assert unflattened == l
        assert isinstance(unflattened, immutable_list)

    def test_immutable_dict_pytree_ops(self):
        rand_tensor = torch.randn(5, 3)
        d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})

        flattened, spec = pytree.tree_flatten(d)
        assert flattened == [3, rand_tensor, 42]

        unflattened = pytree.tree_unflatten(flattened, spec)
        assert unflattened == d
        assert isinstance(unflattened, immutable_dict)
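    # Editorial note: FX stores Node args in `immutable_list` / `immutable_dict`
    # precisely so that in-place mutation is impossible and every change goes
    # through the Node API (update_arg / update_kwarg), keeping use-def
    # bookkeeping consistent. The round-trips above check that pytree
    # flatten/unflatten preserves those immutable container types.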
    def test_move_before(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
        _, *relu_args = b.args
        b.args = (neg, *relu_args)
        b.prepend(neg)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_prepend_self(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        b.prepend(b)
        x.append(b)
        self.assertEqual(len(graph.nodes), 3)

    def test_erase_node_error(self):
        st = SimpleTest()
        traced = symbolic_trace(st)

        for node in traced.graph.nodes:
            # Test deleting with uses both in another Node and at the output
            if node.target in [operator.add, torch.relu]:
                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
                    traced.graph.erase_node(node)
    def test_copy_it(self):
        d = immutable_dict([(3, 4), (5, 6)])
        l = immutable_list([(3, 4), (5, 6)])

        self.assertEqual(d, deepcopy(d))
        self.assertEqual(l, deepcopy(l))

    def test_get_torch_func_signature(self):
        for key in dir(torch):
            obj = getattr(torch, key)
            if callable(obj):
                schemas = get_signature_for_torch_op(obj)
    def test_find_uses(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))

        y = torch.relu(x)
        z = x + x
        u = torch.neg(x)
        graph.output((y + z + u).node)

        users_of_x = x.node.users
        self.assertEqual(len(users_of_x), 3)
        expected_ops = {'relu', 'add', 'neg'}
        for use in users_of_x:
            assert any(use.name.startswith(prefix) for prefix in expected_ops)
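    # A minimal sketch (editorial addition) of the Proxy-driven graph building
    # used above: wrapping a Node in a Proxy lets ordinary torch calls and
    # operators record new nodes into the owning Graph.
    def _proxy_graph_building_sketch(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))
        out = torch.relu(x) + x          # each op appends a node to `graph`
        graph.output(out.node)
        return torch.fx.GraphModule(torch.nn.Module(), graph)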
    def test_inline_graph(self):
        class InlineInto(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class ToInline(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        inline_into = symbolic_trace(InlineInto())
        to_inline = symbolic_trace(ToInline())

        combined_graph = torch.fx.Graph()
        output_node = combined_graph.graph_copy(inline_into.graph, {})

        input_node = next(iter(to_inline.graph.nodes))
        assert input_node and input_node.op == 'placeholder'

        val_map = {input_node : output_node}
        output = combined_graph.graph_copy(to_inline.graph, val_map)
        combined_graph.output(output)

        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)

        input = torch.rand(3, 4)
        self.assertEqual(combined_module(input), input.relu().neg())
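    # Editorial note on `graph_copy`: the second argument is a value-remapping
    # dict from nodes of the *source* graph to nodes already present in the
    # *destination* graph. Seeding it with {placeholder: prior_output}, as
    # above, is what stitches the two graphs end-to-end.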
    def test_multi_insert_point(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))
        relu = torch.relu(x)

        with graph.inserting_before(relu.node):
            y = torch.neg(x)
            z = torch.tanh(y)

        graph.output((relu.node, z.node))

        expected_ops = ['x', 'neg', 'tanh', 'relu']
        for node, expected in zip(graph.nodes, expected_ops):
            assert expected in node.name

    def test_reassign_args_kwargs_uses(self):
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        zed = z + z + z
        graph.output(zed.node)

        # zed = z + z + z -> zed = z + z + x
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])

        # z = x + y -> z = y + y
        z.node.args = (y.node, y.node)
        self.assertEqual(list(x.node.users.keys()), [zed.node])

    def test_trace_function(self):
        def foo(x, y):
            return torch.relu(x) + y

        x, y = torch.randn(3, 4), torch.randn(3, 4)
        self.checkGraphModule(foo, (x, y))
    def test_trace_return_dataclass(self):
        """
        Test case for a Module that returns a dataclass
        """
        from dataclasses import dataclass

        @dataclass
        class MyOutput:
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnDataclass(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d + d, bar=d * 3)

        module = ModuleReturnDataclass()
        traced_graph = symbolic_trace(module).graph

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_return_dataclass_nested(self):
        """
        Test case for a Module that returns a dataclass from a nested call
        """
        from dataclasses import dataclass

        @dataclass
        class MyOutput:
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnDataclass(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d + d, bar=d * 3)

        class CallsModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = ModuleReturnDataclass()

            def forward(self, x):
                tmp = self.m(x)
                return MyOutput(foo=tmp.foo, bar=tmp.bar)

        module = CallsModule()
        traced_graph = symbolic_trace(module).graph

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_return_namedtuple(self):
        """
        Test case for a Module that returns a namedtuple
        """
        class MyOutput(NamedTuple):
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnNamedTuple(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d, bar=d)

        module = ModuleReturnNamedTuple()
        traced_graph = symbolic_trace(module).graph

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))
    def test_trace_dict_int_keys(self):
        class ModWithDictArg(torch.nn.Module):
            def forward(self, d : Dict[int, torch.Tensor]):
                return d[42]

        class CallsModWithDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = ModWithDictArg()

            def forward(self, x):
                return self.m({42: x})

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return isinstance(m, ModWithDictArg)

        traced_graph = MyTracer().trace(CallsModWithDict())

    def test_trace_dict_proxy_keys(self):
        class ModWithDictArg(torch.nn.Module):
            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
                return d[42]

        class CallsModWithDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = ModWithDictArg()

            def forward(self, x):
                return self.m({x: x})

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return isinstance(m, ModWithDictArg)

        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
            traced_graph = MyTracer().trace(CallsModWithDict())
    def test_module_deepcopy_edit_nodes(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        traced1 = symbolic_trace(Foo())
        copied = copy.deepcopy(traced1)

        for node in copied.graph.nodes:
            if node.target == torch.relu:
                node.target = torch.neg

        copied.recompile()
        traced1.recompile()

        x = torch.randn(15, 15)
        torch.testing.assert_close(traced1(x), torch.relu(x))
        torch.testing.assert_close(copied(x), torch.neg(x))

    def test_direct_param_use(self):
        class TransposeTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.b

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = TransposeTest()

            def forward(self, x):
                return self.a.b, self.a.b.t(), self.a.b.view(12)

        traced = torch.fx.symbolic_trace(Foo())
        assert all('constant' not in node.target for node in traced.graph.nodes)
    def test_single_default_arg(self):
        class M(torch.nn.Module):
            def forward(self, y=1):
                return y

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))

    def test_multiple_default_args(self):
        class M(torch.nn.Module):
            def forward(self, y=1, z=2):
                return y + z

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))
        self.checkGraphModule(m, (3, 4))

    def test_regular_and_default_args(self):
        class M(torch.nn.Module):
            def forward(self, x, y=1):
                return x + y

        m = M()
        self.checkGraphModule(m, (2,))
        self.checkGraphModule(m, (2, 3))

    def test_string_literal_return(self):
        class M(torch.nn.Module):
            def forward(self):
                return "foo"

        m = M()
        self.checkGraphModule(m, ())
    def test_namedtuple_return_qualname(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return MyNamedTup(x, x)

        traced = symbolic_trace(NamedTupReturn())
        input = torch.rand(3, 4)
        self.assertEqual(traced(input), MyNamedTup(input, input))

    def test_update_args_kwargs_yells_at_you(self):
        symtraced = symbolic_trace(SimpleTest())
        node = next(iter(symtraced.graph.nodes))
        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
            node.__update_args_kwargs((), {})
    def test_torchbind_class_attribute_in_fx(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")

        class FooBar1234(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        m = FooBar1234()
        self.checkGraphModule(m, ())

    def test_torchbind_class_attribute_in_fx_tensor_arg(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")

        class FooBar2341(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._ReLUClass()

            def forward(self, x):
                return self.f.run(x)

        m = FooBar2341()
        traced = symbolic_trace(m)
        input = torch.randn(3, 4)
        self.assertEqual(traced(input), m(input))

        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
    def test_script_method_trace(self):
        class Scripted(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class Holder(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.s = torch.jit.script(Scripted())

            def forward(self, x):
                return self.s(x)

        h = Holder()
        traced = symbolic_trace(h)
        input = torch.randn(3, 4)
        self.assertEqual(traced(input), h(input))

        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))

    def test_namedtuple_return_trace(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return Pair(x, x)

        traced = symbolic_trace(NamedTupReturn())
        input = torch.rand(3, 4)
        self.assertEqual(traced(input), Pair(input, input))
    def test_named_tuple_inlined(self):
        class NamedTupMod(torch.nn.Module):
            def forward(self, inp):
                return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))

        m = NamedTupMod()
        input = torch.rand(3, 4)
        ref = m(input)

        traced = symbolic_trace(m)
        res = traced(input)
        self.assertEqual(ref, res)

        # Check Pair NamedTuple works when inlined into the function call.
        ph = call_func = None
        for node in traced.graph.nodes:
            if node.op == "placeholder":
                ph = node
            elif node.op == "call_function" and node.target == wrapped_named_tup:
                node.update_arg(0, Pair(ph, 1.2))
                node.update_kwarg("p2", Pair(3.4, ph))
                call_func = node
                break

        self.assertTrue(call_func is not None)
        self.assertTrue(isinstance(call_func.args[0], Pair))
        self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
        self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
        self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")

        traced.graph.eliminate_dead_code()
        traced.recompile()
        res = traced(input)
        self.assertEqual(ref, res)
    def test_return_type_exists(self):
        class ReturnTypeModule(torch.nn.Module):
            def other(self, x: List[str]) -> List[str]:
                return x

            def forward(self, x: List[str]) -> List[str]:
                return self.other(x)

        traced = symbolic_trace(ReturnTypeModule())
        self.assertIn("-> typing_List[str]", traced._code)
        scripted = torch.jit.script(traced)
        self.assertIn("-> List[str]", scripted.code)

    def getitem_inner(self):
        class GetItemBase(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pe = torch.nn.Buffer(torch.randn(8, 8))

        class GetItem1(GetItemBase):
            def forward(self, x):
                return self.pe[:, :x.size(0)]

        class GetItem2(GetItemBase):
            def forward(self, x):
                return self.pe[x.size(0)]

        class GetItem3(GetItemBase):
            def forward(self, x):
                return self.pe[4]  # fx creates `self._tensor_constant0` here

        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
        self.checkGraphModule(GetItem3(), [torch.zeros(4)])
    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
                         "Will be checked in test_getitem_subproc")
    def test_getitem(self):
        self.getitem_inner()

    def test_getitem_subproc(self):
        # need to run this test in a subproc to work around:
        # https://github.com/pytorch/pytorch/issues/50710
        proc = Process(target=run_getitem_target)
        proc.start()
        proc.join()
        self.assertEqual(proc.exitcode, 0)
    def test_user_friendly_call_provenance_with_function(self):
        def fn(x):
            return wrapper_fn(x)

        traced = torch.fx.symbolic_trace(fn)

        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
                                    "being compiled since it was called"
                                    " from 'fn.forward'"):
            scripted = torch.jit.script(traced)

    def test_user_friendly_call_provenance_with_module(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return wrapper_fn(x)

        traced = torch.fx.symbolic_trace(M())

        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
                                    "being compiled since it was called"
                                    " from 'M.forward'"):
            scripted = torch.jit.script(traced)
    def test_snake_case(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activations = torch.nn.ModuleDict([
                    ["snake_case", torch.nn.ReLU()],
                    ["PascalCase", torch.nn.LeakyReLU()],
                    ["ALL_CAPS", torch.nn.PReLU()]
                ])

            def forward(self, x):
                a = self.activations["snake_case"](x)
                b = self.activations["PascalCase"](x)
                c = self.activations["ALL_CAPS"](x)
                return a, b, c

        traced = symbolic_trace(M())

        check = [
            ("activations_snake_case", "activations.snake_case"),
            ("activations_pascal_case", "activations.PascalCase"),
            ("activations_all_caps", "activations.ALL_CAPS")
        ]

        i = 0
        for node in traced.graph.nodes:
            if node.op == "placeholder" or node.op == "output":
                continue
            name = check[i][0]
            target = check[i][1]
            self.assertEqual(name, node.name)
            self.assertEqual(target, node.target)
            i += 1

        self.assertEqual(i, 3)
    def test_no_mutation(self):
        from torch.fx.immutable_collections import immutable_list
        x = immutable_list([3, 4])
        with self.assertRaisesRegex(NotImplementedError, "new_args"):
            x[0] = 4
    def test_partial_trace(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                if y:
                    return 2 * x
                else:
                    return x

        mod = Foo()
        mod_true = symbolic_trace(mod, concrete_args={'y': True})
        mod_false = symbolic_trace(mod, concrete_args={'y': False})
        self.assertEqual(mod_true(3, True), 6)
        print(mod_true.code)
        assert any(i.target == torch._assert for i in mod_true.graph.nodes)
        with self.assertRaises(AssertionError):
            mod_true(3, False)
        self.assertEqual(mod_false(3, False), 3)
        with self.assertRaises(AssertionError):
            mod_false(3, True)

        def f_higher(a, f):
            return f(a)

        nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
        self.assertEqual(nf(3, lambda x: x * 2), 6)
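    # Editorial note: `concrete_args` burns the given values into the trace
    # and inserts `torch._assert` guard nodes, which is why calling the
    # specialized module with the other boolean raises above. A minimal
    # sketch of specializing away a control-flow flag:
    #
    #     def f(x, flag):
    #         return x * 2 if flag else x
    #     f_true = torch.fx.symbolic_trace(f, concrete_args={'flag': True})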
    def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(5))

            def forward(self, x):
                return torch.dot(self.W, x)

        traced = torch.fx.symbolic_trace(M())

        out = [n for n in traced.graph.nodes if n.op == "output"][-1]
        with traced.graph.inserting_before(out):
            relu_out = traced.graph.call_method(method_name='relu',
                                                args=(out.args[0],))
        out.args = (relu_out,)

        traced.recompile()

        with self.capture_stderr() as captured:
            with self.assertRaises(TypeError):
                traced(5)

        self.assertRegex(captured[0],
                         r"Call using an FX-traced Module, line .* of the "
                         r"traced Module's generated forward function:")
    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 4)

            def forward(self, x):
                return self.linear(x)

        traced = torch.fx.symbolic_trace(M())

        # Do not change this to `capture_stderr` or another context
        # manager without ensuring that the output is as expected
        try:
            traced(torch.rand(5, 5))
        except RuntimeError:
            captured = traceback.format_exc()

        self.assertNotRegex(captured,
                            r"Call using an FX-traced Module, line .* of the "
                            r"traced Module's generated forward function:")

    def test_graph_module_replicate_for_dp(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        gm = torch.fx.symbolic_trace(Foo())

        x = torch.randn(5, 3)
        out = gm(x)

        replica = gm._replicate_for_data_parallel()
        out_replica = replica(x)

        torch.testing.assert_close(out_replica, out)
    def test_ast_rewriter_rewrites_assert(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):
                assert y == z
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        traced.graph.lint()

    def test_ast_rewriter_rewrites_assert_with_message(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):
                assert y == z, "msg"
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        traced.graph.lint()
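    # Editorial note: RewritingTracer re-parses the module's source and
    # rewrites Python `assert` statements into traceable `torch._assert`
    # calls, so the check survives in the graph instead of vanishing at
    # trace time.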
    def test_throw_out_variant(self):
        def foo(x):
            y = torch.rand_like(x)
            torch.sigmoid(x, out=y)
            return y

        class MyTracer(torch.fx.Tracer):
            check_mutable_operations = True

        tracer = MyTracer()
        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
            traced_graph = tracer.trace(foo)
    def test_ast_rewriter_reassigns_submodules(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(100)

            def forward(self, x: torch.Tensor):
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

    def test_ast_rewriter_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return (
                a_lifted_leaf((4, y), 3)
                + a_lifted_leaf((3, 4), 5)
                + a_lifted_leaf((y, y), y)
            )

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("a_lifted_leaf", traced.code)
        self.assertEqual(27, traced(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

    def test_ast_rewriter_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return (
                a_lifted_leaf2((4, y), 3)
                + a_lifted_leaf2((3, 4), 5)
                + a_lifted_leaf2((y, y), y)
            )

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("a_lifted_leaf2", traced.code)
        self.assertEqual(27, traced(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
    def test_profiler_ranges_side_effect(self):
        g = torch.fx.Graph()
        handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
        g.output(None)

        found_targets = {}
        for node in g.nodes:
            if node.op == 'call_function':
                found_targets.setdefault(node.target)
        self.assertEqual(
            list(found_targets.keys()),
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
        )

        g.eliminate_dead_code()
        found_targets = {}
        for node in g.nodes:
            if node.op == 'call_function':
                found_targets.setdefault(node.target)
        self.assertEqual(
            list(found_targets.keys()),
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
        )
    def test_ast_rewriter_wrapped_via_decorator(self):
        class F(torch.nn.Module):
            def forward(self, x):
                return wrapped_via_decorator(x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(F())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_via_decorator", traced.code)
        self.assertEqual(traced(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_via_decorator", traced.code)
        self.assertEqual(traced(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(traced).transform()
        self.assertIn("wrapped_via_decorator", transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_ast_rewriter_wrap_with_submodule(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_with_submodule", traced.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), traced(input))
    def test_submodule_manipulation_API(self):
        class C(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.conv(torch.cat([self.param, x]))

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(100, 200)
                self.buf = torch.nn.Buffer(torch.randn(2, 3))
                self.net_c = C()

            def forward(self, x):
                return self.linear(torch.cat([self.buf, self.net_c(x)]))

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net_b = B()
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.net_b(x) + self.param

        a = symbolic_trace(A())

        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))

        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
        with a.graph.inserting_before(conv):
            with warnings.catch_warnings(record=True) as w:
                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
                                              args=conv.args)
                self.assertEqual(len(w), 0)

        conv.replace_all_uses_with(dropout)
        a.graph.erase_node(conv)
        a.recompile()

        def module_exists(gm: GraphModule, path: str) -> bool:
            return any(path == name for name, _ in gm.named_modules())

        def parameter_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_parameters())
                    and any(path == name for name in gm.state_dict().keys()))

        def buffer_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_buffers())
                    and any(path == name for name in gm.state_dict().keys()))

        # Test that we added the "dropout" submodule
        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))

        # Test `get_submodule` with an added submodule
        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))

        # Test that the "conv" submodule is still there
        self.assertTrue(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with an original module
        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))

        # Test that the "conv" node is NOT still there
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
        self.assertEqual(conv, [])

        a.delete_submodule("net_b.net_c.conv")

        # Test that the "conv" submodule is now gone
        self.assertFalse(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with a deleted submodule
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`conv`"):
            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))

        # Test `get_attr` warnings
        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]

        with a.graph.inserting_before(cat):
            with warnings.catch_warnings(record=True) as w:
                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
                self.assertEqual(len(w), 0)

            with self.assertWarnsRegex(UserWarning, "Attempted to "
                                       "insert a get_attr Node with no "
                                       "underlying reference in the "
                                       "owning GraphModule"):
                bad_param = a.graph.get_attr(qualified_name="net_b.param")
                a.graph.erase_node(bad_param)

        cat.args = (*cat.args, param)
        with warnings.catch_warnings(record=True) as w:
            a.recompile()
            self.assertEqual(len(w), 0)

        # Test `get_parameter`
        a.get_parameter("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "is not an "
                                    "nn.Parameter"):
            a.get_parameter("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`param`"):
            a.get_parameter("net_b.param")

        # Test `get_buffer`
        a.get_buffer("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "is not a "
                                    "buffer"):
            a.get_buffer("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`buf`"):
            a.get_buffer("net_b.net_c.buf")

        # Test non-nested attributes
        a.get_submodule("")
        a.get_parameter("param")

        # Insert some unused submodules
        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))

        # Garbage collection
        a.delete_all_unused_submodules()

        # Test that all the unused submodules are gone
        self.assertFalse(module_exists(a, "net_b.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
        self.assertFalse(module_exists(a, "batch_norm_2d"))

        # Test that we didn't delete any unused Parameters or buffers
        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
        self.assertTrue(buffer_exists(a, "net_b.buf"))
    def test_delete_unused_submodules_leaf(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.relu(x)
                return x

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubModule()

            def forward(self, x):
                x = self.submod(x)
                return x

        model = Model()

        class MyCustomTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return module_qualified_name == "submod"

        inputs = torch.randn(1, 10)
        traced_graph = MyCustomTracer().trace(model)
        gm2 = torch.fx.GraphModule(model, traced_graph)
        gm2.delete_all_unused_submodules()
        torch.testing.assert_close(gm2(inputs), model(inputs))
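    # Editorial note: because MyCustomTracer marks "submod" as a leaf, the
    # trace contains a single call_module node for it, so its inner `linear`
    # and `relu` children still count as used and survive
    # `delete_all_unused_submodules()`.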
    def test_fx_stateless(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = torch.nn.Linear(1, 1)
                self.buffer = torch.nn.Buffer(torch.ones(1))

            def forward(self, x):
                return self.l1(x) + self.buffer

        module = MockModule()
        x = torch.rand((1, 1))
        weight = torch.tensor([[1.0]], requires_grad=True)
        bias = torch.tensor([0.0], requires_grad=True)
        buffer = torch.tensor([0.0])
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        fx_module = torch.fx.symbolic_trace(module)
        res = torch.func.functional_call(fx_module, parameters, x)
        res.backward()
        self.assertIsNotNone(weight.grad)
        self.assertIsNotNone(bias.grad)
        self.assertIsNone(buffer.grad)
        # Gradients were not calculated for the module's own parameters and buffers
        self.assertIsNone(module.l1.weight.grad)
        self.assertIsNone(module.l1.bias.grad)
        self.assertIsNone(module.buffer.grad)
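    # Editorial note: `torch.func.functional_call` runs the module with the
    # supplied tensors substituted for its registered parameters/buffers, so
    # autograd flows into `weight`/`bias` above while the module's own state
    # stays untouched. A one-line sketch of the call shape:
    #
    #     out = torch.func.functional_call(mod, {'l1.weight': w}, (x,))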
    def test_tracing_graphmodules_as_leaf_submodules(self):
        class A(torch.nn.Module):
            def forward(self, t):
                return t + t

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super(type(self), self).__init__()
                self.calling = False
                self.called = False

            def forward(self, t):
                if self.calling:
                    return t - t
                else:
                    return t + t

            def __call__(self, *args):
                self.called = True
                self.calling = True
                rv = super(type(self), self).__call__(*args)
                self.calling = False
                return rv

        class M(torch.nn.Module):
            def __init__(self, a, b):
                super().__init__()
                self.a = a
                self.b = b

            def forward(self, t):
                x = self.a(t)
                y = self.b(t)
                return x + y

        class LeafTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        class LeafTracerNotB(Tracer):
            def is_leaf_module(self, module, name):
                return False if "b" in name else True

        # Recompile calls added "for fun", since they
        # chain __call__ wrappers.

        # Test: B as a regular, non-leaf module
        a = symbolic_trace(A())
        a.recompile()
        m = M(a, B())
        graph = LeafTracerNotB().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        # Test graphmodule/submodule a is not inlined.
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        # Test submodule b is not treated as leaf.
        self.assertFalse(hasattr(gm, "b"))

        # Test assert custom __call__ on submodule b was honored.
        match = [
            n
            for n in gm.graph.nodes
            if n.op == "call_function" and n.target == operator.sub
        ]
        self.assertTrue(len(match) == 1)
        # Test: B as a regular, leaf module
        # symbolic_trace should only patch torch.nn.Module.__call__,
        # which means B.__call__ should still execute
        a = symbolic_trace(A())
        a.recompile()
        b = B()
        m = M(a, b)
        graph = LeafTracer().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        # Test graphmodule/submodule a is not inlined.
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        # Test submodule b is leaf:
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
        self.assertTrue(len(match) == 1)

        # Test b.__call__ was run
        self.assertTrue(b.called)
        self.assertTrue(gm.get_submodule("b").called)

        # Test: B as GraphModule leaf
        # __call__ not honored since symbolic_trace directly invokes forward()
        a = symbolic_trace(A())
        a.recompile()
        b = symbolic_trace(B())
        b.recompile()
        m = M(a, b)
        graph = LeafTracer().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
        self.assertTrue(len(match) == 1)
    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_buff = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter(
                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
                )

            def forward(self, x):
                return x + self.my_buff + self.my_param

        mod = MyModule()
        mod_traced = symbolic_trace(mod)

        # Create new GraphModule based on original, either w/ dict or root module.
        orig_buff = mod_traced.get_buffer("my_buff")
        orig_param = mod_traced.get_parameter("my_param")
        mod_traced_new = GraphModule(
            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
            mod_traced.graph,
        )

        # Check that both my_buff and my_param are found and the same.
        try:
            new_buff = mod_traced_new.get_buffer("my_buff")
        except Exception:
            self.fail("Did not find my_buff")
        self.assertEqual(orig_buff, new_buff)

        try:
            new_param = mod_traced_new.get_parameter("my_param")
        except Exception:
            self.fail("Did not find my_param")
        self.assertEqual(orig_param, new_param)

        x = torch.rand(3, 4)
        orig_out = mod_traced(x)
        submodules_out = mod_traced_new(x)

        self.assertEqual(orig_out, submodules_out)

    def test_graph_module_init_buffer_param_copied_dict_init(self):
        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)

    def test_graph_module_init_buffer_param_copied_mod_init(self):
        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
    def test_annotations_with_no_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
                return a(x[0])

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
                return a(x[0])

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
                     "`annotations` is not defined in Python <3.7")
    def test_annotation_with_future(self):
        try:
            import fx.test_future    # noqa: F401
        finally:
            del sys.modules["__future__"]

    @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
    def test_annotations_empty_tuple(self):
        class Foo(torch.nn.Module):
            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
                return "foo"

        traced = torch.fx.symbolic_trace(Foo())

        x = ()
        y = ("bar", ())

        traced(x, y)

        FileCheck().check("_Tuple[()]") \
                   .check("typing_Tuple[str,typing_Tuple[()]]") \
                   .run(traced.code)

        scripted = torch.jit.script(traced)

        scripted(x, y)

        FileCheck().check("Tuple[()]") \
                   .check("Tuple[str, Tuple[()]]") \
                   .run(scripted.code)
    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
    def test_assert(self):
        def f(x):
            assert x > 1
            return x + 1

        try:
            torch.fx.proxy.TracerBase.trace_asserts = True
            traced = symbolic_trace(f)
        finally:
            torch.fx.proxy.TracerBase.trace_asserts = False

        self.assertEqual(f(2), traced(2))
        with self.assertRaises(AssertionError):
            traced(0)
    def test_pytree(self):
        # Used to test that you can use your own placeholder class
        class PHTest(PHBase):
            pass

        def f_sum(*args):
            return sum(args)

        def f_sum_dict(x):
            out = 0
            for v in x.values():
                out += v
            return out

        def f_dict_list_map(x):
            new_dict = {}
            for k, v in x.items():
                new_dict[k] = [i + 1 for i in v]
            return new_dict

        def f_dict_add(x):
            return x['a'] + sum(x['z'])

        def f_namedtuple_add(x):
            return x.x + x.y

        class Foo:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        pytree.register_pytree_node(
            Foo,
            lambda x: ([x.a, x.b], None),
            lambda x, _: Foo(x[0], x[1]),
        )
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

        def f_custom(x):
            return x.a + x.b

        def f_custom_dict(x):
            return f_sum_dict(x.a) + x.b

        def f_return_custom(x):
            return Foo(x.b, x.a)

        tests = [
            (f_sum, [PH, PH, PH]),
            (f_sum, []),
            (f_sum, [PHTest(), PHTest(), PHTest()]),
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
            (f_dict_list_map, {5: (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': []}),
            (f_custom, Foo(PH, PH)),
            (f_custom, Foo(PH, 3)),
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
            # (f_return_custom, Foo(PH, PH)),  # Don't currently support output pytrees
            (f_namedtuple_add, Point(PH, PH)),
        ]
        def verify_pytree(f, inp):
            val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
            num_flat_args = len(pytree.tree_leaves(inp))
            orig_out = f(val)

            nf = symbolic_trace(f, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)

            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
            bare_fx.graph.set_codegen(CodeGen())
            bare_fx.recompile()
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            nf = symbolic_trace(nf)
            self.assertEqual(nf(val), orig_out)
            assert "tree_flatten_spec" not in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1

            nf = symbolic_trace(nf, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            pickled = pickle.dumps(nf)
            nf = pickle.loads(pickled)
            self.assertEqual(nf(val), orig_out)

        for f, inp in tests:
            verify_pytree(f, inp)
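    # A minimal standalone sketch (editorial addition) of the custom-pytree
    # registration exercised above; `Box` is an illustrative stand-in. Kept
    # commented out so the type isn't registered twice at import time:
    #
    #     class Box:
    #         def __init__(self, v):
    #             self.v = v
    #     pytree.register_pytree_node(
    #         Box,
    #         lambda b: ([b.v], None),                # flatten: children + context
    #         lambda children, _: Box(children[0]),   # unflatten
    #     )
    #     leaves, spec = pytree.tree_flatten(Box(torch.randn(2)))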
    def test_pytree_concrete(self):
        def f(a, b):
            if b:
                return a['a']
            else:
                return a['z']

        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
        nf = symbolic_trace(f, concrete_args=inp)
        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
        self.assertEqual(nf(**val), f(**val))

        nf = symbolic_trace(nf)
        self.assertEqual(nf(**val), f(**val))
    def test_metadata_on_ph(self):
        def f_sum(a: int, b: int) -> int:
            return a + b

        # Due to unflattening of dict, the batch argument
        # will be split into two separate nodes with the names
        # "batch_1" and "batch_2", referring to the keys
        # "f1" and "f2" respectively in the dict.
        def f_dict(a: Dict[str, str]) -> bool:
            return a["f1"] == a["f2"]

        def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
            for node in gm.graph.nodes:
                if node.op == "placeholder":
                    self.assertTrue(node.name in arg_names)
                    self.assertTrue(node.ph_key in metadata)

        verify_metadata(
            gm=symbolic_trace(
                f_sum,
                concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
            ),
            arg_names=["a_1", "b_1"],
            metadata=["a", "b"]
        )
        verify_metadata(
            gm=symbolic_trace(
                f_dict,
                concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
            ),
            arg_names=["a_1", "a_2"],
            metadata=["f1", "f2"]
        )

        # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = "foo"
                return n

        class PHWithTag(PHBase):
            def __init__(self, tag: str):
                super().__init__()
                self.tag = tag

        g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
        for n in g.nodes:
            self.assertTrue(hasattr(n, "tag"))
            # Ensure that tag is still "foo" and not "bar" (from PHWithTag)
            self.assertEqual(n.tag, "foo")
    def test_custom_codegen(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()

        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
        bare_fx.graph.set_codegen(CodeGen())
        bare_fx.recompile()

        self.assertEqual(nf(vals), f(*vals))
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))

        ts_f = torch.jit.script(nf)
        self.assertEqual(nf(vals), ts_f(vals))
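    # Editorial note: a custom CodeGen changes only the generated calling
    # convention; the Graph's nodes are untouched. `process_inputs` /
    # `process_outputs` bridge between the custom convention and the flat
    # positional one, which is exactly what the `bare_fx` round-trip checks.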
    def test_custom_codegen_with_transformer(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(nf(vals), f(*vals))

        transformed_gm = Transformer(nf).transform()
        self.assertEqual(nf(vals), transformed_gm(vals))
    def test_interpreter_with_codegen(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

            def generate_output(self, output_args):
                return f'return list({repr(output_args)})'

            def process_outputs(self, outputs):
                return list(outputs)

        def f(a, b):
            a = a + b
            b = a + b
            return a, b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(Interpreter(nf).run(vals), nf(vals))
    def test_imul_code_print(self):
        graph = torch.fx.Graph()
        a = graph.placeholder("a")
        b = graph.placeholder("b")
        graph.call_function(operator.imul, (a, b), {})
        graph.output(a)

        gm = torch.fx.GraphModule({}, graph)

        self.assertEqual(gm(2, 3), 6)
        self.assertIn("a *= b", gm.code)
    def test_deepcopy_tracer(self):
        def fn(x, y):
            return (x + y).relu().sin()

        tracer = Tracer()
        tracer_before = copy.deepcopy(tracer)
        tracer.trace(fn)
        tracer_after = copy.deepcopy(tracer)

        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))

    def test_deepcopy_graphmodule(self):
        m = symbolic_trace(SimpleTest())
        m.meta['hello'] = 'world'
        copy_m = copy.deepcopy(m)
        self.assertEqual(copy_m.meta['hello'], 'world')

    def test_deepcopy_no_recursion(self):
        m = symbolic_trace(SimpleTest())
        m.meta['hello'] = m  # circular reference
        copy_m = copy.deepcopy(m)  # finishes
        self.assertEqual(id(copy_m), id(copy_m.meta['hello']))
    def test_enum(self):
        from enum import Enum

        class Foo(Enum):
            A = 1

        def leaf_fn(arr, enum_val):
            # Use the raw enum.
            arr.append(enum_val)
            return arr[-1].value

        def foo(x):
            # Pass the enum as argument.
            return leaf_fn(x, Foo.A)

        traced = torch.fx.symbolic_trace(foo)
        self.assertEqual(foo([]), traced([]))

    def test_insert_arg(self):
        m = symbolic_trace(SimpleTest())
        m.buf = torch.nn.Buffer(torch.tensor(0))
        output_node = next(iter(reversed(m.graph.nodes)))
        with m.graph.inserting_before(output_node):
            a = m.graph.get_attr("buf")
        r = len(output_node.args)
        output_node.insert_arg(0, a)
        self.assertEqual(len(output_node.args), r + 1)
        self.assertEqual(len(a.users), 1)
        self.assertIs(output_node.args[0], a)
        self.assertIs(next(iter(a.users.keys())), output_node)
        output_node.insert_arg(2, a)
        self.assertEqual(len(output_node.args), r + 2)
        self.assertEqual(len(a.users), 1)
        self.assertIs(output_node.args[2], a)
        self.assertIs(next(iter(a.users.keys())), output_node)
    def test_delete_unused_values(self):
        from torch.fx.experimental.proxy_tensor import make_fx

        # disable mutable checking temporarily
        orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = False

        # The traced graph must contain an in-place copy_ whose result is
        # otherwise unused, so eliminate_dead_code has to keep it as a side
        # effect while clearing the dangling reference ("copy_ = None").
        def fn(a, b, c, d):
            x = a + b
            y = c + d
            y.copy_(x)
            x = torch.relu(x)
            return x

        a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4))
        fx_fn = make_fx(fn)(a, b, c, d)

        fx_fn.graph.eliminate_dead_code()
        py_code = fx_fn.recompile()
        self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src)
        self.assertTrue("copy_ = None" in py_code.src)

        # recover the mutable checking flag
        torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag

def run_getitem_target():
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
    try:
        TestFX().getitem_inner()
    finally:
        _wrapped_methods_to_patch.pop()

class TestOperatorSignatures(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        if not isinstance(op.op, types.BuiltinFunctionType):
            raise unittest.SkipTest("This path doesn't work on Python functions")

        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        schemas = get_signature_for_torch_op(op.op)
        if not schemas:
            raise RuntimeError('No Schemas Returned')
        for sample_input in sample_inputs_itr:
            # Iterate through overloads until we hit a match. If we exit this
            # loop via `else`, we haven't found a match
            for schema in schemas:
                try:
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                    bound_args.apply_defaults()
                    op(*bound_args.args, **bound_args.kwargs)
                    break
                except TypeError as e:
                    pass
            else:
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')

class TestFXAPIBackwardCompatibility(JitTestCase):
    def setUp(self):
        super().setUp()
        self.maxDiff = None

        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
    def _fn_to_stable_annotation_str(self, obj):
        """
        Unfortunately we have to serialize function signatures manually since
        serialization for `inspect.Signature` objects is not stable across
        Python versions
        """
        fn_name = torch.typename(obj)

        signature = inspect.signature(obj)

        sig_str = f'{fn_name}{signature}'

        arg_strs = []
        for k, v in signature.parameters.items():
            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
                if v.annotation is not inspect.Signature.empty else ''

            def default_val_str(val):
                if isinstance(val, (tuple, list)):
                    str_pieces = ['(' if isinstance(val, tuple) else '[']
                    str_pieces.append(', '.join(default_val_str(v) for v in val))
                    if isinstance(val, tuple) and len(str_pieces) == 2:
                        str_pieces.append(',')
                    str_pieces.append(')' if isinstance(val, tuple) else ']')
                    return ''.join(str_pieces)

                # Need to fix up some default value strings.
                # First case: modules. Default module `repr` contains the FS path of the module.
                # Don't leak that
                if isinstance(val, types.ModuleType):
                    return f'<module {val.__name__}>'

                # Second case: callables. Callables (such as lambdas) encode their address in
                # their string repr. Don't do that
                if callable(val):
                    return f'<function {val.__name__}>'

                return str(val)

            if v.default is not inspect.Signature.empty:
                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
                maybe_default = f' = {default_val_str}'
            else:
                maybe_default = ''

            maybe_stars = ''
            if v.kind == inspect.Parameter.VAR_POSITIONAL:
                maybe_stars = '*'
            elif v.kind == inspect.Parameter.VAR_KEYWORD:
                maybe_stars = '**'
            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')

        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
            if signature.return_annotation is not inspect.Signature.empty else ''

        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
    def _annotation_type_to_stable_str(self, t, sig_str):
        if t is inspect.Signature.empty:
            return ''

        # Forward references
        if isinstance(t, str):
            return t
        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
            return t.__forward_arg__
        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
            return t.__forward_arg__

        trivial_mappings = {
            str : 'str',
            int : 'int',
            float : 'float',
            bool : 'bool',
            torch.dtype: 'torch.dtype',
            torch.Tensor: 'torch.Tensor',
            torch.device: 'torch.device',
            torch.memory_format: 'torch.memory_format',
            slice : 'slice',
            torch.nn.Module: 'torch.nn.modules.module.Module',
            torch.fx.Graph : 'torch.fx.graph.Graph',
            torch.fx.Node : 'torch.fx.node.Node',
            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
            torch.fx.node.Target : 'torch.fx.node.Target',
            torch.fx.node.Argument : 'torch.fx.node.Argument',
            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
            Ellipsis : '...',
            typing.Any: 'Any',
            type(None): 'NoneType',
            None: 'None',
            typing.Iterator: 'Iterator',
        }

        mapping = trivial_mappings.get(t, None)
4088

        # Handle types with contained types
        contained = getattr(t, '__args__', None) or []

        # Callables contain a bare List for arguments
        contained = t if isinstance(t, list) else contained

        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
        if all(isinstance(ct, typing.TypeVar) for ct in contained):
            contained = []

        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''

        origin = getattr(t, '__origin__', None)

        # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
        origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin

        if origin in {tuple, typing.Tuple}:
            return f'Tuple{contained_type_str}'
        if origin in {typing.Union}:
            # Annoying hack to detect Optional
            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
            return f'Union{contained_type_str}'
        if origin in {dict, typing.Dict}:
            return f'Dict{contained_type_str}'
        if origin in {list, typing.List}:
            return f'List{contained_type_str}'
        if origin in {type, typing.Type}:
            return f'Type{contained_type_str}'
        if isinstance(t, typing.Callable):
            if len(contained) > 0 and contained[0] is not Ellipsis:
                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
            else:
                return f'Callable{contained_type_str}'

        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
                           f'Please add support for this type and confirm with the '
                           f'FX team that your signature change is valid.')
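
    # A minimal sketch (standard library only; a hypothetical `_demo_*` helper)
    # of the `__origin__`/`__args__` introspection used above, including the
    # Optional-as-Union[T, None] special case:
    @staticmethod
    def _demo_typing_introspection():
        t = typing.Optional[int]           # really Union[int, None]
        assert t.__origin__ is typing.Union
        assert type(None) in t.__args__    # this is how Optional is detected

        d = typing.Dict[str, int]
        assert d.__origin__ is dict        # bare runtime origin on Python 3.7+
        assert d.__args__ == (str, int)
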

    def test_function_back_compat(self):
        """
        Test backward compatibility for function signatures with
        @compatibility(is_backward_compatible=True). Currently this checks for
        exact signature matches, which may lead to false positives. If this
        becomes too annoying, we can refine this check to actually parse out
        the saved schema strings and check if the change is truly backward-
        incompatible.
        """
        signature_strs = []

        for obj in _BACK_COMPAT_OBJECTS:
            if not isinstance(obj, type):
                signature_strs.append(self._fn_to_stable_annotation_str(obj))

        signature_strs.sort()

        try:
            self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
        except AssertionError as e:
            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
                  f"as backwards-compatible has experienced a signature change. See the " \
                  f"above exception context for more information. If this change was " \
                  f"unintended, please revert it. If it was intended, check with the FX " \
                  f"team to ensure that the proper deprecation protocols have been followed " \
                  f"and subsequently --accept the change."
            raise AssertionError(msg) from e
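
    # A hedged sketch (hypothetical `_demo_*` helper, never invoked) of how an
    # FX API opts into these checks: decorating with `@compatibility` registers
    # the object in _BACK_COMPAT_OBJECTS / _MARKED_WITH_COMPATIBILITY, which the
    # surrounding tests iterate over.
    @staticmethod
    def _demo_compatibility_marker():
        from torch.fx._compatibility import compatibility

        @compatibility(is_backward_compatible=True)
        def my_public_fx_helper(x):
            return x

        return my_public_fx_helper
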

    def test_class_member_back_compat(self):
        """
        Test backward compatibility for members of classes with
        @compatibility(is_backward_compatible=True). Currently this checks for
        exact matches on the publicly visible members of the class.
        """
        class_method_strs = []

        for obj in _BACK_COMPAT_OBJECTS:
            if isinstance(obj, type):
                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')

        class_method_strs.sort()

        try:
            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
        except AssertionError as e:
            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
                  f"as backwards-compatible has experienced a change in its public members. See the " \
                  f"above exception context for more information. If this change was " \
                  f"unintended, please revert it. If it was intended, check with the FX " \
                  f"team to ensure that the proper deprecation protocols have been followed " \
                  f"and subsequently --accept the change."
            raise AssertionError(msg) from e

    def test_public_api_surface(self):
        non_back_compat_objects = {}

        def check_symbols_have_bc_designation(m, seen):
            if not m.__name__.startswith('torch.fx'):
                return
            if m.__name__.startswith('torch.fx.experimental'):
                return
            # It's really common for inner functions to point to random modules
            # - make sure we don't recurse into modules we've already checked.
            seen.add(m.__name__)
            for k, v in m.__dict__.items():
                if hasattr(v, '__name__') and v.__name__ in seen:
                    continue
                if k.startswith('_'):
                    continue
                if isinstance(v, types.ModuleType):
                    check_symbols_have_bc_designation(v, seen)
                elif isinstance(v, (type, types.FunctionType)):
                    if v not in _MARKED_WITH_COMPATIBILITY:
                        non_back_compat_objects.setdefault(v)

        check_symbols_have_bc_designation(torch.fx, set())
        check_symbols_have_bc_designation(torch.fx.passes, set())

        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
        # Only want objects in torch.fx
        non_back_compat_strs = [
            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
        # Only want objects in public namespaces
        non_back_compat_strs = [
            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
        non_back_compat_strs.sort()

        if len(non_back_compat_strs) != 0:
            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
                                 f"backwards-compatibility classification! Please decorate these "
                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
                                 f"BC guarantees.")
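
    # A compact sketch (standard library only; a hypothetical `_demo_*` helper)
    # of the recursive walk above: collect the public classes/functions of a
    # package while skipping private names and already-visited modules.
    @staticmethod
    def _demo_collect_public_symbols(root):
        seen, symbols = set(), []

        def walk(module):
            seen.add(module.__name__)
            for name, value in module.__dict__.items():
                if name.startswith('_'):
                    continue
                if isinstance(value, types.ModuleType) and value.__name__ not in seen:
                    walk(value)
                elif isinstance(value, (type, types.FunctionType)):
                    symbols.append(f'{module.__name__}.{name}')

        walk(root)
        return symbols
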

    def test_adding_side_effect_function(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                side_effect_func(x)
                return x

        gm = torch.fx.symbolic_trace(TestModule())
        self.assertEqual(len(gm.graph.nodes), 3)
        gm.graph.eliminate_dead_code()
        gm.recompile()
        self.assertEqual(len(gm.graph.nodes), 3)

        found = False
        for node in gm.graph.nodes:
            if node.op == 'call_function' and node.target == side_effect_func:
                found = True
        self.assertTrue(found)
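
    # Contrast sketch (hypothetical `_demo_*` helper, never invoked): without a
    # registered side effect, an unused call *is* removed by dead-code
    # elimination.
    @staticmethod
    def _demo_dce_removes_pure_ops():
        class M(torch.nn.Module):
            def forward(self, x):
                _unused = x + 1          # pure op whose result is never used
                return torch.relu(x)

        gm = torch.fx.symbolic_trace(M())
        before = len(gm.graph.nodes)     # placeholder, add, relu, output
        gm.graph.eliminate_dead_code()
        gm.recompile()
        assert len(gm.graph.nodes) == before - 1  # the `x + 1` node is gone
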

    def test_preserve_unused_attr_after_unpickle(self):
        gm = torch.fx.symbolic_trace(Add())
        gm.add_submodule("foo", Add())
        gm.dummy_buffer = torch.nn.Buffer(torch.empty(1))
        gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))
        b = io.BytesIO()
        torch.save(gm, b)
        b.seek(0)
        # weights_only=False as this loads a GraphModule
        # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
        reload_gm = torch.load(b, weights_only=False)
        self.assertTrue(hasattr(reload_gm, "foo"))
        self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
        self.assertTrue(hasattr(reload_gm, "dummy_parameter"))
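
# A minimal round-trip sketch (hypothetical `_demo_*` helper; reuses the `Add`
# module from earlier in this file): GraphModules pickle through
# torch.save/torch.load, and loading one requires weights_only=False.
def _demo_graphmodule_roundtrip():
    import io

    gm = torch.fx.symbolic_trace(Add())
    buffer = io.BytesIO()
    torch.save(gm, buffer)
    buffer.seek(0)
    return torch.load(buffer, weights_only=False)
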

# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
@unittest.skipIf(
    sys.version_info >= (3, 12), "Failing on python 3.12+"
)
class TestFunctionalTracing(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                    "has_torch_function_variadic", "handle_torch_function",
                    "boolean_dispatch")
    TO_PATCH = {"has_torch_function": None,
                "has_torch_function_unary": None,
                "has_torch_function_variadic": None}

    BUILT_IN_FUNC = (AssertionError, "")
    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")

    UNTRACEABLE_FUNCTIONALS = {
        "adaptive_avg_pool1d": BUILT_IN_FUNC,
        "avg_pool1d": BUILT_IN_FUNC,
        "avg_pool2d": BUILT_IN_FUNC,
        "avg_pool3d": BUILT_IN_FUNC,
        "bilinear": BUILT_IN_FUNC,
        "celu_": BUILT_IN_FUNC,
        "channel_shuffle": BUILT_IN_FUNC,
        "native_channel_shuffle": BUILT_IN_FUNC,
        "conv1d": BUILT_IN_FUNC,
        "conv2d": BUILT_IN_FUNC,
        "conv3d": BUILT_IN_FUNC,
        "conv_tbc": BUILT_IN_FUNC,
        "conv_transpose1d": BUILT_IN_FUNC,
        "conv_transpose2d": BUILT_IN_FUNC,
        "conv_transpose3d": BUILT_IN_FUNC,
        "cosine_similarity": BUILT_IN_FUNC,
        "elu_": BUILT_IN_FUNC,
        "gelu": BUILT_IN_FUNC,
        "hardshrink": BUILT_IN_FUNC,
        "hardtanh_": BUILT_IN_FUNC,
        "leaky_relu_": BUILT_IN_FUNC,
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": ARG_TYPE_MISMATCH,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,
        "pixel_unshuffle": BUILT_IN_FUNC,
        "prelu": BUILT_IN_FUNC,
        "relu_": BUILT_IN_FUNC,
        "rrelu_": BUILT_IN_FUNC,
        "selu_": BUILT_IN_FUNC,
        "scaled_dot_product_attention": BUILT_IN_FUNC,
        "softplus": BUILT_IN_FUNC,
        "softshrink": BUILT_IN_FUNC,
        "threshold_": BUILT_IN_FUNC,

        "adaptive_avg_pool2d": LEN_ERROR,
        "adaptive_avg_pool3d": LEN_ERROR,
        "adaptive_max_pool2d_with_indices": LEN_ERROR,
        "adaptive_max_pool3d_with_indices": LEN_ERROR,
        "instance_norm": CONTROL_FLOW,

        "adaptive_max_pool1d": PROXY_ITERABLE,
        "adaptive_max_pool2d": PROXY_ITERABLE,
        "adaptive_max_pool3d": PROXY_ITERABLE,
        "fractional_max_pool2d": PROXY_ITERABLE,
        "fractional_max_pool3d": PROXY_ITERABLE,
        "max_pool1d": PROXY_ITERABLE,
        "max_pool2d": PROXY_ITERABLE,
        "max_pool3d": PROXY_ITERABLE,

        "lp_pool2d": PROXY_ITERATED,
        "lp_pool3d": PROXY_ITERATED,
        "max_unpool1d": PROXY_ITERATED,
        "max_unpool2d": PROXY_ITERATED,
        "max_unpool3d": PROXY_ITERATED,
        "fold": PROXY_ITERATED,
        "unfold": PROXY_ITERATED,

        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "rms_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
        "alpha_dropout": CONTROL_FLOW,
        "batch_norm": CONTROL_FLOW,
        "binary_cross_entropy": CONTROL_FLOW,
        "binary_cross_entropy_with_logits": CONTROL_FLOW,
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
        "dropout2d": CONTROL_FLOW,
        "dropout3d": CONTROL_FLOW,
        "elu": CONTROL_FLOW,
        "embedding": CONTROL_FLOW,
        "embedding_bag": CONTROL_FLOW,
        "feature_alpha_dropout": CONTROL_FLOW,
        "gaussian_nll_loss": CONTROL_FLOW,
        "glu": CONTROL_FLOW,
        "grid_sample": CONTROL_FLOW,
        "group_norm": CONTROL_FLOW,
        "gumbel_softmax": CONTROL_FLOW,
        "hardsigmoid": CONTROL_FLOW,
        "hardswish": CONTROL_FLOW,
        "hardtanh": CONTROL_FLOW,
        "hinge_embedding_loss": CONTROL_FLOW,
        "huber_loss": CONTROL_FLOW,
        "interpolate": CONTROL_FLOW,
        "kl_div": CONTROL_FLOW,
        "l1_loss": CONTROL_FLOW,
        "leaky_relu": CONTROL_FLOW,
        "local_response_norm": CONTROL_FLOW,
        "margin_ranking_loss": CONTROL_FLOW,
        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "mse_loss": CONTROL_FLOW,
        "multi_head_attention_forward": CONTROL_FLOW,
        "multi_margin_loss": CONTROL_FLOW,
        "multilabel_margin_loss": CONTROL_FLOW,
        "multilabel_soft_margin_loss": CONTROL_FLOW,
        "nll_loss": CONTROL_FLOW,
        "poisson_nll_loss": CONTROL_FLOW,
        "relu": CONTROL_FLOW,
        "relu6": CONTROL_FLOW,
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,
        "triplet_margin_loss": CONTROL_FLOW,
        "triplet_margin_with_distance_loss": CONTROL_FLOW,
        "upsample": CONTROL_FLOW,

        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
    }

    # List of nn.functionals with Tensor inputs but without a type annotation
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )

    # Inconsistent behavior between Python 3.8 and other Python versions:
    # - Python 3.8+: Re-raise an internal exception like `PROXY_ITERATED`
    # - Other Python versions: Raise `argument of type 'Proxy' is not iterable`
    #   due to the same internal exception above
    # Use the following map to override the expected exception for Python 3.8
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": CONTROL_FLOW,
    }
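
    # A hedged illustration (hypothetical `_demo_*` helper, never invoked) of
    # the most common failure mode catalogued above: branching on a traced
    # value raises TraceError, which is what the CONTROL_FLOW entries expect.
    @staticmethod
    def _demo_control_flow_trace_error():
        def f(x):
            if x.sum() > 0:  # data-dependent branch on a Proxy
                return x
            return -x

        try:
            torch.fx.symbolic_trace(f)
        except TraceError:
            pass  # "symbolically traced variables cannot be used as inputs to control flow"
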
    @classmethod
    def _get_functional(cls):
        functional_list = []
        for f in dir(torch.nn.functional):
            if not f.islower():
                continue
            # Ignore internal functions
            if f.startswith('_'):
                continue
            # Ignore supporting functions
            if f in cls.IGNORE_FUNCS:
                continue
            fn = getattr(torch.nn.functional, f)
            # Ignore non-callable objects like modules
            if not isinstance(fn, Callable):
                continue
            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
                try:
                    sig = inspect.signature(fn)
                    has_tensor_arg = False
                    for param in sig.parameters.values():
                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
                            has_tensor_arg = True
                    if not has_tensor_arg:
                        continue
                # Signature is unavailable or the object is not supported
                except ValueError:
                    pass
            functional_list.append((f, fn))
        return functional_list
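
    # A compact sketch (hypothetical `_demo_*` helper) of the annotation filter
    # above: keep only callables with at least one parameter annotated as a
    # Tensor subclass, e.g. _demo_has_tensor_param(torch.nn.functional.relu).
    @staticmethod
    def _demo_has_tensor_param(fn):
        sig = inspect.signature(fn)
        return any(
            isinstance(p.annotation, type) and issubclass(p.annotation, torch.Tensor)
            for p in sig.parameters.values()
        )
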
    @classmethod
    def generate_test_func(cls, func_name, fn):

        def functional_test(self):
            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
                    sys.version_info >= (3, 8) and sys.version_info < (3, 12):
                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            else:
                symbolic_trace(fn)
        return functional_test

    @classmethod
    def generate_tests(cls):
        functional_list = cls._get_functional()
        for func_name, fn in functional_list:
            test_name = "test_nn_functional_" + func_name
            functional_test = cls.generate_test_func(func_name, fn)
            setattr(cls, test_name, functional_test)
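
    # A minimal, standalone sketch (hypothetical `_demo_*` helper) of the
    # test-generation pattern used above: build one closure per case and attach
    # it with setattr so unittest discovers it as an ordinary test method.
    @staticmethod
    def _demo_attach_generated_tests(target_cls, case_names):
        def make_test(name):
            def test(self):
                self.assertTrue(isinstance(name, str))
            return test

        for name in case_names:
            setattr(target_cls, "test_generated_" + name, make_test(name))
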
    @classmethod
    def setUpClass(cls):

        def no(*args, **kwargs):
            return False

        for name in cls.TO_PATCH.keys():
            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

    @classmethod
    def tearDownClass(cls):
        for name in cls.TO_PATCH.keys():
            setattr(torch.nn.functional, name, cls.TO_PATCH[name])

TestFunctionalTracing.generate_tests()

instantiate_device_type_tests(TestOperatorSignatures, globals())

@skipIfTorchDynamo("too slow")
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor",
    )

    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
        "ssd300_vgg16": PROXY_ITERATED,
        "fcos_resnet50_fpn": PROXY_ITERATED,
        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
    }
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }

    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }

    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs)
            model = model.eval()
            if name in self.UNTRACEABLE_MODELS:
                err, exc = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(err, exc):
                    graph = symbolic_trace(model)
            else:
                out_transform = self.output_transform.get(name, lambda x: x)
                graph: torch.fx.GraphModule = symbolic_trace(model)
                a = out_transform(model(x))
                b = out_transform(graph(x))
                self.assertEqual(a, b)

                if name in self.UNSCRIPTABLE_MODELS:
                    err, exc = self.UNSCRIPTABLE_MODELS[name]
                    with self.assertRaisesRegex(err, exc):
                        script = torch.jit.script(graph)
                else:
                    script = torch.jit.script(graph)
                    c = out_transform(script(x))
                    self.assertEqual(a, c)

        return run_test

    @classmethod
    def generate_classification_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models):
            test_name = 'test_torchvision_models_' + k
            x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_segmentation_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.segmentation):
            test_name = 'test_torchvision_models_segmentation_' + k
            x = torch.rand(1, 3, 32, 32)
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_detection_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.detection):
            test_name = 'test_torchvision_models_detection_' + k
            x = [torch.rand(3, 300, 300)]
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_video_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.video):
            test_name = 'test_torchvision_models_video_' + k
            x = (
                torch.rand(1, 3, 4, 112, 112)
                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
                else torch.rand(1, 3, 16, 224, 224)
            )
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_tests(cls):
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()


TestVisionTracing.generate_tests()

if __name__ == '__main__':
    run_tests()