"""
``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions.

It requires minimal changes to the existing code - you only need to declare :class:`Tensor` s
for which gradients should be computed with the ``requires_grad=True`` keyword.
As of now, we only support autograd for floating point :class:`Tensor` types (
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
"""
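# A minimal usage sketch: operations on tensors created with
# ``requires_grad=True`` are recorded, and calling ``backward()`` on a scalar
# result populates ``.grad`` on the leaves (values below are illustrative).
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
#   >>> y = (x * x).sum()
#   >>> y.backward()
#   >>> x.grad
#   tensor([2., 4., 6.])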
import warnings
from typing import cast, List, Optional, Sequence, Tuple, Union

import torch
from torch import _vmap_internals
from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge

from . import forward_ad, functional, graph
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from .function import Function, NestedIOFunction
from .grad_mode import (
    _force_original_view_tracking,
    _unsafe_preserve_version_counter,
    enable_grad,
    inference_mode,
    no_grad,
    set_grad_enabled,
    set_multithreading_enabled,
)
from .gradcheck import gradcheck, gradgradcheck
from .graph import _engine_run_backward
from .variable import Variable
__all__ = [
    "set_multithreading_enabled",
]

_OptionalTensor = Optional[torch.Tensor]
_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
def _calculate_shape(
    output: Union[torch.Tensor, graph.GradientEdge],
    grad: torch.Tensor,
    is_grads_batched: bool,
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
    # local import to avoid a circular dependency
    from torch.nested._internal.nested_tensor import NestedTensor

    if isinstance(output, graph.GradientEdge):
        if is_grads_batched:
            raise RuntimeError("Batched grads are not supported with GradientEdge")
        out_metadata = output.node._input_metadata[output.output_nr]
        return torch.Size(out_metadata.shape), grad.shape

    if output.is_nested and not isinstance(output, NestedTensor):
        # C++ nested tensors: compare per-component sizes rather than a single shape
        if is_grads_batched:
            raise RuntimeError("Batched grads are not supported with Nested Tensor.")
        out_shape = output._nested_tensor_size()
        grad_shape = grad._nested_tensor_size()

        return out_shape, grad_shape

    reg_out_shape = output.shape
    # with batched grads, the leading dimension of ``grad`` is the batch dimension
    reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
    return reg_out_shape, reg_grad_shape
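# Sketch of the batched-grad rule above (hypothetical shapes): with
# ``is_grads_batched=True`` the leading dimension of ``grad`` is treated as the
# batch dimension, so only the trailing dimensions are compared to the output.
#
#   >>> out = torch.randn(2, 3)
#   >>> grad = torch.randn(5, 2, 3)  # a batch of 5 grad "vectors"
#   >>> _calculate_shape(out, grad, is_grads_batched=True)
#   (torch.Size([2, 3]), torch.Size([2, 3]))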
def _make_grads(
    outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
    grads: Sequence[_OptionalTensor],
    is_grads_batched: bool,
) -> Tuple[_OptionalTensor, ...]:
    new_grads: List[_OptionalTensor] = []
    for out, grad in zip(outputs, grads):
        out = cast(Union[torch.Tensor, graph.GradientEdge], out)
        out_size = None
        out_device = None

        if isinstance(out, graph.GradientEdge):
            # no Tensor is available, so read shape/dtype/device from the node metadata
            out_metadata = out.node._input_metadata[out.output_nr]
            out_size = torch.Size(out_metadata.shape)
            out_dtype = out_metadata.dtype
            out_device = out_metadata.device
            out_is_nested = out_metadata.is_nested_tensor
            if out_metadata.is_cpp_nested_tensor:
                raise RuntimeError(
                    "C++ NestedTensor are not supported with GradientEdge"
                )
            out_is_cpp_nested = False
        else:
            # local import to avoid a circular dependency
            from torch.nested._internal.nested_tensor import NestedTensor

            assert isinstance(out, torch.Tensor)
            out_dtype = out.dtype
            out_is_nested = out.is_nested
            out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
            if not out_is_cpp_nested:
                out_size = out.shape
        if isinstance(grad, torch.Tensor):
            from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq

            first_grad = grad if not is_grads_batched else grad[0]

            if out_is_cpp_nested:
                # C++ nested tensor sizes can only be compared via is_same_size
                assert isinstance(out, torch.Tensor)
                shape_matches = torch.is_same_size(out, first_grad)
            else:
                assert out_size is not None
                shape_matches = expect_true(sym_eq(out_size, first_grad.size()))

            if not shape_matches:
                out = cast(Union[torch.Tensor, graph.GradientEdge], out)
                out_shape, grad_shape = _calculate_shape(
                    out, first_grad, is_grads_batched
                )
                if is_grads_batched:
                    raise RuntimeError(
                        "If `is_grads_batched=True`, we interpret the first "
                        "dimension of each grad_output as the batch dimension. "
                        "The sizes of the remaining dimensions are expected to match "
                        "the shape of corresponding output, but a mismatch "
                        "was detected: grad_output["
                        + str(grads.index(grad))
                        + "] has a shape of "
                        + str(grad_shape)
                        + " and output["
                        + str(outputs.index(out))
                        + "] has a shape of "
                        + str(out_shape)
                        + ". "
                        "If you only want some tensors in `grad_output` to be considered "
                        "batched, consider using vmap."
                    )
                else:
                    raise RuntimeError(
                        "Mismatch in shape: grad_output["
                        + str(grads.index(grad))
                        + "] has a shape of "
                        + str(grad_shape)
                        + " and output["
                        + str(outputs.index(out))
                        + "] has a shape of "
                        + str(out_shape)
                        + "."
                    )
            if out_dtype.is_complex != grad.dtype.is_complex:
                raise RuntimeError(
                    "For complex Tensors, both grad_output and output"
                    " are required to have the same dtype."
                    " Mismatch in dtype: grad_output["
                    + str(grads.index(grad))
                    + "] has a dtype of "
                    + str(grad.dtype)
                    + " and output["
                    + str(outputs.index(out))
                    + "] has a dtype of "
                    + str(out_dtype)
                    + "."
                )
            new_grads.append(grad)
        elif grad is None:
            if isinstance(out, graph.GradientEdge) or out.requires_grad:
                if isinstance(out, graph.GradientEdge):
                    assert out_size is not None
                    out_numel_is_1 = all(o == 1 for o in out_size)
                else:
                    assert isinstance(out, torch.Tensor)
                    out_numel_is_1 = out.numel() == 1
                if not out_numel_is_1:
                    raise RuntimeError(
                        "grad can be implicitly created only for scalar outputs"
                    )
                if not out_dtype.is_floating_point:
                    msg = (
                        "grad can be implicitly created only for real scalar outputs"
                        f" but got {out_dtype}"
                    )
                    raise RuntimeError(msg)
                if isinstance(out, graph.GradientEdge):
                    # build the implicit "ones" gradient from the recorded metadata
                    assert out_size is not None
                    assert out_device is not None
                    new_grads.append(
                        torch.ones(
                            out_size,
                            dtype=out_dtype,
                            device=out_device,
                        )
                    )
                else:
                    assert isinstance(out, torch.Tensor)
                    new_grads.append(
                        torch.ones_like(out, memory_format=torch.preserve_format)
                    )
            else:
                new_grads.append(None)
        else:
            raise RuntimeError(
                "gradients can be either Tensors or None, but got "
                + type(grad).__name__
            )
    return tuple(new_grads)
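# Sketch of the implicit-gradient rule implemented above: when ``backward()`` is
# called on a scalar output without an explicit gradient, a gradient of
# ``torch.ones_like(out)`` is used (values below are illustrative).
#
#   >>> x = torch.tensor(2.0, requires_grad=True)
#   >>> y = x * 3
#   >>> torch.autograd.backward(y)  # grad_tensors=None -> implicit ones_like(y)
#   >>> x.grad
#   tensor(3.)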
def _tensor_or_tensors_to_tuple(
    tensors: Optional[_TensorOrTensors], length: int
) -> Tuple[_OptionalTensor, ...]:
    if tensors is None:
        return (None,) * length
    if isinstance(tensors, torch.Tensor):
        return (tensors,)
    return tuple(tensors)
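# Behaviour of the helper above, in brief (t, t1, t2 stand for arbitrary Tensors):
#   _tensor_or_tensors_to_tuple(None, 3)      -> (None, None, None)
#   _tensor_or_tensors_to_tuple(t, 3)         -> (t,)
#   _tensor_or_tensors_to_tuple([t1, t2], 2)  -> (t1, t2)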
def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
) -> None:
    r"""Compute the sum of gradients of given tensors with respect to graph leaves.

    The graph is differentiated using the chain rule. If any of ``tensors``
    are non-scalar (i.e. their data has more than one element) and require
    gradient, then the Jacobian-vector product would be computed, in this
    case the function additionally requires specifying ``grad_tensors``.
    It should be a sequence of matching length that contains the "vector"
    in the Jacobian-vector product, usually the gradient of the differentiated
    function w.r.t. corresponding tensors (``None`` is an acceptable value for
    all tensors that don't need gradient tensors).

    This function accumulates gradients in the leaves - you might need to zero
    ``.grad`` attributes or set them to ``None`` before calling it.
    See :ref:`Default gradient layouts<default-grad-layouts>`
    for details on the memory layout of accumulated gradients.

    .. note::
        Using this method with ``create_graph=True`` will create a reference cycle
        between the parameter and its gradient which can cause a memory leak.
        We recommend using ``autograd.grad`` when creating the graph to avoid this.
        If you have to use this function, make sure to reset the ``.grad`` fields of your
        parameters to ``None`` after use to break the cycle and avoid the leak.

    .. note::

        If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::

        When ``inputs`` are provided and a given input is not a leaf,
        the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
        It is an implementation detail on which the user should not rely.
        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
    Args:
        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
            computed.
        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
            the Jacobian-vector product, usually gradients w.r.t. each element of
            corresponding tensors. None values can be specified for scalar Tensors or
            ones that don't require grad. If a None value would be acceptable for all
            grad_tensors, then this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing higher order derivative products to be computed.
            Defaults to ``False``.
        inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
            will be accumulated into ``.grad``. All other Tensors will be ignored. If
            not provided, the gradient is accumulated into all the leaf Tensors that
            were used to compute the :attr:`tensors`.
    """
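    # Illustrative sketch of explicit ``grad_tensors``: the supplied tensor is the
    # "vector" in the vector-Jacobian product (values below are illustrative).
    #
    #   >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
    #   >>> y = x * 2                     # non-scalar output
    #   >>> torch.autograd.backward(y, grad_tensors=torch.tensor([1.0, 0.5]))
    #   >>> x.grad
    #   tensor([2., 1.])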
    if torch._C._are_functorch_transforms_active():
        raise RuntimeError(
            "backward() called inside a functorch transform. This is not "
            "supported, please use functorch.grad or functorch.vjp instead "
            "or call backward() outside of functorch transforms."
        )

    if grad_variables is not None:
        warnings.warn(
            "`grad_variables` is deprecated. Use `grad_tensors` instead.",
            FutureWarning,
            stacklevel=2,
        )
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError(
                "`grad_tensors` and `grad_variables` (deprecated) "
                "arguments both passed to `backward()`. Please only "
                "use `grad_tensors`."
            )
    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")

    # Normalize ``tensors`` and ``inputs`` to tuples before handing them to the engine.
    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (
        (inputs,)
        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
        else tuple(inputs)
        if inputs is not None
        else tuple()
    )

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
    if retain_graph is None:
        retain_graph = create_graph

    _engine_run_backward(
        tensors,
        grad_tensors_,
        retain_graph,
        create_graph,
        inputs,
        allow_unreachable=True,
        accumulate_grad=True,
    )
def grad(
    outputs: _TensorOrTensorsOrGradEdge,
    inputs: _TensorOrTensorsOrGradEdge,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: Optional[bool] = None,
    is_grads_batched: bool = False,
    materialize_grads: bool = False,
) -> Tuple[torch.Tensor, ...]:
    r"""Compute and return the sum of gradients of outputs with respect to the inputs.

    ``grad_outputs`` should be a sequence of length matching ``outputs``,
    containing the "vector" in the vector-Jacobian product, usually the pre-computed
    gradients w.r.t. each of the outputs. If an output doesn't require_grad,
    then the gradient can be ``None``.

    .. note::

        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::

        ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
        To accumulate gradient for other parts of the graph, please use
        ``torch.autograd.backward``.
    Args:
        outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
        inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
            Usually gradients w.r.t. each output. None values can be specified for scalar
            Tensors or ones that don't require grad. If a None value would be acceptable
            for all grad_outputs, then this argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing higher order derivative products to be computed.
            Defaults to ``False``.
        allow_unused (Optional[bool], optional): If ``False``, specifying inputs
            that were not used when computing outputs (and therefore their grad is
            always zero) is an error. Defaults to the value of ``materialize_grads``.
        is_grads_batched (bool, optional): If ``True``, the first dimension of each
            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
            Instead of computing a single vector-Jacobian product, we compute a
            batch of vector-Jacobian products for each "vector" in the batch.
            We use the vmap prototype feature as the backend to vectorize calls
            to the autograd engine so that this computation can be performed in a
            single call. This should lead to performance improvements when compared
            to manually looping and performing backward multiple times. Note that
            due to this feature being experimental, there may be performance
            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
            to show any performance warnings and file an issue on GitHub if warnings exist
            for your use case. Defaults to ``False``.
        materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
            to zero instead of None. This is useful when computing higher-order derivatives.
            If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
            will be raised. Defaults to ``False``.
    """
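    # Illustrative sketch of ``is_grads_batched=True``: a batch of "vectors" is
    # pushed through a single vectorized vector-Jacobian product call, here to
    # recover the full Jacobian of an elementwise function (values illustrative).
    #
    #   >>> x = torch.randn(3, requires_grad=True)
    #   >>> y = x * 2
    #   >>> vs = torch.eye(3)  # one basis "vector" per row
    #   >>> (jac,) = torch.autograd.grad(y, x, grad_outputs=vs, is_grads_batched=True)
    #   >>> jac.shape
    #   torch.Size([3, 3])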
    if materialize_grads and allow_unused is False:
        raise ValueError(
            "Expected allow_unused to be True or not passed when materialize_grads=True, "
            "but got: allow_unused=False."
        )
    if allow_unused is None:
        allow_unused = materialize_grads
    if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
        outputs = cast(
            Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
        )
    else:
        outputs = tuple(outputs)
    if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
        inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
    else:
        inputs = tuple(inputs)
    t_outputs = tuple(i for i in outputs if is_tensor_like(i))
    t_inputs = tuple(i for i in inputs if is_tensor_like(i))
    overridable_args = t_outputs + t_inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
            is_grads_batched=is_grads_batched,
            materialize_grads=materialize_grads,
        )

    if not only_inputs:
        warnings.warn(
            "only_inputs argument is deprecated and is ignored now "
            "(defaults to True). To accumulate gradient for other "
            "parts of the graph, please use torch.autograd.backward.",
            FutureWarning,
            stacklevel=2,
        )

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
    grad_outputs_ = _make_grads(
        outputs, grad_outputs_, is_grads_batched=is_grads_batched
    )

    if retain_graph is None:
        retain_graph = create_graph
    if is_grads_batched:

        def vjp(gO):
            return _engine_run_backward(
                outputs,
                gO,
                retain_graph,
                create_graph,
                inputs,
                allow_unused,
                accumulate_grad=False,
            )

        result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
            grad_outputs_
        )
    else:
        result = _engine_run_backward(
            outputs,
            grad_outputs_,
            retain_graph,
            create_graph,
            inputs,
            allow_unused,
            accumulate_grad=False,
        )
    if materialize_grads:
        if any(
            result[i] is None and not is_tensor_like(inputs[i])
            for i in range(len(inputs))
        ):
            raise RuntimeError(
                "materialize_grads cannot be used when the given input is a GradientEdge"
            )
        result = tuple(
            output
            if output is not None
            else torch.zeros_like(input, requires_grad=True)
            for (output, input) in zip(result, inputs)
        )
    return result
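# Sketch of ``materialize_grads=True`` (illustrative): an input that did not
# participate in the output gets a zero gradient instead of ``None``.
#
#   >>> x = torch.tensor(1.0, requires_grad=True)
#   >>> unused = torch.tensor(2.0, requires_grad=True)
#   >>> gx, gu = torch.autograd.grad(x * 3, (x, unused), materialize_grads=True)
#   >>> gx, gu.item()
#   (tensor(3.), 0.0)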
def _is_checkpoint_valid():
    return Variable._execution_engine.is_checkpoint_valid()


def variable(*args, **kwargs):
    raise RuntimeError(
        "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
    )


# Make ``torch.autograd.variable.Variable`` resolvable even though the module
# name is shadowed by the deprecated ``variable()`` function above.
variable.Variable = Variable  # type: ignore[attr-defined]

if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")
from torch._C._autograd import (
    _disable_profiler_legacy,
    _enable_profiler_legacy,
    _enable_record_function,
    _pop_saved_tensors_default_hooks,
    _push_saved_tensors_default_hooks,
    _record_function_with_args_enter,
    _record_function_with_args_exit,
    _set_empty_test_observer,
    _supported_activities,
    _toggle_collection_dynamic,
)
from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState

from . import profiler
def _register_py_tensor_class_for_device(device, cls):
    if not isinstance(cls, type):
        raise RuntimeError("cls isn't a typeinfo object")
    torch._C._register_py_class_for_device(device, cls)


is_multithreading_enabled = torch._C._is_multithreading_enabled
torch._C._add_docstr(
    is_multithreading_enabled, "Returns True if multithreading is currently enabled."
)

is_view_replay_enabled = torch._C._is_view_replay_enabled
torch._C._add_docstr(
    is_view_replay_enabled, "Returns True if view-replay is currently enabled."
)