# mypy: ignore-errors

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Union, Tuple, List, Any, Optional
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import (
    tree_flatten,
    tree_unflatten,
    tree_map,
    tree_map_only,
    tree_map_,
    treespec_pprint,
)
from torch.utils import _pytree as pytree
from torch.fx.experimental import const_fold
from torch.fx.experimental.proxy_tensor import make_fx
import torch.autograd.forward_ad as fwAD
from torch._subclasses.functional_tensor import FunctionalTensor

from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes
from .apis import vmap

from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    _grad_increment_nesting,
    _grad_decrement_nesting,
    _jvp_increment_nesting,
    _jvp_decrement_nesting,
    _wrap_functional_tensor,
    _unwrap_functional_tensor,
    _func_decrement_nesting,
    _func_increment_nesting,
    _assert_wrapped_functional,
    _propagate_functional_input_mutation,
    set_inplace_requires_grad_allowed,
    get_inplace_requires_grad_allowed,
)
from torch._functorch.utils import exposed_in, argnums_t


def lazy_dynamo_disable(func):
    import torch._dynamo
    return torch._dynamo.disable(func)


@contextlib.contextmanager
def enable_inplace_requires_grad(enabled):
    prev_state = get_inplace_requires_grad_allowed()
    set_inplace_requires_grad_allowed(enabled)
    try:
        yield
    finally:
        set_inplace_requires_grad_allowed(prev_state)


def _vjp_treespec_compare(primals_out, cotangents):
    # Revert this once #116264 gets fixed
    _, primals_out_spec = tree_flatten(primals_out)
    _, cotangents_spec = tree_flatten(cotangents)
    # Dynamo fails to trace operator.ne below. To bypass this limitation, this
    # function is not inlined.
    if primals_out_spec != cotangents_spec:
        raise RuntimeError(
            f'Expected pytree structure of cotangents to be the same '
            f'as pytree structure of outputs to the function. '
            f'cotangents: {treespec_pprint(cotangents_spec)}, '
            f'primal output: {treespec_pprint(primals_out_spec)}')


def _set_tensor_requires_grad(x):
    # avoid graph-break on x.requires_grad_()
    # https://github.com/pytorch/pytorch/pull/110053
    return x.requires_grad_()


def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
            with enable_inplace_requires_grad(True):
                return _set_tensor_requires_grad(x)
        raise ValueError(f'Thing passed to transform API must be Tensor, '
                         f'got {type(x)}')
    return tree_map(create_differentiable, inps)


def _undo_create_differentiable(inps, level=None):
    def unwrap_tensors(x):
        if isinstance(x, torch.Tensor):
            return _unwrap_for_grad(x, level)
        # TODO: Remove the following hack for namedtuples
        if isinstance(x, tuple):
            return tree_map(unwrap_tensors, tuple(x))

        raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

    return tree_map(unwrap_tensors, inps)


def _is_differentiable(maybe_tensor):
    if not isinstance(maybe_tensor, torch.Tensor):
        return False
    return maybe_tensor.requires_grad


def _any_differentiable(tensor_or_tuple_of_tensors):
    # NB: tree_flatten (not tree_unflatten) is needed here; we want the leaves.
    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
    return any(tuple(map(_is_differentiable, flat_args)))


def _wrap_tensor_for_grad(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    return _wrap_for_grad(maybe_tensor, level)


def _wrap_all_tensors(tensor_pytree, level):
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)

# Version of autograd.grad that handles outputs that don't depend on inputs


def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
        if len(result) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*result)
    if len(diff_outputs) == 0:
        return tuple(torch.zeros_like(inp) for inp in inputs)
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
                                      retain_graph=retain_graph,
                                      create_graph=create_graph,
                                      allow_unused=True)
    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
                        for gi, inp in zip(grad_inputs, inputs))
    return grad_inputs
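
# A minimal sketch of the behavior described in the comment above
# (illustrative only; _autograd_grad is an internal helper): outputs that
# don't require grad are dropped, and inputs that receive no gradient get
# zeros instead of None.
#
# >>> x = torch.randn(3, requires_grad=True)
# >>> y = x.detach()               # y does not depend on x for autograd
# >>> _autograd_grad((y.sum(),), (x,))
# (tensor([0., 0., 0.]),)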

# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
#   with torch.no_grad():
#     c = x ** 2
#   return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
# with torch.no_grad():
#   grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad, but not the
# outer one. This is because `grad` is a "function transform": its result
# should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
#   active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
#   then we temporarily restore the previous `is_grad_enabled`. This is
#   because we're crossing the boundary from a `grad` outside the
#   no_grad to a `grad` inside the no_grad.
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
# it respects c10::AutoFwGradMode. We've implemented the same logic for
# our jvp transform (it will have special handling if FwGradMode is disabled).
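
# For illustration, the two cases above in doctest form (a sketch; assumes
# the transforms behave as documented):
#
# >>> x = torch.randn([])
# >>> def f(x):
# ...     with torch.no_grad():
# ...         c = x ** 2
# ...     return x - c
# >>> torch.func.grad(f)(x)        # Case 1: c is treated as a constant
# tensor(1.)
# >>> with torch.no_grad():        # Case 2: the outer no_grad is not respected
# ...     torch.func.grad(f)(x)
# tensor(1.)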

# How do we increment and decrement the nesting? I don't think we can.
@exposed_in("torch.func")
def vjp(func: Callable, *primals, has_aux: bool = False):
    """
    Standing for the vector-Jacobian product, returns a tuple containing the
    results of ``func`` applied to ``primals`` and a function that, when
    given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with
    respect to ``primals`` times ``cotangents``.

    Args:
        func (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, vjp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the vjp of
        ``func`` with respect to all ``primals`` using the cotangents passed
        to the returned function. If ``has_aux is True``, then instead returns a
        ``(output, vjp_fn, aux)`` tuple.
        The returned ``vjp_fn`` function will return a tuple of each VJP.

    When used in simple cases, :func:`vjp` behaves the same as :func:`grad`

        >>> x = torch.randn([5])
        >>> f = lambda x: x.sin().sum()
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> grad = vjpfunc(torch.tensor(1.))[0]
        >>> assert torch.allclose(grad, torch.func.grad(f)(x))

    However, :func:`vjp` can support functions with multiple outputs by
    passing in the cotangents for each of the outputs

        >>> x = torch.randn([5])
        >>> f = lambda x: (x.sin(), x.cos())
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    :func:`vjp` can even support outputs being Python structs

        >>> x = torch.randn([5])
        >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
        >>> vjps = vjpfunc(cotangents)
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    The function returned by :func:`vjp` will compute the partials with
    respect to each of the ``primals``

        >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
        >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
        >>> cotangents = torch.randn([5, 5])
        >>> vjps = vjpfunc(cotangents)
        >>> assert len(vjps) == 2
        >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
        >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

    ``primals`` are the positional arguments for ``f``. All kwargs use their
    default value

        >>> x = torch.randn([5])
        >>> def f(x, scale=4.):
        >>>     return x * scale
        >>>
        >>> (_, vjpfunc) = torch.func.vjp(f, x)
        >>> vjps = vjpfunc(torch.ones_like(x))
        >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``vjp``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:

            >>> # xdoctest: +SKIP(failing)
            >>> with torch.no_grad():
            >>>     vjp(f)(x)

        In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``vjp`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)


@contextlib.contextmanager
def grad_increment_nesting():
    try:
        grad_level = _grad_increment_nesting()
        yield grad_level
    finally:
        _grad_decrement_nesting()


@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
    # This is the same function as vjp but also accepts an argnums argument
    # All args are the same as vjp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #   If None, computes the gradients with respect to all inputs (used for vjp). Default: None
    #
    # WARN: Users should NOT call this function directly and should just be calling vjp.
    # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
    #
    # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
    # for only the primal elements given by argnums.
    with grad_increment_nesting() as level:
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            primals = _wrap_all_tensors(primals, level)
            # Note for the reviewer: This is extremely odd but it passes the
            # assertion "len(self.block_stack) == 1" on symbolic_convert.py
            # The equivalent "if argnums is None" fails for some reason
            if not isinstance(argnums, int) and not argnums:
                diff_primals = _create_differentiable(primals, level)
            else:
                diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_primals)
            primals_out = func(*primals)

            if has_aux:
                if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
                    raise RuntimeError(
                        "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                primals_out, aux = primals_out
                aux = _undo_create_differentiable(aux, level)

            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
            assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
            results = _undo_create_differentiable(primals_out, level)

            for primal_out in flat_primals_out:
                assert isinstance(primal_out, torch.Tensor)
                if primal_out.is_floating_point() or primal_out.is_complex():
                    continue
                raise RuntimeError("vjp(f, ...): All outputs of f must be "
                                   "floating-point or complex Tensors, got Tensor "
                                   f"with dtype {primal_out.dtype}")

    def wrapper(cotangents, retain_graph=True, create_graph=None):
        if create_graph is None:
            create_graph = torch.is_grad_enabled()
        flat_cotangents, cotangents_spec = tree_flatten(cotangents)
        _vjp_treespec_compare(primals_out, cotangents)
        result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                retain_graph=retain_graph, create_graph=create_graph)
        return tree_unflatten(result, primals_spec)

    if has_aux:
        return results, wrapper, aux
    else:
        return results, wrapper


def _safe_zero_index(x):
    assert len(x) == 1
    return x[0]

# jacrev and jacfwd don't support complex functions
# Helper function to throw appropriate error.
def error_if_complex(func_name, args, is_input):
    flat_args = pytree.tree_leaves(args)
    for idx, arg in enumerate(flat_args):
        if isinstance(arg, torch.Tensor) and arg.dtype.is_complex:
            input_or_output = ("inputs" if is_input else "outputs")
            err_msg = (f"{func_name}: Expected all {input_or_output} "
                       f"to be real but received complex tensor at flattened input idx: {idx}")
            raise RuntimeError(err_msg)

@exposed_in("torch.func")
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
    """
    Computes the Jacobian of ``func`` with respect to the arg(s) at index
    ``argnums`` using reverse mode autodiff

    .. note::
        Using :attr:`chunk_size=1` is equivalent to computing the jacobian
        row-by-row with a for-loop, i.e., the constraints of :func:`vmap` are
        not applicable.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        chunk_size (None or int): If None (default), use the maximum chunk size
            (equivalent to doing a single vmap over vjp to compute the jacobian).
            If 1, then compute the jacobian row-by-row with a for-loop.
            If not None, then compute the jacobian :attr:`chunk_size` rows at a time
            (equivalent to doing multiple vmap over vjp). If you run into memory issues computing
            the jacobian, please try to specify a non-None chunk_size.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from torch.func import jacrev
        >>> x = torch.randn(5)
        >>> jacobian = jacrev(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from torch.func import jacrev
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>     return x.sin()
        >>>
        >>> def g(x):
        >>>     result = f(x)
        >>>     return result, result
        >>>
        >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

        >>> from torch.func import jacrev, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacrev(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians

        >>> from torch.func import jacrev
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacrev(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

        >>> from torch.func import jacrev
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

        >>> from torch.func import jacrev
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            >>>     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    if not (chunk_size is None or chunk_size > 0):
        raise ValueError("jacrev: `chunk_size` should be greater than 0.")

    @wraps(func)
    def wrapper_fn(*args):
        error_if_complex("jacrev", args, is_input=True)
        vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
        else:
            output, vjp_fn = vjp_out

        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
        flat_output, output_spec = tree_flatten(output)

        error_if_complex("jacrev", flat_output, is_input=False)

        # NB: vjp already checks that all outputs are tensors
        # Step 1: Construct grad_outputs by splitting the standard basis
        flat_output_numels = tuple(out.numel() for out in flat_output)

        primals = _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)

        def compute_jacobian_stacked():
            # Helper function to compute chunked Jacobian
            # The intermediate chunked calculations are only
            # scoped at this function level.
            chunked_results = []
            for flat_basis_chunk in _chunked_standard_basis_for_(flat_output,
                                                                 flat_output_numels,
                                                                 chunk_size=chunk_size):
                if chunk_size == 1:
                    # sanity check.
                    for t in flat_basis_chunk:
                        assert t.size(0) == 1

                    flat_basis_chunk = tree_map(lambda t: torch.squeeze(t, 0), flat_basis_chunk)

                basis = tree_unflatten(flat_basis_chunk, output_spec)

                if chunk_size == 1:
                    # Behaviour with `chunk_size=1` is same as `for-loop`
                    # i.e. user shouldn't deal with the limitations of vmap.
                    chunked_result = vjp_fn(basis)
                else:  # chunk_size is None or chunk_size != 1
                    chunked_result = vmap(vjp_fn)(basis)

                flat_results = pytree.tree_leaves(chunked_result)

                if chunk_size == 1:
                    flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)

                chunked_results.append(flat_results)

            if len(chunked_results) == 1:
                # Short-circuit if we used a single chunk
                return chunked_results[0]

            # Concatenate chunks.
            flat_results = []
            # Iterate and concat the jacobians of different
            # inputs.
            for idx in range(len(flat_primals)):
                r = tuple(r_[idx] for r_ in chunked_results)
                flat_results.append(torch.cat(r, 0))

            return flat_results

        def compute_jacobian_preallocate_and_copy():
            # Helper function to compute chunked Jacobian
            # The intermediate chunked calculations are only
            # scoped at this function level.
            out_vec_size = sum(flat_output_numels)

            # Don't pre-allocate if we have a single chunk.
            if not (chunk_size is None or chunk_size >= out_vec_size):
                stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals]

            for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output,
                                                                                flat_output_numels,
                                                                                chunk_size=chunk_size)):
                if chunk_size == 1:
                    # sanity check.
                    for t in flat_basis_chunk:
                        assert t.size(0) == 1

                    flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk]

                basis = tree_unflatten(flat_basis_chunk, output_spec)

                if chunk_size == 1:
                    # Behaviour with `chunk_size=1` is same as `for-loop`
                    # i.e. user shouldn't deal with the limitations of vmap.
                    chunked_result = vjp_fn(basis)
                else:  # chunk_size is None or chunk_size != 1
                    chunked_result = vmap(vjp_fn)(basis)

                flat_results = pytree.tree_leaves(chunked_result)

                # Short-circuit if we have a single chunk.
                if chunk_size is None or chunk_size >= out_vec_size:
                    if chunk_size == 1:  # and out_vec_size == 1
                        # Since we squeezed the output dim
                        flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results)
                    return flat_results

                for r, sr in zip(flat_results, stacked_results):
                    sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)

            return stacked_results

        if _preallocate_and_copy:
            flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
        else:
            flat_jacobians_per_input = compute_jacobian_stacked()

        # Step 2: The returned jacobian is one big tensor per input. In this step,
        # we split each Tensor by output.
        flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input]
        flat_input_flat_output = [
            tuple(split.view(out.shape + primal.shape)
                  for split, out in zip(splits, flat_output))
            for splits, primal in zip(flat_jacobians_per_input, flat_primals)
        ]

        # Step 3: Right now, `jacobian` is a List[List[Tensor]].
        # The outer List corresponds to the number of primals,
        # the inner List corresponds to the number of outputs.
        # We need to:
        # a. Exchange the order of the outer List and inner List
        # b. tree_unflatten the inner Lists (which correspond to the primals)
        # c. handle the argnums=int case
        # d. tree_unflatten the outer List (which corresponds to the outputs)
        flat_output_flat_input = tuple(zip(*flat_input_flat_output))

        flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
                                  for flat_input in flat_output_flat_input)

        if isinstance(argnums, int):
            flat_output_input = tuple(_safe_zero_index(flat_input)
                                      for flat_input in flat_output_input)
        output_input = tree_unflatten(flat_output_input, output_spec)
        if has_aux:
            return output_input, aux
        return output_input
    return wrapper_fn

# NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
#
# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
# It turns out we can compute the jacobian of this function with a single
# call to autograd.grad by using vmap over the correct grad_outputs.
#
# Firstly, one way to compute the jacobian is to concatenate x**2 and x.sum()
# into one vector with 4 elements. E.g., use g(x) = torch.cat([x**2, x.sum().reshape(1)])
#
# To get the first row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
# To get the 2nd row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
# and so on.
#
# Using vmap, we can vectorize all 4 of these computations into one by
# passing the standard basis for R^4 as the grad_output.
# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
#
# Now, how do we compute the jacobian *without stacking the output*?
# We can just split the standard basis across the outputs. So to
# compute the jacobian of f(x), we'd use
# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
# The grad_outputs looks like the following:
# ( torch.tensor([[1, 0, 0],
#                 [0, 1, 0],
#                 [0, 0, 1],
#                 [0, 0, 0]]),
#   torch.tensor([[0],
#                 [0],
#                 [0],
#                 [1]]) )
#
# But we're not done yet!
# >>> vmap(partial(autograd.grad, f(x), x))(grad_outputs)
# returns a Tensor of shape [4, 3]. We have to remember to split the
# jacobian of shape [4, 3] into two:
# - one of shape [3, 3] for the first output
# - one of shape [3] for the second output

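# For illustration, a sketch of the stacked-output recipe above (assumes
# x = torch.randn(3, requires_grad=True); shown in doctest form):
#
# >>> out = torch.cat([x ** 2, x.sum().reshape(1)])   # g: R^3 -> R^4
# >>> basis = torch.eye(4)
# >>> rows = vmap(lambda v: torch.autograd.grad(out, x, v, retain_graph=True)[0])(basis)
# >>> rows.shape
# torch.Size([4, 3])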


def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],     tensor([[0],
    #           [0],             [1, 0],             [0],
    #           [0],             [0, 1],             [0],
    #           [0]])  ,         [0, 0]])  ,         [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
    # for context behind this function.
    # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
    # one huge basis matrix. `chunk_size` dictates the maximum size of the
    # basis matrix along dim=0.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    assert chunk_size is None or chunk_size > 0
    total_numel = sum(tensor_numels)
    if chunk_size and chunk_size < total_numel:
        chunk_numels = get_chunk_sizes(total_numel, chunk_size)
    else:  # chunk_size is None or chunk_size >= total_numel
        chunk_size = total_numel
        chunk_numels = [total_numel]

    diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())

    for chunk_idx, total_numel in enumerate(chunk_numels):
        chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
                       for tensor, tensor_numel in zip(tensors, tensor_numels))

        for chunk, diag_start_idx in zip(chunks, diag_start_indices):
            chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
        chunks = tuple(chunk.view(total_numel, *tensor.shape)
                       for chunk, tensor in zip(chunks, tensors))
        yield chunks


def _construct_standard_basis_for(tensors, tensor_numels):
    for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
        return basis
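
# Illustrative example (a sketch): for tensors with numels (2, 1), the
# standard basis (a 3x3 identity) is split column-wise across the tensors:
#
# >>> t1, t2 = torch.zeros(2), torch.zeros(1)
# >>> _construct_standard_basis_for((t1, t2), (2, 1))
# (tensor([[1., 0.],
#          [0., 1.],
#          [0., 0.]]), tensor([[0.],
#                              [0.],
#                              [1.]]))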


def _validate_and_wrap_argnum(argnum, num_args):
    if not isinstance(argnum, int):
        raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
    if argnum >= 0 and argnum < num_args:
        return argnum
    if argnum < 0 and argnum >= -num_args:
        return argnum + num_args
    raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs')


def _check_unique_non_empty(argnums):
    if isinstance(argnums, tuple):
        if len(argnums) == 0:
            raise RuntimeError("argnums must be non-empty")
        if len(set(argnums)) != len(argnums):
            raise RuntimeError(f"argnums elements must be unique, got {argnums}")


def _replace_args(old_args, new_args, argnums):
    if isinstance(argnums, int):
        if len(new_args) != 1:
            raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
        return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
    if isinstance(argnums, tuple):
        if len(new_args) != len(argnums):
            raise RuntimeError(
                "new_args should have the same size as argnums. "
                f"Argnums size {len(argnums)}, new_args size {len(new_args)}")

        def get_right_elem(i):
            return new_args[argnums.index(i)] if i in argnums else old_args[i]

        return tuple(get_right_elem(i) for i in range(len(old_args)))
    raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
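
# Illustrative examples for _replace_args (a sketch):
#
# >>> _replace_args(('a', 'b', 'c'), ('B',), argnums=1)
# ('a', 'B', 'c')
# >>> _replace_args(('a', 'b', 'c'), ('A', 'C'), argnums=(0, 2))
# ('A', 'b', 'C')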


def _validate_and_wrap_argnums(argnums, num_args):
    if isinstance(argnums, int):
        return _validate_and_wrap_argnum(argnums, num_args)
    if isinstance(argnums, tuple):
        return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
    raise AssertionError("Should never get here")


def _slice_argnums(args, argnums, as_tuple=True):
    if not isinstance(argnums, int) and not isinstance(argnums, tuple):
        raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
    argnums = _validate_and_wrap_argnums(argnums, len(args))
    _check_unique_non_empty(argnums)
    if isinstance(argnums, int):
        if as_tuple:
            return (args[argnums],)
        else:
            return args[argnums]
    return tuple(args[i] for i in argnums)
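
# Illustrative examples for _slice_argnums (a sketch): negative indices are
# wrapped before slicing:
#
# >>> _slice_argnums(('a', 'b', 'c'), -1)
# ('c',)
# >>> _slice_argnums(('a', 'b', 'c'), (0, 2))
# ('a', 'c')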


JVP_NESTING = 0


@contextlib.contextmanager
def noop():
    yield


def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
    if not isinstance(elts, tuple):
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}')
    for elt in elts:
        if isinstance(elt, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got '
            f'a tuple with an element of type {type(elt)}')
    if len(elts) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to be a non-empty tuple of Tensors.')


def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
    if (len(output) == 1 and output[0] is None) or len(output) < 1:
        raise RuntimeError(
            f'{api}: Expected f to be a function that has non-empty output (got output = {output})'
        )
    for o in output:
        if not isinstance(o, torch.Tensor):
            raise RuntimeError(
                f'{api}: expected f(*primals) to return only tensors'
                f', got unsupported type {type(o)}'
            )


def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
    if isinstance(output, torch.Tensor):
        return
    if not isinstance(output, tuple):
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(output)}')
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected output of f to be a non-empty tuple of Tensors.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(out)} as an output')


def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None:
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to contain at least one Tensor.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to only contain Tensors, got '
            f'{type(out)}')


jvp_str = 'jvp(f, primals, tangents)'


def safe_unpack_dual(dual, strict):
    if not isinstance(dual, torch.Tensor):
        raise RuntimeError(
            f'{jvp_str}: expected f(*args) to return only tensors'
            f', got unsupported type {type(dual)}'
        )

    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is None:
        if strict:
            raise RuntimeError(
                'jvp(f, primals, tangents, strict=True): '
                'The output of f is independent of '
                'the inputs. This is not allowed with strict=True.')
        tangent = torch.zeros_like(primal)
    return primal, tangent


@exposed_in("torch.func")
def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
    """
    Standing for the Jacobian-vector product, returns a tuple containing
    the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        tangents (Tensors): The "vector" for which Jacobian-vector-product is
            computed. Must be the same structure and sizes as the inputs to
            ``func``.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
        evaluated at ``primals`` and the Jacobian-vector product.
        If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.

    jvp is useful when you wish to compute gradients of a function R^1 -> R^N

        >>> from torch.func import jvp
        >>> x = torch.randn([])
        >>> f = lambda x: x * torch.tensor([1., 2., 3.])
        >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
        >>> assert torch.allclose(value, f(x))
        >>> assert torch.allclose(grad, torch.tensor([1., 2., 3.]))

    :func:`jvp` can support functions with multiple inputs by passing in the
    tangents for each of the inputs

        >>> from torch.func import jvp
        >>> x = torch.randn(5)
        >>> y = torch.randn(5)
        >>> f = lambda x, y: (x * y)
        >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
        >>> assert torch.allclose(output, x + y)

    """

    return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)


@doesnt_support_saved_tensors_hooks
def _jvp_with_argnums(func: Callable, primals: Any, tangents: Any, argnums: Optional[argnums_t], *,
                      strict: bool = False, has_aux: bool):
    # This is the same function as jvp but also accepts an argnums argument
    # Most args are the same as jvp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #   If None, computes the gradients with respect to all inputs (used for jvp). Default: None
    # Because of this, tangents must have the same length as argnums and match up to the corresponding
    # primals, whose indices are given by argnums
    #
    # WARN: Users should NOT call this function directly and should just be calling jvp.
    # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd
    #
    # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to
    # the primals given by argnums
    if not isinstance(primals, tuple):
        raise RuntimeError(
            f'{jvp_str}: Expected primals to be a tuple. '
            f'E.g. it should be valid to call f(*primals).')
    diff_args = primals if argnums is None else _slice_argnums(primals, argnums)
    flat_primals, primals_spec = tree_flatten(diff_args)
    flat_tangents, tangents_spec = tree_flatten(tangents)
    if primals_spec != tangents_spec:
        raise RuntimeError(
            f'{jvp_str}: Expected primals and tangents to have the same python '
            f'structure. For example, if primals is a tuple of 3 tensors, '
            f'tangents also must be. Got primals with structure {primals_spec} '
            f'and tangents with structure {tangents_spec}')
    assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals')
    assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents')

    level = _jvp_increment_nesting()
    try:
        global JVP_NESTING
        JVP_NESTING += 1
        with fwAD._set_fwd_grad_enabled(True):
            ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
            with ctx():
                flat_duals = tuple(fwAD.make_dual(p, t)
                                   for p, t in zip(flat_primals, flat_tangents))
                duals = tree_unflatten(flat_duals, primals_spec)
                if argnums is not None:
                    primals = _wrap_all_tensors(primals, level)
                    duals = _replace_args(primals, duals, argnums)
                result_duals = func(*duals)
                if has_aux:
                    if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
                        raise RuntimeError(
                            f"{jvp_str}: output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    result_duals, aux = result_duals
                    aux = _undo_create_differentiable(aux, level)

                result_duals, spec = tree_flatten(result_duals)
                assert_non_empty_tensor_output(result_duals, jvp_str)

                primals_out, tangents_out = \
                    zip(*[safe_unpack_dual(dual, strict) for dual in result_duals])
                primals_out = tree_map(
                    partial(_undo_create_differentiable, level=level), primals_out)
                tangents_out = tree_map(
                    partial(_undo_create_differentiable, level=level), tangents_out)

                primals_out_unflatten = tree_unflatten(primals_out, spec)
                tangents_out_unflatten = tree_unflatten(tangents_out, spec)
                if has_aux:
                    return primals_out_unflatten, tangents_out_unflatten, aux

                return primals_out_unflatten, tangents_out_unflatten
    finally:
        _jvp_decrement_nesting()
        JVP_NESTING -= 1


def safe_unflatten(tensor, dim, shape):
    if len(shape) == 0:
        assert tensor.shape[dim] == 1
        return tensor.squeeze(dim)
    return tensor.unflatten(dim, shape)
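
# Illustrative examples (a sketch): safe_unflatten also handles scalar
# (0-dim) target shapes, which Tensor.unflatten alone does not accept:
#
# >>> safe_unflatten(torch.zeros(4, 6), 1, (2, 3)).shape
# torch.Size([4, 2, 3])
# >>> safe_unflatten(torch.zeros(4, 1), 1, ()).shape
# torch.Size([4])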


@exposed_in("torch.func")
def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
    """
    Computes the Jacobian of ``func`` with respect to the arg(s) at index
    ``argnums`` using forward-mode autodiff

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        randomness (str): Flag indicating what type of randomness to use.
            See :func:`vmap` for more detail. Allowed: "different", "same", "error".
            Default: "error"

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use :func:`jacrev`, which has better operator coverage.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from torch.func import jacfwd
        >>> x = torch.randn(5)
        >>> jacobian = jacfwd(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    :func:`jacfwd` can be composed with vmap to produce batched
    Jacobians:

        >>> from torch.func import jacfwd, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacfwd(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from torch.func import jacfwd
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>     return x.sin()
        >>>
        >>> def g(x):
        >>>     result = f(x)
        >>>     return result, result
        >>>
        >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
    to produce Hessians

        >>> from torch.func import jacfwd, jacrev
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacfwd(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacfwd` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

        >>> from torch.func import jacfwd
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacfwd(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

        >>> from torch.func import jacfwd
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    """
    @wraps(func)
    def wrapper_fn(*args):
        error_if_complex("jacfwd", args, is_input=True)
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
            # output[0] is the output of `func(*args)`
            error_if_complex("jacfwd", output[0], is_input=False)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp, randomness=randomness)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in
                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)

        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)
    return wrapper_fn


@exposed_in("torch.func")
def hessian(func, argnums=0):
    """
    Computes the Hessian of ``func`` with respect to the arg(s) at index
    ``argnums`` via a forward-over-reverse strategy.

    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
    a good default for good performance. It is possible to compute Hessians
    through other compositions of :func:`jacfwd` and :func:`jacrev` like
    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Hessian with respect to.
            Default: 0.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Hessian of ``func`` with respect to the arg(s) at
        ``argnums``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use ``jacrev(jacrev(func))``, which has better
        operator coverage.

    A basic usage with an R^N -> R^1 function gives an N x N Hessian:

        >>> from torch.func import hessian
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
        >>> assert torch.allclose(hess, torch.diag(-x.sin()))

    """
    return jacfwd(jacrev(func, argnums), argnums)


@doesnt_support_saved_tensors_hooks
def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable:
    with grad_increment_nesting() as level:
        output, aux, grad_input = None, None, None
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            args = _wrap_all_tensors(args, level)
            kwargs = _wrap_all_tensors(kwargs, level)
            diff_args = _slice_argnums(args, argnums, as_tuple=False)
            tree_map_(partial(_create_differentiable, level=level), diff_args)

            output = func(*args, **kwargs)
            if has_aux:
                if not (isinstance(output, tuple) and len(output) == 2):
                    raise RuntimeError(
                        "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                output, aux = output

            if not isinstance(output, torch.Tensor):
                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                   f'to return a Tensor, got {type(output)}')
            if output.dim() != 0:
                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                   'to return a scalar Tensor, got tensor with '
                                   f'{output.dim()} dims. Maybe you wanted to '
                                   'use the vjp or jacrev APIs instead?')

            flat_diff_args, spec = tree_flatten(diff_args)

            # NB: need create_graph so that backward pass isn't run in no_grad mode
            flat_outputs = _as_tuple(output)
            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
            grad_input = tree_unflatten(flat_grad_input, spec)

            grad_input = _undo_create_differentiable(grad_input, level)
            output = _undo_create_differentiable(output, level)
            if has_aux:
                aux = _undo_create_differentiable(aux, level)

    if has_aux:
        return grad_input, (output, aux)
    return grad_input, output

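# Relationship sketch (illustrative): grad_impl below simply drops the value
# that grad_and_value_impl computes alongside the gradient, so conceptually:
#
# >>> f = lambda x: (x ** 2).sum()
# >>> x = torch.randn(3)
# >>> grads, value = torch.func.grad_and_value(f)(x)
# >>> torch.allclose(torch.func.grad(f)(x), grads)
# True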
def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
    if has_aux:
        grad, (_, aux) = results
        return grad, aux
    grad, _ = results
    return grad


def _maybe_wrap_functional_tensor(maybe_tensor, level, *, _python_functionalize: bool = False):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    wrapped = _wrap_functional_tensor(maybe_tensor, level)
    _assert_wrapped_functional(maybe_tensor, wrapped)
    if _python_functionalize:
        out = FunctionalTensor(wrapped)
        torch._mirror_autograd_meta_to(maybe_tensor, out)
        return out
    return wrapped


def _wrap_all_tensors_to_functional(tensor_pytree, level, *, _python_functionalize: bool = False):
    return tree_map(lambda x: _maybe_wrap_functional_tensor(
        x, level, _python_functionalize=_python_functionalize), tensor_pytree)


def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    if isinstance(maybe_tensor, FunctionalTensor):
        maybe_tensor = maybe_tensor.elem

    if not torch._is_functional_tensor(maybe_tensor):
        # If it's not a functional tensor, just return it.
        # This can happen if we functionalize a fn that returns a global,
        # which was never wrapped properly.
        return maybe_tensor
    # Sync any pending updates on the output tensor
    torch._sync(maybe_tensor)
    return _unwrap_functional_tensor(maybe_tensor, reapply_views)


def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
    return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree)


@exposed_in("torch.func")
def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
    """
    functionalize is a transform that can be used to remove (intermediate)
    mutations and aliasing from a function, while preserving the function's
    semantics.

    ``functionalize(func)`` returns a new function with the same semantics
    as ``func``, but with all intermediate mutations removed.
    Every inplace operation performed on an intermediate tensor:
    ``intermediate.foo_()``
    gets replaced by its out-of-place equivalent:
    ``intermediate_updated = intermediate.foo()``.

    functionalize is useful for shipping a pytorch program off to
    backends or compilers that aren't able to easily represent
    mutations or aliasing operators.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        remove (str): An optional string argument, that takes on either
            the value 'mutations' or 'mutations_and_views'.
            If 'mutations' is passed in then all mutating operators
            will be replaced with their non-mutating equivalents.
            If 'mutations_and_views' is passed in, then additionally, all aliasing
            operators will be replaced with their non-aliasing equivalents.
            Default: 'mutations'.

    Returns:
        Returns a new "functionalized" function. It takes the same inputs as
        ``func``, and has the same behavior, but any mutations
        (and optionally aliasing) performed on intermediate tensors
        in the function will be removed.

    functionalize will also remove mutations (and views) that were performed on function inputs.
    However to preserve semantics, functionalize will "fix up" the mutations after
    the transform has finished running, by detecting if any tensor inputs "should have"
    been mutated, and copying the new data back to the inputs if necessary.


    Example::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.func import functionalize
        >>>
        >>> # A function that uses mutations and views, but only on intermediate tensors.
        >>> def f(a):
        ...     b = a + 1
        ...     c = b.view(-1)
        ...     c.add_(1)
        ...     return b
        ...
        >>> inpt = torch.randn(2)
        >>>
        >>> out1 = f(inpt)
        >>> out2 = functionalize(f)(inpt)
        >>>
        >>> # semantics are the same (outputs are equivalent)
        >>> print(torch.allclose(out1, out2))
        True
        >>>
        >>> f_traced = make_fx(f)(inpt)
        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>>
        >>> print(f_traced.code)



        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1);  a_1 = None
            view = torch.ops.aten.view(add, [-1])
            add_ = torch.ops.aten.add_(view, 1);  view = None
            return add

        >>> print(f_no_mutations_traced.code)



        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1);  a_1 = None
            view = torch.ops.aten.view(add, [-1]);  add = None
            add_1 = torch.ops.aten.add(view, 1);  view = None
            view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
            return view_1

        >>> print(f_no_mutations_and_views_traced.code)



        def forward(self, a_1):
            add = torch.ops.aten.add(a_1, 1);  a_1 = None
            view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
            add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
            return view_copy_1


        >>> # A function that mutates its input tensor
        >>> def f(a):
        ...     b = a.view(-1)
        ...     b.add_(1)
        ...     return a
        ...
        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
        >>> #
        >>> # All mutations and views have been removed,
        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
        >>> # after the function has completed.
        >>> print(f_no_mutations_and_views_traced.code)



        def forward(self, a_1):
            view_copy = torch.ops.aten.view_copy(a_1, [-1])
            add = torch.ops.aten.add(view_copy, 1);  view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
            copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
            return view_copy_1


    There are a few "failure modes" for functionalize that are worth calling out:
      (1) Like other torch.func transforms, `functionalize()` doesn't work with functions
          that directly use `.backward()`. The same is true for torch.autograd.grad.
          If you want to use autograd, you can compute gradients directly
          with `functionalize(grad(f))`.
      (2) Like other torch.func transforms, `functionalize()` doesn't work with global state.
          If you call `functionalize(f)` on a function that takes views / mutations of
          non-local state, functionalization will simply no-op and pass the view/mutation
          calls directly to the backend.
          One way to work around this is to ensure that any non-local state creation
          is wrapped into a larger function, which you then call functionalize on.
      (3) `resize_()` has some limitations: functionalize will only work on programs
          that use `resize_()` as long as the tensor being resized is not a view.
      (4) `as_strided()` has some limitations: functionalize will not work on
          `as_strided()` calls that result in tensors with overlapping memory.


    Finally, a helpful mental model for understanding functionalization is that
    most user pytorch programs are written with the public torch API.
    When executed, torch operators are generally decomposed into
    our internal C++ "ATen" API.
    The logic for functionalization happens entirely at the level of ATen.
    Functionalization knows how to take every aliasing operator in ATen,
    and map it to its non-aliasing equivalent
    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
    and how to take every mutating operator in ATen,
    and map it to its non-mutating equivalent
    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
    while tracking aliases and mutations out-of-line to know when to fix things up.
    Information about which ATen operators are aliasing or mutating all comes from
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
    """
    if remove == 'mutations':
        reapply_views = True
    elif remove == 'mutations_and_views':
        reapply_views = False
    else:
        raise RuntimeError(
            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
            " Valid options are:\n"
            "     remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
            " with their out-of-place equivalents.\n"
            "     remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
            " replaced with their non-aliasing counterparts, {view}_copy.\n"
        )

    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            func_level = _func_increment_nesting(reapply_views)
            func_args = _wrap_all_tensors_to_functional(args, func_level)
            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)

            flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
            flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
            flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
            flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)

            func_outputs = func(*func_args, **func_kwargs)
            outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
            flat_outputs, func_out_spec = tree_flatten(outputs)

            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
                if isinstance(a, torch.Tensor):
                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
                    torch._sync(a)

            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
            for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)
            for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)

            return outputs
        finally:
            _func_decrement_nesting()
    return wrapped


@exposed_in("torch.func")
def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
    '''
    Returns the value of ``func`` at ``primals`` and linear approximation
    at ``primals``.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. These are the values at which the function is linearly approximated.

    Returns:
        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the jvp of
        ``func`` evaluated at ``primals``.

    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
    to achieve this, linearize saves intermediate computation and has higher memory requirements
    than directly applying `jvp`. So, if all the ``tangents`` are known, it may be more efficient
    to compute vmap(jvp) instead of using linearize.

    .. note::
        linearize evaluates ``func`` twice. Please file an issue for an implementation
        with a single evaluation.

    Example::
        >>> import torch
        >>> from torch.func import linearize
        >>> def fn(x):
        ...     return x.sin()
        ...
        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
        >>> jvp_fn(torch.ones(3, 3))
        tensor([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]])
        >>>

    '''
    # Note: We evaluate `fn` twice.
    # Once for returning the output and once while
    # tracing the graph.
    # If this becomes a bottle-neck, we should update
    # make_fx such that it also returns the output.

    output = func(*primals)
    _, output_spec = tree_flatten(output)

    flat_primals, primals_argspec = tree_flatten(primals)

    # tangents for tracing
    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)

    # function to trace
    def trace_fn(flat_tangents):
        with fwAD.dual_level():
            flat_duals = tuple(fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents))
            duals = tree_unflatten(flat_duals, primals_argspec)
            output = func(*duals)
            tangents = tree_map_only(torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output)

        return tangents

    jvp_graph = make_fx(trace_fn)(flat_tangents)
    const_folded_jvp_graph = const_fold.split_const_subgraphs(jvp_graph)

    # Hold only the meta-data regarding the primals.
    flat_primals_shape = tuple(p.shape for p in flat_primals)
    flat_primals_device = tuple(p.device for p in flat_primals)
    flat_primals_dtype = tuple(p.dtype for p in flat_primals)

    def forward_ad_checks(flat_tangents):
        for idx, t in enumerate(flat_tangents):
            if t.shape != flat_primals_shape[idx]:
                msg = (f"tangent:{idx} with shape {t.shape} in flattened "
                       f"pytree doesn't match the shape {flat_primals_shape[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

            if t.device != flat_primals_device[idx]:
                msg = (f"tangent:{idx} with device {t.device} in flattened "
                       f"pytree doesn't match the device {flat_primals_device[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

            if t.dtype != flat_primals_dtype[idx]:
                msg = (f"tangent:{idx} with dtype {t.dtype} in flattened "
                       f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
                       "of the corresponding primal.")
                raise RuntimeError(msg)

    # jvp_fn : callable to return
    #   It takes care of checking the argspec of tangents,
    #   calling the folded fx graph and unflattening fx graph output
    def jvp_fn(*tangents):
        flat_tangents, tangent_argspec = tree_flatten(tangents)
        if tangent_argspec != primals_argspec:
            raise RuntimeError(f"Expected the tangents {tangent_argspec} to have "
                               f"the same argspec as the primals {primals_argspec}")

        forward_ad_checks(flat_tangents)

        flat_output = const_folded_jvp_graph(*flat_tangents)
        # const folded graph can return flat output,
        # so transform output.
        return tree_unflatten(flat_output, output_spec)

    return output, jvp_fn