pytorch

Форк
0
/
test_optim.py 
338 строк · 11.5 Кб
1
# Owner(s): ["module: optimizer"]
2

3
import torch
4
from torch.optim import (
5
    Adadelta,
6
    Adagrad,
7
    Adam,
8
    Adamax,
9
    AdamW,
10
    ASGD,
11
    NAdam,
12
    RAdam,
13
    RMSprop,
14
    Rprop,
15
    SGD,
16
)
17
from torch.testing._internal.common_utils import (
18
    gradcheck,
19
    load_tests,
20
    skipIfTorchDynamo,
21
    TestCase,
22
)
23

24

25
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. Re-binding the imported name here makes it look
# used, which silences flake8's unused-import warning.
load_tests = load_tests
28

29

30
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
31
    # Ignored is the list of values in `opt_differentiable_state`, we do this
32
    # for `gradcheck` to correctly track the state tensors as function inputs
33
    # because otherwise it can't unpack the values in the `opt_differentiable_state`
34
    # dict
35
    p = p.clone()
36
    p.grad = grad
37
    opt_differentiable_state = {
38
        k: v.clone() if isinstance(v, torch.Tensor) else v
39
        for k, v in opt_differentiable_state.items()
40
    }
41
    opt = opt_class([p], **kwargs)
42
    opt.state[p].update(opt_differentiable_state)
43
    opt.step()
44
    return (p,) + tuple(
45
        v
46
        for v in opt.state[p].values()
47
        if isinstance(v, torch.Tensor) and v.requires_grad
48
    )
49

50

51
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    """Gradcheck one ``step()`` of each optimizer constructed with ``differentiable=True``.

    Each test creates a parameter ``p``, a gradient ``grad`` and an initial
    per-parameter optimizer ``state``, all float64 so that ``gradcheck``'s
    finite differences are accurate. It then gradchecks ``_diff_fn``, which
    performs a single optimizer step on clones of those inputs.

    Counters such as ``step`` are stored as float tensors but are not
    continuous variables. They are therefore created without
    ``requires_grad`` and are not differentiated through.
    """

    @staticmethod
    def _rand_vec():
        """Return a fresh 10-element float64 ``torch.rand`` tensor that requires grad."""
        return torch.rand(10, requires_grad=True, dtype=torch.float64)

    @staticmethod
    def _gradcheck_step(p, grad, state, opt_class, kwargs):
        """Gradcheck a single ``opt_class(**kwargs)`` step that starts from ``state``.

        The state values are appended to the inputs because ``gradcheck``
        cannot see tensors stored inside the ``state`` dict. ``_diff_fn``
        accepts these trailing arguments and ignores them.
        """
        gradcheck(_diff_fn, (p, grad, state, opt_class, kwargs, *state.values()))

    def test_sgd(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {"momentum_buffer": self._rand_vec()}
        # SGD's default momentum is 0, in which case the buffer is not used.
        self._gradcheck_step(p, grad, state, SGD, {"lr": 0.9, "differentiable": True})
        # With momentum enabled the buffer feeds into the update, so its
        # gradient path is exercised as well.
        self._gradcheck_step(
            p, grad, state, SGD, {"lr": 0.9, "momentum": 0.9, "differentiable": True}
        )

    def test_adam(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": self._rand_vec(),
            "exp_avg_sq": self._rand_vec(),
            "max_exp_avg_sq": self._rand_vec(),
        }
        self._gradcheck_step(
            p, grad, state, Adam, {"lr": 0.9, "differentiable": True, "amsgrad": True}
        )

    def test_rmsprop(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.zeros((), dtype=torch.float64),
            "square_avg": self._rand_vec(),
            "momentum_buffer": self._rand_vec(),
            # Kept small: large values here can produce NaNs through the sqrt
            # ops used by centered RMSprop.
            "grad_avg": 1e-2 * self._rand_vec(),
        }
        self._gradcheck_step(
            p,
            grad,
            state,
            RMSprop,
            {
                "lr": 0.9,
                "maximize": True,
                "momentum": 0.9,
                "differentiable": True,
                "centered": True,
                "weight_decay": 0.1,
            },
        )

    def test_adadelta(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "square_avg": self._rand_vec(),
            "acc_delta": self._rand_vec(),
        }
        self._gradcheck_step(
            p,
            grad,
            state,
            Adadelta,
            {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
        )

    def test_adagrad(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "sum": self._rand_vec(),
        }
        self._gradcheck_step(
            p,
            grad,
            state,
            Adagrad,
            {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
        )

    def test_adamax(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": self._rand_vec(),
            "exp_inf": self._rand_vec(),
        }
        self._gradcheck_step(
            p,
            grad,
            state,
            Adamax,
            {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
        )

    @skipIfTorchDynamo(
        "The inplace mu update fails with dynamo, "
        "since this is only happening when differentiable is enabled, skipping for now"
    )
    def test_asgd(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        # `step`, `eta` and `mu` are not continuous variables (even though they
        # are stored as floats), so they do not require gradients.
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "eta": torch.tensor(0.9, requires_grad=False, dtype=torch.float64),
            "mu": torch.tensor(1.0, requires_grad=False, dtype=torch.float64),
            "ax": self._rand_vec(),
        }
        self._gradcheck_step(p, grad, state, ASGD, {"lr": 0.9, "differentiable": True})

    def test_rprop(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "prev": self._rand_vec(),
            "step_size": self._rand_vec(),
        }
        self._gradcheck_step(p, grad, state, Rprop, {"lr": 0.9, "differentiable": True})

    def test_adamw(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": self._rand_vec(),
            "exp_avg_sq": self._rand_vec(),
            "max_exp_avg_sq": self._rand_vec(),
        }
        self._gradcheck_step(
            p, grad, state, AdamW, {"lr": 0.9, "differentiable": True, "amsgrad": True}
        )

    def test_nadam(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": self._rand_vec(),
            "exp_avg_sq": self._rand_vec(),
            "mu_product": torch.tensor(1.0, requires_grad=True, dtype=torch.float64),
        }
        self._gradcheck_step(p, grad, state, NAdam, {"lr": 0.9, "differentiable": True})
        # A non-zero weight_decay is required for decoupled_weight_decay to
        # change the update at all; with the default of 0 this would just
        # repeat the check above.
        self._gradcheck_step(
            p,
            grad,
            state,
            NAdam,
            {
                "lr": 0.9,
                "weight_decay": 0.1,
                "decoupled_weight_decay": True,
                "differentiable": True,
            },
        )

    def test_radam(self):
        p = self._rand_vec()
        grad = self._rand_vec()
        state = {
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": self._rand_vec(),
            "exp_avg_sq": self._rand_vec(),
        }
        self._gradcheck_step(p, grad, state, RAdam, {"lr": 0.9, "differentiable": True})
        self._gradcheck_step(
            p,
            grad,
            state,
            RAdam,
            {
                "lr": 0.9,
                "weight_decay": 0.1,
                "decoupled_weight_decay": True,
                "differentiable": True,
            },
        )
335

336

337
if __name__ == "__main__":
    # This module only defines test cases. They are meant to be collected and
    # run through test/test_optim.py, so running this file directly only
    # prints that reminder.
    print("These tests should be run through test/test_optim.py instead")
339

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.