# Owner(s): ["oncall: fx"]

import contextlib
import pickle
from io import BytesIO
from unittest.mock import patch

import torch
from torch import fx
from torch.fx._lazy_graph_module import (
    _LazyGraphModule,
    _make_graph_module,
    _use_lazy_graph_module,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests, TestCase

class TestLazyGraphModule(TestCase):
26
cls.exit_stack = contextlib.ExitStack()
27
cls.exit_stack.enter_context(_use_lazy_graph_module(True))
30
def tearDownClass(cls):
31
cls.exit_stack.close()
34
def replace_sin_with_cos(gm):
35
for n in gm.graph.nodes:
39
def test_replace_sin_with_cos(self):
44
gm = fx.symbolic_trace(f)
45
self.replace_sin_with_cos(gm)
51
self.assertTrue(torch.allclose(expected, actual))
52
code = gm.print_readable(False)
53
self.assertTrue("cos()" in code)
54
self.assertTrue(isinstance(gm, _LazyGraphModule))
56
def test_call_forward_directly(self):
61
gm = fx.symbolic_trace(f)
62
self.assertTrue(isinstance(gm, _LazyGraphModule))
63
self.replace_sin_with_cos(gm)
66
actual = gm.forward(x)
68
self.assertTrue(torch.allclose(expected, actual))
70
def test_needs_recompile(self):
72
Make sure needs_recompile() return the corrent state.
78
gm = fx.symbolic_trace(f)
79
self.assertTrue(isinstance(gm, _LazyGraphModule))
80
self.assertTrue(gm._needs_recompile())
82
self.assertFalse(gm._needs_recompile())
84
def test_multi_recompile(self):
86
Cover the case that multiple recompilation happens.
92
gm = fx.symbolic_trace(f)
93
self.assertTrue(isinstance(gm, _LazyGraphModule))
94
self.assertTrue(gm._needs_recompile())
96
# trigger the first recompilation
97
self.assertTrue(torch.allclose(x.sin(), gm(x)))
98
self.assertFalse(gm._needs_recompile())
100
self.replace_sin_with_cos(gm)
101
self.assertFalse(gm._needs_recompile())
103
self.assertTrue(gm._needs_recompile())
104
# trigger the second recompilation
105
self.assertTrue(torch.allclose(x.cos(), gm(x)))
106
self.assertFalse(gm._needs_recompile())
108
def test_accessing_code_cause_recompiling(self):
110
Make sure we recompile if we have not done that yet when we access the code
111
property of a GraphModule.
117
gm = fx.symbolic_trace(f)
118
self.assertTrue(isinstance(gm, _LazyGraphModule))
119
self.assertTrue(gm._needs_recompile())
120
# should trigger a recompilation
122
self.assertTrue("sin" in code)
123
self.assertFalse(gm._needs_recompile())
125
def test_graph_module_str(self):
129
gm = fx.symbolic_trace(f)
130
self.assertTrue(isinstance(gm, _LazyGraphModule))
131
self.assertTrue("sin" in str(gm))
133
def test_recapture_with_make_fx(self):
137
gm = fx.symbolic_trace(f)
138
self.assertTrue(isinstance(gm, _LazyGraphModule))
139
self.assertTrue(gm._needs_recompile())
140
gm2 = make_fx(gm)(torch.randn(2, 3))
141
self.assertTrue(isinstance(gm2, _LazyGraphModule))
142
self.assertTrue(gm2._needs_recompile())
144
# make_fx will cal foward method of gm. That clears the _needs_recompile()
146
self.assertFalse(gm._needs_recompile())
148
def test_recapture_with_symbolic_trace(self):
152
gm = fx.symbolic_trace(f)
153
self.assertTrue(isinstance(gm, _LazyGraphModule))
154
self.assertTrue(gm._needs_recompile())
155
gm2 = fx.symbolic_trace(gm)
157
# the lazy recompilcation is already realized. We realize the
158
# recompilation in the beginning of symbolic_trace since symbolic_trace can not
159
# handle the tracing of lazy recompilation.
160
self.assertFalse(gm._needs_recompile())
161
self.assertTrue(gm2._needs_recompile())
163
def test_recapture_with_dynamo(self):
167
gm = fx.symbolic_trace(f)
168
self.assertTrue(isinstance(gm, _LazyGraphModule))
169
self.assertTrue(gm._needs_recompile())
170
torch.compile(gm)(torch.rand(2, 3))
172
# dynamo calls gm.forward with eval hook installed. That will trigger
173
# the real recompilation.
174
self.assertFalse(gm._needs_recompile())
176
def test_save_lazy_foward(self):
178
Save the lazy forward method and call it repeatly. Make sure we
179
don't recompile for each such call.
185
orig_gm_recompile = fx.GraphModule.recompile
188
def mock_gm_recompile(self):
189
nonlocal recompile_count
191
return orig_gm_recompile(self)
193
with patch.object(fx.GraphModule, "recompile", mock_gm_recompile):
194
gm = fx.symbolic_trace(f)
195
self.assertTrue(isinstance(gm, _LazyGraphModule))
196
saved_fwd = gm.forward
202
self.assertEqual(recompile_count, 1)
204
def test_pickle(self):
206
Fx graph cache need the ability to pickle GraphModule/_LazyGraphModule.
212
gm = fx.symbolic_trace(f)
213
self.assertTrue(isinstance(gm, _LazyGraphModule))
214
serialized = pickle.dumps(gm)
215
gm2 = pickle.loads(serialized)
216
self.assertTrue(isinstance(gm2, _LazyGraphModule))
217
self.assertTrue("sin" in gm2.code)
219
def test_make_graph_module(self):
220
gm = fx.symbolic_trace(lambda x: x.sin())
221
self.assertTrue(isinstance(gm, _LazyGraphModule))
223
gm1 = _make_graph_module(
224
gm, gm.graph, class_name="MyGraphModule", graph_module_cls=fx.GraphModule
226
self.assertFalse(isinstance(gm1, _LazyGraphModule))
227
self.assertTrue(gm1.__class__.__name__ == "MyGraphModule")
229
gm2 = _make_graph_module(gm, gm.graph)
230
self.assertTrue(isinstance(gm2, _LazyGraphModule))
231
self.assertTrue(gm2.__class__.__name__ == "GraphModule")
233
def test_package_fx_simple(self):
235
Copied from test/package/test_package_fx.py to make sure LazyGraphModule
236
works with torch.package.
239
class SimpleTest(torch.nn.Module):
240
def forward(self, x):
241
return torch.relu(x + 3.0)
244
traced = fx.symbolic_trace(st)
247
with PackageExporter(f) as pe:
248
pe.save_pickle("model", "model.pkl", traced)
251
pi = PackageImporter(f)
252
loaded_traced = pi.load_pickle("model", "model.pkl")
253
input = torch.rand(2, 3)
254
self.assertEqual(loaded_traced(input), traced(input))
256
def test_dynamo_innermost_fn(self):
258
Repro for https://github.com/pytorch/pytorch/issues/121198 .
264
gm = torch.fx.symbolic_trace(f)
265
lazy_gm = torch.fx._lazy_graph_module._LazyGraphModule.from_graphmodule(gm)
267
wrapped_forward = torch._dynamo.disable(gm.forward)
268
got_inner_forward = torch._dynamo.eval_frame.innermost_fn(wrapped_forward)
269
assert hasattr(got_inner_forward, "__self__")
271
wrapped_lazy_forward = torch._dynamo.disable(lazy_gm.forward)
272
got_lazy_inner_forward = torch._dynamo.eval_frame.innermost_fn(
275
assert hasattr(got_lazy_inner_forward, "__self__")
if __name__ == "__main__":
    # Standard PyTorch test entry point.
    run_tests()