import warnings
import functools
from typing import Union, Iterable, List, Dict, Tuple, Optional, cast

import torch
from torch import Tensor
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support, _device_has_foreach_support

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']

def _no_grad(func):
    """
    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
    clip_grad_norm_ and clip_grad_value_ themselves.
    """
    def _no_grad_wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)
    functools.update_wrapper(_no_grad_wrapper, func)
    return _no_grad_wrapper

@_no_grad
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
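
    A minimal usage sketch (``net`` below is a stand-in model, not part of this
    module): populate ``.grad`` with ``backward()``, then clip before the optimizer step.

    Example::

        >>> net = torch.nn.Linear(4, 4)
        >>> net(torch.randn(2, 4)).sum().backward()
        >>> total_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)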
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    first_device = grads[0].device
    grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]
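    # Grouping by (device, dtype) lets the foreach fast path below batch each
    # group's gradients into fused multi-tensor ops instead of looping per tensor.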

    norms: List[Tensor] = []
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(device_grads, device))
            or (foreach and _device_has_foreach_support(device))
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
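    # Illustrative numbers: with max_norm=1.0 and total_norm=4.0 the coef is ~0.25,
    # so every gradient is scaled down 4x; with total_norm=0.5 the raw coef would be
    # ~2.0, but the clamp keeps it at 1.0 and gradients are left unchanged.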
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(device_grads, device))
            or (foreach and _device_has_foreach_support(device))
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm


def clip_grad_norm(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    .. warning::
        This method is now deprecated in favor of
        :func:`torch.nn.utils.clip_grad_norm_`.
    """
    warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
                  "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
    return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)


@_no_grad
def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
    r"""Clip the gradients of an iterable of parameters at specified value.

    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients clipped
        clip_value (float): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and
            silently fall back to the slow implementation for other device types.
            Default: ``None``
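
    A minimal usage sketch (``net`` below is a stand-in model, not part of this
    module):

    Example::

        >>> net = torch.nn.Linear(4, 4)
        >>> net(torch.randn(2, 4)).sum().backward()
        >>> torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=0.5)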
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)

    grads = [p.grad for p in parameters if p.grad is not None]
    grouped_grads = _group_tensors_by_device_and_dtype([grads])

    for ((device, _), ([grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(cast(List[Tensor], grads), device=device))
            or (foreach and _device_has_foreach_support(device))
        ):
            torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
            torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            for grad in grads:
                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)