# Scraped page header: pytorch · Fork 0 · test_lazy_graph_module.py · 279 lines · 8.4 KB
# Owner(s): ["oncall: fx"]

import contextlib
import pickle
from io import BytesIO
from unittest.mock import patch

import torch
import torch._dynamo
import torch._export
from torch import fx
from torch.fx._lazy_graph_module import (
    _LazyGraphModule,
    _make_graph_module,
    _use_lazy_graph_module,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests, TestCase


class TestLazyGraphModule(TestCase):
22
    exit_stack = None
23

24
    @classmethod
25
    def setUpClass(cls):
26
        cls.exit_stack = contextlib.ExitStack()
27
        cls.exit_stack.enter_context(_use_lazy_graph_module(True))
28

29
    @classmethod
30
    def tearDownClass(cls):
31
        cls.exit_stack.close()
32

33
    @staticmethod
34
    def replace_sin_with_cos(gm):
35
        for n in gm.graph.nodes:
36
            if n.target == "sin":
37
                n.target = "cos"
38

39
    def test_replace_sin_with_cos(self):
40
        def f(x):
41
            return x.sin()
42

43
        x = torch.randn(2, 3)
44
        gm = fx.symbolic_trace(f)
45
        self.replace_sin_with_cos(gm)
46

47
        gm.recompile()
48
        expected = x.cos()
49
        actual = gm(x)
50

51
        self.assertTrue(torch.allclose(expected, actual))
52
        code = gm.print_readable(False)
53
        self.assertTrue("cos()" in code)
54
        self.assertTrue(isinstance(gm, _LazyGraphModule))
55

56
    def test_call_forward_directly(self):
57
        def f(x):
58
            return x.sin()
59

60
        x = torch.randn(2, 3)
61
        gm = fx.symbolic_trace(f)
62
        self.assertTrue(isinstance(gm, _LazyGraphModule))
63
        self.replace_sin_with_cos(gm)
64
        gm.recompile()
65
        expected = x.cos()
66
        actual = gm.forward(x)
67

68
        self.assertTrue(torch.allclose(expected, actual))
69

70
    def test_needs_recompile(self):
71
        """
72
        Make sure needs_recompile() return the corrent state.
73
        """
74

75
        def f(x):
76
            return x.sin()
77

78
        gm = fx.symbolic_trace(f)
79
        self.assertTrue(isinstance(gm, _LazyGraphModule))
80
        self.assertTrue(gm._needs_recompile())
81
        gm(torch.randn(2, 3))
82
        self.assertFalse(gm._needs_recompile())
83

84
    def test_multi_recompile(self):
85
        """
86
        Cover the case that multiple recompilation happens.
87
        """
88

89
        def f(x):
90
            return x.sin()
91

92
        gm = fx.symbolic_trace(f)
93
        self.assertTrue(isinstance(gm, _LazyGraphModule))
94
        self.assertTrue(gm._needs_recompile())
95
        x = torch.randn(2, 3)
96
        # trigger the first recompilation
97
        self.assertTrue(torch.allclose(x.sin(), gm(x)))
98
        self.assertFalse(gm._needs_recompile())
99

100
        self.replace_sin_with_cos(gm)
101
        self.assertFalse(gm._needs_recompile())
102
        gm.recompile()
103
        self.assertTrue(gm._needs_recompile())
104
        # trigger the second recompilation
105
        self.assertTrue(torch.allclose(x.cos(), gm(x)))
106
        self.assertFalse(gm._needs_recompile())
107

108
    def test_accessing_code_cause_recompiling(self):
109
        """
110
        Make sure we recompile if we have not done that yet when we access the code
111
        property of a GraphModule.
112
        """
113

114
        def f(x):
115
            return x.sin()
116

117
        gm = fx.symbolic_trace(f)
118
        self.assertTrue(isinstance(gm, _LazyGraphModule))
119
        self.assertTrue(gm._needs_recompile())
120
        # should trigger a recompilation
121
        code = gm.code
122
        self.assertTrue("sin" in code)
123
        self.assertFalse(gm._needs_recompile())
124

125
    def test_graph_module_str(self):
126
        def f(x):
127
            return x.sin()
128

129
        gm = fx.symbolic_trace(f)
130
        self.assertTrue(isinstance(gm, _LazyGraphModule))
131
        self.assertTrue("sin" in str(gm))
132

133
    def test_recapture_with_make_fx(self):
134
        def f(x):
135
            return x.sin()
136

137
        gm = fx.symbolic_trace(f)
138
        self.assertTrue(isinstance(gm, _LazyGraphModule))
139
        self.assertTrue(gm._needs_recompile())
140
        gm2 = make_fx(gm)(torch.randn(2, 3))
141
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
142
        self.assertTrue(gm2._needs_recompile())
143

144
        # make_fx will cal foward method of gm. That clears the _needs_recompile()
145
        # flag.
146
        self.assertFalse(gm._needs_recompile())
147

148
    def test_recapture_with_symbolic_trace(self):
149
        def f(x):
150
            return x.sin()
151

152
        gm = fx.symbolic_trace(f)
153
        self.assertTrue(isinstance(gm, _LazyGraphModule))
154
        self.assertTrue(gm._needs_recompile())
155
        gm2 = fx.symbolic_trace(gm)
156

157
        # the lazy recompilcation is already realized. We realize the
158
        # recompilation in the beginning of symbolic_trace since symbolic_trace can not
159
        # handle the tracing of lazy recompilation.
160
        self.assertFalse(gm._needs_recompile())
161
        self.assertTrue(gm2._needs_recompile())
162

163
    def test_recapture_with_dynamo(self):
164
        def f(x):
165
            return x.sin()
166

167
        gm = fx.symbolic_trace(f)
168
        self.assertTrue(isinstance(gm, _LazyGraphModule))
169
        self.assertTrue(gm._needs_recompile())
170
        torch.compile(gm)(torch.rand(2, 3))
171

172
        # dynamo calls gm.forward with eval hook installed. That will trigger
173
        # the real recompilation.
174
        self.assertFalse(gm._needs_recompile())
175

176
    def test_save_lazy_foward(self):
177
        """
178
        Save the lazy forward method and call it repeatly. Make sure we
179
        don't recompile for each such call.
180
        """
181

182
        def f(x):
183
            return x.sin()
184

185
        orig_gm_recompile = fx.GraphModule.recompile
186
        recompile_count = 0
187

188
        def mock_gm_recompile(self):
189
            nonlocal recompile_count
190
            recompile_count += 1
191
            return orig_gm_recompile(self)
192

193
        with patch.object(fx.GraphModule, "recompile", mock_gm_recompile):
194
            gm = fx.symbolic_trace(f)
195
            self.assertTrue(isinstance(gm, _LazyGraphModule))
196
            saved_fwd = gm.forward
197

198
            x = torch.rand(2, 3)
199
            for _ in range(10):
200
                saved_fwd(x)
201

202
        self.assertEqual(recompile_count, 1)
203

204
    def test_pickle(self):
205
        """
206
        Fx graph cache need the ability to pickle GraphModule/_LazyGraphModule.
207
        """
208

209
        def f(x):
210
            return x.sin()
211

212
        gm = fx.symbolic_trace(f)
213
        self.assertTrue(isinstance(gm, _LazyGraphModule))
214
        serialized = pickle.dumps(gm)
215
        gm2 = pickle.loads(serialized)
216
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
217
        self.assertTrue("sin" in gm2.code)
218

219
    def test_make_graph_module(self):
220
        gm = fx.symbolic_trace(lambda x: x.sin())
221
        self.assertTrue(isinstance(gm, _LazyGraphModule))
222

223
        gm1 = _make_graph_module(
224
            gm, gm.graph, class_name="MyGraphModule", graph_module_cls=fx.GraphModule
225
        )
226
        self.assertFalse(isinstance(gm1, _LazyGraphModule))
227
        self.assertTrue(gm1.__class__.__name__ == "MyGraphModule")
228

229
        gm2 = _make_graph_module(gm, gm.graph)
230
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
231
        self.assertTrue(gm2.__class__.__name__ == "GraphModule")
232

233
    def test_package_fx_simple(self):
234
        """
235
        Copied from test/package/test_package_fx.py to make sure LazyGraphModule
236
        works with torch.package.
237
        """
238

239
        class SimpleTest(torch.nn.Module):
240
            def forward(self, x):
241
                return torch.relu(x + 3.0)
242

243
        st = SimpleTest()
244
        traced = fx.symbolic_trace(st)
245

246
        f = BytesIO()
247
        with PackageExporter(f) as pe:
248
            pe.save_pickle("model", "model.pkl", traced)
249

250
        f.seek(0)
251
        pi = PackageImporter(f)
252
        loaded_traced = pi.load_pickle("model", "model.pkl")
253
        input = torch.rand(2, 3)
254
        self.assertEqual(loaded_traced(input), traced(input))
255

256
    def test_dynamo_innermost_fn(self):
257
        """
258
        Repro for https://github.com/pytorch/pytorch/issues/121198 .
259
        """
260

261
        def f(x):
262
            return x * 2
263

264
        gm = torch.fx.symbolic_trace(f)
265
        lazy_gm = torch.fx._lazy_graph_module._LazyGraphModule.from_graphmodule(gm)
266

267
        wrapped_forward = torch._dynamo.disable(gm.forward)
268
        got_inner_forward = torch._dynamo.eval_frame.innermost_fn(wrapped_forward)
269
        assert hasattr(got_inner_forward, "__self__")
270

271
        wrapped_lazy_forward = torch._dynamo.disable(lazy_gm.forward)
272
        got_lazy_inner_forward = torch._dynamo.eval_frame.innermost_fn(
273
            wrapped_lazy_forward
274
        )
275
        assert hasattr(got_lazy_inner_forward, "__self__")
276

277

278
if __name__ == "__main__":
279
    run_tests()
280

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.