from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap

from . import forward_ad as fwAD

__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# Utility functions


def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp
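

# For illustration only (a sketch added here, not from the original source):
# _as_tuple normalizes both calling conventions to a uniform tuple form, and
# the leading boolean is what _tuple_postprocess later uses to restore the
# caller's original packaging.
#   _as_tuple(torch.ones(2), "inputs", "vjp")
#       -> (False, (tensor([1., 1.]),))
#   _as_tuple((torch.ones(2), torch.ones(3)), "inputs", "vjp")
#       -> (True, (tensor([1., 1.]), tensor([1., 1., 1.])))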


def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res
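

# For illustration only (a sketch added here, not from the original source),
# how the three to_unpack forms behave:
#   _tuple_postprocess((a,), False)          -> a        # undo a single _as_tuple
#   _tuple_postprocess((a, b), True)         -> (a, b)   # keep tuple packaging
#   _tuple_postprocess(((a,), (b,)), (True, False))
#       -> (a, b)    # strip the inner nesting, keep the outer tuple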


def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)


def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)


def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.size() != el_other.size():
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
            )


def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )


def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=False,
):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs: Tuple[torch.Tensor, ...] = tuple()
    new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        return torch.autograd.grad(
            new_outputs,
            inputs,
            new_grad_outputs,
            allow_unused=True,
            create_graph=create_graph,
            retain_graph=retain_graph,
            is_grads_batched=is_grads_batched,
        )
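

# For illustration only (a sketch added here, not from the original source):
# outputs that are None or do not require grad are filtered out before the
# engine is called, so callers may pass e.g. (out, None) without special-casing:
#   _autograd_grad((y, None), (x,), (v, None))
# behaves like torch.autograd.grad((y,), (x,), (v,), allow_unused=True).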


def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage gives us information of which backward call we consider to give good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res: Tuple[torch.Tensor, ...] = tuple()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents from using the double backward trick to "
                        "replace forward mode AD."
                    )

            grads_i = torch.zeros_like(refs[i])
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res


def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )


def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)
        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
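    # Why the double backward trick works (an explanatory note added here, not
    # from the original source): grad_inputs above is the map g -> J^T g of the
    # dummy grad_outputs g, so differentiating that map with respect to g along
    # the direction v yields d(J^T g)/dg . v = J v, the desired jvp.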
    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )


def _construct_standard_basis_for(
    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
    #           [0],             [1, 0],              [0],
    #           [0],             [0, 1],              [0],
    #           [0]])  ,         [0, 0]])  ,          [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function. All the pre-conditions are guarded for
    # in torch.autograd.functional.jacobian.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    chunks = tuple(
        tensor.new_zeros(total_numel, tensor_numel)
        for tensor, tensor_numel in zip(tensors, tensor_numels)
    )
    diag_start_idx = 0
    for chunk, numel in zip(chunks, tensor_numels):
        chunk.diagonal(diag_start_idx).fill_(1)
        diag_start_idx -= numel
    return chunks
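

# For illustration only (an explanatory sketch added here, not from the
# original source): the chunks built above match a column-wise split of the
# identity matrix,
#   torch.eye(sum(tensor_numels)).split(tensor_numels, dim=1)
# except that each chunk is allocated directly with the dtype/device of its
# matching tensor rather than sliced out of one big `eye` tensor.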


def _jacfwd(func, inputs, strict=False, vectorize=False):
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    output_info = []

    if vectorize:
        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
        input_numels = tuple(input.numel() for input in inputs)

        # Step 1: Prepare tangents
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute vmap over computation with dual tensors
        def jvp(tangents):
            with fwAD.dual_level():
                dual_inputs = tuple(
                    fwAD.make_dual(input, tangent.view_as(input))
                    for input, tangent in zip(inputs, tangents)
                )
                _is_outputs_tuple, dual_outputs = _as_tuple(
                    func(*dual_inputs), "outputs"
                )
                output_info.append(_is_outputs_tuple)
                jv = []
                primal_outs = []
                for dual_out in dual_outputs:
                    primal, tangent = fwAD.unpack_dual(dual_out)
                    primal_outs.append(primal)
                    if tangent is not None:
                        jv.append(tangent)
                    else:
                        jv.append(torch.zeros_like(primal))
                output_info.append(primal_outs)
                return tuple(jv)

        outputs_before_split = _vmap(jvp)(tangents)
        is_outputs_tuple, outputs = output_info
        # Step 3: for each of the output tangents, split along dim 0
        jacobian_input_output = []
        for jac_output_i, output_i in zip(outputs_before_split, outputs):
            jacobian_output_i_output = []
            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
                # We need to transpose the Jacobian because in forward AD, the
                # batch dimension represents that of the inputs
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                    (*output_i.shape, *input_j.shape)
                )

                jacobian_output_i_output.append(jacobian_input_i_output_j)
            jacobian_input_output.append(jacobian_output_i_output)

        # Omit [Step 4] because everything is already transposed w/ forward AD
        return _tuple_postprocess(
            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
        )
    else:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )


def jacobian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    strategy="reverse-mode",
):
    r"""Compute the Jacobian of a given function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jacobian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.jacrev` or
            :func:`torch.func.jacfwd` instead if you are looking for something
            less experimental and more performant.
            When computing the jacobian, usually we invoke
            ``autograd.grad`` once per row of the jacobian. If this flag is
            ``True``, we perform only a single ``autograd.grad`` call with
            ``batched_grad=True`` which uses the vmap prototype feature.
            Though this should lead to performance improvements in many cases,
            because this feature is still experimental, there may be performance
            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
            more information.
        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
            determine whether the Jacobian will be computed with forward or reverse
            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
            prefer to use ``"reverse-mode"``.

    Returns:
        Jacobian (Tensor or nested tuple of Tensors): if there is a single
            input and output, this will be a single Tensor containing the
            Jacobian for the linearized inputs and output. If one of the two is
            a tuple, then the Jacobian will be a tuple of Tensors. If both of
            them are tuples, then the Jacobian will be a tuple of tuple of
            Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
            ``i``\th output and ``j``\th input and will have as size the
            concatenation of the sizes of the corresponding output and the
            corresponding input and will have same dtype and device as the
            corresponding input. If strategy is ``forward-mode``, the dtype will be
            that of the output; otherwise, the input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jacobian(exp_reducer, inputs)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]])

        >>> jacobian(exp_reducer, inputs, create_graph=True)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

        >>> def exp_adder(x, y):
        ...     return 2 * x.exp() + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> jacobian(exp_adder, inputs)
        (tensor([[2.8052, 0.0000],
                 [0.0000, 3.3963]]),
         tensor([[3., 0.],
                 [0., 3.]]))
    """
    assert strategy in ("forward-mode", "reverse-mode"), (
        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
        'Otherwise, prefer to use "reverse-mode".'
    )
    if strategy == "forward-mode":
        if create_graph:
            raise NotImplementedError(
                "torch.autograd.functional.jacobian: `create_graph=True` "
                'and `strategy="forward-mode"` are not supported together (yet). '
                "Please either set `create_graph=False` or "
                '`strategy="reverse-mode"`.'
            )
        return _jacfwd(func, inputs, strict, vectorize)
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jacobian"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if vectorize:
            if strict:
                raise RuntimeError(
                    "torch.autograd.functional.jacobian: `strict=True` "
                    "and `vectorize=True` are not supported together. "
                    "Please either set `strict=False` or "
                    "`vectorize=False`."
                )
            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
            #
            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
            # It turns out we can compute the jacobian of this function with a single
            # call to autograd.grad by using vmap over the correct grad_outputs.
            #
            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
            # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
            #
            # To get the first row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
            # To get the 2nd row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
            # and so on.
            #
            # Using vmap, we can vectorize all 4 of these computations into one by
            # passing the standard basis for R^4 as the grad_output.
            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
            #
            # Now, how do we compute the jacobian *without stacking the output*?
            # We can just split the standard basis across the outputs. So to
            # compute the jacobian of f(x), we'd use
            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
            # The grad_outputs looks like the following:
            # ( torch.tensor([[1, 0, 0],
            #                 [0, 1, 0],
            #                 [0, 0, 1],
            #                 [0, 0, 0]]),
            #   torch.tensor([[0],
            #                 [0],
            #                 [0],
            #                 [1]]) )
            #
            # But we're not done yet!
            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
            # returns a Tensor of shape [4, 3]. We have to remember to split the
            # jacobian of shape [4, 3] into two:
            # - one of shape [3, 3] for the first output
            # - one of shape [3] for the second output
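            #
            # A concrete sketch of that bookkeeping (illustrative, added here
            # and not part of the original comment): with output_numels = (3, 1)
            # the basis has 4 rows, the batched grad returns one [4, 3] tensor
            # for x, and
            #   jac_x.split((3, 1), dim=0)
            # gives the [3, 3] block for x**2 and the [1, 3] row for x.sum(),
            # which is then viewed as shape [3].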
            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)

            # Step 2: Call vmap + autograd.grad
            def vjp(grad_output):
                vj = list(
                    _autograd_grad(
                        flat_outputs,
                        inputs,
                        grad_output,
                        create_graph=create_graph,
                        is_grads_batched=True,
                    )
                )
                for el_idx, vj_el in enumerate(vj):
                    if vj_el is not None:
                        continue
                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
                        (sum(output_numels),) + inputs[el_idx].shape
                    )
                return tuple(vj)

            jacobians_of_flat_output = vjp(grad_outputs)

            # Step 3: The returned jacobian is one big tensor per input. In this step,
            # we split each Tensor by output.
            jacobian_input_output = []
            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
                jacobian_input_i_output = []
                for jac, output_j in zip(
                    jac_input_i.split(output_numels, dim=0), outputs
                ):
                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
                    jacobian_input_i_output.append(jacobian_input_i_output_j)
                jacobian_input_output.append(jacobian_input_i_output)

            # Step 4: Right now, `jacobian` is a List[List[Tensor]].
            # The outer List corresponds to the number of inputs,
            # the inner List corresponds to the number of outputs.
            # We need to exchange the order of these and convert to tuples
            # before returning.
            jacobian_output_input = tuple(zip(*jacobian_input_output))
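            # For illustration (an explanatory note added here, not from the
            # original source): tuple(zip(*x)) transposes the nesting, e.g.
            #   [[a, b], [c, d]]  ->  ((a, c), (b, d))
            # turning per-input lists of per-output blocks into per-output
            # tuples of per-input blocks.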

            jacobian_output_input = _grad_postprocess(
                jacobian_output_input, create_graph
            )
            return _tuple_postprocess(
                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
            )

        jacobian: Tuple[torch.Tensor, ...] = tuple()

        for i, out in enumerate(outputs):
            # mypy complains that expression and variable have different types due to the empty list
            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
            for j in range(out.nelement()):
                vj = _autograd_grad(
                    (out.reshape(-1)[j],),
                    inputs,
                    retain_graph=True,
                    create_graph=create_graph,
                )

                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
                    zip(jac_i, vj, inputs)
                ):
                    if vj_el is not None:
                        if strict and create_graph and not vj_el.requires_grad:
                            msg = (
                                "The jacobian of the user-provided function is "
                                f"independent of input {i}. This is not allowed in "
                                "strict mode when create_graph=True."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(vj_el)
                    else:
                        if strict:
                            msg = (
                                f"Output {i} of the user-provided function is "
                                f"independent of input {el_idx}. This is not allowed in "
                                "strict mode."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(torch.zeros_like(inp_el))

            jacobian += (
                tuple(
                    torch.stack(jac_i_el, dim=0).view(
                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
                    )
                    for (el_idx, jac_i_el) in enumerate(jac_i)
                ),
            )

        jacobian = _grad_postprocess(jacobian, create_graph)

        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))


def hessian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    r"""Compute the Hessian of a given scalar function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result can not
            require gradients or be disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
            hessian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.hessian`
            instead if you are looking for something less experimental and more performant.
            When computing the hessian, usually we invoke
            ``autograd.grad`` once per row of the hessian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases, however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.
        outer_jacobian_strategy (str, optional): The Hessian is computed by
            computing the Jacobian of a Jacobian. The inner Jacobian is always
            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
            or ``"reverse-mode"`` determines whether the outer Jacobian will be
            computed with forward or reverse mode AD. Currently, computing the outer
            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
            to ``"reverse-mode"``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
            this will be a single Tensor containing the Hessian for the input.
            If it is a tuple, then the Hessian will be a tuple of tuples where
            ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
            and ``j``\th input with size the sum of the size of the ``i``\th input plus
            the size of the ``j``\th input. ``Hessian[i][j]`` will have the same
            dtype and device as the corresponding ``i``\th input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)

        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
    assert outer_jacobian_strategy in (
        "forward-mode",
        "reverse-mode",
    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'

    def ensure_single_output_function(*inp):
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(
            out, "outputs of the user-provided function", "hessian"
        )
        _check_requires_grad(t_out, "outputs", strict=strict)

        if is_out_tuple or not isinstance(out, torch.Tensor):
            raise RuntimeError(
                "The function given to hessian should return a single Tensor"
            )

        if out.nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hessian should contain a single element"
            )

        return out.squeeze()

    def jac_func(*inp):
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and input to require_grad
            # or else the input will be detached
            inp = tuple(t.requires_grad_(True) for t in inp)
        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
        return jac

    res = jacobian(
        jac_func,
        inputs,
        create_graph=create_graph,
        strict=strict,
        vectorize=vectorize,
        strategy=outer_jacobian_strategy,
    )
    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
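

# For illustration only (an explanatory sketch added here, not from the
# original source): for a scalar-valued f, the computation above is
# conceptually a Jacobian of a Jacobian,
#   hessian(f, x) == jacobian(lambda y: jacobian(f, y, create_graph=True), x)
# with the strict/vectorize/strategy plumbing layered on top.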


def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between vector ``v`` and the Hessian of a given scalar function at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
            same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))
        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vhp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to vhp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to vhp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")

    outputs = _grad_postprocess(outputs, create_graph)
    vhp = _grad_postprocess(vhp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vhp, is_inputs_tuple
    )


def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Hessian of a given scalar function and a vector ``v`` at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))

        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))

        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:

        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which is
        much faster with the current implementation.

    """
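    # An explanatory note (added here, not from the original source): for a
    # twice continuously differentiable f the Hessian is symmetric, so the hvp
    # computed below equals the vhp of the same v; the Note above recommends
    # vhp because the double backward trick used here costs an extra backward
    # pass.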
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "hvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to hvp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hvp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)

        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
        # Check double_back (not jac, which was already checked above) so the
        # "hessian" strict-mode error can actually trigger.
        _check_requires_grad(double_back, "hessian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
        hvp = _fill_in_zeros(
            grad_res, inputs, strict, create_graph, "double_back_trick"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    hvp = _grad_postprocess(hvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        hvp, is_inputs_tuple
    )