import contextlib
import warnings

import numpy as np

import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
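

# Helper: temporarily set the process-global autograd fallback mode and
# restore the previous mode on exit, even if the body raises.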
@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)
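

# The tests run under both fallback modes: "nothing" (the fallback stays
# silent) and "warn" (differentiating through an op that has no autograd
# kernel emits a UserWarning); see _check_ctx below.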
class TestAutogradFallback(TestCase):
    test_ns = "_test_autograd_fallback"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            self.lib._destroy()
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")
        self.lib = lib
        return lib
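
    # If the inputs do not require grad (or the op runs under torch.no_grad),
    # the fallback must not make the output require grad, and no warning
    # should fire.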
    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)
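
    # backward() through an op with no autograd kernel: "warn" mode emits a
    # UserWarning; in "nothing" mode the output does not require grad, so
    # backward() raises.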
    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs require grad.
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
            self.assertIsNone(b.grad)

    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        # In "warn" mode the fallback warns; in "nothing" mode it either
        # raises (when a test differentiates a tensor that does not require
        # grad) or does nothing.
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        assert mode == "nothing"
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()
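
    # An in-place op with no autograd kernel: backward through any of the
    # inputs, outputs, or views involved should hit the fallback.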
    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        with autograd_fallback_mode(mode):
            # Inputs modified in-place get returned as the outputs.
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # Op that mutates its input but returns nothing.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                z = op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))
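
    # Kernels that return their input unchanged are not really OK, but they
    # are common enough in practice that the fallback must cope with them.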
    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))
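
    # A kernel composed of differentiable torch ops registered to CPU only:
    # the gradient still flows through the kernel's internal graph, though
    # the fallback warns in "warn" mode.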
    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                self.assertEqual(x.grad, x.cos())
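
    # A torch.autograd.Function registered as the CPU kernel: its backward
    # supplies the gradient even though no Autograd kernel was registered.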
    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                self.assertEqual(x.grad, x.cos())
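
    # Same, but the autograd.Function mutates its input in-place and marks
    # it dirty; gradients must flow both through the returned view and its
    # base.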
    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

                # y = sin(x[0]), so dy/dx is zero except at index 0.
                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                self.assertEqual(gx, expected)

                # z = [sin(x[0]), x[1], x[2]], so dz/dx is one except at 0.
                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                self.assertEqual(gx, expected)
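
    # Mutating a tensor that does not require grad does not rebase its
    # history; differentiating through it afterwards must still raise.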
    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Mutated input is returned.
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    # Representative kernel: mutate x behind autograd's back.
                    with torch.no_grad():
                        x.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Mutated input, but a fresh tensor is returned.
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    with torch.no_grad():
                        x.add_(y)
                    return x.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # Mutated input with no returns at all.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    with torch.no_grad():
                        x.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # In all cases, the mutated tensor still does not require
                # grad, so differentiating through it raises.
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)
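
    # The kernel returns a leaf that requires grad; backward on it must not
    # crash even though the leaf was created inside the op.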
    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            lib.impl(
                "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU"
            )
            x = torch.randn(3, requires_grad=True)
            y, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()
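
    # Undefined (None) inputs and outputs must pass through the fallback
    # without error.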
    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return None, b.clone()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            # NB: the dispatcher treats None as an undefined Tensor.
            y, z = op(None, x)
            with self._check_ctx(mode):
                z.sum().backward()
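
    # Undefined gradients flowing into the fallback's backward must be
    # handled gracefully (no errors, no bogus grads).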
    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return a.sin(), b.cos()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(3)
            w, z = op(x, y)
            w = torch._C._functions.UndefinedGrad()(w)
            z = torch._C._functions.UndefinedGrad()(z)
            with self._check_ctx(mode):
                (z + w).sum().backward()
            self.assertIsNone(x.grad)
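
    # In-place op on a view whose base does not require grad (while the view
    # does): the fallback must still wire up backward correctly.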
    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            op = self.get_op("foo")

            def foo_impl(a):
                # Representative kernel: any in-place mutation works here.
                with torch.no_grad():
                    return a.zero_()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3)
            y = x[:]
            y.requires_grad_()
            w = y[:]
            self.assertTrue(w._base is x)
            z = op(w)
            with self._check_ctx(mode):
                z.sum().backward()
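
    # The kernel returns a mix of tensors that do and do not require grad;
    # only differentiating the no-grad outputs should raise in "nothing"
    # mode.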
    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                # First and third outputs are detached from autograd; the
                # middle one carries a real grad_fn.
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(
                    y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True
                )
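
    # Ops that take and return Tensor[] must also be covered by the
    # fallback.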
    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    a,
                    (x, y, z),
                    torch.ones_like(a),
                    allow_unused=True,
                    retain_graph=True,
                )
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    b,
                    (x, y, z),
                    torch.ones_like(b),
                    allow_unused=True,
                    retain_graph=True,
                )


instantiate_parametrized_tests(TestAutogradFallback)

if __name__ == "__main__":
    run_tests()