# mypy: ignore-errors

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Union, Tuple, List, Any, Optional
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import (
    tree_flatten,
    tree_unflatten,
    tree_map,
    tree_map_only,
    tree_map_,
    treespec_pprint,
)
from torch.utils import _pytree as pytree
from torch.fx.experimental import const_fold
from torch.fx.experimental.proxy_tensor import make_fx
import torch.autograd.forward_ad as fwAD
from torch._subclasses.functional_tensor import FunctionalTensor

from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes
from .apis import vmap

from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    _grad_increment_nesting,
    _grad_decrement_nesting,
    _jvp_increment_nesting,
    _jvp_decrement_nesting,
    _wrap_functional_tensor,
    _unwrap_functional_tensor,
    _func_decrement_nesting,
    _func_increment_nesting,
    _assert_wrapped_functional,
    _propagate_functional_input_mutation,
    set_inplace_requires_grad_allowed,
    get_inplace_requires_grad_allowed,
)
from torch._functorch.utils import exposed_in, argnums_t


def lazy_dynamo_disable(func):
    import torch._dynamo
    return torch._dynamo.disable(func)

@contextlib.contextmanager
def enable_inplace_requires_grad(enabled):
    prev_state = get_inplace_requires_grad_allowed()
    set_inplace_requires_grad_allowed(enabled)
    try:
        yield
    finally:
        set_inplace_requires_grad_allowed(prev_state)


def _vjp_treespec_compare(primals_out, cotangents):
    # Revert this once #116264 gets fixed
    _, primals_out_spec = tree_flatten(primals_out)
    _, cotangents_spec = tree_flatten(cotangents)
    # Dynamo fails to trace operator.ne below. To bypass this limitation, this
    # function is not inlined.
    if primals_out_spec != cotangents_spec:
        raise RuntimeError(
            f'Expected pytree structure of cotangents to be the same '
            f'as pytree structure of outputs to the function. '
            f'cotangents: {treespec_pprint(cotangents_spec)}, '
            f'primal output: {treespec_pprint(primals_out_spec)}')


def _set_tensor_requires_grad(x):
    # avoid graph-break on x.requires_grad_()
    # https://github.com/pytorch/pytorch/pull/110053
    return x.requires_grad_()

def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
            with enable_inplace_requires_grad(True):
                return _set_tensor_requires_grad(x)
        raise ValueError(f'Thing passed to transform API must be Tensor, '
                         f'got {type(x)}')
    return tree_map(create_differentiable, inps)


def _undo_create_differentiable(inps, level=None):
    def unwrap_tensors(x):
        if isinstance(x, torch.Tensor):
            return _unwrap_for_grad(x, level)
        # TODO: Remove the following hack for namedtuples
        if isinstance(x, tuple):
            return tree_map(unwrap_tensors, tuple(x))

        raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

    return tree_map(unwrap_tensors, inps)


def _is_differentiable(maybe_tensor):
    if not isinstance(maybe_tensor, torch.Tensor):
        return False
    return maybe_tensor.requires_grad


def _any_differentiable(tensor_or_tuple_of_tensors):
    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
    return any(tuple(map(_is_differentiable, flat_args)))


def _wrap_tensor_for_grad(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    return _wrap_for_grad(maybe_tensor, level)


def _wrap_all_tensors(tensor_pytree, level):
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)

# Version of autograd.grad that handles outputs that don't depend on inputs


def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
        if len(result) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*result)
    if len(diff_outputs) == 0:
        return tuple(torch.zeros_like(inp) for inp in inputs)
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
                                      retain_graph=retain_graph,
                                      create_graph=create_graph,
                                      allow_unused=True)
    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
                        for gi, inp in zip(grad_inputs, inputs))
    return grad_inputs

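# Illustrative sketch (not part of the API): unlike torch.autograd.grad with
# allow_unused=True, which returns None for inputs that the outputs do not
# depend on, _autograd_grad fills those slots with zeros of the input's shape:
#
# >>> x = torch.randn(3, requires_grad=True)
# >>> y = torch.randn(3, requires_grad=True)
# >>> out = (2 * x).sum()                      # independent of y
# >>> gx, gy = _autograd_grad((out,), (x, y))
# >>> # gx is torch.full((3,), 2.); gy is torch.zeros(3) rather than None
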
# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
#   with torch.no_grad():
#     c = x ** 2
#   return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
# with torch.no_grad():
#   grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad, but not the
# outer one. This is because `grad` is a "function transform": its result
# should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
#   active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
#   then we temporarily restore the previous `is_grad_enabled`. This is
#   because we're crossing the boundary from a `grad` outside the
#   no_grad to a `grad` inside the no_grad.
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
# it respects c10::AutoFwGradMode. We've implemented the same logic for
# our jvp transform (it will have special handling if FwGradMode is disabled).

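# As a concrete illustration of the two cases above (an informal sketch using
# the public torch.func.grad API; d/dx (x - c) == 1 once c is computed under
# no_grad and treated as a constant):
#
# >>> def f(x):
# ...     with torch.no_grad():
# ...         c = x ** 2
# ...     return x - c
# >>> x = torch.tensor(3.)
# >>> torch.func.grad(f)(x)        # Case 1: tensor(1.)
# >>> with torch.no_grad():
# ...     torch.func.grad(f)(x)    # Case 2: still tensor(1.); the outer no_grad is ignored
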

# How do we increment and decrement the nesting? I don't think we can.
@exposed_in("torch.func")
def vjp(func: Callable, *primals, has_aux: bool = False):
    """
    Standing for the vector-Jacobian product, returns a tuple containing the
    results of ``func`` applied to ``primals`` and a function that, when
    given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with
    respect to ``primals`` times ``cotangents``.

    Args:
        func (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, vjp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the vjp of
        ``func`` with respect to all ``primals`` using the cotangents passed
        to the returned function. If ``has_aux is True``, then instead returns a
        ``(output, vjp_fn, aux)`` tuple.
        The returned ``vjp_fn`` function will return a tuple of each VJP.

    When used in simple cases, :func:`vjp` behaves the same as :func:`grad`

        >>> x = torch.randn([5])
        >>> f = lambda x: x.sin().sum()
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> grad = vjpfunc(torch.tensor(1.))[0]
        >>> assert torch.allclose(grad, torch.func.grad(f)(x))

    However, :func:`vjp` can support functions with multiple outputs by
    passing in the cotangents for each of the outputs

        >>> x = torch.randn([5])
        >>> f = lambda x: (x.sin(), x.cos())
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    :func:`vjp` can even support outputs being Python structs

        >>> x = torch.randn([5])
        >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
        >>> vjps = vjpfunc(cotangents)
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    The function returned by :func:`vjp` will compute the partials with
    respect to each of the ``primals``

        >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
        >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
        >>> cotangents = torch.randn([5, 5])
        >>> vjps = vjpfunc(cotangents)
        >>> assert len(vjps) == 2
        >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
        >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

    ``primals`` are the positional arguments for ``f``. All kwargs use their
    default value

        >>> x = torch.randn([5])
        >>> def f(x, scale=4.):
        >>>   return x * scale
        >>>
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> vjps = vjpfunc(torch.ones_like(x))
        >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``vjp``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:

            >>> # xdoctest: +SKIP(failing)
            >>> with torch.no_grad():
            >>>     vjp(f)(x)

        In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``vjp`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)


@contextlib.contextmanager
def grad_increment_nesting():
    try:
        grad_level = _grad_increment_nesting()
        yield grad_level
    finally:
        _grad_decrement_nesting()


@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
    # This is the same function as vjp but also accepts an argnums argument
    # All args are the same as vjp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #         If None, computes the gradients with respect to all inputs (used for vjp). Default: None
    #
    # WARN: Users should NOT call this function directly and should just be calling vjp.
    # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
    #
    # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
    # for only the primal elements given by argnums.
    with grad_increment_nesting() as level:
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            primals = _wrap_all_tensors(primals, level)
            # Note for the reviewer: This is extremely odd but it passes the
            # assertion "len(self.block_stack) == 1" on symbolic_convert.py
            # The equivalent "if argnums is None" fails for some reason
            if not isinstance(argnums, int) and not argnums:
                diff_primals = _create_differentiable(primals, level)
            else:
                diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_primals)
            primals_out = func(*primals)

            if has_aux:
                if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
                    raise RuntimeError(
                        "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                primals_out, aux = primals_out
                aux = _undo_create_differentiable(aux, level)

            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
            assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
            results = _undo_create_differentiable(primals_out, level)

            for primal_out in flat_primals_out:
                assert isinstance(primal_out, torch.Tensor)
                if primal_out.is_floating_point() or primal_out.is_complex():
                    continue
                raise RuntimeError("vjp(f, ...): All outputs of f must be "
                                   "floating-point or complex Tensors, got Tensor "
                                   f"with dtype {primal_out.dtype}")

        def wrapper(cotangents, retain_graph=True, create_graph=None):
            if create_graph is None:
                create_graph = torch.is_grad_enabled()
            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
            _vjp_treespec_compare(primals_out, cotangents)
            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                    retain_graph=retain_graph, create_graph=create_graph)
            return tree_unflatten(result, primals_spec)

    if has_aux:
        return results, wrapper, aux
    else:
        return results, wrapper


def _safe_zero_index(x):
    assert len(x) == 1
    return x[0]

# jacrev and jacfwd don't support complex functions
# Helper function to throw appropriate error.
def error_if_complex(func_name, args, is_input):
    flat_args = pytree.tree_leaves(args)
    for idx, arg in enumerate(flat_args):
        if isinstance(arg, torch.Tensor) and arg.dtype.is_complex:
            input_or_output = ("inputs" if is_input else "outputs")
            err_msg = (f"{func_name}: Expected all {input_or_output} "
                       f"to be real but received complex tensor at flattened input idx: {idx}")
            raise RuntimeError(err_msg)

@exposed_in("torch.func")
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
    """
    Computes the Jacobian of ``func`` with respect to the arg(s) at index
    ``argnum`` using reverse mode autodiff

    .. note::
        Using :attr:`chunk_size=1` is equivalent to computing the jacobian
        row-by-row with a for-loop i.e. the constraints of :func:`vmap` are
        not applicable.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        chunk_size (None or int): If None (default), use the maximum chunk size
            (equivalent to doing a single vmap over vjp to compute the jacobian).
            If 1, then compute the jacobian row-by-row with a for-loop.
            If not None, then compute the jacobian :attr:`chunk_size` rows at a time
            (equivalent to doing multiple vmap over vjp). If you run into memory issues computing
            the jacobian, please try to specify a non-None chunk_size.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from torch.func import jacrev
        >>> x = torch.randn(5)
        >>> jacobian = jacrev(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from torch.func import jacrev
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>   return x.sin()
        >>>
        >>> def g(x):
        >>>   result = f(x)
        >>>   return result, result
        >>>
        >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

        >>> from torch.func import jacrev, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacrev(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians

        >>> from torch.func import jacrev
        >>> def f(x):
        >>>   return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacrev(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

        >>> from torch.func import jacrev
        >>> def f(x, y):
        >>>   return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

        >>> from torch.func import jacrev
        >>> def f(x, y):
        >>>   return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            >>>     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    if not (chunk_size is None or chunk_size > 0):
        raise ValueError("jacrev: `chunk_size` should be greater than 0.")

    @wraps(func)
    def wrapper_fn(*args):
        error_if_complex("jacrev", args, is_input=True)
        vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
        else:
            output, vjp_fn = vjp_out

        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
        flat_output, output_spec = tree_flatten(output)

        error_if_complex("jacrev", flat_output, is_input=False)

        # NB: vjp already checks that all outputs are tensors
        # Step 1: Construct grad_outputs by splitting the standard basis
        flat_output_numels = tuple(out.numel() for out in flat_output)

        primals = _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)

        def compute_jacobian_stacked():
            # Helper function to compute chunked Jacobian
            # The intermediate chunked calculations are only
            # scoped at this function level.
            chunked_results = []
            for flat_basis_chunk in _chunked_standard_basis_for_(flat_output,
                                                                 flat_output_numels,
                                                                 chunk_size=chunk_size):
                if chunk_size == 1:
                    # sanity check.
                    for t in flat_basis_chunk:
                        assert t.size(0) == 1

                    flat_basis_chunk = tree_map(lambda t: torch.squeeze(t, 0), flat_basis_chunk)

                basis = tree_unflatten(flat_basis_chunk, output_spec)

                if chunk_size == 1:
                    # Behaviour with `chunk_size=1` is the same as a `for-loop`,
                    # i.e. the user shouldn't deal with the limitations of vmap.
                    chunked_result = vjp_fn(basis)
                else:  # chunk_size is None or chunk_size != 1
                    chunked_result = vmap(vjp_fn)(basis)

                flat_results = pytree.tree_leaves(chunked_result)

                if chunk_size == 1:
                    flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)

                chunked_results.append(flat_results)

            if len(chunked_results) == 1:
                # Short-circuit if we used a single chunk
                return chunked_results[0]

            # Concatenate chunks.
            flat_results = []
            # Iterate and concat the jacobians of different
            # inputs.
            for idx in range(len(flat_primals)):
                r = tuple(r_[idx] for r_ in chunked_results)
                flat_results.append(torch.cat(r, 0))

            return flat_results

        def compute_jacobian_preallocate_and_copy():
            # Helper function to compute chunked Jacobian
            # The intermediate chunked calculations are only
            # scoped at this function level.
            out_vec_size = sum(flat_output_numels)

            # Don't pre-allocate if we have a single chunk.
            if not (chunk_size is None or chunk_size >= out_vec_size):
                stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals]

            for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output,
                                                                                flat_output_numels,
                                                                                chunk_size=chunk_size)):
                if chunk_size == 1:
                    # sanity check.
                    for t in flat_basis_chunk:
                        assert t.size(0) == 1

                    flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk]

                basis = tree_unflatten(flat_basis_chunk, output_spec)

                if chunk_size == 1:
                    # Behaviour with `chunk_size=1` is the same as a `for-loop`,
                    # i.e. the user shouldn't deal with the limitations of vmap.
                    chunked_result = vjp_fn(basis)
                else:  # chunk_size is None or chunk_size != 1
                    chunked_result = vmap(vjp_fn)(basis)

                flat_results = pytree.tree_leaves(chunked_result)

                # Short-circuit if we have a single chunk.
                if chunk_size is None or chunk_size >= out_vec_size:
                    if chunk_size == 1:  # and out_vec_size == 1
                        # Since we squeezed the output dim
                        flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)
                    return flat_results

                for r, sr in zip(flat_results, stacked_results):
                    sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)

            return stacked_results

        if _preallocate_and_copy:
            flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
        else:
            flat_jacobians_per_input = compute_jacobian_stacked()

        # Step 2: The returned jacobian is one big tensor per input. In this step,
        # we split each Tensor by output.
        flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input]
        flat_input_flat_output = [
            tuple(split.view(out.shape + primal.shape)
                  for split, out in zip(splits, flat_output))
            for splits, primal in zip(flat_jacobians_per_input, flat_primals)
        ]

        # Step 3: Right now, `jacobian` is a List[List[Tensor]].
        # The outer List corresponds to the number of primals,
        # the inner List corresponds to the number of outputs.
        # We need to:
        # a. Exchange the order of the outer List and inner List
        # b. tree_unflatten the inner Lists (which correspond to the primals)
        # c. handle the argnums=int case
        # d. tree_unflatten the outer List (which corresponds to the outputs)
        flat_output_flat_input = tuple(zip(*flat_input_flat_output))

        flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
                                  for flat_input in flat_output_flat_input)

        if isinstance(argnums, int):
            flat_output_input = tuple(_safe_zero_index(flat_input)
                                      for flat_input in flat_output_input)
        output_input = tree_unflatten(flat_output_input, output_spec)
        if has_aux:
            return output_input, aux
        return output_input
    return wrapper_fn

# NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
#
# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
# It turns out we can compute the jacobian of this function with a single
# call to autograd.grad by using vmap over the correct grad_outputs.
#
# Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
#
# To get the first row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
# To get the 2nd row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
# and so on.
#
# Using vmap, we can vectorize all 4 of these computations into one by
# passing the standard basis for R^4 as the grad_output.
# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
#
# Now, how do we compute the jacobian *without stacking the output*?
# We can just split the standard basis across the outputs. So to
# compute the jacobian of f(x), we'd use
# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
# The grad_outputs looks like the following:
# ( torch.tensor([[1, 0, 0],
#                 [0, 1, 0],
#                 [0, 0, 1],
#                 [0, 0, 0]]),
#   torch.tensor([[0],
#                 [0],
#                 [0],
#                 [1]]) )
#
# But we're not done yet!
# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
# returns a Tensor of shape [4, 3]. We have to remember to split the
# jacobian of shape [4, 3] into two:
# - one of shape [3, 3] for the first output
# - one of shape [   3] for the second output

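# The same idea, phrased as an informal sketch with the public torch.func APIs
# (vjp + vmap, which is what jacrev does under the hood):
#
# >>> x = torch.randn(3)
# >>> f = lambda x: (x ** 2, x.sum())
# >>> _, vjp_fn = torch.func.vjp(f, x)
# >>> basis = (torch.eye(4)[:, :3], torch.eye(4)[:, 3])   # standard basis of R^4, split per output
# >>> rows = torch.func.vmap(vjp_fn)(basis)[0]            # shape [4, 3]
# >>> jac_out0, jac_out1 = rows.split([3, 1])             # [3, 3] and [1, 3]
# >>> jac_out1 = jac_out1.view(3)                         # jacobian of x.sum() w.r.t. x
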

def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
    #           [0],             [1, 0],              [0],
    #           [0],             [0, 1],              [0],
    #           [0]])  ,         [0, 0]])  ,          [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function.
    # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
    #       one huge basis matrix. `chunk_size` dictates the maximum size of the
    #       basis matrix along dim=0.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    assert chunk_size is None or chunk_size > 0
    total_numel = sum(tensor_numels)
    if chunk_size and chunk_size < total_numel:
        chunk_numels = get_chunk_sizes(total_numel, chunk_size)
    else:  # chunk_size is None or chunk_size >= total_numel
        chunk_size = total_numel
        chunk_numels = [total_numel]

    diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())

    for chunk_idx, total_numel in enumerate(chunk_numels):
        chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
                       for tensor, tensor_numel in zip(tensors, tensor_numels))

        for chunk, diag_start_idx in zip(chunks, diag_start_indices):
            chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
        chunks = tuple(chunk.view(total_numel, *tensor.shape)
                       for chunk, tensor in zip(chunks, tensors))
        yield chunks

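# For example (an illustrative sketch), with two tensors of shape (2,) each,
# so tensor_numels = [2, 2], and chunk_size=2, _chunked_standard_basis_for_
# yields two chunk tuples: the 4x4 identity split along dim 0 into chunks and
# split per tensor along dim 1:
#
#   ( tensor([[1., 0.],        tensor([[0., 0.],
#             [0., 1.]])  ,            [0., 0.]]) )
#   ( tensor([[0., 0.],        tensor([[1., 0.],
#             [0., 0.]])  ,            [0., 1.]]) )
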
def _construct_standard_basis_for(tensors, tensor_numels):
    for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
        return basis


def _validate_and_wrap_argnum(argnum, num_args):
    if not isinstance(argnum, int):
        raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
    if argnum >= 0 and argnum < num_args:
        return argnum
    if argnum < 0 and argnum >= -num_args:
        return argnum + num_args
    raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs')


def _check_unique_non_empty(argnums):
    if isinstance(argnums, tuple):
        if len(argnums) == 0:
            raise RuntimeError("argnums must be non-empty")
        if len(set(argnums)) != len(argnums):
            raise RuntimeError(f"argnums elements must be unique, got {argnums}")


def _replace_args(old_args, new_args, argnums):
    if isinstance(argnums, int):
        if len(new_args) != 1:
            raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
        return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
    if isinstance(argnums, tuple):
        if len(new_args) != len(argnums):
            raise RuntimeError(
                "new_args should have the same size as argnums. "
                f"Argnums size {len(argnums)}, new_args size {len(new_args)}")

        def get_right_elem(i):
            return new_args[argnums.index(i)] if i in argnums else old_args[i]

        return tuple(get_right_elem(i) for i in range(len(old_args)))
    raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')

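# For example (illustrative): _replace_args swaps the selected positional args
# while leaving the rest untouched:
#   _replace_args((a, b, c), (new,), 1)          -> (a, new, c)
#   _replace_args((a, b, c), (n0, n2), (0, 2))   -> (n0, b, n2)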

def _validate_and_wrap_argnums(argnums, num_args):
    if isinstance(argnums, int):
        return _validate_and_wrap_argnum(argnums, num_args)
    if isinstance(argnums, tuple):
        return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
    raise AssertionError("Should never get here")


def _slice_argnums(args, argnums, as_tuple=True):
    if not isinstance(argnums, int) and not isinstance(argnums, tuple):
        raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
    argnums = _validate_and_wrap_argnums(argnums, len(args))
    _check_unique_non_empty(argnums)
    if isinstance(argnums, int):
        if as_tuple:
            return (args[argnums],)
        else:
            return args[argnums]
    return tuple(args[i] for i in argnums)

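# For example (illustrative): _slice_argnums picks out the positional args
# selected by argnums (negative indices wrap around):
#   _slice_argnums((a, b, c), 1)                  -> (b,)
#   _slice_argnums((a, b, c), 1, as_tuple=False)  -> b
#   _slice_argnums((a, b, c), (0, 2))             -> (a, c)
#   _slice_argnums((a, b, c), -1)                 -> (c,)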

JVP_NESTING = 0


@contextlib.contextmanager
def noop():
    yield


def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
    if not isinstance(elts, tuple):
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}')
    for elt in elts:
        if isinstance(elt, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got '
            f'a tuple with an element of type {type(elt)}')
    if len(elts) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to be a non-empty tuple of Tensors.')


def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
    if (len(output) == 1 and output[0] is None) or len(output) < 1:
        raise RuntimeError(
            f'{api}: Expected f to be a function that has non-empty output (got output = {output})'
        )
    for o in output:
        if not isinstance(o, torch.Tensor):
            raise RuntimeError(
                f'{api}: expected f(*primals) to return only tensors'
                f', got unsupported type {type(o)}'
            )


def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
    if isinstance(output, torch.Tensor):
        return
    if not isinstance(output, tuple):
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(output)}')
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected output of f to be a non-empty tuple of Tensors.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(out)} as an output')


def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None:
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to contain at least one Tensor.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to only contain Tensors, got '
            f'{type(out)}')


jvp_str = 'jvp(f, primals, tangents)'


def safe_unpack_dual(dual, strict):
    if not isinstance(dual, torch.Tensor):
        raise RuntimeError(
            f'{jvp_str}: expected f(*args) to return only tensors'
            f', got unsupported type {type(dual)}'
        )

    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is None:
        if strict:
            raise RuntimeError(
                'jvp(f, primals, tangents, strict=True): '
                'The output of f is independent of '
                'the inputs. This is not allowed with strict=True.')
        tangent = torch.zeros_like(primal)
    return primal, tangent


@exposed_in("torch.func")
def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
    """
    Standing for the Jacobian-vector product, returns a tuple containing
    the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        tangents (Tensors): The "vector" for which Jacobian-vector-product is
            computed. Must be the same structure and sizes as the inputs to
            ``func``.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
        evaluated at ``primals`` and the Jacobian-vector product.
        If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.

    jvp is useful when you wish to compute gradients of a function R^1 -> R^N

        >>> from torch.func import jvp
        >>> x = torch.randn([])
        >>> f = lambda x: x * torch.tensor([1., 2., 3])
        >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
        >>> assert torch.allclose(value, f(x))
        >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

    :func:`jvp` can support functions with multiple inputs by passing in the
    tangents for each of the inputs

         >>> from torch.func import jvp
         >>> x = torch.randn(5)
         >>> y = torch.randn(5)
         >>> f = lambda x, y: (x * y)
         >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
         >>> assert torch.allclose(output, x + y)

    """

    return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)


@doesnt_support_saved_tensors_hooks
def _jvp_with_argnums(func: Callable, primals: Any, tangents: Any, argnums: Optional[argnums_t], *,
                      strict: bool = False, has_aux: bool):
    # This is the same function as jvp but also accepts an argnums argument
    # Most args are the same as jvp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #         If None, computes the gradients with respect to all inputs (used for jvp). Default: None
    # Because of this, tangents must have the same length as argnums, with each tangent matching the
    # corresponding primal whose index is given by argnums
    #
    # WARN: Users should NOT call this function directly and should just be calling jvp.
    # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd
    #
    # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to
    # the primals given by argnums
    if not isinstance(primals, tuple):
        raise RuntimeError(
            f'{jvp_str}: Expected primals to be a tuple. '
            f'E.g. it should be valid to call f(*primals).')
    diff_args = primals if argnums is None else _slice_argnums(primals, argnums)
    flat_primals, primals_spec = tree_flatten(diff_args)
    flat_tangents, tangents_spec = tree_flatten(tangents)
    if primals_spec != tangents_spec:
        raise RuntimeError(
            f'{jvp_str}: Expected primals and tangents to have the same python '
            f'structure. For example, if primals is a tuple of 3 tensors, '
            f'tangents also must be. Got primals with structure {primals_spec} '
            f'and tangents with structure {tangents_spec}')
    assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals')
    assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents')

    level = _jvp_increment_nesting()
    try:
        global JVP_NESTING
        JVP_NESTING += 1
        with fwAD._set_fwd_grad_enabled(True):
            ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
            with ctx():
                flat_duals = tuple(fwAD.make_dual(p, t)
                                   for p, t in zip(flat_primals, flat_tangents))
                duals = tree_unflatten(flat_duals, primals_spec)
                if argnums is not None:
                    primals = _wrap_all_tensors(primals, level)
                    duals = _replace_args(primals, duals, argnums)
                result_duals = func(*duals)
                if has_aux:
                    if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
                        raise RuntimeError(
                            f"{jvp_str}: output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    result_duals, aux = result_duals
                    aux = _undo_create_differentiable(aux, level)

                result_duals, spec = tree_flatten(result_duals)
                assert_non_empty_tensor_output(result_duals, jvp_str)

                primals_out, tangents_out = \
                    zip(*[safe_unpack_dual(dual, strict) for dual in result_duals])
                primals_out = tree_map(
                    partial(_undo_create_differentiable, level=level), primals_out)
                tangents_out = tree_map(
                    partial(_undo_create_differentiable, level=level), tangents_out)

                primals_out_unflatten = tree_unflatten(primals_out, spec)
                tangents_out_unflatten = tree_unflatten(tangents_out, spec)
                if has_aux:
                    return primals_out_unflatten, tangents_out_unflatten, aux

                return primals_out_unflatten, tangents_out_unflatten
    finally:
        _jvp_decrement_nesting()
        JVP_NESTING -= 1


def safe_unflatten(tensor, dim, shape):
    if len(shape) == 0:
        assert tensor.shape[dim] == 1
        return tensor.squeeze(dim)
    return tensor.unflatten(dim, shape)

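# For example (illustrative): safe_unflatten handles the scalar-primal case,
# where shape == () and the size-1 dim is squeezed away instead of unflattened:
#   safe_unflatten(torch.randn(5, 1), -1, ())      -> shape [5]
#   safe_unflatten(torch.randn(5, 6), -1, (2, 3))  -> shape [5, 2, 3]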

@exposed_in("torch.func")
def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
    """
    Computes the Jacobian of ``func`` with respect to the arg(s) at index
    ``argnum`` using forward-mode autodiff

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        randomness(str): Flag indicating what type of randomness to use.
            See :func:`vmap` for more detail. Allowed: "different", "same", "error".
            Default: "error"

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use :func:`jacrev`, which has better operator coverage.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from torch.func import jacfwd
        >>> x = torch.randn(5)
        >>> jacobian = jacfwd(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    :func:`jacfwd` can be composed with vmap to produce batched
    Jacobians:

        >>> from torch.func import jacfwd, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacfwd(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from torch.func import jacfwd
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>   return x.sin()
        >>>
        >>> def g(x):
        >>>   result = f(x)
        >>>   return result, result
        >>>
        >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
    to produce Hessians

        >>> from torch.func import jacfwd, jacrev
        >>> def f(x):
        >>>   return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacfwd(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacfwd` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

        >>> from torch.func import jacfwd
        >>> def f(x, y):
        >>>   return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacfwd(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

        >>> from torch.func import jacfwd
        >>> def f(x, y):
        >>>   return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    """
    @wraps(func)
    def wrapper_fn(*args):
        error_if_complex("jacfwd", args, is_input=True)
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
            # output[0] is the output of `func(*args)`
            error_if_complex("jacfwd", output[0], is_input=False)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp, randomness=randomness)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in
                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)

        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)
    return wrapper_fn


@exposed_in("torch.func")
def hessian(func, argnums=0):
    """
    Computes the Hessian of ``func`` with respect to the arg(s) at index
    ``argnum`` via a forward-over-reverse strategy.

    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
    a good default for good performance. It is possible to compute Hessians
    through other compositions of :func:`jacfwd` and :func:`jacrev` like
    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Hessian with respect to.
            Default: 0.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Hessian of ``func`` with respect to the arg(s) at
        ``argnums``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use ``jacrev(jacrev(func))``, which has better
        operator coverage.

    A basic usage with an R^N -> R^1 function gives an N x N Hessian:

        >>> from torch.func import hessian
        >>> def f(x):
        >>>   return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
        >>> assert torch.allclose(hess, torch.diag(-x.sin()))

    """
    return jacfwd(jacrev(func, argnums), argnums)
1236

1237

1238
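# Illustrative sketch (not part of the upstream file): a hypothetical helper
# checking that hessian(f) matches the explicit jacfwd(jacrev(f)) composition
# documented above, on a function whose Hessian is known in closed form.
def _hessian_composition_sketch():
    def f(x):
        return (x ** 3).sum()

    x = torch.randn(4)
    hess = hessian(f)(x)
    # The Hessian of sum(x**3) is diag(6 * x).
    assert torch.allclose(hess, jacfwd(jacrev(f))(x))
    assert torch.allclose(hess, torch.diag(6 * x))

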
@doesnt_support_saved_tensors_hooks
def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable:
    with grad_increment_nesting() as level:
        output, aux, grad_input = None, None, None
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            args = _wrap_all_tensors(args, level)
            kwargs = _wrap_all_tensors(kwargs, level)
            diff_args = _slice_argnums(args, argnums, as_tuple=False)
            tree_map_(partial(_create_differentiable, level=level), diff_args)

            output = func(*args, **kwargs)
            if has_aux:
                if not (isinstance(output, tuple) and len(output) == 2):
                    raise RuntimeError(
                        "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                output, aux = output

            if not isinstance(output, torch.Tensor):
                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                   f'to return a Tensor, got {type(output)}')
            if output.dim() != 0:
                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                   'to return a scalar Tensor, got tensor with '
                                   f'{output.dim()} dims. Maybe you wanted to '
                                   'use the vjp or jacrev APIs instead?')

            flat_diff_args, spec = tree_flatten(diff_args)

            # NB: need create_graph so that backward pass isn't run in no_grad mode
            flat_outputs = _as_tuple(output)
            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
            grad_input = tree_unflatten(flat_grad_input, spec)

            grad_input = _undo_create_differentiable(grad_input, level)
            output = _undo_create_differentiable(output, level)
            if has_aux:
                aux = _undo_create_differentiable(aux, level)

        if has_aux:
            return grad_input, (output, aux)
        return grad_input, output


def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
    if has_aux:
        grad, (_, aux) = results
        return grad, aux
    grad, _ = results
    return grad

def _maybe_wrap_functional_tensor(maybe_tensor, level, *, _python_functionalize: bool = False):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    wrapped = _wrap_functional_tensor(maybe_tensor, level)
    _assert_wrapped_functional(maybe_tensor, wrapped)
    if _python_functionalize:
        out = FunctionalTensor(wrapped)
        torch._mirror_autograd_meta_to(maybe_tensor, out)
        return out
    return wrapped


def _wrap_all_tensors_to_functional(tensor_pytree, level, *, _python_functionalize: bool = False):
    return tree_map(partial(lambda x: _maybe_wrap_functional_tensor(
        x, level, _python_functionalize=_python_functionalize)), tensor_pytree)


def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    if isinstance(maybe_tensor, FunctionalTensor):
        maybe_tensor = maybe_tensor.elem

    if not torch._is_functional_tensor(maybe_tensor):
        # If it's not a functional tensor, just return it.
        # This can happen if we functionalize a fn that returns a global,
        # which was never wrapped properly.
        return maybe_tensor
    # Sync any pending updates on the output tensor
    torch._sync(maybe_tensor)
    return _unwrap_functional_tensor(maybe_tensor, reapply_views)


def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
    return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree)


@exposed_in("torch.func")
def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
    """
    functionalize is a transform that can be used to remove (intermediate)
    mutations and aliasing from a function, while preserving the function's
    semantics.

    ``functionalize(func)`` returns a new function with the same semantics
    as ``func``, but with all intermediate mutations removed.
    Every inplace operation performed on an intermediate tensor:
    ``intermediate.foo_()``
    gets replaced by its out-of-place equivalent:
    ``intermediate_updated = intermediate.foo()``.

    functionalize is useful for shipping a pytorch program off to
    backends or compilers that aren't able to easily represent
    mutations or aliasing operators.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        remove (str): An optional string argument that takes on either
            the value 'mutations' or 'mutations_and_views'.
            If 'mutations' is passed in, then all mutating operators
            will be replaced with their non-mutating equivalents.
            If 'mutations_and_views' is passed in, then additionally, all aliasing
            operators will be replaced with their non-aliasing equivalents.
            Default: 'mutations'.

    Returns:
        Returns a new "functionalized" function. It takes the same inputs as
        ``func``, and has the same behavior, but any mutations
        (and optionally aliasing) performed on intermediate tensors
        in the function will be removed.

    functionalize will also remove mutations (and views) that were performed on function inputs.
    However, to preserve semantics, functionalize will "fix up" the mutations after
    the transform has finished running, by detecting if any tensor inputs "should have"
    been mutated, and copying the new data back to the inputs if necessary.


    Example::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.func import functionalize
        >>>
        >>> # A function that uses mutations and views, but only on intermediate tensors.
        >>> def f(a):
        ...     b = a + 1
        ...     c = b.view(-1)
        ...     c.add_(1)
        ...     return b
        ...
        >>> inpt = torch.randn(2)
        >>>
        >>> out1 = f(inpt)
        >>> out2 = functionalize(f)(inpt)
        >>>
        >>> # semantics are the same (outputs are equivalent)
        >>> print(torch.allclose(out1, out2))
        True
        >>>
        >>> f_traced = make_fx(f)(inpt)
        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>>
        >>> print(f_traced.code)



        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1);  a_1 = None
            view = torch.ops.aten.view(add, [-1])
            add_ = torch.ops.aten.add_(view, 1);  view = None
            return add

        >>> print(f_no_mutations_traced.code)



        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1);  a_1 = None
            view = torch.ops.aten.view(add, [-1]);  add = None
            add_1 = torch.ops.aten.add(view, 1);  view = None
            view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
            return view_1

        >>> print(f_no_mutations_and_views_traced.code)



        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1);  a_1 = None
            view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
            add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
            return view_copy_1


        >>> # A function that mutates its input tensor
        >>> def f(a):
        ...     b = a.view(-1)
        ...     b.add_(1)
        ...     return a
        ...
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>> #
        >>> # All mutations and views have been removed,
        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
        >>> # after the function has completed.
        >>> print(f_no_mutations_and_views_traced.code)



        def forward(self, a_1):
            view_copy = torch.ops.aten.view_copy(a_1, [-1])
            add = torch.ops.aten.add(view_copy, 1);  view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
            copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
            return view_copy_1


    There are a few "failure modes" for functionalize that are worth calling out:
      (1) Like other torch.func transforms, `functionalize()` doesn't work with functions
          that directly use `.backward()`. The same is true for torch.autograd.grad.
          If you want to use autograd, you can compute gradients directly
          with `functionalize(grad(f))`.
      (2) Like other torch.func transforms, `functionalize()` doesn't work with global state.
          If you call `functionalize(f)` on a function that takes views / mutations of
          non-local state, functionalization will simply no-op and pass the view/mutation
          calls directly to the backend.
          One way to work around this is to ensure that any non-local state creation
          is wrapped into a larger function, which you then call functionalize on.
      (3) `resize_()` has some limitations: functionalize will only work on programs
          that use `resize_()` as long as the tensor being resized is not a view.
      (4) `as_strided()` has some limitations: functionalize will not work on
          `as_strided()` calls that result in tensors with overlapping memory.


    Finally, a helpful mental model for understanding functionalization is that
    most user pytorch programs are written with the public torch API.
    When executed, torch operators are generally decomposed into
    our internal C++ "ATen" API.
    The logic for functionalization happens entirely at the level of ATen.
    Functionalization knows how to take every aliasing operator in ATen,
    and map it to its non-aliasing equivalent
    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
    and how to take every mutating operator in ATen,
    and map it to its non-mutating equivalent
    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
    while tracking aliases and mutations out-of-line to know when to fix things up.
    Information about which ATen operators are aliasing or mutating all comes from
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
    """
    if remove == 'mutations':
        reapply_views = True
    elif remove == 'mutations_and_views':
        reapply_views = False
    else:
        raise RuntimeError(
            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
            " Valid options are:\n"
            "     remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
            " with their out-of-place equivalents.\n"
            "     remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
            " replaced with their non-aliasing counterparts, {view}_copy.\n"
        )

    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            func_level = _func_increment_nesting(reapply_views)
            func_args = _wrap_all_tensors_to_functional(args, func_level)
            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)

            flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
            flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
            flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
            flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)

            func_outputs = func(*func_args, **func_kwargs)
            outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
            flat_outputs, func_out_spec = tree_flatten(outputs)

            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
                if isinstance(a, torch.Tensor):
                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
                    torch._sync(a)

            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
            for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)
            for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)

            return outputs
        finally:
            _func_decrement_nesting()
    return wrapped


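# Illustrative sketch (not part of the upstream file): a hypothetical helper
# checking the input-mutation "fix up" described in the docstring above --
# functionalize removes the mutation from the program it runs, but still
# copies the final data back into the input tensor afterwards.
def _functionalize_input_mutation_sketch():
    def f(a):
        b = a.view(-1)
        b.add_(1)
        return a

    x = torch.zeros(2, 2)
    out = functionalize(f)(x)
    # The mutation is re-applied to the input after the transform runs.
    assert torch.allclose(x, torch.ones(2, 2))
    assert torch.allclose(out, torch.ones(2, 2))

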
@exposed_in("torch.func")
def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
    '''
    Returns the value of ``func`` at ``primals`` and linear approximation
    at ``primals``.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. These are the values at which the function is linearly approximated.

    Returns:
        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the jvp of
        ``func`` evaluated at ``primals``.

    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
    to achieve this, linearize saves intermediate computation and has higher memory requirements
    than directly applying `jvp`. So, if all the ``tangents`` are known, it may be more efficient
    to compute vmap(jvp) instead of using linearize.

    .. note::
        linearize evaluates ``func`` twice. Please file an issue for an implementation
        with a single evaluation.

    Example::
        >>> import torch
        >>> from torch.func import linearize
        >>> def fn(x):
        ...     return x.sin()
        ...
        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
        >>> jvp_fn(torch.ones(3, 3))
        tensor([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]])
        >>>

    '''
    # Note: We evaluate `fn` twice.
    # Once for returning the output and another time
    # while tracing the graph.
    # If this becomes a bottleneck, we should update
    # make_fx such that it also returns the output.

    output = func(*primals)
    _, output_spec = tree_flatten(output)

    flat_primals, primals_argspec = tree_flatten(primals)

    # tangents for tracing
    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)

    # function to trace
    def trace_fn(flat_tangents):
        with fwAD.dual_level():
            flat_duals = tuple(fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents))
            duals = tree_unflatten(flat_duals, primals_argspec)
            output = func(*duals)
            tangents = tree_map_only(torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output)

        return tangents

    jvp_graph = make_fx(trace_fn)(flat_tangents)
    const_folded_jvp_graph = const_fold.split_const_subgraphs(jvp_graph)

    # Hold only the meta-data regarding the primals.
    flat_primals_shape = tuple(p.shape for p in flat_primals)
    flat_primals_device = tuple(p.device for p in flat_primals)
    flat_primals_dtype = tuple(p.dtype for p in flat_primals)

    def forward_ad_checks(flat_tangents):
        for idx, t in enumerate(flat_tangents):
            if t.shape != flat_primals_shape[idx]:
                msg = (f"tangent:{idx} with shape {t.shape} in flattened "
                       f"pytree doesn't match the shape {flat_primals_shape[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

            if t.device != flat_primals_device[idx]:
                msg = (f"tangent:{idx} with device {t.device} in flattened "
                       f"pytree doesn't match the device {flat_primals_device[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

            if t.dtype != flat_primals_dtype[idx]:
                msg = (f"tangent:{idx} with dtype {t.dtype} in flattened "
                       f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

    # jvp_fn : callable to return
    #   It takes care of checking the argspec of tangents,
    #   calling the folded fx graph and unflattening the fx graph output
    def jvp_fn(*tangents):
        flat_tangents, tangent_argspec = tree_flatten(tangents)
        if tangent_argspec != primals_argspec:
            raise RuntimeError(f"Expected the tangents {tangent_argspec} to have "
                               f"the same argspec as the primals {primals_argspec}")

        forward_ad_checks(flat_tangents)

        flat_output = const_folded_jvp_graph(*flat_tangents)
        # The const-folded graph returns a flat output,
        # so unflatten it back into the original output structure.
        return tree_unflatten(flat_output, output_spec)

    return output, jvp_fn
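

# Illustrative sketch (not part of the upstream file): a hypothetical helper
# checking that the jvp_fn returned by linearize agrees with torch.func.jvp
# when called repeatedly with different tangents at the same primal. The
# import is deferred into the function body to avoid import cycles.
def _linearize_vs_jvp_sketch():
    from torch.func import jvp

    def fn(x):
        return x.sin()

    primal = torch.randn(3, 3)
    output, jvp_fn = linearize(fn, primal)

    for _ in range(2):
        tangent = torch.randn(3, 3)
        expected_out, expected_jvp = jvp(fn, (primal,), (tangent,))
        assert torch.allclose(output, expected_out)
        assert torch.allclose(jvp_fn(tangent), expected_jvp)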
