pytorch

Форк
0
/
test_functionalization_of_rng_ops.py 
347 строк · 11.3 Кб
1
# Owner(s): ["oncall: pt2"]
2
import functools
3
import sys
4
import unittest
5
from unittest.mock import patch
6

7
import torch
8
import torch.utils.checkpoint
9
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop
10

11
from torch.testing._internal.common_device_type import (
12
    dtypes,
13
    instantiate_device_type_tests,
14
)
15

16
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase
17

18
if IS_WINDOWS and IS_CI:
19
    sys.stderr.write("torch.compile not supported on windows")
20
    if __name__ == "__main__":
21
        sys.exit(0)
22
    raise unittest.SkipTest("torch.compile not supported on windows")
23

24

25
def count_philox_rand(gm, args, freq):
26
    assert [node.target for node in gm.graph.nodes].count(
27
        torch.ops.rngprims.philox_rand.default
28
    ) == freq
29
    return gm
30

31

32
class TestFunctionalizationRngOps(TestCase):
    """Tests for ``torch._functorch.config.functionalize_rng_ops``.

    Most tests run a function eagerly and then through ``aot_function`` (or
    ``torch.compile`` with the ``aot_eager`` backend), re-seeding the CUDA
    generator with the same seed before each run, and compare the results.
    Where ``count_philox_rand`` is used as the compiler, the test also pins
    how many ``rngprims.philox_rand`` nodes each traced graph contains.
    """

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            # Separate leaf for the compiled run so the two backward passes do
            # not accumulate into the same ``.grad`` and can be compared.
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)

        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        # Two rand_like calls in forward, one in backward.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple aot traced
        # graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            # Restoring the saved state makes the next rand_like repeat the
            # previous draw.
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)

        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(
            fn,
            fwd_compiler,
            bwd_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
        torch.cuda.manual_seed(123)
        # Eager run as a smoke test only; its values are not compared (see below).
        fn(x, y)

        # With checkpointing, dropout is recomputed in the backward, but the
        # philox_rand result is passed from the forward, so only the forward
        # graph should contain a philox_rand node.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        aot_fn(x)


# RNG functionalization tests are generated for CUDA devices only.
only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)


class NegativeTest(TestCase):
328
    @dtypes(torch.float32)
329
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
330
    def test_on_cpu(self, dtype, device):
331
        def fn(x):
332
            a = torch.rand_like(x) * x
333
            a = torch.rand_like(x) * a
334
            return a
335

336
        x = torch.rand(10, device=device, dtype=dtype)
337

338
        aot_fn = aot_function(fn, nop)
339
        with self.assertRaises(RuntimeError):
340
            aot_fn(x)
341

342

343
# The negative (expected-failure) tests are generated for CPU only.
only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.