from typing import Any

import torch
from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)

__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]


class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a ``requires_grad`` kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """

    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)


class enable_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        enable_grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
        >>> @torch.enable_grad
        ... def tripler(x):
        ...     return x * 3
        >>> with torch.no_grad():
        ...     out = tripler(x)
        >>> out.requires_grad
        True
    """

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
            (``False``). This can be used to conditionally enable
            gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
        # Restore the previous grad state when used as a decorator factory,
        # so that wrapping a function does not flip grad mode at definition time.
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
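

# Illustrative sketch (not part of the original module): constructing
# set_grad_enabled takes effect immediately, while using it as a decorator
# defers the mode change to each call of the wrapped function (see the
# restore in __call__ above). Function name here is hypothetical.
@set_grad_enabled(False)
def _no_grad_helper_example(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # grad is disabled only while this function runs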


class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode.

    InferenceMode is a context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., you are not training a model). Code run under this mode gets better
    performance by disabling view tracking and version counter bumps. Note that
    unlike some other mechanisms that locally enable or disable grad,
    entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool or function): Either a boolean flag whether to enable or
            disable inference mode, or a Python function to decorate with
            inference mode enabled.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...     y = x * x
        >>> y.requires_grad
        False
        >>> # xdoctest: +SKIP("want string isn't quite right")
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...     return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
        >>> @torch.inference_mode
        ... def doubler(x):
        ...     return x * 2
        >>> out = doubler(x)
        >>> out.requires_grad
        False
    """

    def __init__(self, mode: bool = True) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.mode = mode

    def __new__(cls, mode=True):
        if isinstance(mode, bool):
            return super().__new__(cls)
        # Called as a bare decorator (@torch.inference_mode): `mode` is the
        # decorated function, so wrap it with a default-constructed instance.
        return cls()(mode)

    def __enter__(self) -> None:
        self._inference_mode_context = torch._C._InferenceMode(self.mode)
        self._inference_mode_context.__enter__()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)


def _enter_inference_mode(mode):
    mode_context = torch._C._InferenceMode(mode)
    mode_context.__enter__()
    return mode_context


def _exit_inference_mode(mode):
    mode.__exit__(None, None, None)
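

# Illustrative sketch (not part of the original module): `_enter_inference_mode`
# and `_exit_inference_mode` are a paired, non-context-manager API; the guard
# object returned by the former must be handed back to the latter.
def _inference_mode_pairing_example():
    ctx = _enter_inference_mode(True)  # returns the live _InferenceMode guard
    try:
        y = torch.ones(2) * 2  # computed with inference mode enabled
    finally:
        _exit_inference_mode(ctx)  # always unwind the guard, even on error
    return y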


class set_multithreading_enabled(_DecoratorContextManager):
    r"""Context-manager that sets multithreaded backwards on or off.

    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
            (``False``).

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_multithreading_enabled()
        torch._C._set_multithreading_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        # The mode was already set in __init__; entering is a no-op.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
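

# Illustrative sketch (not part of the original module): disable multithreaded
# backwards for a single .backward() call, e.g. when debugging engine behavior.
def _single_threaded_backward_example():
    x = torch.ones(3, requires_grad=True)
    with set_multithreading_enabled(False):
        x.sum().backward()  # backward runs single-threaded inside this block
    return x.grad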


class _force_original_view_tracking(_DecoratorContextManager):
    r"""Context-manager that sets whether or not to always enable view-replay in autograd.

    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    When a tensor view is mutated, the autograd engine needs to decide whether or not
    to regenerate the "updated view" either by replaying the chain of views from the updated base,
    or with a single call to ``as_strided``.

    If set_view_replay_enabled is set to True, then autograd will always use view replay.
    Otherwise, it will fall back to its existing logic.

    Args:
        mode (bool): Flag whether to enable view-replay (``True``), or disable
            (``False``).

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_view_replay_enabled()
        torch._C._set_view_replay_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        # The mode was already set in __init__; entering is a no-op.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_view_replay_enabled(self.prev)

    def clone(self) -> "_force_original_view_tracking":
        return self.__class__(self.mode)
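

# Illustrative sketch (not part of the original module): force view-replay while
# mutating a view, so autograd reconstructs the updated view by replaying the
# view chain rather than via as_strided.
def _view_replay_example():
    base = torch.zeros(4, requires_grad=True).clone()  # non-leaf base tensor
    view = base[:2]
    with _force_original_view_tracking(True):
        view.mul_(2)  # the updated view is regenerated by replaying base[:2]
    return base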


class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to a tensor by incrementing its `._version` attribute.
    This is generally important for correctness: for example, mutating a tensor that autograd has saved
    for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
    and error out in this situation.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
    if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
    the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor
        self.prev_version = tensor._version

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args) -> None:
        # Roll the version counter back so autograd never sees the mutation.
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
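

# Illustrative sketch (not part of the original module): mutate a tensor without
# bumping its version counter, so autograd's saved-tensor check does not trip.
# This is only safe if the original values are restored before backward reads them.
def _preserve_version_counter_example(t: torch.Tensor) -> None:
    stashed = t.detach().clone()  # keep the original values elsewhere
    with torch.no_grad(), _unsafe_preserve_version_counter(t):
        t.zero_()         # mutate the tensor (e.g., to free or repurpose it)...
        t.copy_(stashed)  # ...then re-populate it before autograd needs it
    # t._version is unchanged, so autograd never notices the round-trip.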