import torch

from . import _functional as F
from .optimizer import Optimizer, _maximize_doc

__all__ = ['SparseAdam']


class SparseAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool = False):
        if not 0.0 < lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
        super().__init__(params, defaults)

        # SparseAdam requires dense, non-complex parameter tensors; reject anything else up front
        sparse_params = []
        complex_params = []
        for index, param_group in enumerate(self.param_groups):
            assert isinstance(param_group, dict), f"param_groups must be a list of dicts, but got {type(param_group)}"
            for d_index, d_param in enumerate(param_group['params']):
                if d_param.is_sparse:
                    sparse_params.append([index, d_index])
                if d_param.is_complex():
                    complex_params.append([index, d_index])
        if sparse_params:
            raise ValueError(
                f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"
            )
        if complex_params:
            raise ValueError(
                f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters"
            )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']
            maximize = group.get('maximize', False)

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if not p.grad.is_sparse:
                        raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
                    grads.append(p.grad)

                    state = self.state[p]

                    # Lazy state initialization: exponential moving averages of the
                    # gradient and its square, matching the parameter's memory format
                    if len(state) == 0:
                        state['step'] = 0
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    # Update and record the per-parameter step count
                    state['step'] += 1
                    state_steps.append(state['step'])

            F.sparse_adam(params_with_grad,
                          grads,
                          exp_avgs,
                          exp_avg_sqs,
                          state_steps,
                          beta1=beta1,
                          beta2=beta2,
                          lr=group['lr'],
                          eps=group['eps'],
                          maximize=maximize)

        return loss


SparseAdam.__doc__ = fr"""SparseAdam implements a masked version of the Adam algorithm
    suitable for sparse gradients. Currently, due to implementation constraints (explained
    below), SparseAdam is only intended for a narrow subset of use cases, specifically
    parameters of a dense layout with gradients of a sparse layout. This occurs in a
    special case where the module backwards produces grads already in a sparse layout.
    One example NN module that behaves as such is ``nn.Embedding(sparse=True)``.

    SparseAdam approximates the Adam algorithm by masking out the parameter and moment
    updates corresponding to the zero values in the gradients. Whereas the Adam algorithm
    will update the first moment, the second moment, and the parameters based on all values
    of the gradients, SparseAdam only updates the moments and parameters corresponding
    to the non-zero values of the gradients.

    A simplified way of thinking about the `intended` implementation is as such
    (a dense-tensor sketch of these steps appears after the list):

    1. Create a mask of the non-zero values in the sparse gradients. For example,
       if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1].
    2. Apply this mask over the running moments and do computation on only the
       non-zero values.
    3. Apply this mask over the parameters and only apply an update on non-zero values.
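
    The following is only an illustrative, dense-tensor sketch of steps 1-3 above; the
    real implementation operates on sparse-layout tensors and never materializes the
    zeros, and Adam's bias correction is omitted here for brevity::

        import torch

        param = torch.randn(5)
        exp_avg = torch.zeros(5)
        exp_avg_sq = torch.zeros(5)
        grad = torch.tensor([0., 5., 0., 0., 9.])     # semantically sparse gradient
        lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

        mask = grad != 0                              # step 1: mask of non-zero values
        exp_avg[mask] = beta1 * exp_avg[mask] + (1 - beta1) * grad[mask]            # step 2
        exp_avg_sq[mask] = beta2 * exp_avg_sq[mask] + (1 - beta2) * grad[mask] ** 2
        param[mask] -= lr * exp_avg[mask] / (exp_avg_sq[mask].sqrt() + eps)         # step 3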

    In actuality, we use sparse layout Tensors to optimize this approximation, which means the
    more gradients that are masked by not being materialized, the more performant the optimization.
    Since we rely on using sparse layout tensors, we infer that any materialized value in the
    sparse layout is non-zero and we do NOT actually verify that all values are not zero!
    It is important not to conflate a semantically sparse tensor (a tensor where many
    of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse``
    returns ``True``). The SparseAdam approximation is intended for `semantically` sparse
    tensors and the sparse layout is only an implementation detail. A clearer implementation
    would be to use MaskedTensors, but those are experimental.
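
    For instance (a minimal illustration of the layout distinction only), a dense-layout
    tensor that is mostly zeros is semantically sparse yet still reports a strided
    layout, while converting it yields a sparse-layout tensor::

        import torch

        dense = torch.tensor([0., 5., 0., 0., 9.])    # semantically sparse, dense layout
        print(dense.is_sparse)                        # False
        sparse = dense.to_sparse()                    # sparse COO layout
        print(sparse.is_sparse)                       # True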

    .. note::

        If you suspect your gradients are semantically sparse (but do not have sparse
        layout), this variant may not be the best for you. Ideally, you want to avoid
        materializing anything that is suspected to be sparse in the first place, since
        needing to convert all your grads from dense layout to sparse layout may outweigh
        the performance gain. In that case, using Adam may be the best alternative, unless
        you can easily rig up your module to output sparse grads similar to
        ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do
        so by manually overriding your parameters' ``.grad`` fields with their sparse
        equivalents before calling ``.step()``, as in the sketch below.
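
    A minimal sketch of that conversion (the module, objective, and hyperparameters below
    are placeholders; converting dense grads like this is usually a last resort, as noted
    above)::

        import torch
        import torch.nn as nn

        model = nn.Linear(4, 2)                        # toy module producing dense grads
        optimizer = torch.optim.SparseAdam(model.parameters(), lr=1e-3)

        loss = model(torch.randn(8, 4)).pow(2).sum()   # toy objective
        loss.backward()
        for p in model.parameters():
            if p.grad is not None and not p.grad.is_sparse:
                p.grad = p.grad.to_sparse()            # hand SparseAdam a sparse-layout grad
        optimizer.step()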

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        {_maximize_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """