""" Optimizer Tests

These tests were adapted from PyTorch's optimizer tests.

"""
import math
import pytest
import functools
from copy import deepcopy

import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
from timm.scheduler import PlateauLRScheduler

from timm.optim import create_optimizer_v2

import importlib
import os

torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)
torch_device = os.environ.get('TORCH_DEVICE', 'cuda')
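# Illustrative invocations (the backend module and device names are examples,
# not requirements of this suite):
#   TORCH_DEVICE=cpu pytest test_optim.py
#   TORCH_BACKEND=intel_extension_for_pytorch TORCH_DEVICE=xpu pytest test_optim.py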

# HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
torch_tc = TestCase()
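# (TestCase.assertEqual compares tensors element-wise with dtype-appropriate
# default tolerances, unlike a bare Python assert on `==`.)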


def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
    weight = Parameter(weight)
    bias = Parameter(bias)
    input = Parameter(input)
    optimizer = constructor(weight, bias)
    schedulers = []
    for scheduler_constructor in scheduler_constructors:
        schedulers.append(scheduler_constructor(optimizer))

    # to check if the optimizer can be printed as a string
    optimizer.__repr__()

    def fn():
        optimizer.zero_grad()
        y = weight.mv(input)
        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
            y = y.cuda(bias.get_device())
        loss = (y + bias).pow(2).sum()
        loss.backward()
        return loss

    initial_value = fn().item()
    for _i in range(200):
        for scheduler in schedulers:
            if isinstance(scheduler, PlateauLRScheduler):
                val_loss = fn()
                scheduler.step(val_loss)
            else:
                scheduler.step()
        optimizer.step(fn)

    assert fn().item() < initial_value


def _test_state_dict(weight, bias, input, constructor):
    weight = Parameter(weight)
    bias = Parameter(bias)
    input = Parameter(input)

    def fn_base(optimizer, weight, bias):
        optimizer.zero_grad()
        # NOTE: input_device is defined later in this function; fn_base only reads
        # it (via this closure) once it is called with non-cpu weights below
        i = input_device if weight.device.type != 'cpu' else input
        loss = (weight.mv(i) + bias).pow(2).sum()
        loss.backward()
        return loss

    optimizer = constructor(weight, bias)
    fn = functools.partial(fn_base, optimizer, weight, bias)

    # Prime the optimizer
    for _i in range(20):
        optimizer.step(fn)
    # Clone the weights and construct new optimizer for them
    with torch.no_grad():
        weight_c = Parameter(weight.clone().detach())
        bias_c = Parameter(bias.clone().detach())
    optimizer_c = constructor(weight_c, bias_c)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict_c)

    # Run both optimizations in parallel
    for _i in range(20):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        torch_tc.assertEqual(weight, weight_c)
        torch_tc.assertEqual(bias, bias_c)
    # Make sure state dict is deterministic with equal but not identical parameters
    torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
    # Make sure repeated parameters have identical representation in state dict
    optimizer_c.param_groups.extend(optimizer_c.param_groups)
    torch_tc.assertEqual(optimizer.state_dict()['param_groups'][-1], optimizer_c.state_dict()['param_groups'][-1])

    # Check that state dict can be loaded even when we cast parameters
    # to a different type and move to a different device.
    if torch_device == 'cpu':
        return
    elif torch_device == 'cuda' and not torch.cuda.is_available():
        return

    with torch.no_grad():
        input_device = Parameter(input.clone().detach().float().to(torch_device))
        weight_device = Parameter(weight.clone().detach().to(torch_device))
        bias_device = Parameter(bias.clone().detach().to(torch_device))
    optimizer_device = constructor(weight_device, bias_device)
    fn_device = functools.partial(fn_base, optimizer_device, weight_device, bias_device)

    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_device.load_state_dict(state_dict_c)

    # Make sure state dict wasn't modified
    torch_tc.assertEqual(state_dict, state_dict_c)

    for _i in range(20):
        optimizer.step(fn)
        optimizer_device.step(fn_device)
        torch_tc.assertEqual(weight, weight_device)
        torch_tc.assertEqual(bias, bias_device)

    # validate deepcopy() copies all public attributes
    def getPublicAttr(obj):
        return set(k for k in obj.__dict__ if not k.startswith('_'))

    assert getPublicAttr(optimizer) == getPublicAttr(deepcopy(optimizer))


def _test_basic_cases(constructor, scheduler_constructors=None):
    if scheduler_constructors is None:
        scheduler_constructors = []
    _test_state_dict(
        torch.randn(10, 5),
        torch.randn(10),
        torch.randn(5),
        constructor
    )
    _test_basic_cases_template(
        torch.randn(10, 5),
        torch.randn(10),
        torch.randn(5),
        constructor,
        scheduler_constructors
    )
    # non-contiguous parameters
    _test_basic_cases_template(
        torch.randn(10, 5, 2)[..., 0],
        torch.randn(10, 2)[..., 0],
        torch.randn(5),
        constructor,
        scheduler_constructors
    )
    # CUDA (or the alternative device selected via TORCH_DEVICE)
    if torch_device == 'cpu':
        return
    elif torch_device == 'cuda' and not torch.cuda.is_available():
        return

    _test_basic_cases_template(
        torch.randn(10, 5).to(torch_device),
        torch.randn(10).to(torch_device),
        torch.randn(5).to(torch_device),
        constructor,
        scheduler_constructors
    )


def _test_model(optimizer, params, device=torch.device('cpu')):
    weight = torch.tensor(
        [[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]],
        device=device, requires_grad=True)
    bias = torch.tensor([-0.1085, -0.2979, 0.6892], device=device, requires_grad=True)
    weight2 = torch.tensor([[-0.0508, -0.3941, -0.2843]], device=device, requires_grad=True)
    bias2 = torch.tensor([-0.0711], device=device, requires_grad=True)
    input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2)

    model = torch.nn.Sequential(torch.nn.Linear(2, 3),
                                torch.nn.Sigmoid(),
                                torch.nn.Linear(3, 1),
                                torch.nn.Sigmoid())
    model.to(device)

    pretrained_dict = model.state_dict()
    pretrained_dict['0.weight'] = weight
    pretrained_dict['0.bias'] = bias
    pretrained_dict['2.weight'] = weight2
    pretrained_dict['2.bias'] = bias2
    model.load_state_dict(pretrained_dict)

    optimizer = create_optimizer_v2(model, opt=optimizer, **params)

    prev_loss = float('inf')
    for i in range(20):
        optimizer.zero_grad()
        output = model(input)
        loss = output.sum()
        loss.backward()
        loss = loss.item()
        assert loss < prev_loss
        prev_loss = loss
        optimizer.step()
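

# Note: create_optimizer_v2 resolves optimizers by name and accepts either an
# nn.Module (as above) or raw parameters / param-group dicts, which the tests
# below rely on. A minimal usage sketch (model name illustrative):
#   opt = create_optimizer_v2(my_model, opt='adamw', lr=1e-3, weight_decay=0.05)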


def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def drosenbrock(tensor):
    # analytic gradient of rosenbrock: (d/dx, d/dy)
    x, y = tensor
    return torch.tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
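

# Illustrative sanity check (a sketch, not part of the original suite): the
# hand-written gradient above should agree with autograd's gradient of rosenbrock.
def _check_drosenbrock(point=(1.5, 1.5)):
    t = torch.tensor(point, requires_grad=True)
    rosenbrock(t).backward()
    torch.testing.assert_close(t.grad, drosenbrock(t.detach()))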


def _test_rosenbrock(constructor, scheduler_constructors=None):
    if scheduler_constructors is None:
        scheduler_constructors = []
    params_t = torch.tensor([1.5, 1.5])

    params = Parameter(params_t)
    optimizer = constructor([params])
    schedulers = []
    for scheduler_constructor in scheduler_constructors:
        schedulers.append(scheduler_constructor(optimizer))

    solution = torch.tensor([1, 1])
    initial_dist = params.clone().detach().dist(solution)

    def eval(params, w):
        # Depending on w, provide only the x or y gradient
        optimizer.zero_grad()
        loss = rosenbrock(params)
        loss.backward()
        grad = drosenbrock(params.clone().detach())
        # NB: We torture test the optimizer by returning an
        # uncoalesced sparse tensor (the index is repeated, so the two
        # partial values sum to the full gradient component when densified)
        if w:
            i = torch.LongTensor([[0, 0]])
            x = grad[0]
            v = torch.tensor([x / 4., x - x / 4.])
        else:
            i = torch.LongTensor([[1, 1]])
            y = grad[1]
            v = torch.tensor([y - y / 4., y / 4.])
        x = torch.sparse_coo_tensor(i, v, torch.Size([2]), dtype=v.dtype)
        with torch.no_grad():
            params.grad = x.to_dense()
        return loss

    for i in range(2000):
        # Do cyclic coordinate descent
        w = i % 2
        optimizer.step(functools.partial(eval, params, w))
        for scheduler in schedulers:
            if isinstance(scheduler, PlateauLRScheduler):
                scheduler.step(rosenbrock(params))
            else:
                scheduler.step()

    torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)


def _build_params_dict(weight, bias, **kwargs):
    return [{'params': [weight]}, dict(params=[bias], **kwargs)]


def _build_params_dict_single(weight, bias, **kwargs):
    return [dict(params=bias, **kwargs)]
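
# For example (illustrative):
#   _build_params_dict(w, b, lr=1e-2)        -> [{'params': [w]}, {'params': [b], 'lr': 1e-2}]
#   _build_params_dict_single(w, b, lr=1e-2) -> [{'params': b, 'lr': 1e-2}]
# i.e. the first keeps the weight on optimizer-level defaults and overrides lr
# for the bias; the "single" variant passes only the bias parameter.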


#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
# FIXME the 'momentum' variant frequently fails in the GitHub runner, but never locally despite many attempts
@pytest.mark.parametrize('optimizer', ['sgd'])
def test_sgd(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=1e-2),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-2),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
    )
    # _test_basic_cases(
    #     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
    #     [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
    #      lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
    #      lambda opt: ReduceLROnPlateau(opt)]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
    #      lambda opt: ExponentialLR(opt, gamma=0.99),
    #      lambda opt: ReduceLROnPlateau(opt)]
    # )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax'])
def test_adam(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
    )
    _test_model(optimizer, dict(lr=5e-2))


@pytest.mark.parametrize('optimizer', ['adabelief'])
def test_adabelief(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
    )
    _test_model(optimizer, dict(lr=5e-2))


@pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
def test_rectified(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
def test_adaother(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-1)
    )
    _test_model(optimizer, dict(lr=5e-2))


@pytest.mark.parametrize('optimizer', ['adafactor'])
def test_adafactor(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
    )
    _test_model(optimizer, dict(lr=5e-2))


@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
def test_lamb(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=1e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=1e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
def test_madgrad(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
    )
    _test_model(optimizer, dict(lr=1e-2))


@pytest.mark.parametrize('optimizer', ['novograd'])
def test_novograd(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
def test_rmsprop(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
    )
    _test_model(optimizer, dict(lr=1e-2))


@pytest.mark.parametrize('optimizer', ['adamp'])
def test_adamp(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
    )
    _test_model(optimizer, dict(lr=5e-2))


@pytest.mark.parametrize('optimizer', ['sgdp'])
def test_sgdp(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=1e-3))
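

# The 'lookahead_' prefix below asks create_optimizer_v2 to wrap the base optimizer
# in timm's Lookahead wrapper (Zhang et al., "Lookahead Optimizer: k steps forward,
# 1 step back"); note these variants skip _test_model.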


@pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
def test_lookahead_sgd(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )


@pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
def test_lookahead_adam(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
    )


@pytest.mark.parametrize('optimizer', ['lookahead_radam'])
def test_lookahead_radam(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3),
            optimizer,
            lr=1e-3)
    )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
    )