# Owner(s): ["oncall: pt2"]
import sys
import unittest
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from functorch.compile import aot_function, nop, min_cut_rematerialization_partition
from unittest.mock import patch
import functools
import torch.utils.checkpoint


from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
)

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "torch.compile not supported on windows"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("torch.compile not supported on windows")
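
# Used as the fwd/bwd "compiler" for aot_function in the tests below: it asserts that
# the traced graph contains exactly `freq` functional philox_rand calls and returns
# the graph unchanged.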
def count_philox_rand(gm, args, freq):
    assert [node.target for node in gm.graph.nodes].count(torch.ops.rngprims.philox_rand.default) == freq
    return gm


class TestFunctionalizationRngOps(TestCase):
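    # Each rand_like call in the eager function should lower to one functional
    # philox_rand call in the traced forward graph, and the compiled result should
    # match eager for every seed.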
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)
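
    # The dynamic-shape variants go through torch.compile with the "aot_eager" backend,
    # so they only check numerical parity with eager across shapes, not the philox_rand count.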
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)
            res.sum().backward()

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)
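
    # aot_function traces through the custom autograd.Function, so the two rand_like
    # calls in forward and the single rand_like in backward should each lower to a
    # philox_rand call in the corresponding graph.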
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple aot traced
        # graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)
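
    # torch.cuda.get_rng_state / set_rng_state inside the traced function must be
    # functionalized as well: restoring the state makes the third rand_like replay
    # the second one's randomness, and the forward graph should still contain three
    # philox_rand calls.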
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)
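
    # With the min-cut rematerialization partitioner the two philox_rand calls should
    # stay in the forward graph and not be recomputed in backward (bwd freq=0), while
    # the fwd/bwd calling convention is preserved.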
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(fn, fwd_compiler, bwd_compiler, partition_fn=min_cut_rematerialization_partition)
        # aot_custom = aot_function(fn, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        # x = torch.rand(2, 2, device="cuda", requires_grad=True)
        x = torch.ones(2, 2, device="cuda", requires_grad=True)
        y = torch.rand(2, 2, device="cuda", requires_grad=True)
        torch.cuda.manual_seed(123)
        ref = fn(x, y)

        # With checkpointing we should recompute dropout in the backward pass, so
        # philox_rand should appear once in the forward graph and once in the backward graph.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        # We can't check accuracy here because rand_like generates different random
        # numbers than dropout.
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different random
        # numbers than dropout.
        aot_fn(x)


only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)
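

# Negative test: rng functionalization is only exercised for CUDA above; on CPU the
# aot-compiled function is expected to raise a RuntimeError.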
class NegativeTest(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        aot_fn = aot_function(fn, nop)
        with self.assertRaises(RuntimeError):
            aot_fn(x)


only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()