sparse_adam.py

import torch
from . import _functional as F
from .optimizer import Optimizer, _maximize_doc

__all__ = ['SparseAdam']

class SparseAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool = False):
        if not 0.0 < lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
        super().__init__(params, defaults)

        sparse_params = []
        complex_params = []
        for index, param_group in enumerate(self.param_groups):
            assert isinstance(param_group, dict), f"param_groups must be a list of dicts, but got {type(param_group)}"
            # given param group, convert given params to a list first before iterating
            for d_index, d_param in enumerate(param_group['params']):
                if d_param.is_sparse:
                    sparse_params.append([index, d_index])
                if d_param.is_complex():
                    complex_params.append([index, d_index])
        if sparse_params:
            raise ValueError(
                f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"
            )
        if complex_params:
            raise ValueError(
                f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters"
            )


    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
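
        A minimal closure sketch (``model``, ``inputs``, ``targets``, and ``loss_fn``
        are placeholders for your own training setup, not part of this module)::

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model(inputs), targets)
                loss.backward()
                return loss

            loss = optimizer.step(closure)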
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            eps = group['eps']
            lr = group['lr']
            beta1, beta2 = group['betas']
            maximize = group.get('maximize', False)

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if not p.grad.is_sparse:
                        raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
                    grads.append(p.grad)

                    state = self.state[p]

                    # State initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            F.sparse_adam(params_with_grad,
                          grads,
                          exp_avgs,
                          exp_avg_sqs,
                          state_steps,
                          beta1=beta1,
                          beta2=beta2,
                          lr=lr,
                          eps=eps,
                          maximize=maximize)

        return loss

SparseAdam.__doc__ = fr"""SparseAdam implements a masked version of the Adam algorithm
    suitable for sparse gradients. Currently, due to implementation constraints (explained
    below), SparseAdam is only intended for a narrow subset of use cases, specifically
    parameters of a dense layout with gradients of a sparse layout. This occurs in a
    special case where the module's backward pass produces grads already in a sparse layout.
    One example NN module that behaves as such is ``nn.Embedding(sparse=True)``.
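
    A minimal usage sketch for that case (the embedding size, fake indices, and dummy
    loss below are illustrative placeholders, not a recommendation)::

        emb = torch.nn.Embedding(10, 4, sparse=True)
        optimizer = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)

        ids = torch.tensor([1, 4, 4, 7])
        loss = emb(ids).pow(2).sum()

        optimizer.zero_grad()
        loss.backward()   # emb.weight.grad arrives in sparse layout
        optimizer.step()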

    SparseAdam approximates the Adam algorithm by masking out the parameter and moment
    updates corresponding to the zero values in the gradients. Whereas the Adam algorithm
    will update the first moment, the second moment, and the parameters based on all values
    of the gradients, SparseAdam only updates the moments and parameters corresponding
    to the non-zero values of the gradients.

    A simplified way of thinking about the `intended` implementation is as follows:

    1. Create a mask of the non-zero values in the sparse gradients. For example,
       if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1].
    2. Apply this mask over the running moments and do computation on only the
       non-zero values.
    3. Apply this mask over the parameters and only apply an update on non-zero values.
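
    A toy, dense-layout illustration of this masking idea (a conceptual sketch only,
    not the actual sparse kernel, and it skips Adam's bias correction)::

        beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8
        param = torch.ones(5)
        exp_avg = torch.zeros(5)
        exp_avg_sq = torch.zeros(5)

        grad = torch.tensor([0., 5., 0., 0., 9.])
        mask = grad != 0    # step 1: the mask is [0, 1, 0, 0, 1]

        # step 2: update the moments only at the masked (non-zero) positions
        exp_avg[mask] = beta1 * exp_avg[mask] + (1 - beta1) * grad[mask]
        exp_avg_sq[mask] = beta2 * exp_avg_sq[mask] + (1 - beta2) * grad[mask] ** 2

        # step 3: update only the masked parameter entries
        param[mask] -= lr * exp_avg[mask] / (exp_avg_sq[mask].sqrt() + eps)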

    In actuality, we use sparse layout Tensors to optimize this approximation, which means the
    more gradients that are masked by not being materialized, the more performant the optimization.
    Since we rely on using sparse layout tensors, we infer that any materialized value in the
    sparse layout is non-zero and we do NOT actually verify that all values are not zero!
    It is important not to conflate a semantically sparse tensor (a tensor where many
    of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse``
    returns ``True``). The SparseAdam approximation is intended for `semantically` sparse
    tensors and the sparse layout is only an implementation detail. A clearer implementation
    would be to use MaskedTensors, but those are experimental.
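
    For instance, a dense layout tensor that happens to be mostly zeros is still not a
    sparse layout tensor::

        t = torch.tensor([0., 5., 0.])
        t.is_sparse               # False: dense layout, despite the zeros
        t.to_sparse().is_sparse   # True: sparse (COO) layout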


    .. note::

        If you suspect your gradients are semantically sparse (but do not have sparse
        layout), this variant may not be the best for you. Ideally, you want to avoid
        materializing anything that is suspected to be sparse in the first place, since
        needing to convert all your grads from dense layout to sparse layout may outweigh
        the performance gain. Here, using Adam may be the best alternative, unless you
        can easily rig up your module to output sparse grads similar to
        ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do
        so by manually overriding your parameters' ``.grad`` fields with their sparse
        equivalents before calling ``.step()``.
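
        A rough sketch of such a conversion (``model`` and ``optimizer`` below are
        placeholders for your own module and SparseAdam instance)::

            for p in model.parameters():
                if p.grad is not None and not p.grad.is_sparse:
                    p.grad = p.grad.to_sparse()
            optimizer.step()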


    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        {_maximize_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """