pytorch / functional.py
from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap
from . import forward_ad as fwAD

__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# Utility functions


def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp


def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res


def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)


def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)


def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.size() != el_other.size():
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
            )


def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )


def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=False,
):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs: Tuple[torch.Tensor, ...] = tuple()
    new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        return torch.autograd.grad(
            new_outputs,
            inputs,
            new_grad_outputs,
            allow_unused=True,
            create_graph=create_graph,
            retain_graph=retain_graph,
            is_grads_batched=is_grads_batched,
        )


def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage gives us information about which backward call we consider to give a good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res: Tuple[torch.Tensor, ...] = tuple()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents from using the double backward trick to "
                        "replace forward mode AD."
                    )

            grads_i = torch.zeros_like(refs[i])
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res


# Public API


def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed.  Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs.  Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )

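
# Editor's note (illustrative addition, not part of the upstream PyTorch source):
# for small problems, the result of vjp above can be cross-checked against an
# explicitly materialized Jacobian, since the returned vector-Jacobian product
# should equal v contracted with jacobian(func, inputs) over the output
# dimension. The `_example_*` helper below is a hedged sketch of that check and
# is never called at import time.
def _example_check_vjp_against_jacobian():
    def f(x):
        return (x.exp() * 2).sum(dim=1)

    x = torch.rand(3, 4)
    v = torch.ones(3)
    _, vjp_result = vjp(f, x, v)
    jac = jacobian(f, x)  # shape (3, 3, 4): output dims first, then input dims
    expected = torch.einsum("i,ijk->jk", v, jac)
    assert torch.allclose(vjp_result, expected)
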

def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs.  Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)
        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )

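
# Editor's note (illustrative addition, not part of the upstream PyTorch source):
# as the Note in jvp above suggests, the same Jacobian-vector product can be
# obtained with the low-level forward-mode AD API (fwAD, imported at the top of
# this file) in a single forward pass, avoiding the double-backward trick. This
# hedged sketch mirrors the dual-tensor pattern used by _jacfwd below and is
# never called at import time.
def _example_jvp_with_forward_mode():
    def f(x):
        return x.exp().sum(dim=1)

    x = torch.rand(4, 4)
    v = torch.ones(4, 4)
    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, v)  # attach v as the tangent of x
        primal_out, jvp_result = fwAD.unpack_dual(f(dual_x))
    _, jvp_reference = jvp(f, x, v)
    assert torch.allclose(jvp_result, jvp_reference)
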

def _construct_standard_basis_for(
    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
    #           [0],             [1, 0],              [0],
    #           [0],             [0, 1],              [0],
    #           [0]])  ,         [0, 0]])  ,          [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function. All the pre-conditions are guarded for
    # in torch.autograd.functional.jacobian.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    chunks = tuple(
        tensor.new_zeros(total_numel, tensor_numel)
        for tensor, tensor_numel in zip(tensors, tensor_numels)
    )
    diag_start_idx = 0
    for chunk, numel in zip(chunks, tensor_numels):
        chunk.diagonal(diag_start_idx).fill_(1)
        diag_start_idx -= numel
    return chunks

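
# Editor's note (illustrative addition, not part of the upstream PyTorch source):
# concatenating the chunks returned by _construct_standard_basis_for along
# dim=1 recovers the NxN identity matrix described in the comment above. This
# hedged self-check uses the same [1, 2, 1] example and is never called at
# import time.
def _example_check_standard_basis():
    tensors = (torch.zeros(1), torch.zeros(2), torch.zeros(1))
    numels = tuple(t.numel() for t in tensors)
    chunks = _construct_standard_basis_for(tensors, numels)
    assert torch.equal(torch.cat(chunks, dim=1), torch.eye(sum(numels)))
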

def _jacfwd(func, inputs, strict=False, vectorize=False):
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    output_info = []

    if vectorize:
        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
        input_numels = tuple(input.numel() for input in inputs)

        # Step 1: Prepare tangents
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute vmap over computation with dual tensors
        def jvp(tangents):
            with fwAD.dual_level():
                dual_inputs = tuple(
                    fwAD.make_dual(input, tangent.view_as(input))
                    for input, tangent in zip(inputs, tangents)
                )
                _is_outputs_tuple, dual_outputs = _as_tuple(
                    func(*dual_inputs), "outputs"
                )
                output_info.append(_is_outputs_tuple)
                jv = []
                primal_outs = []
                for dual_out in dual_outputs:
                    primal, tangent = fwAD.unpack_dual(dual_out)
                    primal_outs.append(primal)
                    if tangent is not None:
                        jv.append(tangent)
                    else:
                        jv.append(torch.zeros_like(primal))
                output_info.append(primal_outs)
                return tuple(jv)

        outputs_before_split = _vmap(jvp)(tangents)
        is_outputs_tuple, outputs = output_info
        # Step 3: for each of the output tangents, split along dim 0
        jacobian_input_output = []
        for jac_output_i, output_i in zip(outputs_before_split, outputs):
            jacobian_output_i_output = []
            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
                # We need to transpose the Jacobian because in forward AD, the
                # batch dimension represents that of the inputs
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                    (*output_i.shape, *input_j.shape)
                )  # noqa: C409

                jacobian_output_i_output.append(jacobian_input_i_output_j)
            jacobian_input_output.append(jacobian_output_i_output)

        # Omit [Step 4] because everything is already transposed w/ forward AD
        return _tuple_postprocess(
            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
        )
    else:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )


def jacobian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    strategy="reverse-mode",
):
    r"""Compute the Jacobian of a given function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs.  Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jacobian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.jacrev` or
            :func:`torch.func.jacfwd` instead if you are looking for something
            less experimental and more performant.
            When computing the jacobian, usually we invoke
            ``autograd.grad`` once per row of the jacobian. If this flag is
            ``True``, we perform only a single ``autograd.grad`` call with
            ``batched_grad=True`` which uses the vmap prototype feature.
            Though this should lead to performance improvements in many cases,
            because this feature is still experimental, there may be performance
            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
            more information.
        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
            determine whether the Jacobian will be computed with forward or reverse
            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
            prefer to use ``"reverse-mode"``.

    Returns:
        Jacobian (Tensor or nested tuple of Tensors): if there is a single
        input and output, this will be a single Tensor containing the
        Jacobian for the linearized inputs and output. If one of the two is
        a tuple, then the Jacobian will be a tuple of Tensors. If both of
        them are tuples, then the Jacobian will be a tuple of tuple of
        Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
        ``i``\th output and ``j``\th input and will have as size the
        concatenation of the sizes of the corresponding output and the
        corresponding input and will have same dtype and device as the
        corresponding input. If strategy is ``forward-mode``, the dtype will be
        that of the output; otherwise, the input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jacobian(exp_reducer, inputs)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]])

        >>> jacobian(exp_reducer, inputs, create_graph=True)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

        >>> def exp_adder(x, y):
        ...     return 2 * x.exp() + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> jacobian(exp_adder, inputs)
        (tensor([[2.8052, 0.0000],
                [0.0000, 3.3963]]),
         tensor([[3., 0.],
                 [0., 3.]]))
    """
    assert strategy in ("forward-mode", "reverse-mode"), (
        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
        'Otherwise, prefer to use "reverse-mode".'
    )
    if strategy == "forward-mode":
        if create_graph:
            raise NotImplementedError(
                "torch.autograd.functional.jacobian: `create_graph=True` "
                'and `strategy="forward-mode"` are not supported together (yet). '
                "Please either set `create_graph=False` or "
                '`strategy="reverse-mode"`.'
            )
        return _jacfwd(func, inputs, strict, vectorize)

    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jacobian"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if vectorize:
            if strict:
                raise RuntimeError(
                    "torch.autograd.functional.jacobian: `strict=True` "
                    "and `vectorize=True` are not supported together. "
                    "Please either set `strict=False` or "
                    "`vectorize=False`."
                )
            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
            #
            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
            # It turns out we can compute the jacobian of this function with a single
            # call to autograd.grad by using vmap over the correct grad_outputs.
            #
            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
            # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
            #
            # To get the first row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
            # To get the 2nd row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
            # and so on.
            #
            # Using vmap, we can vectorize all 4 of these computations into one by
            # passing the standard basis for R^4 as the grad_output.
            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
            #
            # Now, how do we compute the jacobian *without stacking the output*?
            # We can just split the standard basis across the outputs. So to
            # compute the jacobian of f(x), we'd use
            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
            # The grad_outputs looks like the following:
            # ( torch.tensor([[1, 0, 0],
            #                 [0, 1, 0],
            #                 [0, 0, 1],
            #                 [0, 0, 0]]),
            #   torch.tensor([[0],
            #                 [0],
            #                 [0],
            #                 [1]]) )
            #
            # But we're not done yet!
            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
            # returns a Tensor of shape [4, 3]. We have to remember to split the
            # jacobian of shape [4, 3] into two:
            # - one of shape [3, 3] for the first output
            # - one of shape [   3] for the second output

            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)

            # Step 2: Call vmap + autograd.grad
            def vjp(grad_output):
                vj = list(
                    _autograd_grad(
                        flat_outputs,
                        inputs,
                        grad_output,
                        create_graph=create_graph,
                        is_grads_batched=True,
                    )
                )
                for el_idx, vj_el in enumerate(vj):
                    if vj_el is not None:
                        continue
                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
                        (sum(output_numels),) + inputs[el_idx].shape
                    )
                return tuple(vj)

            jacobians_of_flat_output = vjp(grad_outputs)

            # Step 3: The returned jacobian is one big tensor per input. In this step,
            # we split each Tensor by output.
            jacobian_input_output = []
            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
                jacobian_input_i_output = []
                for jac, output_j in zip(
                    jac_input_i.split(output_numels, dim=0), outputs
                ):
                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
                    jacobian_input_i_output.append(jacobian_input_i_output_j)
                jacobian_input_output.append(jacobian_input_i_output)

            # Step 4: Right now, `jacobian` is a List[List[Tensor]].
            # The outer List corresponds to the number of inputs,
            # the inner List corresponds to the number of outputs.
            # We need to exchange the order of these and convert to tuples
            # before returning.
            jacobian_output_input = tuple(zip(*jacobian_input_output))

            jacobian_output_input = _grad_postprocess(
                jacobian_output_input, create_graph
            )
            return _tuple_postprocess(
                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
            )

        jacobian: Tuple[torch.Tensor, ...] = tuple()

        for i, out in enumerate(outputs):
            # mypy complains that expression and variable have different types due to the empty list
            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
            for j in range(out.nelement()):
                vj = _autograd_grad(
                    (out.reshape(-1)[j],),
                    inputs,
                    retain_graph=True,
                    create_graph=create_graph,
                )

                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
                    zip(jac_i, vj, inputs)
                ):
                    if vj_el is not None:
                        if strict and create_graph and not vj_el.requires_grad:
                            msg = (
                                "The jacobian of the user-provided function is "
                                f"independent of input {i}. This is not allowed in "
                                "strict mode when create_graph=True."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(vj_el)
                    else:
                        if strict:
                            msg = (
                                f"Output {i} of the user-provided function is "
                                f"independent of input {el_idx}. This is not allowed in "
                                "strict mode."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(torch.zeros_like(inp_el))

            jacobian += (
                tuple(
                    torch.stack(jac_i_el, dim=0).view(
                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
                    )
                    for (el_idx, jac_i_el) in enumerate(jac_i)
                ),
            )

        jacobian = _grad_postprocess(jacobian, create_graph)

        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))

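
# Editor's note (illustrative addition, not part of the upstream PyTorch source):
# the vectorized and non-vectorized code paths of jacobian above are expected
# to agree; the hedged sketch below checks that for the two-output function
# used in NOTE [Computing jacobian with vmap and grad for multiple outputs].
# It is never called at import time.
def _example_check_vectorized_jacobian():
    def f(x):
        return x**2, x.sum()

    x = torch.randn(3)
    jac_loop = jacobian(f, x)  # one autograd.grad call per output element
    jac_vmap = jacobian(f, x, vectorize=True)  # a single batched autograd.grad call
    for jac_a, jac_b in zip(jac_loop, jac_vmap):
        assert torch.allclose(jac_a, jac_b)
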

def hessian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    r"""Compute the Hessian of a given scalar function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result can not
            require gradients or be disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
            hessian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.hessian`
            instead if you are looking for something less experimental and more performant.
            When computing the hessian, usually we invoke
            ``autograd.grad`` once per row of the hessian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases, however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.
        outer_jacobian_strategy (str, optional): The Hessian is computed by
            computing the Jacobian of a Jacobian. The inner Jacobian is always
            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
            or ``"reverse-mode"`` determines whether the outer Jacobian will be
            computed with forward or reverse mode AD. Currently, computing the outer
            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
            to ``"reverse-mode"``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
        this will be a single Tensor containing the Hessian for the input.
        If it is a tuple, then the Hessian will be a tuple of tuples where
        ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
        and ``j``\th input with size the sum of the size of the ``i``\th input plus
        the size of the ``j``\th input. ``Hessian[i][j]`` will have the same
        dtype and device as the corresponding ``i``\th input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
    assert outer_jacobian_strategy in (
        "forward-mode",
        "reverse-mode",
    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'

    def ensure_single_output_function(*inp):
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(
            out, "outputs of the user-provided function", "hessian"
        )
        _check_requires_grad(t_out, "outputs", strict=strict)

        if is_out_tuple or not isinstance(out, torch.Tensor):
            raise RuntimeError(
                "The function given to hessian should return a single Tensor"
            )

        if out.nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hessian should contain a single element"
            )

        return out.squeeze()

    def jac_func(*inp):
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and input to require_grad
            # or else the input will be detached
            inp = tuple(t.requires_grad_(True) for t in inp)
        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
        return jac

    res = jacobian(
        jac_func,
        inputs,
        create_graph=create_graph,
        strict=strict,
        vectorize=vectorize,
        strategy=outer_jacobian_strategy,
    )
    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))

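
# Editor's note (illustrative addition, not part of the upstream PyTorch source):
# as the docstring above says, the Hessian is the Jacobian of a Jacobian, so for
# a single-input scalar function it should match composing jacobian with an
# explicit gradient function. Hedged sketch only; never called at import time.
def _example_check_hessian_is_jacobian_of_gradient():
    def f(x):
        return x.pow(3).sum()

    x = torch.rand(2, 2)

    def grad_f(y):
        # create_graph=True keeps the inner Jacobian differentiable so the
        # outer jacobian call can differentiate through it.
        return jacobian(f, y, create_graph=True)

    assert torch.allclose(hessian(f, x), jacobian(grad_f, x))
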

def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between vector ``v`` and the Hessian of a given scalar function at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
            same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))
        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vhp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to vhp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to vhp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")

    outputs = _grad_postprocess(outputs, create_graph)
    vhp = _grad_postprocess(vhp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vhp, is_inputs_tuple
    )


def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs.  Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))

        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:

        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which is
        much faster with the current implementation.

    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "hvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to hvp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hvp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)

        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
        _check_requires_grad(jac, "hessian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
        hvp = _fill_in_zeros(
            grad_res, inputs, strict, create_graph, "double_back_trick"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    hvp = _grad_postprocess(hvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        hvp, is_inputs_tuple
    )

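
# Editor's note (illustrative addition, not part of the upstream PyTorch source):
# the Note in hvp above points out that for a twice continuously differentiable
# function the Hessian is symmetric, so the (typically faster) vhp gives the
# same vector as hvp. Hedged sketch of that equivalence for a cubic reducer;
# never called at import time.
def _example_check_hvp_matches_vhp():
    def f(x):
        return x.pow(3).sum()

    x = torch.rand(2, 2)
    v = torch.ones(2, 2)
    _, hvp_result = hvp(f, x, v)
    _, vhp_result = vhp(f, x, v)
    assert torch.allclose(hvp_result, vhp_result)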