# Owner(s): ["module: fx"]

import builtins
import contextlib
import copy
import functools
import io
import math
import numbers
import operator
import os
import pickle
import sys
import unittest

import torch
from math import sqrt

from functorch.experimental import control_flow
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from torch.fx._symbolic_trace import PHBase, PHWithMeta

from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401
from fx.test_matcher_utils import TestMatcher  # noqa: F401
from fx.test_source_matcher_utils import TestSourceMatcher  # noqa: F401

from fx.test_gradual_type import AnnotationsTest  # noqa: F401
from fx.test_gradual_type import TypeCheckerTest  # noqa: F401

from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union

from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)

def a_non_torch_leaf(a, b):
    return a + b

# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    return int(x)

def fx_int_x2(x: float) -> int:
    return int(x) * 2

# used in test_pytree. It's all the way out here because pickling a GraphModule
# that uses Point errors out if Point is local to the function
Point = namedtuple('Point', ['x', 'y'])

# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

wrap(a_lifted_leaf2)

def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)
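
# A minimal sketch of the wrap() contract the tests below rely on
# (hypothetical helper, not part of the suite): wrapping a module-level
# function -- by object or by name -- keeps it opaque during tracing, so
# symbolic_trace() records a single call_function node for it instead of
# tracing into its body.
def _sketch_leaf(x):
    return x * 2

wrap(_sketch_leaf)
# e.g. symbolic_trace(lambda x: _sketch_leaf(x) + 1).code mentions '_sketch_leaf'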
@wrap
def wrapped_via_decorator(a):
    return a + 1

wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

def my_decorator(f):
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

@my_decorator
def wrapped_decorated_fn(x):
    return x

real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
_sqrt = sqrt

wrap('wrapped_decorated_fn')

class Pair(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"
class Foo:  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

class Add(torch.nn.Module):
    def forward(self, x):
        return x + x

@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    print(x)

class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)
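
    # The helper above is the round trip most tests below lean on (sketch):
    #
    #   ref = m(*args)               # eager reference
    #   gm = symbolic_trace(m)       # Proxy-driven trace -> fx.Graph
    #   assert gm(*args) == ref      # generated forward() replays the graph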
    def test_graph_module(self):
        class MySub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        ms = torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        gm2 = symbolic_trace(m2)

        class T(torch.nn.Module):
            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)

        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_close(new_instance(x), torch.relu(x))
    def test_informative_co_filename(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return a * 2

        gm = symbolic_trace(MyModule())
        self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)
    def test_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))
    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
    def test_varargs_concrete(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + args[1]
                return x

        args = (torch.rand(1), torch.rand(1))

        t = T()
        ref_outs = t(*args)
        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
        gm.graph.lint()
        test_outs = gm(*args)
        self.assertEqual(ref_outs, test_outs)
    def test_args_kwargs_no_self(self):
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
    def test_fx_shifts(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))
    def test_fx_and_or(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x & x, x | x

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))
    def test_dict(self):
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}

        m = MyDictMod()
        self.checkGraphModule(m, (input_dict,))
    def test_matmul_tracing(self):
        const = torch.randn(3)

        def matmul_f(x):
            return x @ const

        mod = symbolic_trace(matmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), matmul_f(inp))

        def rmatmul_f(x):
            return const @ x

        mod = symbolic_trace(rmatmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), rmatmul_f(inp))
    @skipIfNoDynamoSupport
    def test_control_flow_tracing(self):
        def true(x, y):
            return x + y

        def false(x, y):
            return x - y

        def f(x, y):
            x = control_flow.cond(x[0] == 0, true, false, [x, y])

        with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
            _ = symbolic_trace(f)
    def test_disallow_override(self):
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                name = target if isinstance(target, str) else torch.typename(target)
                if name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name)

        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x

        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x

        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)
    def test_leaf_module(self):
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
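
    # For contrast with NoLeafModulesTracer above (sketch): the default
    # Tracer.is_leaf_module() treats standard torch.nn modules as leaves, so
    # symbolic_trace(MyReluMod()) would record self.relu as a single
    # call_module node rather than tracing through it.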
    def test_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
    def test_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf2', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
    def test_wrapped_via_decorator(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
    def test_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(m).transform()
        self.assertIn('wrapped_via_decorator', transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
    def test_wrap_with_submodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        m = symbolic_trace(M())

        self.assertIn("wrapped_with_submodule", m.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), m(input))
    def test_wrapped_retrace(self):
        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)

        retraced = symbolic_trace(m)
        self.assertIn('wrapped_via_decorator', retraced.code)
        self.assertEqual(retraced(0), 1)
    def test_wrap_decorated_function(self):
        def to_trace(y):
            return wrapped_decorated_fn(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_decorated_fn', m.code)
        self.assertEqual(m(1), 1)
    def test_graph_edit_with_proxy(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        self.assertEqual(gm(3, 4), 14)
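
    # The Proxy-on-a-copied-graph idiom above generalizes (sketch, assuming a
    # Graph `g` whose copied output value is `out`):
    #
    #   t = Proxy(out)
    #   g.output((t + t).node)   # arithmetic on the Proxy appends nodes to g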
    def test_concrete_arg_none_assert(self):
        class Foo(torch.nn.Module):
            def forward(self, x, val=None):
                return x if val is None else x + val

        f = Foo()
        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
            traced(torch.randn(5), torch.randn(5))

        x = torch.randn(5)
        torch.testing.assert_close(traced(x), f(x))
    def test_trace_multiple_funcs(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

            def minus_forward(self, x, y):
                return x - y

            def multiply_forward(self, x, y):
                return x * y

        f = Foo()
        x, y = torch.randn(5), torch.randn(5)

        print(torch.__version__)

        tracer = Tracer()
        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

        tracer.traced_func_name = "minus_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.minus_forward(x, y),
        )

        tracer.traced_func_name = "multiply_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.multiply_forward(x, y),
        )

        tracer.traced_func_name = "add_forward"
        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
            tracer.trace(f)
    def test_graph_unique_names(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)
    def test_stack_traces(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        # saving the original list because we will insert new nodes as a part of a test
        orig_graph_nodes = list(graph.nodes)
        for node in orig_graph_nodes:
            if node.op == 'output':
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

            # verify that copying the node does not lose the stack trace
            new_node = graph.node_copy(node)
            self.assertTrue(new_node.stack_trace is not None)
            assert 'test_fx.py' in new_node.stack_trace
    def test_stack_traces_with_transformer(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        new_gm = Transformer(gm).transform()

        # nodes after Transformer should still preserve the original node's stack trace
        for node in new_gm.graph.nodes:
            if node.op in {'placeholder', 'output'}:
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace
    def test_lineno_map(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                a = torch.sin(a)
                b = torch.cos(b)
                return a + b

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        expected = {1: 2, 2: 3, 3: 4, 4: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

        # test custom codegen
        def transform_code(code):
            return ["print('hello!')\n", *code]
        gm.graph.on_generate_code(lambda _: transform_code)
        gm.recompile()
        expected = {2: 2, 3: 3, 4: 4, 5: 5}
        self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
    def test_graph_unique_names_manual(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))

        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)
    def test_unpack(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))
    def test_native_callable(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul",
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, "Unsupported call target " + target
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value
            #
            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint()

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_close(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_close(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_close(imported_out, ref_out)
    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint()
        for node in m_g.graph.nodes:
            self.assertTrue(node.name != "getattr")
    @unittest.skip("Hotfix for SEV remediation")
    def test_trace_buffer_slice(self):
        bs = 503
        d_hid = 512

        class ExampleCode(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)
                self.register_buffer('buffer', torch.randn(bs + 100, d_hid))

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                skip_connection = x
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
                x = self.lin(x)
                x = torch.relu(x)
                x = x + skip_connection
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        ec = ExampleCode()

        traced = torch.fx.symbolic_trace(ec)

        x = torch.randn(bs, d_hid)
        torch.testing.assert_close(ec(x), traced(x))
    def test_node_tagging(self):
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint()
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')
    def test_tensor_attribute(self):
        class TensorAttribute(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint()
        traced2(torch.rand(4, 4))
    def test_tensor_attribute_coalseced(self):

        def count_attrs(fx_module):
            targets = set()
            for node in fx_module.graph.nodes:
                if node.op == 'get_attr':
                    targets.add(node.target)
            return len(targets)

        val = torch.tensor(5)

        def f(x):
            return x + val + val

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 1)

        val2 = torch.tensor(5)

        def f(x):
            val = torch.tensor(5)
            return x + val + val2

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 2)
    def test_symbolic_trace_sequential(self):
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))
    def test_tensor_constant(self):
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint()
        traced(torch.rand(4, 4))
    def test_pickle_graphmodule(self):
        class Nested(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint()
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)

        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))
    def test_pickle_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)

        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(loaded(x, y), gm(x, y))
    def test_all_input_nodes(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])
    def test_deepcopy_graphmodule_with_transform(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            relu_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(relu_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint()
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))
    def test_deepcopy_with_submods_params(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint()
        copied = copy.deepcopy(traced)
    def test_deepcopy_graph_with_tracer_cls(self):
        class TestTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        g = Graph(tracer_cls=TestTracer)
        x = g.placeholder("x")
        g.output(x)

        h = copy.deepcopy(g)
        self.assertIsNotNone(h._tracer_cls)
        self.assertTrue(g._tracer_cls == h._tracer_cls)
    def test_unpack_list_better_error(self):
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)
    def test_unpack_dict_better_error(self):
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)
    def test_pretty_print_targets(self):
        # Test that Graph pretty-print prints friendly name for targets
        # in `operator` and `builtins`

        class SomeMod(torch.nn.Module):
            def forward(self, x):
                return torch.add(x.foo + x.bar, 3.0)

        traced = symbolic_trace(SomeMod())
        graph_str = str(traced.graph)
        self.assertIn('builtins.getattr', graph_str)
        self.assertIn('operator.add', graph_str)
        self.assertIn('torch.add', graph_str)
    def test_pretty_print_node(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param: torch.nn.Parameter = torch.nn.Parameter(
                    torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor, y: int = 2):
                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)

        traced = symbolic_trace(M())

        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])

        FileCheck().check("x").check("placeholder") \
            .check("y").check("placeholder") \
            .check("getitem").check("call_function") \
            .check("param").check("get_attr") \
            .check("add").check("call_function") \
            .check("linear").check("call_module") \
            .check("clamp").check("call_method") \
            .run(all_formatted)
    def test_script_tensor_constant(self):
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)
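
    # Sketch of the naming contract tested above: after tracing
    # IHaveATensorConstant, the anonymous tensor is installed on the module as
    # an attribute named `_tensor_constant0` (single leading underscore, which
    # TorchScript preserves) and is read back through a get_attr node.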
    def test_autowrap_functions(self):
        class AutowrapFnTest(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2)

        class AutowrapFnTest2(torch.nn.Module):
            def forward(self, x):
                return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)

        # Check function(s) are wrapped
        # `int` would normally throw a TypeError as argument can't be `Proxy`
        tracer = Tracer(autowrap_functions=(fx_int,))
        graph = tracer.trace(AutowrapFnTest())
        traced = GraphModule(tracer.root, graph, 'test')
        tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
        tracer_2.trace(AutowrapFnTest2())

        # Test scriptability
        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(4)), 2)
    def test_tuple_no_subscript(self):
        def foo(x : Tuple):
            return x[0]

        traced = torch.fx.symbolic_trace(foo)
        x = (torch.randn(5, 3),)
        torch.testing.assert_close(traced(x), x[0])

        bio = io.BytesIO()

        torch.save(traced, bio)

        bio.seek(0)

        loaded = torch.load(bio)

        torch.testing.assert_close(loaded(x), x[0])
    def test_torch_fx_len(self):
        class FXLenTest(torch.nn.Module):
            def forward(self, x):
                return len(x)

        traced = symbolic_trace(FXLenTest())
        self.assertEqual(traced(torch.rand(3, 4)), 3)

        # Test scriptability
        scripted = torch.jit.script(FXLenTest())
        self.assertEqual(scripted(torch.rand(3)), 3)

        traced_scripted = torch.jit.script(traced)
        self.assertEqual(traced_scripted(torch.rand(3)), 3)

        # Test non-proxy len
        class FXLenTest2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = [3, 4, 5]

            def forward(self, x):
                return x + len(self.l)

        traced2 = symbolic_trace(FXLenTest2())
        inp = torch.rand(3, 4)
        self.assertEqual(traced2(inp), inp + 3.0)
        self.assertIs(len, builtins.len)
    def test_torch_fx_getattr(self):
        class FXGetattrTest(torch.nn.Module):
            def forward(self, x):
                return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))

        traced = symbolic_trace(FXGetattrTest())
        self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
    def test_sqrt(self):
        class Sqrt1(torch.nn.Module):
            def forward(self, x):
                return sqrt(x.size(0))

        class Sqrt2(torch.nn.Module):
            def forward(self, x):
                return math.sqrt(x.size(0))

        class Sqrt3(torch.nn.Module):
            def forward(self, x):
                return x + math.sqrt(2) + sqrt(2)

        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
        self.assertIs(sqrt, _sqrt)
        self.assertIs(math.sqrt, _sqrt)
    def test_torch_custom_ops(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))

        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        out = gm(input)
        self.assertEqual(out, ref_out)
    def test_torch_op_overloads(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.add.Tensor(a, a)
                return b

        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint()
        out = gm(input)
        self.assertEqual(out, ref_out)

        for node in gm.graph.nodes:
            if node.op == 'call_function':
                assert isinstance(node.target, torch._ops.OpOverload)
                assert node.target.__name__ == 'add.Tensor'
    def test_pickle_torch_custom_ops(self):
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))

        m = M()
        input = torch.randn(3)
        gm = symbolic_trace(m)
        gm.graph.lint()
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        self.assertEqual(loaded(input), gm(input))
    def test_pretty_print(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()
        printed = str(traced)
        assert 'SimpleTest()' in printed
        assert 'torch.relu' in printed
    def test_pretty_print_graph(self):
        class KwargPrintTest(torch.nn.Module):
            def forward(self, x):
                return torch.squeeze(x + 3.0, dim=2)
        st = KwargPrintTest()
        traced = symbolic_trace(st)
        traced.graph.lint()
        stringed = str(traced.graph)
        for s in ['args', 'kwargs', 'num_users']:
            assert s in stringed
    def test_custom_proxy_type(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair(x : TensorPair, y : TensorPair):
            s = x.add(y)
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

        ref_out = use_tensor_pair(x, y)

        traced = symbolic_trace(use_tensor_pair)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
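
    # ProxyableClassMeta (sketch): constructing TensorPair inside a traced
    # function yields a Proxy for the whole object, so calls like x.add(y)
    # above are recorded symbolically in the graph instead of eagerly
    # computing tensors during the trace.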
    def test_custom_proxy_type_literal(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_literal(x : TensorPair):
            s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))

        ref_out = use_tensor_pair_literal(x)

        traced = symbolic_trace(use_tensor_pair_literal)

        traced_out = traced(x)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
    def test_custom_proxy_dynamic_value(self):
        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = symbolic_trace(use_tensor_pair_ctor)

        traced_out = traced(x, y)
        self.assertEqual(traced_out.left, ref_out.left)
        self.assertEqual(traced_out.right, ref_out.right)
    def test_custom_proxy_input_dependent_control_flow(self):
        class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, inp):
                if inp.sum() == 0:
                    self.is_zero = True
                    self.tensor = torch.tensor([])
                else:
                    self.is_zero = False
                    self.tensor = inp

            def add(self, other):
                if self.is_zero:
                    return ZeroTensor(other.tensor)
                elif other.is_zero:
                    return self

        def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
            return ZeroTensor(x + y)

        x, y = torch.randn(5, 3), torch.randn(5, 3)

        ref_out = use_zero_tensor(x, y)

        traced = symbolic_trace(use_zero_tensor)

        traced_out = traced(x, y)

        self.assertEqual(traced_out.is_zero, ref_out.is_zero)
        self.assertEqual(traced_out.tensor, ref_out.tensor)
    def test_graph_fns(self):
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)

        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)
    def test_remove_uses(self):
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertTrue(neg not in relu.users)
    def test_remove_uses_with_custom_filter(self):
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu, lambda x: x != neg)

        self.assertTrue(neg in relu.users)
    def test_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(3, 4)
        symbolic_trace(eb)

    def test_pickle_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
        traced = symbolic_trace(eb)
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 4])
        self.assertEqual(loaded(input, offsets), traced(input, offsets))
    def test_return_tuple(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return (torch.add(x, x), x)

        original = M()
        traced = symbolic_trace(original)
        self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
    def test_construct_root_dict(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)

        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
        add_param : torch.Tensor = torch.rand(3, 4)
        gm : torch.fx.GraphModule = torch.fx.GraphModule(
            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
        gm.graph.lint()

        assert 'self.foo.bar.baz' in gm.code

        x : torch.Tensor = torch.rand(3, 3)
        out : torch.Tensor = gm(x)
        ref_out : torch.Tensor = linear_mod(x) + add_param
        self.assertEqual(out, ref_out)
    def test_symbolic_trace_assert(self):

        class AssertsTensorShape(torch.nn.Module):
            def forward(self, x):
                torch._assert(x.shape[1] > 4, "assert_foobar")
                return x

        m = AssertsTensorShape()
        # verify traceability
        traced = symbolic_trace(m)
        # verify assertion on traced model works correctly at runtime
        traced(torch.rand(4, 5))
        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
            traced(torch.rand(4, 3))
        # verify the symbolically traced module is scriptable
        ms = torch.jit.script(m)
        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
            ms(torch.rand(4, 3))
    def test_fx_create_arg(self):
        class CustomArgObject:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __fx_create_arg__(self, tracer: torch.fx.Tracer):
                return tracer.create_node(
                    "call_function",
                    CustomArgObject,
                    args=(
                        tracer.create_arg(self.x),
                        tracer.create_arg(self.y),
                    ),
                    kwargs={},
                )

        class HasCustomArgObjectWhenLeaf(torch.nn.Module):
            def forward(self, o: CustomArgObject):
                # Not normally traceable; good reason to make
                # this module a leaf.
                for x in o.x:
                    o.y += x
                return o.y

        class Root(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.inner = HasCustomArgObjectWhenLeaf()

            def forward(self, x, y):
                o = CustomArgObject(x, y)
                return self.inner(o)

        class CreateArgTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return type(m) is HasCustomArgObjectWhenLeaf

        m = Root()
        graph = CreateArgTracer().trace(m)
        gm = torch.fx.GraphModule(m, graph)
        assert "CustomArgObject(" in gm.code
    def test_trace_fn_constant(self):
        some_constant = torch.rand(3, 4)

        def add_const(x):
            return some_constant + x

        traced = symbolic_trace(add_const)
        traced.graph.lint()

        input = torch.rand(3, 4)
        self.assertEqual(traced(input), add_const(input))
    def test_copy_no_remap(self):
        traced = symbolic_trace(SimpleTest())
        g = traced.graph
        copied = torch.fx.Graph()
        for node in g.nodes:
            copied.node_copy(node)
        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
            copied.lint()
    def test_wrong_topo(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        nodes = list(graph.nodes)
        nodes[3].append(nodes[2])
        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
            graph.lint()
    def test_wrong_target_type(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        with self.assertRaises(ValueError):
            n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
                              args=(), kwargs={})
    def test_example_shape_prop(self):
        class TestCase(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.randn(3, 4)
                self.submod = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.neg(self.submod(x.relu() + self.attr))

        tc = TestCase()
        tc_traced = symbolic_trace(tc)
        ref_out = tc_traced(torch.rand(3, 4))
        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

        # Make sure we're testing all opcodes
        opcodes = set()
        output_shape : Optional[torch.Size] = None
        output_stride : Optional[Tuple[int]] = None
        for node in tc_traced.graph.nodes:
            opcodes.add(node.op)
            if node.op == 'output':
                output_shape = node.args[0].meta['tensor_meta'].shape
                output_stride = node.args[0].meta['tensor_meta'].stride
        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                                   'call_module', 'output'})

        # Test shape propagation and make sure results match actual
        self.assertEqual(output_shape, ref_out.shape)
        self.assertEqual(output_stride, ref_out.stride())
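
    # ShapeProp (sketch): it interprets the module once, node by node, with
    # the example input and stamps node.meta['tensor_meta'] (shape, dtype,
    # stride, memory_format) onto every node it executes -- that is what the
    # output_shape / output_stride checks above read back.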
    def test_shape_prop_layout(self):
        class ConvTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_mod = torch.nn.Conv2d(5, 5, 3)

            def forward(self, x):
                return self.conv_mod(x)

        test_mod = ConvTest()
        traced = symbolic_trace(test_mod)
        x = torch.randn(5, 5, 224, 224)
        shape_prop.ShapeProp(traced).propagate(x)

        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
                   for node in traced.graph.nodes)

        x_channels_last = x.contiguous(memory_format=torch.channels_last)
        traced.to(memory_format=torch.channels_last)
        shape_prop.ShapeProp(traced).propagate(x_channels_last)
        for node in traced.graph.nodes:
            # NB: the implementation of conv may not preserve the memory format,
            # unfortunately. The best we can do is just check that the placeholder
            # node is channels-last
            if node.op in {'placeholder'}:
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
    def test_shape_prop_aggregate(self):
        class ReturnTwo(torch.nn.Module):
            def forward(self, x):
                return (3, torch.sum(x))

        class UnderTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rt = ReturnTwo()

            def forward(self, x):
                return self.rt(x)

        ut = UnderTest()

        class RTTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return type(m) is ReturnTwo

        graph = RTTracer().trace(ut)
        mod = torch.fx.GraphModule(ut, graph)

        shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))

        for node in mod.graph.nodes:
            if node.op == 'call_module':
                assert 'tensor_meta' in node.meta
                tensor_meta = node.meta['tensor_meta']
                assert tensor_meta[0] == 3
                assert tensor_meta[1].shape == torch.Size([])
    def test_shape_prop_layout_3d(self):
        class ConvTest3d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_mod = torch.nn.Conv3d(5, 5, 3)

            def forward(self, x):
                return self.conv_mod(x)

        test_mod_3d = ConvTest3d()
        traced_3d = symbolic_trace(test_mod_3d)
        x_3d = torch.randn(5, 5, 224, 224, 15)
        shape_prop.ShapeProp(traced_3d).propagate(x_3d)
        assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
                   for node in traced_3d.graph.nodes)

        x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
        traced_3d.to(memory_format=torch.channels_last_3d)
        shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
        for node in traced_3d.graph.nodes:
            # NB: the implementation of conv may not preserve the memory format,
            # unfortunately. The best we can do is just check that the placeholder
            # node is channels-last
            if node.op in {'placeholder'}:
                self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
    def test_nn_module_stack(self):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)

            def forward(self, x):
                return self.conv_mod(x)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub_mod = SubModule()

            def forward(self, x):
                return self.sub_mod(x)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        mod_stack = {}
        expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
                          ('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
        for node in gm.graph.nodes:
            mod_stack = node.meta.get('nn_module_stack', {})
            if mod_stack:
                break
        stack_list = list(mod_stack.items())
        self.assertEqual(stack_list, expected_stack)
    def test_transformer_preserves_nn_module_stack_for_get_attr(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))

            def forward(self, x):
                return self.weight + x

        tracer = torch.fx.Tracer()
        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        for node in gm.graph.nodes:
            if node.op == 'get_attr':
                node.meta["nn_module_stack"] = "self"
                node.meta["stack_trace"] = "stack_trace"
                node.meta["source_fn_stack"] = "source_fn_stack"
        new_gm = Transformer(gm).transform()
        for node in new_gm.graph.nodes:
            if node.op == 'get_attr':
                self.assertEqual(node.meta["nn_module_stack"], "self")
                self.assertEqual(node.meta["stack_trace"], "stack_trace")
                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
    def test_interpreter(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        interpreter = Interpreter(gm)
        input = torch.randn(3, 4)
        self.assertEqual(interpreter.run(input), gm(input))
        self.assertEqual(interpreter.run(input), m(input))
    def test_interpreter_other_graph(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        interpreter = Interpreter(gm, graph=gm.graph)
        input = torch.randn(3, 4)
        self.assertEqual(interpreter.run(input), gm(input))
        self.assertEqual(interpreter.run(input), m(input))
    def test_interpreter_run_node_override(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        class RunNodeInterpreter(Interpreter):
            def __init__(self, module):
                super().__init__(module)

            def run_node(self, n : Node) -> Any:
                result = super().run_node(n)
                n.cached_value = result
                return result

        input = torch.randn(3, 4)
        RunNodeInterpreter(gm).run(input)
        for node in gm.graph.nodes:
            assert hasattr(node, 'cached_value')
    def test_interpreter_onthefly_swap(self):

        def fn(x):
            return torch.sigmoid(x).neg()

        gm = torch.fx.symbolic_trace(fn)

        class NegSigmSwapInterpreter(Interpreter):
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(target, args, kwargs)

            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == 'neg':
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(target, args, kwargs)

        input = torch.randn(3, 4)
        result = NegSigmSwapInterpreter(gm).run(input)
        self.assertEqual(result, torch.neg(input).sigmoid())
    def test_interpreter_partial_eval(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        gm = torch.fx.symbolic_trace(MyModule())
        interp = Interpreter(gm)
        env = {}
        for node in gm.graph.nodes:
            if node.op == 'call_module' and node.target == 'linear':
                env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0

        assert len(env) == 1
        x = torch.randn(3, 4)
        result = interp.run(x, initial_env=env)
        self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
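
    # Partial evaluation above (sketch): any {node: value} pairs passed via
    # initial_env are treated as already-computed results, so run() skips
    # those nodes and feeds the supplied value to their users.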
    def test_interpreter_star_args(self):
        def with_star_args(x, *args):
            return x + args[0]

        gm = torch.fx.symbolic_trace(with_star_args)
        interp = Interpreter(gm)
        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
        self.assertEqual(result, torch.ones(3, 4) * 2.0)
    @skipIfNoTorchVision
    def test_interpreter_noop_resnet18(self):
        rn18 = torchvision_models.resnet18()
        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
        inp = torch.randn(5, 3, 224, 224)
        self.assertEqual(transformed(inp), rn18(inp))

    @skipIfNoTorchVision
    def test_interpreter_gc_values(self):
        rn18 = torchvision_models.resnet18()
        interp = Interpreter(symbolic_trace(rn18))
        inp = torch.rand(5, 3, 224, 224)
        out = interp.run(inp)
        env_key_names = {n.name for n in interp.env.keys()}
        self.assertEqual(env_key_names, {'output'})
    def test_interpreter_default_args(self):
        class Model(torch.nn.Module):
            def forward(self, x, y=3.14159):
                return x + y

        model = Model()
        gm = torch.fx.symbolic_trace(model)

        interp = Interpreter(gm)
        x = torch.randn(5, 3)
        out = interp.run(x)
        torch.testing.assert_close(out, x + 3.14159)
    def test_interpreter_not_enough_args(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = Model()
        gm = torch.fx.symbolic_trace(model)

        interp = Interpreter(gm)
        x = torch.randn(5, 3)
        with self.assertRaisesRegex(RuntimeError,
                                    'Expected positional argument for parameter y, but one was not passed in'):
            interp.run(x)
    def test_transformer_noop(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        new_gm = Transformer(gm).transform()

        input = torch.randn(3, 4)
        self.assertEqual(new_gm(input), gm(input))
    def test_transformer_op_swap(self):

        def fn(x):
            return torch.sigmoid(x).neg()

        gm = torch.fx.symbolic_trace(fn)

        class NegSigmSwapXformer(Transformer):
            def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(target, args, kwargs)

            def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
                if target == 'neg':
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(target, args, kwargs)

        transformed = NegSigmSwapXformer(gm).transform()
        input = torch.randn(3, 4)
        self.assertEqual(transformed(input), torch.neg(input).sigmoid())
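
    # Interpreter vs. Transformer (sketch): NegSigmSwapInterpreter (earlier)
    # swaps ops at run() time on concrete tensors, while NegSigmSwapXformer's
    # overrides receive Proxys, so transform() re-traces and returns a new
    # GraphModule with the swap baked into its graph.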
    def test_transformer_multi_outputs(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                x = x + self.param
                out = self.linear(x)
                return x, out

        m = MyModule()
        gm = torch.fx.symbolic_trace(m)

        new_gm = Transformer(gm).transform()

        input = torch.randn(3, 4)
        self.assertEqual(new_gm(input), gm(input))
    def test_fn_type_annotations(self):
        class Foo(torch.nn.Module):
            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
                return {'a': p.x + p.y + z + i}

        foo_scripted = torch.jit.script(Foo())
        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

        fxed = symbolic_trace(Foo())
        fxed_scripted = torch.jit.script(fxed)
        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

    def test_fn_type_annotation_empty(self):
        def forward(a : List[torch.Tensor]):
            return a[0]
        torch.jit.script(symbolic_trace(forward))
    def test_wrapped_method(self):
        def wrap_with_relu(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                return torch.relu(fn(*args, **kwargs))
            return wrapper

        class Foo(torch.nn.Module):
            @wrap_with_relu
            def forward(self, x, w):
                return torch.matmul(x, w)

        f = Foo()
        traced = symbolic_trace(f)
        x, w = torch.rand(3, 4), torch.rand(4, 4)
        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
    def test_empty_graph_codegen(self):
        graph = torch.fx.Graph()
        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(gm(), None)

    def test_sequential(self):
        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
        gm = torch.fx.symbolic_trace(m)
        gm_copy = copy.deepcopy(gm)
    def test_ctx_mgr(self):
        @contextlib.contextmanager
        def do_nothing():
            yield

        class M(torch.nn.Module):
            @do_nothing()
            def forward(self, x):
                return torch.relu(x)

        m = M()
        self.checkGraphModule(m, (torch.rand(3, 4),))
    def test_typename_print(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
                                              type_expr=List[float])
        output : torch.fx.Node = graph.output(b)

        self.assertTrue('typing.List[float]' in str(graph))
    def test_layout(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)

        traced = symbolic_trace(M())
        x = torch.rand(5, 9, 3, 4)
        self.assertEqual(traced(x), torch.zeros_like(x))

    def test_ellipsis(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y[:, 1:10, ...]

        traced = symbolic_trace(M())
        x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
        self.assertEqual(traced(x, y), x + y[:, 1:10, ...])
2110
    def test_inf_nan(self):
        class FooMod(torch.nn.Module):
            def forward(self, x):
                return x + float('inf'), x + float('-inf'), x + float('nan')

        fm = FooMod()
        self.checkGraphModule(fm, (torch.rand(3, 4),))

    def test_inf_nan_kwds(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
        graph.output((b, c))

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        x = torch.rand(3, 4)
        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))

    def test_deepcopy_recursion_depth(self):
        depth = sys.getrecursionlimit() + 20

        g = torch.fx.Graph()
        x = g.placeholder('x')
        for i in range(depth):
            x = g.call_function(torch.relu, (x,))
        g.output(x)

        copied_graph = copy.deepcopy(g)

        val_map = {}
        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            val_map[orig_node] = new_node

        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            orig_users = set(orig_node.users.keys())
            orig_users_equiv = {val_map[u] for u in orig_users}
            new_users = set(new_node.users.keys())
            self.assertEqual(orig_users_equiv, new_users)

    @skipIfNoTorchVision
    def test_replace_uses(self):
        rn18 = torchvision_models.resnet18()

        class LowerReluTracer(torch.fx.Tracer):
            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
                if isinstance(m, torch.nn.ReLU):
                    return False
                return super().is_leaf_module(m, qualname)

        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

        to_erase = []
        for node in rn18_traced.graph.nodes:
            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
                kwargs = node.kwargs.copy()
                # Neg doesn't have in-place
                kwargs.pop('inplace')
                with rn18_traced.graph.inserting_before(node):
                    new_node = rn18_traced.graph.call_function(
                        the_function=torch.neg, args=node.args, kwargs=kwargs)
                node.replace_all_uses_with(replace_with=new_node)
                to_erase.append(node)

        for node in to_erase:
            rn18_traced.graph.erase_node(node)

    def test_replace_input(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        b.replace_input_with(x, y)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input_x = torch.randn(33, 44)
        input_y = torch.randn(11, 22)
        self.assertEqual(gm(input_x, input_y), torch.relu(input_y))

    def test_insertion_point(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        with graph.inserting_before(b):
            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
            _, *relu_args = b.args
            b.args = (neg, *relu_args)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

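    # `Graph.inserting_before` has a symmetric counterpart, `inserting_after`,
    # which places newly created nodes immediately after the anchor node; a
    # quick sketch of the pattern:
    #
    #     with graph.inserting_after(b):
    #         post = graph.call_function(torch.neg, args=(b,))
    #
    # Any node created while the context manager is active lands at that
    # insertion point instead of at the end of the graph.
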
    def test_update_args_api(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))

        b.update_arg(0, y)
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))

    def test_update_kwargs_api(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        y : torch.fx.Node = graph.create_node('placeholder', 'y')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
        output : torch.fx.Node = graph.output(b)

        orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
        self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))

        b.update_kwarg('input', y)
        new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))

    def test_immutable_list_pytree_ops(self):
        rand_tensor = torch.randn(5, 3)
        l = immutable_list([3, [rand_tensor, 42]])

        flattened, spec = pytree.tree_flatten(l)
        assert flattened == [3, rand_tensor, 42]

        unflattened = pytree.tree_unflatten(flattened, spec)
        assert unflattened == l
        assert isinstance(unflattened, immutable_list)

    def test_immutable_dict_pytree_ops(self):
        rand_tensor = torch.randn(5, 3)
        d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})

        flattened, spec = pytree.tree_flatten(d)
        assert flattened == [3, rand_tensor, 42]

        unflattened = pytree.tree_unflatten(flattened, spec)
        assert unflattened == d
        assert isinstance(unflattened, immutable_dict)

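    # The immutable variants keep their contract across flatten/unflatten: the
    # reconstructed container is again immutable, and in-place mutation raises.
    # A quick sketch of that contract:
    #
    #     d = immutable_dict({'a': 1})
    #     d['a'] = 2   # raises NotImplementedError
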
    def test_move_before(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
        _, *relu_args = b.args
        b.args = (neg, *relu_args)
        b.prepend(neg)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_prepend_self(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        output : torch.fx.Node = graph.output(b)

        b.prepend(b)
        x.append(b)
        self.assertEqual(len(graph.nodes), 3)

    def test_erase_node_error(self):
        st = SimpleTest()
        traced = symbolic_trace(st)

        for node in traced.graph.nodes:
            # Test deleting with uses both in another Node and at the output
            if node.target in [operator.add, torch.relu]:
                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
                    traced.graph.erase_node(node)

    def test_copy_it(self):
        d = immutable_dict([(3, 4), (5, 6)])
        l = immutable_list([(3, 4), (5, 6)])

        self.assertEqual(d, deepcopy(d))
        self.assertEqual(l, deepcopy(l))

    def test_get_torch_func_signature(self):
        for key in dir(torch):
            obj = getattr(torch, key)
            if callable(obj):
                schemas = get_signature_for_torch_op(obj)

    def test_find_uses(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))

        y = torch.relu(x)
        z = x + x
        u = torch.neg(x)
        graph.output((y + z + u).node)

        users_of_x = x.node.users
        self.assertEqual(len(users_of_x), 3)
        expected_ops = {'relu', 'add', 'neg'}
        for use in users_of_x:
            assert any(use.name.startswith(prefix) for prefix in expected_ops)

    def test_inline_graph(self):
        class InlineInto(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class ToInline(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        inline_into = symbolic_trace(InlineInto())
        to_inline = symbolic_trace(ToInline())

        combined_graph = torch.fx.Graph()
        output_node = combined_graph.graph_copy(inline_into.graph, {})

        input_node = next(iter(to_inline.graph.nodes))
        assert input_node and input_node.op == 'placeholder'

        val_map = {input_node : output_node}
        output = combined_graph.graph_copy(to_inline.graph, val_map)
        combined_graph.output(output)

        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)

        input = torch.rand(3, 4)
        self.assertEqual(combined_module(input), input.relu().neg())

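    # `graph_copy` returns the value the copied graph's output would have
    # produced, so chaining two copies stitches graphs end-to-end. A
    # generalized sketch of the idiom used above, assuming each input has a
    # single placeholder and output (the helper name is illustrative):
    def _stitch_graphs_sketch(self, first: GraphModule, second: GraphModule) -> 'torch.fx.Graph':
        combined = torch.fx.Graph()
        # Copy the first graph wholesale; its output value comes back.
        mid_val = combined.graph_copy(first.graph, {})
        # Feed that value in wherever the second graph read its placeholder.
        second_ph = next(iter(second.graph.nodes))
        out_val = combined.graph_copy(second.graph, {second_ph: mid_val})
        combined.output(out_val)
        return combined
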
    def test_multi_insert_point(self):
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))
        relu = torch.relu(x)

        with graph.inserting_before(relu.node):
            y = torch.neg(x)
            z = torch.tanh(y)

        graph.output((relu.node, z.node))

        expected_ops = ['x', 'neg', 'tanh', 'relu']
        for node, expected in zip(graph.nodes, expected_ops):
            assert expected in node.name

    def test_reassign_args_kwargs_uses(self):
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        zed = z + z + z
        graph.output(zed.node)

        # zed = z + z + z -> zed = z + z + x
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])

        # z = x + y -> z = y + y
        z.node.args = (y.node, y.node)
        self.assertEqual(list(x.node.users.keys()), [zed.node])

    def test_trace_function(self):
        def foo(x, y):
            return torch.relu(x) + y

        x, y = torch.randn(3, 4), torch.randn(3, 4)
        self.checkGraphModule(foo, (x, y))

    def test_trace_return_dataclass(self):
        """
        Test case for a Module that returns a dataclass
        """
        from dataclasses import dataclass

        @dataclass
        class MyOutput:
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnDataclass(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d + d, bar=d * 3)

        module = ModuleReturnDataclass()
        traced_graph = symbolic_trace(module).graph

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_return_dataclass_nested(self):
        """
        Test case for a Module that returns a dataclass from a nested call
        """
        from dataclasses import dataclass

        @dataclass
        class MyOutput:
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnDataclass(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d + d, bar=d * 3)

        class CallsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m = ModuleReturnDataclass()

            def forward(self, x):
                tmp = self.m(x)
                return MyOutput(foo=tmp.foo, bar=tmp.bar)

        module = CallsModule()
        traced_graph = symbolic_trace(module).graph

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_return_namedtuple(self):
        """
        Test case for a Module that returns a namedtuple
        """
        class MyOutput(NamedTuple):
            foo: torch.Tensor
            bar: torch.Tensor

        class ModuleReturnNamedTuple(torch.nn.Module):
            def forward(self, d : torch.Tensor):
                return MyOutput(foo=d, bar=d)

        module = ModuleReturnNamedTuple()

        traced_graph = symbolic_trace(module).graph

        gm = GraphModule(module, traced_graph)
        x = torch.rand(1)

        self.assertEqual(module(x), gm(x))

    def test_trace_dict_int_keys(self):
        class ModWithDictArg(torch.nn.Module):
            def forward(self, d : Dict[int, torch.Tensor]):
                return d[42]

        class CallsModWithDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m = ModWithDictArg()

            def forward(self, x):
                return self.m({42: x})

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return isinstance(m, ModWithDictArg)

        traced_graph = MyTracer().trace(CallsModWithDict())

    def test_trace_dict_proxy_keys(self):
        class ModWithDictArg(torch.nn.Module):
            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
                return d[42]

        class CallsModWithDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m = ModWithDictArg()

            def forward(self, x):
                return self.m({x: x})

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return isinstance(m, ModWithDictArg)

        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
            traced_graph = MyTracer().trace(CallsModWithDict())

    def test_module_deepcopy_edit_nodes(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        traced1 = symbolic_trace(Foo())
        copied = copy.deepcopy(traced1)

        for node in copied.graph.nodes:
            if node.target == torch.relu:
                node.target = torch.neg

        copied.recompile()
        traced1.recompile()

        x = torch.randn(15, 15)
        torch.testing.assert_close(traced1(x), torch.relu(x))
        torch.testing.assert_close(copied(x), torch.neg(x))

    def test_direct_param_use(self):
        class TransposeTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.b

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = TransposeTest()

            def forward(self, x):
                return self.a.b, self.a.b.t(), self.a.b.view(12)

        traced = torch.fx.symbolic_trace(Foo())
        assert all('constant' not in node.target for node in traced.graph.nodes)

    def test_single_default_arg(self):
        class M(torch.nn.Module):
            def forward(self, y=1):
                return y

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))

    def test_multiple_default_args(self):
        class M(torch.nn.Module):
            def forward(self, y=1, z=2):
                return y + z

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))
        self.checkGraphModule(m, (3, 4))

    def test_regular_and_default_args(self):
        class M(torch.nn.Module):
            def forward(self, x, y=1):
                return x + y

        m = M()
        self.checkGraphModule(m, (2,))
        self.checkGraphModule(m, (2, 3))

    def test_string_literal_return(self):
        class M(torch.nn.Module):
            def forward(self):
                return "foo"

        m = M()
        self.checkGraphModule(m, ())

    def test_namedtuple_return_qualname(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return MyNamedTup(x, x)

        traced = symbolic_trace(NamedTupReturn())
        input = torch.rand(3, 4)
        self.assertEqual(traced(input), MyNamedTup(input, input))

    def test_update_args_kwargs_yells_at_you(self):
        symtraced = symbolic_trace(SimpleTest())
        node = next(iter(symtraced.graph.nodes))
        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
            node.__update_args_kwargs((), {})

    def test_torchbind_class_attribute_in_fx(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")

        class FooBar1234(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        m = FooBar1234()
        self.checkGraphModule(m, ())

    def test_torchbind_class_attribute_in_fx_tensor_arg(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")

        class FooBar2341(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._ReLUClass()

            def forward(self, x):
                return self.f.run(x)

        m = FooBar2341()

        traced = symbolic_trace(m)
        input = torch.randn(3, 4)
        self.assertEqual(traced(input), m(input))

        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))

    def test_script_method_trace(self):
        class Scripted(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class Holder(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.s = torch.jit.script(Scripted())

            def forward(self, x):
                return self.s(x)

        h = Holder()
        traced = symbolic_trace(h)
        input = torch.randn(3, 4)
        self.assertEqual(traced(input), h(input))

        self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))

    def test_namedtuple_return_trace(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return Pair(x, x)

        traced = symbolic_trace(NamedTupReturn())
        input = torch.rand(3, 4)
        self.assertEqual(traced(input), Pair(input, input))

    def test_named_tuple_inlined(self):
        class NamedTupMod(torch.nn.Module):
            def forward(self, inp):
                return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))

        m = NamedTupMod()
        input = torch.rand(3, 4)
        ref = m(input)
        traced = symbolic_trace(m)

        res = traced(input)
        self.assertEqual(ref, res)

        # Check Pair NamedTuple works when inlined into the function call.
        ph = call_func = None
        for node in traced.graph.nodes:
            if node.op == "placeholder":
                ph = node
            elif node.op == "call_function" and node.target == wrapped_named_tup:
                node.update_arg(0, Pair(ph, 1.2))
                node.update_kwarg("p2", Pair(3.4, ph))
                call_func = node

        self.assertTrue(call_func is not None)
        self.assertTrue(isinstance(call_func.args[0], Pair))
        self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
        self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
        self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")

        traced.graph.eliminate_dead_code()
        traced.recompile()
        res = traced(input)
        self.assertEqual(ref, res)

    def test_return_type_exists(self):
        class ReturnTypeModule(torch.nn.Module):
            def other(self, x: List[str]) -> List[str]:
                return x

            def forward(self, x: List[str]) -> List[str]:
                return self.other(x)

        traced = symbolic_trace(ReturnTypeModule())
        self.assertIn("-> typing_List[str]", traced._code)
        scripted = torch.jit.script(traced)
        self.assertIn("-> List[str]", scripted.code)

    def getitem_inner(self):
        class GetItemBase(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer('pe', torch.randn(8, 8))

        class GetItem1(GetItemBase):
            def forward(self, x):
                return self.pe[:, :x.size(0)]

        class GetItem2(GetItemBase):
            def forward(self, x):
                return self.pe[x.size(0)]

        class GetItem3(GetItemBase):
            def forward(self, x):
                return self.pe[4]  # fx creates `self._tensor_constant0` here

        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
        self.checkGraphModule(GetItem3(), [torch.zeros(4)])

    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
                         "Will be checked in test_getitem_subproc")
    def test_getitem(self):
        self.getitem_inner()

    def test_getitem_subproc(self):
        # need to run this test in a subproc to work around:
        #   https://github.com/pytorch/pytorch/issues/50710
        proc = Process(target=run_getitem_target)
        proc.start()
        proc.join()
        self.assertEqual(proc.exitcode, 0)

    def test_user_friendly_call_provenance_with_function(self):
        def fn(x):
            return wrapper_fn(x)

        traced = torch.fx.symbolic_trace(fn)

        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
                                    "being compiled since it was called"
                                    " from 'fn.forward'"):
            scripted = torch.jit.script(traced)

    def test_user_friendly_call_provenance_with_module(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return wrapper_fn(x)

        traced = torch.fx.symbolic_trace(M())

        with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
                                    "being compiled since it was called"
                                    " from 'M.forward'"):
            scripted = torch.jit.script(traced)

    def test_snake_case(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.activations = torch.nn.ModuleDict([
                    ["snake_case", torch.nn.ReLU()],
                    ["PascalCase", torch.nn.LeakyReLU()],
                    ["ALL_CAPS", torch.nn.PReLU()]
                ])

            def forward(self, x):
                a = self.activations["snake_case"](x)
                b = self.activations["PascalCase"](x)
                c = self.activations["ALL_CAPS"](x)
                return a, b, c

        traced = symbolic_trace(M())

        check = [
            ("activations_snake_case", "activations.snake_case"),
            ("activations_pascal_case", "activations.PascalCase"),
            ("activations_all_caps", "activations.ALL_CAPS")
        ]

        i = 0
        for node in traced.graph.nodes:
            if node.op == "placeholder" or node.op == "output":
                continue
            name = check[i][0]
            target = check[i][1]
            self.assertEqual(name, node.name)
            self.assertEqual(target, node.target)
            i += 1
        self.assertEqual(i, 3)

    def test_no_mutation(self):
        from torch.fx.immutable_collections import immutable_list
        x = immutable_list([3, 4])
        with self.assertRaisesRegex(NotImplementedError, "new_args"):
            x[0] = 4

    def test_partial_trace(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                if y:
                    return 2 * x
                else:
                    return x

        mod = Foo()
        mod_true = symbolic_trace(mod, concrete_args={'y': True})
        mod_false = symbolic_trace(mod, concrete_args={'y': False})
        self.assertEqual(mod_true(3, True), 6)
        print(mod_true.code)
        assert any(i.target == torch._assert for i in mod_true.graph.nodes)
        with self.assertRaises(AssertionError):
            mod_true(3, False)
        self.assertEqual(mod_false(3, False), 3)
        with self.assertRaises(AssertionError):
            mod_false(3, True)

        def f_higher(a, f):
            return f(a)

        nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
        self.assertEqual(nf(3, lambda x: x * 2), 6)

    def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(5))

            def forward(self, x):
                return torch.dot(self.W, x)

        traced = torch.fx.symbolic_trace(M())

        out = [n for n in traced.graph.nodes if n.op == "output"][-1]
        with traced.graph.inserting_before(out):
            relu_out = traced.graph.call_method(method_name='relu',
                                                args=(out.args[0],))
        out.args = (relu_out,)

        traced.recompile()

        with self.capture_stderr() as captured:
            with self.assertRaises(TypeError):
                traced(5)

        self.assertRegex(captured[0],
                         r"Call using an FX-traced Module, line .* of the "
                         r"traced Module's generated forward function:")

    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 4)

            def forward(self, x):
                return self.linear(x)

        traced = torch.fx.symbolic_trace(M())

        # Do not change this to `capture_stderr` or another context
        # manager without ensuring that the output is as expected
        try:
            traced(torch.rand(5, 5))
        except RuntimeError:
            captured = traceback.format_exc()

        self.assertNotRegex(captured,
                            r"Call using an FX-traced Module, line .* of the "
                            r"traced Module's generated forward function:")

    def test_graph_module_replicate_for_dp(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        gm = torch.fx.symbolic_trace(Foo())

        x = torch.randn(5, 3)
        out = gm(x)

        replica = gm._replicate_for_data_parallel()
        out_replica = replica(x)

        torch.testing.assert_close(out_replica, out)

    def test_ast_rewriter_rewrites_assert(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):
                assert y == z
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        traced.graph.lint()

    def test_ast_rewriter_rewrites_assert_with_message(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):
                assert y == z, "msg"
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        traced.graph.lint()

    def test_throw_out_variant(self):
        def foo(x):
            y = torch.rand_like(x)
            torch.sigmoid(x, out=y)
            return y

        class MyTracer(torch.fx.Tracer):
            check_mutable_operations = True

        tracer = MyTracer()
        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
            traced_graph = tracer.trace(foo)

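    # `check_mutable_operations` is a class-level Tracer flag: when enabled,
    # tracing raises on ops whose schema marks an argument as mutable (such as
    # the `out=` overload above) instead of silently recording a call that
    # would mutate memory behind the graph's back.
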
    def test_ast_rewriter_reassigns_submodules(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(100)

            def forward(self, x: torch.Tensor):
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        traced.graph.lint()

    def test_ast_rewriter_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return (
                a_lifted_leaf((4, y), 3)
                + a_lifted_leaf((3, 4), 5)
                + a_lifted_leaf((y, y), y)
            )

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("a_lifted_leaf", traced.code)
        self.assertEqual(27, traced(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

    def test_ast_rewriter_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return (
                a_lifted_leaf2((4, y), 3)
                + a_lifted_leaf2((3, 4), 5)
                + a_lifted_leaf2((y, y), y)
            )

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("a_lifted_leaf2", traced.code)
        self.assertEqual(27, traced(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)

    def test_profiler_ranges_side_effect(self):
        g = torch.fx.Graph()
        handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
        g.output(None)

        found_targets = {}
        for node in g.nodes:
            if node.op == 'call_function':
                found_targets.setdefault(node.target)
        self.assertEqual(
            list(found_targets.keys()),
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
        )

        g.eliminate_dead_code()
        found_targets = {}
        for node in g.nodes:
            if node.op == 'call_function':
                found_targets.setdefault(node.target)
        self.assertEqual(
            list(found_targets.keys()),
            [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]
        )

    def test_ast_rewriter_wrapped_via_decorator(self):
        class F(torch.nn.Module):
            def forward(self, x):
                return wrapped_via_decorator(x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(F())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_via_decorator", traced.code)
        self.assertEqual(traced(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_via_decorator", traced.code)
        self.assertEqual(traced(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(traced).transform()
        self.assertIn("wrapped_via_decorator", transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_ast_rewriter_wrap_with_submodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_with_submodule", traced.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), traced(input))

    def test_submodule_manipulation_API(self):
        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.conv(torch.cat([self.param, x]))

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(100, 200)
                self.register_buffer("buf", torch.randn(2, 3))
                self.net_c = C()

            def forward(self, x):
                return self.linear(torch.cat([self.buf, self.net_c(x)]))

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net_b = B()
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.net_b(x) + self.param

        a = symbolic_trace(A())

        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))

        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
        with a.graph.inserting_before(conv):
            with warnings.catch_warnings(record=True) as w:
                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
                                              args=conv.args)
                self.assertEqual(len(w), 0)

        conv.replace_all_uses_with(dropout)
        a.graph.erase_node(conv)

        def module_exists(gm: GraphModule, path: str) -> bool:
            return any(path == name for name, _ in gm.named_modules())

        def parameter_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_parameters())
                    and any(path == name for name in gm.state_dict().keys()))

        def buffer_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_buffers())
                    and any(path == name for name in gm.state_dict().keys()))

        # Test that we added the "dropout" submodule
        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))

        # Test `get_submodule` with an added submodule
        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))

        # Test that the "conv" submodule is still there
        self.assertTrue(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with an original module
        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))

        # Test that the "conv" node is NOT still there
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
        self.assertEqual(conv, [])

        a.delete_submodule("net_b.net_c.conv")

        # Test that the "conv" submodule is now gone
        self.assertFalse(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with a deleted submodule
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`conv`"):
            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))

        # Test `get_attr` warnings
        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]

        with a.graph.inserting_before(cat):

            with warnings.catch_warnings(record=True) as w:
                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
                self.assertEqual(len(w), 0)

            with self.assertWarnsRegex(UserWarning, "Attempted to "
                                       "insert a get_attr Node with no "
                                       "underlying reference in the "
                                       "owning GraphModule"):
                bad_param = a.graph.get_attr(qualified_name="net_b.param")
                a.graph.erase_node(bad_param)

        cat.args = (*cat.args, param)

        # Test `get_parameter`
        a.get_parameter("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "is not an "
                                    "nn.Parameter"):
            a.get_parameter("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`param`"):
            a.get_parameter("net_b.param")

        # Test `get_buffer`
        a.get_buffer("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "is not a "
                                    "buffer"):
            a.get_buffer("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`buf`"):
            a.get_buffer("net_b.net_c.buf")

        # Test non-nested attributes
        a.get_parameter("param")

        # Insert some unused submodules
        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))

        # Garbage collection
        a.delete_all_unused_submodules()

        # Test that all the unused submodules are gone
        self.assertFalse(module_exists(a, "net_b.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
        self.assertFalse(module_exists(a, "batch_norm_2d"))

        # Test that we didn't delete any unused Parameters or buffers
        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
        self.assertTrue(buffer_exists(a, "net_b.buf"))

    def test_delete_unused_submodules_leaf(self):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.relu(x)
                return x

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submod = SubModule()

            def forward(self, x):
                x = self.submod(x)
                return x

        model = Model()

        class MyCustomTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
                return module_qualified_name == "submod"

        inputs = torch.randn(1, 10)
        traced_graph = MyCustomTracer().trace(model)
        gm2 = torch.fx.GraphModule(model, traced_graph)
        gm2.delete_all_unused_submodules()
        torch.testing.assert_close(gm2(inputs), model(inputs))

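    # The leaf boundary is what keeps `submod` alive here: because the tracer
    # reports it as a leaf, the graph holds a single call_module node whose
    # target is "submod", so delete_all_unused_submodules sees a live use and
    # keeps the module (its inner linear/relu are never referenced directly).
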
    def test_fx_stateless(self):
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(1, 1)
                self.register_buffer('buffer', torch.ones(1))

            def forward(self, x):
                return self.l1(x) + self.buffer

        module = MockModule()
        x = torch.rand((1, 1))
        weight = torch.tensor([[1.0]], requires_grad=True)
        bias = torch.tensor([0.0], requires_grad=True)
        buffer = torch.tensor([0.0])
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        fx_module = torch.fx.symbolic_trace(module)
        res = torch.func.functional_call(fx_module, parameters, x)
        res.backward()
        self.assertIsNotNone(weight.grad)
        self.assertIsNotNone(bias.grad)
        self.assertIsNone(buffer.grad)
        # Gradients were not computed for the module's own parameters and buffers
        self.assertIsNone(module.l1.weight.grad)
        self.assertIsNone(module.l1.bias.grad)
        self.assertIsNone(module.buffer.grad)

    def test_tracing_graphmodules_as_leaf_submodules(self):
        class A(torch.nn.Module):
            def forward(self, t):
                return t + t

        class B(torch.nn.Module):
            def __init__(self):
                super(type(self), self).__init__()
                self.calling = False
                self.called = False

            def forward(self, t):
                if self.calling:
                    return t - t
                else:
                    return t + t

            def __call__(self, *args):
                self.called = True
                self.calling = True
                return super(type(self), self).__call__(*args)
                self.calling = False

        class M(torch.nn.Module):
            def __init__(self, a, b):
                super().__init__()
                self.a = a
                self.b = b

            def forward(self, t):
                x = self.a(t)
                y = self.b(t)
                return x + y

        class LeafTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        class LeafTracerNotB(Tracer):
            def is_leaf_module(self, module, name):
                return False if "b" in name else True

        # Recompile calls added "for fun", since they
        # chain __call__ wrappers.

        # Test: B as a regular, non-leaf module
        a = symbolic_trace(A())
        a.recompile()
        m = M(a, B())
        graph = LeafTracerNotB().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        # Test graphmodule/submodule a is not inlined.
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        # Test submodule b is not treated as leaf.
        self.assertFalse(hasattr(gm, "b"))

        # Test assert custom __call__ on submodule b was honored.
        match = [
            n
            for n in gm.graph.nodes
            if n.op == "call_function" and n.target == operator.sub
        ]
        self.assertTrue(len(match) == 1)

        # Test: B as a regular, leaf module
        # symbolic_trace should only patch torch.nn.Module.__call__,
        # which means B.__call__ should still execute
        a = symbolic_trace(A())
        a.recompile()
        b = B()
        m = M(a, b)
        graph = LeafTracer().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        # Test graphmodule/submodule a is not inlined.
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        # Test submodule b is leaf:
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
        self.assertTrue(len(match) == 1)

        # Test b.__call__ was run
        self.assertTrue(b.called)
        self.assertTrue(gm.get_submodule("b").called)

        # Test: B as GraphModule leaf
        # __call__ not honored since symbolic_trace directly invokes forward()
        a = symbolic_trace(A())
        a.recompile()
        b = symbolic_trace(B())
        b.recompile()
        m = M(a, b)
        graph = LeafTracer().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
        self.assertTrue(len(match) == 1)

    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("my_buff", torch.rand(3, 4))
                self.register_parameter(
                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
                )

            def forward(self, x):
                return x + self.my_buff + self.my_param

        mod = MyModule()
        mod_traced = symbolic_trace(mod)

        # Create new GraphModule based on original, either w/ dict or root module.
        orig_buff = mod_traced.get_buffer("my_buff")
        orig_param = mod_traced.get_parameter("my_param")
        mod_traced_new = GraphModule(
            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
            mod_traced.graph,
        )

        # Check that both my_buff and my_param are found and the same.
        try:
            new_buff = mod_traced_new.get_buffer("my_buff")
        except Exception:
            self.fail("Did not find my_buff")
        self.assertEqual(orig_buff, new_buff)

        try:
            new_param = mod_traced_new.get_parameter("my_param")
        except Exception:
            self.fail("Did not find my_param")
        self.assertEqual(orig_param, new_param)

        x = torch.rand(3, 4)
        orig_out = mod_traced(x)
        submodules_out = mod_traced_new(x)

        self.assertEqual(orig_out, submodules_out)

    def test_graph_module_init_buffer_param_copied_dict_init(self):
        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)

    def test_graph_module_init_buffer_param_copied_mod_init(self):
        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)

    def test_annotations_with_no_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
                     "`annotations` is not defined in Python <3.7")
    def test_annotation_with_future(self):
        try:
            import fx.test_future    # noqa: F401
        finally:
            del sys.modules["__future__"]

    @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
    def test_annotations_empty_tuple(self):
        class Foo(torch.nn.Module):
            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
                return "foo"

        traced = torch.fx.symbolic_trace(Foo())

        x = ()
        y = ("bar", ())

        traced(x, y)

        FileCheck().check("_Tuple[()]") \
                   .check("typing_Tuple[str,typing_Tuple[()]]") \
                   .run(traced.code)

        scripted = torch.jit.script(traced)

        scripted(x, y)

        FileCheck().check("Tuple[()]") \
                   .check("Tuple[str, Tuple[()]]") \
                   .run(scripted.code)

    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
    def test_assert(self):
        def f(x):
            assert x > 1
            return x + 1

        try:
            torch.fx.proxy.TracerBase.trace_asserts = True
            traced = symbolic_trace(f)
        finally:
            torch.fx.proxy.TracerBase.trace_asserts = False

        self.assertEqual(f(2), traced(2))
        with self.assertRaises(AssertionError):
            traced(0)

    def test_pytree(self):
        # Used to test that you can use your own placeholder class
        class PHTest(PHBase):
            pass

        def f_sum(x):
            return sum(x)

        def f_sum_dict(x):
            out = 0
            for v in x.values():
                out += v
            return out

        def f_dict_list_map(x):
            new_dict = {}
            for k, v in x.items():
                new_dict[k] = [i + 1 for i in v]
            return new_dict

        def f_dict_add(x):
            return x['a'] + sum(x['z'])

        def f_namedtuple_add(x):
            return x.x + x.y

        pytree.register_pytree_node(
            Foo,
            lambda x: ([x.a, x.b], None),
            lambda x, _: Foo(x[0], x[1]),
        )
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

        def f_custom(x):
            return x.a + x.b

        def f_custom_dict(x):
            return f_sum_dict(x.a) + x.b

        def f_return_custom(x):
            return Foo(x.b, x.a)

        tests = [
            (f_sum, [PH, PH, PH]),
            (f_sum, [PHTest(), PHTest(), PHTest()]),
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
            (f_dict_list_map, {5: (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': []}),
            (f_custom, Foo(PH, PH)),
            (f_custom, Foo(PH, 3)),
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
            # (f_return_custom, Foo(PH, PH)),  # Don't currently support output pytrees
            (f_namedtuple_add, Point(PH, PH)),
        ]

        def verify_pytree(f, inp):
            val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
            num_flat_args = len([i == PH for i in pytree.tree_leaves(inp)])
            orig_out = f(val)

            nf = symbolic_trace(f, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)

            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
            bare_fx.graph.set_codegen(CodeGen())
            bare_fx.recompile()
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args

            nf = symbolic_trace(nf)
            self.assertEqual(nf(val), orig_out)
            assert "tree_flatten_spec" not in nf.code
            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1

            nf = symbolic_trace(nf, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args

            pickled = pickle.dumps(nf)
            nf = pickle.loads(pickled)
            self.assertEqual(nf(val), orig_out)

        for f, inp in tests:
            verify_pytree(f, inp)

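    # Sketch of what the pytree registration above does, for a hypothetical
    # two-field container (all names here are illustrative): flatten must
    # return the ordered leaves plus any static context, and unflatten must
    # rebuild the container from them.
    #
    #     class Box:
    #         def __init__(self, lo, hi):
    #             self.lo, self.hi = lo, hi
    #
    #     pytree.register_pytree_node(
    #         Box,
    #         lambda b: ([b.lo, b.hi], None),       # flatten: leaves, context
    #         lambda leaves, ctx: Box(*leaves),     # unflatten
    #     )
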
    def test_pytree_concrete(self):
        def f(b, a):
            if b:
                return a['a']
            else:
                return a['z']

        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
        nf = symbolic_trace(f, concrete_args=inp)
        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
        self.assertEqual(nf(**val), f(**val))

        nf = symbolic_trace(nf)
        self.assertEqual(nf(**val), f(**val))

    def test_metadata_on_ph(self):
        def f_sum(a: int, b: int) -> int:
            return a + b

        # Due to unflattening of dict, the batch argument
        # will be split into two separate nodes with the names
        # "batch_1" and "batch_2", referring to the keys
        # "f1" and "f2" respectively in the dict.
        def f_dict(a: Dict[str, str]) -> bool:
            return a["f1"] == a["f2"]

        def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
            for node in gm.graph.nodes:
                if node.op == "placeholder":
                    self.assertTrue(node.name in arg_names)
                    self.assertTrue(node.ph_key in metadata)

        verify_metadata(
            gm=symbolic_trace(
                f_sum,
                concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
            ),
            arg_names=["a_1", "b_1"],
            metadata=["a", "b"]
        )
        verify_metadata(
            gm=symbolic_trace(
                f_dict,
                concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
            ),
            arg_names=["a_1", "a_2"],
            metadata=["f1", "f2"]
        )

        # Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = "foo"
                return n

        class PHWithTag(PHBase):
            def __init__(self, tag: str):
                super().__init__()
                self.tag = tag

        g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
        for n in g.nodes:
            self.assertTrue(hasattr(n, "tag"))
            # Ensure that tag is still "foo" and not "bar" (from PHWithTag)
            self.assertEqual(n.tag, "foo")

    def test_custom_codegen(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()

        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
        bare_fx.graph.set_codegen(CodeGen())
        bare_fx.recompile()

        self.assertEqual(nf(vals), f(*vals))
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))

        ts_f = torch.jit.script(nf)
        self.assertEqual(nf(vals), ts_f(vals))

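    # A custom CodeGen changes only the generated Python around the graph, not
    # the graph itself: gen_fn_def rewrites the forward() signature/prologue,
    # additional_globals injects names the generated code needs (List here),
    # and process_inputs/process_outputs adapt calling conventions. Swapping
    # a default CodeGen back in (as with `bare_fx` above) restores the flat
    # calling convention over the very same nodes.
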
    def test_custom_codegen_with_transformer(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(nf(vals), f(*vals))

        transformed_gm = Transformer(nf).transform()
        self.assertEqual(nf(vals), transformed_gm(vals))

    def test_interpreter_with_codegen(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

            def generate_output(self, output_args):
                return f'return list({repr(output_args)})'

            def process_outputs(self, outputs):
                return list(outputs)

        def f(a, b):
            a = a + b
            b = a + b
            return a, b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(Interpreter(nf).run(vals), nf(vals))

    def test_imul_code_print(self):
        graph = torch.fx.Graph()
        a = graph.placeholder("a")
        b = graph.placeholder("b")
        graph.call_function(operator.imul, (a, b), {})
        graph.output(a)

        gm = torch.fx.GraphModule({}, graph)
        gm.recompile()
        self.assertEqual(gm(2, 3), 6)
        self.assertIn("a *= b", gm.code)

    def test_deepcopy_tracer(self):
        def fn(x, y):
            return (x + y).relu().sin()

        tracer = Tracer()
        tracer_before = copy.deepcopy(tracer)
        tracer.trace(fn)
        tracer_after = copy.deepcopy(tracer)

        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))

    def test_deepcopy_graphmodule(self):
        m = symbolic_trace(SimpleTest())
        m.meta['hello'] = 'world'
        copy_m = copy.deepcopy(m)
        self.assertEqual(copy_m.meta['hello'], 'world')

    def test_deepcopy_no_recursion(self):
        m = symbolic_trace(SimpleTest())
        m.meta['hello'] = m  # circular reference
        copy_m = copy.deepcopy(m)  # finishes
        self.assertEqual(id(copy_m), id(copy_m.meta['hello']))

    def test_enum(self):
        from enum import Enum

        class Foo(Enum):
            A = 1

        def leaf_fn(arr, enum_val):
            # Use the raw enum.
            arr.append(enum_val)
            return arr[-1].value

        def foo(x):
            # Pass the enum as argument.
            return leaf_fn(x, Foo.A)

        traced = torch.fx.symbolic_trace(foo)
        self.assertEqual(foo([]), traced([]))

    def test_insert_arg(self):
        m = symbolic_trace(SimpleTest())
        m.register_buffer("buf", torch.tensor(0))
        output_node = next(iter(reversed(m.graph.nodes)))
        with m.graph.inserting_before(output_node):
            a = m.graph.get_attr("buf")
        r = len(output_node.args)
        output_node.insert_arg(0, a)
        self.assertEqual(len(output_node.args), r + 1)
        self.assertEqual(len(a.users), 1)
        self.assertIs(output_node.args[0], a)
        self.assertIs(next(iter(a.users.keys())), output_node)
        output_node.insert_arg(2, a)
        self.assertEqual(len(output_node.args), r + 2)
        self.assertEqual(len(a.users), 1)
        self.assertIs(output_node.args[2], a)
        self.assertIs(next(iter(a.users.keys())), output_node)

def run_getitem_target():
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
    try:
        TestFX().getitem_inner()
    finally:
        _wrapped_methods_to_patch.pop()


class TestOperatorSignatures(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        if not isinstance(op.op, types.BuiltinFunctionType):
            raise unittest.SkipTest("This path doesn't work on Python functions")
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        schemas = get_signature_for_torch_op(op.op)
        if not schemas:
            raise RuntimeError('No Schemas Returned')
        for sample_input in sample_inputs_itr:
            # Iterate through overloads until we hit a match. If we exit this
            # loop via `else`, we haven't found a match
            for schema in schemas:
                try:
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                    bound_args.apply_defaults()
                    op(*bound_args.args, **bound_args.kwargs)
                    break
                except TypeError as e:
                    pass
            else:
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')


class TestFXAPIBackwardCompatibility(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def _fn_to_stable_annotation_str(self, obj):
        """
        Unfortunately we have to serialize function signatures manually since
        serialization for `inspect.Signature` objects is not stable across
        python versions
        """
        fn_name = torch.typename(obj)

        signature = inspect.signature(obj)

        sig_str = f'{fn_name}{signature}'

        arg_strs = []
        for k, v in signature.parameters.items():
            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
                if v.annotation is not inspect.Signature.empty else ''

            def default_val_str(val):
                if isinstance(val, (tuple, list)):
                    str_pieces = ['(' if isinstance(val, tuple) else '[']
                    str_pieces.append(', '.join(default_val_str(v) for v in val))
                    if isinstance(val, tuple) and len(str_pieces) == 2:
                        str_pieces.append(',')
                    str_pieces.append(')' if isinstance(val, tuple) else ']')
                    return ''.join(str_pieces)

                # Need to fix up some default value strings.
                # First case: modules. Default module `repr` contains the FS path of the module.
                # Don't leak that
                if isinstance(val, types.ModuleType):
                    return f'<module {val.__name__}>'

                # Second case: callables. Callables (such as lambdas) encode their address in
                # their string repr. Don't do that
                if callable(val):
                    return f'<function {val.__name__}>'

                return str(val)

            if v.default is not inspect.Signature.empty:
                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
                maybe_default = f' = {default_val_str}'
            else:
                maybe_default = ''

            maybe_stars = ''
            if v.kind == inspect.Parameter.VAR_POSITIONAL:
                maybe_stars = '*'
            elif v.kind == inspect.Parameter.VAR_KEYWORD:
                maybe_stars = '**'
            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')

        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
            if signature.return_annotation is not inspect.Signature.empty else ''

        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'

    def _annotation_type_to_stable_str(self, t, sig_str):
        if t is inspect.Signature.empty:
            return ''

        # Forward ref
        if isinstance(t, str):
            return f"'{t}'"
        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
            return t.__forward_arg__
        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
            return t.__forward_arg__

        trivial_mappings = {
            str : 'str',
            int : 'int',
            float : 'float',
            bool : 'bool',
            torch.dtype: 'torch.dtype',
            torch.Tensor: 'torch.Tensor',
            torch.device: 'torch.device',
            torch.memory_format: 'torch.memory_format',
            slice : 'slice',
            torch.nn.Module: 'torch.nn.modules.module.Module',
            torch.fx.Graph : 'torch.fx.graph.Graph',
            torch.fx.Node : 'torch.fx.node.Node',
            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
            torch.fx.node.Target : 'torch.fx.node.Target',
            torch.fx.node.Argument : 'torch.fx.node.Argument',
            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
            Ellipsis : '...',
            typing.Any: 'Any',
            type(None): 'NoneType',
            None: 'None',
            typing.Iterator: 'Iterator',
        }

        mapping = trivial_mappings.get(t, None)
        if mapping:
            return mapping

        # Handle types with contained types
        contained = getattr(t, '__args__', None) or []

        # Callables contain a bare List for arguments
        contained = t if isinstance(t, list) else contained

        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
        if all(isinstance(ct, typing.TypeVar) for ct in contained):
            contained = []

        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''

        origin = getattr(t, '__origin__', None)
        # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
        origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin

        if origin in {tuple, typing.Tuple}:
            return f'Tuple{contained_type_str}'
        if origin in {typing.Union}:
            # Annoying hack to detect Optional
            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
            return f'Union{contained_type_str}'
        if origin in {dict, typing.Dict}:
            return f'Dict{contained_type_str}'
        if origin in {list, typing.List}:
            return f'List{contained_type_str}'
        if origin in {type, typing.Type}:
            return f'Type{contained_type_str}'
        if isinstance(t, typing.Callable):
            if len(contained) > 0 and contained[0] is not Ellipsis:
                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
            else:
                return f'Callable{contained_type_str}'

        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
                           f'Please add support for this type and confirm with the '
                           f'FX team that your signature change is valid.')

def test_function_back_compat(self):
4073
Test backward compatibility for function signatures with
4074
@compatibility(is_backward_compatible=True). Currently this checks for
4075
exact signature matches, which may lead to false positives. If this
4076
becomes too annoying, we can refine this check to actually parse out
4077
the saved schema strings and check if the change is truly backward-
4082
for obj in _BACK_COMPAT_OBJECTS:
4083
if not isinstance(obj, type):
4084
signature_strs.append(self._fn_to_stable_annotation_str(obj))
4086
signature_strs.sort()
4089
self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
4090
except AssertionError as e:
4091
msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
4092
f"as backwards-compatible has experienced a signature change. See the " \
4093
f"above exception context for more information. If this change was " \
4094
f"unintended, please revert it. If it was intended, check with the FX " \
4095
f"team to ensure that the proper deprecation protocols have been followed " \
4096
f"and subsequently --accept the change."
4097
raise AssertionError(msg) # noqa: TRY200
    def test_class_member_back_compat(self):
        """
        Test backward compatibility for members of classes with
        @compatibility(is_backward_compatible=True). Currently this checks for
        exact matches on the publicly visible members of the class.
        """
        class_method_strs = []

        for obj in _BACK_COMPAT_OBJECTS:
            if isinstance(obj, type):
                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')

        class_method_strs.sort()

        try:
            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
        except AssertionError as e:
            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
                  f"as backwards-compatible has experienced a change in its public members. See the " \
                  f"above exception context for more information. If this change was " \
                  f"unintended, please revert it. If it was intended, check with the FX " \
                  f"team to ensure that the proper deprecation protocols have been followed " \
                  f"and subsequently --accept the change."
            raise AssertionError(msg) from e
    def test_public_api_surface(self):
        non_back_compat_objects = {}

        def check_symbols_have_bc_designation(m, prefix):
            if not m.__name__.startswith('torch.fx'):
                return
            if m.__name__.startswith('torch.fx.experimental'):
                return
            for k, v in m.__dict__.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, types.ModuleType):
                    check_symbols_have_bc_designation(v, prefix + [k])
                elif isinstance(v, (type, types.FunctionType)):
                    if v not in _MARKED_WITH_COMPATIBILITY:
                        non_back_compat_objects.setdefault(v)

        check_symbols_have_bc_designation(torch.fx, ['torch', 'fx'])
        check_symbols_have_bc_designation(torch.fx.passes, ['torch', 'fx', 'passes'])

        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
        # Only want objects in torch.fx
        non_back_compat_strs = [
            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
        # Only want objects in public namespaces
        non_back_compat_strs = [
            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
        non_back_compat_strs.sort()

        if len(non_back_compat_strs) != 0:
            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
                                 f"backwards-compatibility classification! Please decorate these "
                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
                                 f"BC guarantees.")
    def test_adding_side_effect_function(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                side_effect_func(x)
                return x

        gm = torch.fx.symbolic_trace(TestModule())
        self.assertEqual(len(gm.graph.nodes), 3)
        gm.graph.eliminate_dead_code()
        gm.recompile()
        self.assertEqual(len(gm.graph.nodes), 3)

        found = False
        for node in gm.graph.nodes:
            if node.op == 'call_function' and node.target == side_effect_func:
                found = True
        self.assertTrue(found)
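    # Editor's sketch: the test above relies on `side_effect_func` (defined
    # earlier in this file) being registered as side-effectful before tracing,
    # roughly:
    #
    #     @torch.fx.has_side_effect
    #     @torch.fx.wrap
    #     def side_effect_func(x: torch.Tensor):
    #         print(x)
    #
    # Nodes whose target is marked this way are kept by eliminate_dead_code()
    # even though their results are unused.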
    def test_preserve_unused_attr_after_unpickle(self):
        gm = torch.fx.symbolic_trace(Add())
        gm.add_submodule("foo", Add())
        gm.register_buffer("dummy_buffer", torch.empty(1))
        gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))

        b = io.BytesIO()
        torch.save(gm, b)
        b.seek(0)
        reload_gm = torch.load(b)
        self.assertTrue(hasattr(reload_gm, "foo"))
        self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
        self.assertTrue(hasattr(reload_gm, "dummy_parameter"))
# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
@unittest.skipIf(
    sys.version_info >= (3, 12), "Failing on python 3.12+"
)
class TestFunctionalTracing(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                    "has_torch_function_variadic", "handle_torch_function",
                    "boolean_dispatch")
    TO_PATCH = {"has_torch_function": None,
                "has_torch_function_unary": None,
                "has_torch_function_variadic": None}

    BUILT_IN_FUNC = (AssertionError, "")
    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")
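    # Editor's sketch (not part of the original suite): how PROXY_ITERATED
    # arises. Unpacking iterates a Proxy, which symbolic tracing cannot record:
    #
    #     def f(x):
    #         a, b = x        # unpacking iterates the Proxy for `x`
    #         return a + b
    #
    #     torch.fx.symbolic_trace(f)  # TraceError: Proxy object cannot be iterated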
    UNTRACEABLE_FUNCTIONALS = {
        "adaptive_avg_pool1d": BUILT_IN_FUNC,
        "avg_pool1d": BUILT_IN_FUNC,
        "avg_pool2d": BUILT_IN_FUNC,
        "avg_pool3d": BUILT_IN_FUNC,
        "bilinear": BUILT_IN_FUNC,
        "celu_": BUILT_IN_FUNC,
        "channel_shuffle": BUILT_IN_FUNC,
        "native_channel_shuffle": BUILT_IN_FUNC,
        "conv1d": BUILT_IN_FUNC,
        "conv2d": BUILT_IN_FUNC,
        "conv3d": BUILT_IN_FUNC,
        "conv_tbc": BUILT_IN_FUNC,
        "conv_transpose1d": BUILT_IN_FUNC,
        "conv_transpose2d": BUILT_IN_FUNC,
        "conv_transpose3d": BUILT_IN_FUNC,
        "cosine_similarity": BUILT_IN_FUNC,
        "elu_": BUILT_IN_FUNC,
        "gelu": BUILT_IN_FUNC,
        "hardshrink": BUILT_IN_FUNC,
        "hardtanh_": BUILT_IN_FUNC,
        "leaky_relu_": BUILT_IN_FUNC,
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": ARG_TYPE_MISMATCH,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,
        "pixel_unshuffle": BUILT_IN_FUNC,
        "prelu": BUILT_IN_FUNC,
        "relu_": BUILT_IN_FUNC,
        "rrelu_": BUILT_IN_FUNC,
        "selu_": BUILT_IN_FUNC,
        "scaled_dot_product_attention": BUILT_IN_FUNC,
        "softplus": BUILT_IN_FUNC,
        "softshrink": BUILT_IN_FUNC,
        "threshold_": BUILT_IN_FUNC,

        "adaptive_avg_pool2d": LEN_ERROR,
        "adaptive_avg_pool3d": LEN_ERROR,
        "adaptive_max_pool2d_with_indices": LEN_ERROR,
        "adaptive_max_pool3d_with_indices": LEN_ERROR,
        "instance_norm": CONTROL_FLOW,

        "adaptive_max_pool1d": PROXY_ITERABLE,
        "adaptive_max_pool2d": PROXY_ITERABLE,
        "adaptive_max_pool3d": PROXY_ITERABLE,
        "fractional_max_pool2d": PROXY_ITERABLE,
        "fractional_max_pool3d": PROXY_ITERABLE,
        "max_pool1d": PROXY_ITERABLE,
        "max_pool2d": PROXY_ITERABLE,
        "max_pool3d": PROXY_ITERABLE,

        "lp_pool2d": PROXY_ITERATED,
        "lp_pool3d": PROXY_ITERATED,
        "max_unpool1d": PROXY_ITERATED,
        "max_unpool2d": PROXY_ITERATED,
        "max_unpool3d": PROXY_ITERATED,
        "fold": PROXY_ITERATED,
        "unfold": PROXY_ITERATED,

        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
        "alpha_dropout": CONTROL_FLOW,
        "batch_norm": CONTROL_FLOW,
        "binary_cross_entropy": CONTROL_FLOW,
        "binary_cross_entropy_with_logits": CONTROL_FLOW,
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
        "dropout2d": CONTROL_FLOW,
        "dropout3d": CONTROL_FLOW,
        "elu": CONTROL_FLOW,
        "embedding": CONTROL_FLOW,
        "embedding_bag": CONTROL_FLOW,
        "feature_alpha_dropout": CONTROL_FLOW,
        "gaussian_nll_loss": CONTROL_FLOW,
        "glu": CONTROL_FLOW,
        "grid_sample": CONTROL_FLOW,
        "group_norm": CONTROL_FLOW,
        "gumbel_softmax": CONTROL_FLOW,
        "hardsigmoid": CONTROL_FLOW,
        "hardswish": CONTROL_FLOW,
        "hardtanh": CONTROL_FLOW,
        "hinge_embedding_loss": CONTROL_FLOW,
        "huber_loss": CONTROL_FLOW,
        "interpolate": CONTROL_FLOW,
        "kl_div": CONTROL_FLOW,
        "l1_loss": CONTROL_FLOW,
        "leaky_relu": CONTROL_FLOW,
        "local_response_norm": CONTROL_FLOW,
        "margin_ranking_loss": CONTROL_FLOW,
        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "mse_loss": CONTROL_FLOW,
        "multi_head_attention_forward": CONTROL_FLOW,
        "multi_margin_loss": CONTROL_FLOW,
        "multilabel_margin_loss": CONTROL_FLOW,
        "multilabel_soft_margin_loss": CONTROL_FLOW,
        "nll_loss": CONTROL_FLOW,
        "poisson_nll_loss": CONTROL_FLOW,
        "relu": CONTROL_FLOW,
        "relu6": CONTROL_FLOW,
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,
        "triplet_margin_loss": CONTROL_FLOW,
        "triplet_margin_with_distance_loss": CONTROL_FLOW,
        "upsample": CONTROL_FLOW,

        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
    }
    # List of nn.functionals with Tensor inputs but without type annotations
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )
    # Inconsistent behavior between Python 3.8 and other Python versions:
    # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED`
    # - Other Python versions: Raise `argument of type 'Proxy' is not iterable` due to the same
    #   internal exception above
    # Use the following map to override the expected exception for Python 3.8
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": CONTROL_FLOW,
    }
    @classmethod
    def _get_functional(cls):
        functional_list = []
        for f in dir(torch.nn.functional):
            if not f.islower():
                continue
            # Ignore internal functions
            if f.startswith('_'):
                continue
            # Ignore supporting functions
            if f in cls.IGNORE_FUNCS:
                continue
            fn = getattr(torch.nn.functional, f)
            # Ignore non-callable objects like modules
            if not isinstance(fn, Callable):
                continue
            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
                try:
                    sig = inspect.signature(fn)
                    has_tensor_arg = False
                    for param in sig.parameters.values():
                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
                            has_tensor_arg = True
                    if not has_tensor_arg:
                        continue
                # No signature or object is not supported
                except ValueError:
                    continue
            functional_list.append((f, fn))
        return functional_list
    @classmethod
    def generate_test_func(cls, func_name, fn):

        def functional_test(self):
            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
                    sys.version_info >= (3, 8) and sys.version_info < (3, 12):
                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            else:
                symbolic_trace(fn)
        return functional_test
    @classmethod
    def generate_tests(cls):
        functional_list = cls._get_functional()
        for func_name, fn in functional_list:
            test_name = "test_nn_functional_" + func_name
            functional_test = cls.generate_test_func(func_name, fn)
            setattr(cls, test_name, functional_test)
    @classmethod
    def setUpClass(cls):

        def no(*args, **kwargs):
            return False

        for name in cls.TO_PATCH.keys():
            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

    @classmethod
    def tearDownClass(cls):
        for name in cls.TO_PATCH.keys():
            setattr(torch.nn.functional, name, cls.TO_PATCH[name])
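    # Editor's note (assumption about intent): patching these helpers to always
    # return False keeps the functionals from taking the __torch_function__
    # dispatch path, so tracing exercises their Python bodies and raises the
    # errors the tables above expect; tearDownClass restores the originals.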
TestFunctionalTracing.generate_tests()


instantiate_device_type_tests(TestOperatorSignatures, globals())


@skipIfTorchDynamo("too slow")
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )
    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
        "ssd300_vgg16": PROXY_ITERATED,
        "fcos_resnet50_fpn": PROXY_ITERATED,
        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
    }
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }
    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }
    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs)
            model = model.eval()
            if name in self.UNTRACEABLE_MODELS:
                err, exc = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(err, exc):
                    graph = symbolic_trace(model)
            else:
                out_transform = self.output_transform.get(name, lambda x: x)
                graph: torch.fx.GraphModule = symbolic_trace(model)
                a = out_transform(model(x))
                b = out_transform(graph(x))
                self.assertEqual(a, b)

                if name in self.UNSCRIPTABLE_MODELS:
                    err, exc = self.UNSCRIPTABLE_MODELS[name]
                    with self.assertRaisesRegex(err, exc):
                        script = torch.jit.script(graph)
                else:
                    script = torch.jit.script(graph)
                    c = out_transform(script(x))
                    self.assertEqual(a, c)

        return run_test
    @classmethod
    def generate_classification_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models):
            test_name = 'test_torchvision_models_' + k
            x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)
    @classmethod
    def generate_segmentation_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.segmentation):
            test_name = 'test_torchvision_models_segmentation_' + k
            x = torch.rand(1, 3, 32, 32)
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)
    @classmethod
    def generate_detection_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.detection):
            test_name = 'test_torchvision_models_detection_' + k
            x = [torch.rand(3, 300, 300)]
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)
    @classmethod
    def generate_video_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.video):
            test_name = 'test_torchvision_models_video_' + k
            x = (
                torch.rand(1, 3, 4, 112, 112)
                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
                else torch.rand(1, 3, 16, 224, 224)
            )
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)
    @classmethod
    def generate_tests(cls):
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()


if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()

if __name__ == '__main__':
    run_tests()