pytorch / clip_grad.py · 151 lines · 7.0 KB
import warnings
import functools
from typing import Union, Iterable, List, Dict, Tuple, Optional, cast

import torch
from torch import Tensor
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support, _device_has_foreach_support

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']

def _no_grad(func):
    """
    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
    clip_grad_norm_ and clip_grad_value_ themselves.
    """
    def _no_grad_wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)
    functools.update_wrapper(_no_grad_wrapper, func)
    return _no_grad_wrapper

@_no_grad
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    first_device = grads[0].device
    grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]

    norms: List[Tensor] = []
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(device_grads, device))
            or (foreach and _device_has_foreach_support(device))
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(device_grads, device))
            or (foreach and _device_has_foreach_support(device))
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm
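

# Usage sketch (illustrative, not exported): the typical place for gradient-norm
# clipping in a training loop is between loss.backward() and optimizer.step().
# The model, optimizer, and loss arguments are placeholders supplied by the caller.
def _example_clip_grad_norm_usage(model: torch.nn.Module,
                                  optimizer: torch.optim.Optimizer,
                                  loss: Tensor) -> Tensor:
    loss.backward()                                                 # populate p.grad
    total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)  # rescale grads in place
    optimizer.step()                                                # step with clipped grads
    return total_norm                                               # pre-clipping norm, useful for logging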


def clip_grad_norm(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    .. warning::
        This method is now deprecated in favor of
        :func:`torch.nn.utils.clip_grad_norm_`.
    """
    warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
                  "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
    return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)


@_no_grad
def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
    r"""Clip the gradients of an iterable of parameters at specified value.

    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients clipped
        clip_value (float): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and
            silently fall back to the slow implementation for other device types.
            Default: ``None``
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)

    grads = [p.grad for p in parameters if p.grad is not None]
    grouped_grads = _group_tensors_by_device_and_dtype([grads])

    for ((device, _), ([grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(cast(List[Tensor], grads), device=device))
            or (foreach and _device_has_foreach_support(device))
        ):
            torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
            torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            for grad in grads:
                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
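

# Usage sketch (illustrative, not exported): clip_grad_value_ clamps each gradient
# element into [-clip_value, clip_value] in place and returns None, so it is likewise
# called after backward() and before the optimizer step. The model, optimizer, and
# loss arguments are placeholders supplied by the caller.
def _example_clip_grad_value_usage(model: torch.nn.Module,
                                   optimizer: torch.optim.Optimizer,
                                   loss: Tensor) -> None:
    loss.backward()
    clip_grad_value_(model.parameters(), clip_value=0.5)  # element-wise clamp of p.grad
    optimizer.step()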