from typing import Any

import torch

from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)

__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]


class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a requires_grad kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @torch.no_grad
        ... def tripler(x):
        ...     return x * 3
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """

    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)


class enable_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        enable_grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
        >>> @torch.enable_grad
        ... def tripler(x):
        ...     return x * 3
        >>> with torch.no_grad():
        ...     z = tripler(x)
        >>> z.requires_grad
        True

    """

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)


class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode.

    InferenceMode is a context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., data processing, model evaluation). Code run under
    this mode gets better performance by disabling view tracking and version
    counter bumps. Note that unlike some other mechanisms that locally enable
    or disable grad, entering inference_mode also disables
    :ref:`forward-mode AD <forward-mode-ad>`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool or function): Either a boolean flag whether to enable or
            disable inference mode or a Python function to decorate with
            inference mode enabled

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...     y = x * x
        >>> y.requires_grad
        False
        >>> # xdoctest: +SKIP("want string isnt quite right")
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...     return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
        >>> @torch.inference_mode
        ... def doubler(x):
        ...     return x * 2
        >>> out = doubler(x)
        >>> out.requires_grad
        False

    """

    def __init__(self, mode: bool = True) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.mode = mode

    def __new__(cls, mode=True):
        if isinstance(mode, bool):
            return super().__new__(cls)
        return cls()(mode)

    def __enter__(self) -> None:
        self._inference_mode_context = torch._C._InferenceMode(self.mode)
        self._inference_mode_context.__enter__()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)


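# Imperative helpers mirroring ``inference_mode``: ``_enter_inference_mode``
# returns the underlying guard object, which must later be handed back to
# ``_exit_inference_mode``.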
def _enter_inference_mode(mode):
    mode_context = torch._C._InferenceMode(mode)
    mode_context.__enter__()
    return mode_context


def _exit_inference_mode(mode):
    mode.__exit__(None, None, None)


class set_multithreading_enabled(_DecoratorContextManager):
    r"""Context-manager that sets multithreaded backwards on or off.

    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
                     (``False``).

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

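    Example (illustrative sketch; disabling multithreading only changes how the
    backward pass is scheduled, not the gradients it produces)::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.autograd.grad_mode import set_multithreading_enabled
        >>> x = torch.ones(2, 3, requires_grad=True)
        >>> with set_multithreading_enabled(False):
        ...     # backward for this graph runs single-threaded
        ...     (2 * x).sum().backward()
        >>> x.grad
        tensor([[2., 2., 2.],
                [2., 2., 2.]])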
    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_multithreading_enabled()
        torch._C._set_multithreading_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)


class _force_original_view_tracking(_DecoratorContextManager):
    r"""Context-manager that sets whether or not to always enable view-replay in autograd.

    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    When a tensor view is mutated, the autograd engine needs to decide whether
    to regenerate the "updated view" by replaying the chain of views from the
    updated base, or with a single call to as_strided.

    If set_view_replay_enabled is set to True, then autograd will always use view replay.
    Otherwise, it will fall back to its existing logic.

    Args:
        mode (bool): Flag whether to enable view-replay (``True``), or disable
                     (``False``).

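    Example (illustrative sketch; ``_force_original_view_tracking`` is a private
    API, so its import location and exact behavior may change between releases)::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.autograd.grad_mode import _force_original_view_tracking
        >>> base = torch.randn(4, requires_grad=True).clone()
        >>> view = base[:2]
        >>> with _force_original_view_tracking(True):
        ...     # in-place update of a view: autograd regenerates the updated
        ...     # view by replaying ``base[:2]`` rather than calling as_strided
        ...     view.mul_(2)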
    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch._C._is_view_replay_enabled()
        torch._C._set_view_replay_enabled(mode)
        self.mode = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_view_replay_enabled(self.prev)

    def clone(self):
        return self.__class__(self.mode)


class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
    This is generally important for correctness: for example, mutating a tensor that autograd has saved
    for the backward pass can result in incorrect gradients, and autograd uses the version counter to
    detect this situation and error out.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
    if a tensor is very large and you'd like to free its memory by storing it elsewhere, you can re-populate
    the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

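    Example (illustrative sketch of the version-counter behavior only; heed the
    warning above before using this in real code)::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.autograd.grad_mode import _unsafe_preserve_version_counter
        >>> x = torch.zeros(3)
        >>> x._version
        0
        >>> with _unsafe_preserve_version_counter(x):
        ...     x.add_(1)  # the in-place op bumps ``x._version`` ...
        >>> x._version  # ... but the saved value is restored on exit
        0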
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor
        self.prev_version = tensor._version

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args) -> None:
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
