pytorch

Форк
0
/
test_autograd_fallback.py 
459 строк · 16.4 Кб
1
# Owner(s): ["module: autograd"]
2

3
import contextlib
4
import warnings
5

6
import numpy as np
7
import torch
8
from torch.library import _scoped_library, Library
9
from torch.testing._internal.common_utils import (
10
    instantiate_parametrized_tests,
11
    parametrize,
12
    run_tests,
13
    TestCase,
14
)
15

16

17
@contextlib.contextmanager
18
def autograd_fallback_mode(mode):
19
    prev = torch._C._get_autograd_fallback_mode()
20
    try:
21
        torch._C._set_autograd_fallback_mode(mode)
22
        yield
23
    finally:
24
        torch._C._set_autograd_fallback_mode(prev)
25

26

27
class TestAutogradFallback(TestCase):
28
    test_ns = "_test_autograd_fallback"
29

30
    def tearDown(self):
31
        if hasattr(torch.ops, self.test_ns):
32
            delattr(torch.ops, self.test_ns)
33
        if hasattr(self, "lib"):
34
            del self.lib.m
35
            del self.lib
36

37
    def get_op(self, name):
38
        return getattr(getattr(torch.ops, self.test_ns), name).default
39

40
    def get_lib(self):
41
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
42
        self.lib = lib
43
        return lib
44

45
    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        # Warnings are promoted to errors below, so any warning emitted by the
        # op call fails the test; the output must also not require grad.
        def add_all(a, b, c):
            return a + b + c

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            library.impl("foo", add_all, "CPU")
            foo = self.get_op("foo")

            # Inputs require grad, but the call happens under torch.no_grad().
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    lhs = torch.randn([], requires_grad=True)
                    rhs = torch.randn([], requires_grad=True)
                    result = foo(lhs, rhs, 1)
                self.assertFalse(result.requires_grad)

            # Inputs do not require grad at all.
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                lhs = torch.randn([])
                rhs = torch.randn([])
                result = foo(lhs, rhs, 1)
                self.assertFalse(result.requires_grad)
67

68
    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        # The CPU kernel detaches its inputs and builds the result from a
        # NumPy value via torch.tensor.
        def numpy_add(a, b, c):
            total = a.detach().numpy() + b.detach().numpy() + c
            return torch.tensor(total)

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            foo = self.get_op("foo")
            library.impl("foo", numpy_add, "CPU")

            # Some inputs requiring grad
            plain = torch.randn([], requires_grad=False)
            tracked = torch.randn([], requires_grad=True)
            loss = foo(plain, tracked, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                loss.backward()
            self.assertIsNone(tracked.grad)
88

89
    def _check_ctx(self, mode, *, mode_nothing_raises=False):
90
        if mode == "warn":
91
            return self.assertWarnsRegex(
92
                UserWarning, "an autograd kernel was not registered"
93
            )
94
        assert mode == "nothing"
95
        if mode_nothing_raises:
96
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
97
        return contextlib.nullcontext()
98

99
    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        # Both kernels mutate their arguments inside torch.no_grad().
        def sin_cos_inplace(first, second):
            with torch.no_grad():
                first.sin_()
                second.cos_()
            return first, second

        def sin_inplace(tensor):
            with torch.no_grad():
                tensor.sin_()

        with autograd_fallback_mode(mode):
            library = self.get_lib()

            # input modified in-place gets returned as output
            library.define(
                "foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))"
            )
            foo = self.get_op("foo")
            library.impl("foo", sin_cos_inplace, "CPU")

            leaf = torch.randn(3, requires_grad=True)
            base0 = leaf.clone()
            base1 = leaf.clone()
            view0 = base0[0]
            view1 = base1[1]
            out0, out1 = foo(view0, view1)
            for candidate in (base0, base1, out0, out1, view0, view1):
                with self._check_ctx(mode):
                    candidate.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            library.define("bar(Tensor(a!) self) -> ()")
            bar = self.get_op("bar")
            library.impl("bar", sin_inplace, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                leaf = torch.randn([], requires_grad=True)
                mutated = leaf.clone()
                _ = bar(mutated)
                mutated.backward()
                self.assertEqual(leaf.grad, torch.ones_like(leaf))
142

143
    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        # To be clear, none of these situations are OK and will lead
        # to other problems down the line. We're testing them because
        # it is fairly common to actually do these things.
        #
        # Each entry is (op name, schema); every op is an identity kernel.
        identity_ops = (
            ("foo", "foo(Tensor self) -> Tensor"),
            ("bar", "bar(Tensor(a!) self) -> Tensor(a!)"),
        )
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as library:
                for op_name, schema in identity_ops:
                    library.define(schema)
                    library.impl(op_name, lambda x: x, "CPU")
                    identity = self.get_op(op_name)

                    leaf = torch.randn(3, requires_grad=True)
                    loss = identity(leaf).sum()
                    with self._check_ctx(mode):
                        loss.backward()
                        self.assertEqual(leaf.grad, torch.ones_like(leaf))
169

170
    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        # The CPU kernel is built from differentiable torch ops, and the test
        # expects the gradient cos(x) to reach the input.
        def sin_then_sum(tensor):
            return tensor.sin().sum()

        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as library:
                library.define("foo(Tensor self) -> Tensor")
                library.impl("foo", sin_then_sum, "CPU")
                foo = self.get_op("foo")

                leaf = torch.randn(3, requires_grad=True)
                result = foo(leaf)
                with self._check_ctx(mode):
                    result.backward()
                    self.assertEqual(leaf.grad, leaf.cos())
183

184
    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        # A custom autograd.Function (NumPy forward, torch backward) is
        # registered directly as the CPU kernel.
        class NumpySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                sin_np = np.sin(x.cpu().numpy())
                return torch.tensor(sin_np)

            @staticmethod
            def backward(ctx, gx):
                (saved,) = ctx.saved_tensors
                return gx * saved.cos()

        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as library:
                library.define("foo(Tensor self) -> Tensor")
                library.impl("foo", NumpySin.apply, "CPU")
                foo = self.get_op("foo")

                leaf = torch.randn(3, requires_grad=True)
                loss = foo(leaf).sum()
                with self._check_ctx(mode):
                    loss.backward()
                    self.assertEqual(leaf.grad, leaf.cos())
209

210
    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        # In-place variant: the Function overwrites its input through a NumPy
        # view and marks it dirty; a clone of the original value is saved for
        # the backward pass.
        class NumpySin_(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x.clone())
                buffer = x.detach().numpy()
                np.sin(buffer, out=buffer)
                ctx.mark_dirty(x)
                return x

            @staticmethod
            def backward(ctx, gx):
                (saved,) = ctx.saved_tensors
                return gx * saved.cos()

        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as library:
                library.define("foo(Tensor(a!) self) -> Tensor(a!)")
                library.impl("foo", NumpySin_.apply, "CPU")
                foo = self.get_op("foo")

                leaf = torch.randn(3, requires_grad=True)
                cloned = leaf.clone()
                element = cloned[0]
                result = foo(element)

                # Through the op output: cos(leaf[0]) at index 0, zero elsewhere.
                expected = torch.zeros_like(leaf)
                expected[0] = leaf[0].cos()
                with self._check_ctx(mode):
                    (grad,) = torch.autograd.grad(
                        result, leaf, torch.ones_like(result), retain_graph=True
                    )
                    self.assertEqual(grad, expected)

                # Through the mutated base: cos(leaf[0]) at index 0, one elsewhere.
                expected = torch.ones_like(leaf)
                expected[0] = leaf[0].cos()
                with self._check_ctx(mode):
                    (grad,) = torch.autograd.grad(
                        cloned, leaf, torch.ones_like(cloned)
                    )
                    self.assertEqual(grad, expected)
251

252
    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why

        # Correct usage of (a!)
        def add_return_input(x, y):
            x.detach().add_(y.detach())
            return x

        # Incorrect usage of (a!): user doesn't return tensor as-is
        def add_return_copy(x, y):
            detached = x.detach()
            detached.add_(y.detach())
            return detached.clone()

        # User mutated input tensor but didn't return it.
        def add_return_nothing(x, y):
            x.detach().add_(y.detach())

        registrations = (
            (
                "foo",
                "foo(Tensor(a!) self, Tensor other) -> Tensor(a!)",
                add_return_input,
            ),
            (
                "bar",
                "bar(Tensor(a!) self, Tensor other) -> Tensor(a!)",
                add_return_copy,
            ),
            (
                "baz",
                "baz(Tensor(a!) self, Tensor other) -> ()",
                add_return_nothing,
            ),
        )

        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as library:
                ops = []
                for op_name, schema, kernel in registrations:
                    library.define(schema)
                    library.impl(op_name, kernel, "CPU")
                    ops.append(self.get_op(op_name))

                # Test in-place on non-view
                for op in ops:
                    base = torch.randn(3)
                    other = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        target = base.clone()
                        op(target, other)
                        torch.autograd.grad(
                            target, other, torch.ones_like(target), allow_unused=True
                        )

                # Test in-place on view
                for op in ops:
                    base = torch.randn(3)
                    other = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        target = base[:]
                        op(target, other)
                        torch.autograd.grad(
                            target, base, torch.ones_like(target), allow_unused=True
                        )
310

311
    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        # The second output is detached and then re-flagged as requiring grad
        # inside the kernel; the test backpropagates through that output.
        def clone_and_fresh_leaf(a):
            return a.clone(), a.clone().detach().requires_grad_()

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor a) -> (Tensor, Tensor)")
            foo = self.get_op("foo")
            library.impl("foo", clone_and_fresh_leaf, "CPU")

            leaf = torch.randn(3, requires_grad=True)
            _, fresh = foo(leaf)
            with self._check_ctx(mode):
                fresh.sum().backward()
325

326
    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        # The kernel receives None for its first input and returns None as
        # its first output.
        def none_and_clone(a, b):
            return None, b.clone()

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            foo = self.get_op("foo")
            library.impl("foo", none_and_clone, "CPU")

            leaf = torch.randn(3, requires_grad=True)
            # NB: PyTorch dispatcher treats "None" as undefined Tensor.
            _, cloned = foo(None, leaf)
            with self._check_ctx(mode):
                cloned.sum().backward()
343

344
    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        # Both op outputs are routed through UndefinedGrad nodes before the
        # backward call.
        def sin_and_cos(a, b):
            return a.sin(), b.cos()

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            foo = self.get_op("foo")
            library.impl("foo", sin_and_cos, "CPU")

            tracked = torch.randn(3, requires_grad=True)
            plain = torch.randn(3)
            sin_out, cos_out = foo(tracked, plain)
            undefined_grad = torch._C._functions.UndefinedGrad
            sin_out = undefined_grad()(sin_out)
            cos_out = undefined_grad()(cos_out)
            with self._check_ctx(mode):
                (cos_out + sin_out).sum().backward()
363

364
    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        # The op is applied to a view whose ``_base`` is a tensor that does
        # not require grad.
        def zero_inplace(a):
            with torch.no_grad():
                return a.zero_()

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor(a!) x) -> Tensor(a!)")
            foo = self.get_op("foo")
            library.impl("foo", zero_inplace, "CPU")

            base = torch.randn(3)
            tracked_view = base[:]
            tracked_view.requires_grad_()
            inner_view = tracked_view[:]
            self.assertIs(inner_view._base, base)

            # Hook should be registered on w, but not w._base
            # (``w`` is ``inner_view`` here).
            foo(inner_view)
            with self._check_ctx(mode):
                inner_view.sum().backward()
386

387
    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        # Outputs 0 and 2 are produced under no_grad; output 1 (a * b) is
        # computed with grad tracking left as-is.
        def mixed_outputs(a, b):
            with torch.no_grad():
                first = a.clone()
                third = b.clone()
            second = a * b
            return first, second, third

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            foo = self.get_op("foo")
            library.impl("foo", mixed_outputs, "CPU")

            lhs = torch.randn(3, requires_grad=True)
            rhs = torch.randn(3, requires_grad=True)
            outputs = foo(lhs, rhs)

            # Per output: whether mode "nothing" is expected to raise.
            raises_in_nothing = (True, False, True)
            for output, should_raise in zip(outputs, raises_in_nothing):
                with self._check_ctx(mode, mode_nothing_raises=should_raise):
                    torch.autograd.grad(
                        output,
                        (lhs, rhs),
                        torch.ones_like(output),
                        allow_unused=True,
                        retain_graph=True,
                    )
420

421
    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        # The op takes and returns Tensor[]; both outputs are computed under
        # torch.no_grad().
        def sum_and_product(tensors):
            first, second, third = tensors
            with torch.no_grad():
                return first + second + third, first * second * third

        with autograd_fallback_mode(mode):
            library = self.get_lib()
            library.define("foo(Tensor[] a) -> Tensor[]")
            foo = self.get_op("foo")
            library.impl("foo", sum_and_product, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            total, product = foo([x, y, z])
            for output in (total, product):
                with self._check_ctx(mode, mode_nothing_raises=True):
                    torch.autograd.grad(
                        output,
                        (x, y, z),
                        torch.ones_like(output),
                        allow_unused=True,
                        retain_graph=True,
                    )
454

455

456
# Process the @parametrize decorators on TestAutogradFallback.
instantiate_parametrized_tests(TestAutogradFallback)


# Run the suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
460

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.