import os
from collections import namedtuple
from typing import Any

import torch

from .grad_mode import _DecoratorContextManager

# Current forward AD level, used by default by the other functions in this API.
_current_level = -1


def enter_dual_level():
    r"""Enter a new forward grad level.

    This level can be used to make and unpack dual Tensors to compute
    forward gradients.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    new_level = torch._C._enter_dual_level()
    if new_level != _current_level + 1:
        raise RuntimeError(
            "Entering a new forward AD level but the current level "
            "is not valid. Make sure you did not modify it directly."
        )
    _current_level = new_level
    return new_level


def exit_dual_level(*, level=None):
    r"""Exit a forward grad level.

    This function deletes all the gradients associated with this
    level. Only deleting the latest entered level is allowed.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    if level is None:
        level = _current_level
    if level != _current_level:
        raise RuntimeError(
            "Trying to exit a forward AD level that was not the last one "
            "that was created. This is not supported."
        )
    torch._C._exit_dual_level(level=level)
    _current_level = level - 1
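
# A minimal usage sketch of manual level management (``x`` and ``t`` are assumed
# to be a floating point tensor and a matching tangent); the :class:`dual_level`
# context manager below wraps this enter/exit pattern and is the recommended
# entry point:
#
#     level = enter_dual_level()
#     try:
#         inp = make_dual(x, t)
#         out = inp.sin()
#         primal, tangent = unpack_dual(out, level=level)
#     finally:
#         exit_dual_level(level=level)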


def _maybe_load_decompositions():
    # Importing ``decompositions_for_jvp`` registers the decompositions needed
    # by forward AD with the JIT; this is skipped when the JIT is disabled or
    # when python runs with -O/-OO (``__debug__`` is False).
    if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__:
        from torch._decomp import decompositions_for_jvp  # noqa: F401


def make_dual(tensor, tangent, *, level=None):
    r"""Associate a tensor value with its tangent to create a "dual tensor" for forward AD gradient computation.

    The result is a new tensor aliased to :attr:`tensor` with :attr:`tangent` embedded
    as an attribute as-is if it has the same storage layout or copied otherwise.
    The tangent attribute can be recovered with :func:`unpack_dual`.

    This function is backward differentiable.

    Given a function `f` whose Jacobian is `J`, it allows one to compute the Jacobian-vector product (`jvp`)
    between `J` and a given vector `v` as follows.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> with dual_level():
        ...     inp = make_dual(x, v)
        ...     out = f(inp)
        ...     y, jvp = unpack_dual(out)

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """
    _maybe_load_decompositions()

    if level is None:
        level = _current_level

    if level < 0:
        raise RuntimeError(
            "Trying to create a dual Tensor for forward AD but no level "
            "exists, make sure to enter_dual_level() first."
        )
    if not (tensor.is_floating_point() or tensor.is_complex()):
        raise ValueError(
            f"Expected primal to be floating point or complex, but got: {tensor.dtype}"
        )
    if not (tangent.is_floating_point() or tangent.is_complex()):
        raise ValueError(
            f"Expected tangent to be floating point or complex, but got: {tangent.dtype}"
        )

    return torch._VF._make_dual(tensor, tangent, level=level)
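
# A self-contained sketch of a jvp computation with this API (the check against
# ``torch.cos(x) * v`` assumes the analytic derivative of ``sin``):
#
#     >>> x = torch.tensor([0.5, 1.0])
#     >>> v = torch.ones_like(x)
#     >>> with dual_level():
#     ...     out = torch.sin(make_dual(x, v))
#     ...     jvp = unpack_dual(out).tangent
#     >>> torch.allclose(jvp, torch.cos(x) * v)
#     True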


_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])


class UnpackedDualTensor(_UnpackedDualTensor):
    r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.

    See :func:`unpack_dual` for more details.
    """


def unpack_dual(tensor, *, level=None):
    r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.

    The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of
    :attr:`tensor`'s primal and ``tangent`` is :attr:`tensor`'s tangent as-is.
    Neither of these tensors can be a dual tensor of level :attr:`level`.

    This function is backward differentiable.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     out = f(inp)
        ...     y, jvp = unpack_dual(out)
        ...     jvp = unpack_dual(out).tangent

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """
    if level is None:
        level = _current_level

    if level < 0:
        return UnpackedDualTensor(tensor, None)

    primal, dual = torch._VF._unpack_dual(tensor, level=level)

    return UnpackedDualTensor(primal, dual)
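
# A quick sanity sketch: outside any dual level (or for a plain, non-dual
# tensor) the tangent component is simply ``None``:
#
#     >>> unpack_dual(torch.rand(2)).tangent is None
#     True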


class dual_level(_DecoratorContextManager):
    r"""Context-manager for forward AD, where all forward AD computation must occur within the ``dual_level`` context.

    .. Note::

        The ``dual_level`` context appropriately enters and exits the dual level to
        control the current forward AD level, which is used by default by the other
        functions in this API.

        We currently don't plan to support nested ``dual_level`` contexts, so
        only a single forward AD level is supported. To compute higher-order
        forward grads, one can use :func:`torch.func.jvp`.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> x = torch.tensor([1])
        >>> x_t = torch.tensor([1])
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     # Do computations with inp
        ...     out = your_fn(inp)
        ...     _, grad = unpack_dual(out)
        >>> grad is None
        False
        >>> # After exiting the level, the grad is deleted
        >>> _, grad_after = unpack_dual(out)
        >>> grad_after is None
        True

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """

    def __enter__(self):
        return enter_dual_level()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        exit_dual_level()


_is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled


# Private helper context-manager that enables or disables forward-grad mode.
# The previous state is captured at construction time and restored on exit.
class _set_fwd_grad_enabled(_DecoratorContextManager):
    def __init__(self, mode: bool) -> None:
        self.prev = _is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_fwd_grad_enabled(self.prev)
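
# A minimal sketch of how this private helper might be used internally to
# temporarily disable forward-grad propagation (``some_op`` and ``inp`` are
# hypothetical):
#
#     with _set_fwd_grad_enabled(False):
#         out = some_op(inp)  # tangents are not propagated in this block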