optuna

Форк
0
/
test_study.py 
1625 строк · 55.3 Кб
1
from __future__ import annotations
2

3
from concurrent.futures import as_completed
4
from concurrent.futures import ThreadPoolExecutor
5
import copy
6
import multiprocessing
7
import pickle
8
import platform
9
import threading
10
import time
11
from typing import Any
12
from typing import Callable
13
from unittest.mock import Mock
14
from unittest.mock import patch
15
import uuid
16
import warnings
17

18
import _pytest.capture
19
import pytest
20

21
from optuna import copy_study
22
from optuna import create_study
23
from optuna import create_trial
24
from optuna import delete_study
25
from optuna import distributions
26
from optuna import get_all_study_names
27
from optuna import get_all_study_summaries
28
from optuna import load_study
29
from optuna import logging
30
from optuna import Study
31
from optuna import Trial
32
from optuna import TrialPruned
33
from optuna.exceptions import DuplicatedStudyError
34
from optuna.exceptions import ExperimentalWarning
35
from optuna.study import StudyDirection
36
from optuna.study.study import _SYSTEM_ATTR_METRIC_NAMES
37
from optuna.testing.objectives import fail_objective
38
from optuna.testing.storages import STORAGE_MODES
39
from optuna.testing.storages import StorageSupplier
40
from optuna.trial import FrozenTrial
41
from optuna.trial import TrialState
42

43

44
CallbackFuncType = Callable[[Study, FrozenTrial], None]
45

46

47
def func(trial: Trial) -> float:
    """Toy single-objective function: a shifted paraboloid plus a categorical offset."""
    a = trial.suggest_float("x", -10.0, 10.0)
    b = trial.suggest_float("y", 20, 30, log=True)
    offset = trial.suggest_categorical("z", (-1.0, 1.0))
    return (a - 2) ** 2 + (b - 25) ** 2 + offset
52

53

54
class Func:
    """Callable objective that counts invocations and can simulate slow trials."""

    def __init__(self, sleep_sec: float | None = None) -> None:
        self.n_calls = 0
        self.sleep_sec = sleep_sec
        self.lock = threading.Lock()

    def __call__(self, trial: Trial) -> float:
        # Guard the counter with a lock so parallel optimization counts correctly.
        with self.lock:
            self.n_calls += 1

        # Optional delay lets tests observe parallelism/timeout behavior.
        if self.sleep_sec is not None:
            time.sleep(self.sleep_sec)

        result = func(trial)
        check_params(trial.params)
        return result
71

72

73
def check_params(params: dict[str, Any]) -> None:
    """Assert that exactly the parameters ``x``, ``y`` and ``z`` were suggested."""
    assert set(params) == {"x", "y", "z"}
75

76

77
def check_value(value: float | None) -> None:
    """Assert that an objective value is a float within the range reachable by ``func``."""
    assert isinstance(value, float)
    # Worst case of (x - 2)^2 + (y - 25)^2 + z over the search space of ``func``.
    upper_bound = 12.0**2 + 5.0**2 + 1.0
    assert -1.0 <= value <= upper_bound
80

81

82
def check_frozen_trial(frozen_trial: FrozenTrial) -> None:
    """Validate params and value of a completed trial; other states are skipped."""
    if frozen_trial.state != TrialState.COMPLETE:
        return
    check_params(frozen_trial.params)
    check_value(frozen_trial.value)
86

87

88
def check_study(study: Study) -> None:
    """Validate every trial of a single-objective study and its ``best_*`` accessors."""
    for frozen in study.trials:
        check_frozen_trial(frozen)

    assert not study._is_multi_objective()

    completed = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    if completed:
        check_params(study.best_params)
        check_value(study.best_value)
        check_frozen_trial(study.best_trial)
    else:
        # Without any completed trial the best-* accessors must raise.
        with pytest.raises(ValueError):
            study.best_params
        with pytest.raises(ValueError):
            study.best_value
        with pytest.raises(ValueError):
            study.best_trial
106

107

108
def stop_objective(threshold_number: int) -> Callable[[Trial], float]:
    """Build an objective that asks the study to stop at a given trial number.

    The trial that triggers the stop still completes; every trial returns its
    own number as the objective value.
    """

    def objective(trial: Trial) -> float:
        if trial.number >= threshold_number:
            trial.study.stop()
        return trial.number

    return objective
116

117

118
def test_optimize_trivial_in_memory_new() -> None:
    # A fresh in-memory study should run ten valid trials.
    fresh_study = create_study()
    fresh_study.optimize(func, n_trials=10)
    check_study(fresh_study)
122

123

124
def test_optimize_trivial_in_memory_resume() -> None:
    # A second `optimize` call on the same study must resume cleanly.
    resumed_study = create_study()
    for _ in range(2):
        resumed_study.optimize(func, n_trials=10)
    check_study(resumed_study)
129

130

131
def test_optimize_trivial_rdb_resume_study() -> None:
    # Optimization backed by an in-memory SQLite RDB storage.
    rdb_study = create_study(storage="sqlite:///:memory:")
    rdb_study.optimize(func, n_trials=10)
    check_study(rdb_study)
135

136

137
def test_optimize_with_direction() -> None:
    # Both valid direction strings are accepted and reflected on the study.
    for direction, expected in [
        ("minimize", StudyDirection.MINIMIZE),
        ("maximize", StudyDirection.MAXIMIZE),
    ]:
        study = create_study(direction=direction)
        study.optimize(func, n_trials=10)
        assert study.direction == expected
        check_study(study)

    # An unknown direction string is rejected.
    with pytest.raises(ValueError):
        create_study(direction="test")
150

151

152
@pytest.mark.parametrize("n_trials", (0, 1, 20))
@pytest.mark.parametrize("n_jobs", (1, 2, -1))
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_parallel(n_trials: int, n_jobs: int, storage_mode: str) -> None:
    objective = Func()

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs)
        # Each recorded trial corresponds to exactly one objective invocation.
        assert objective.n_calls == len(study.trials) == n_trials
        check_study(study)
163

164

165
def test_optimize_with_thread_pool_executor() -> None:
    def objective(t: Trial) -> float:
        return t.suggest_float("x", -10, 10)

    study = create_study()
    # Ten concurrent submissions of ten trials each must all be recorded.
    with ThreadPoolExecutor(max_workers=5) as executor:
        for _ in range(10):
            executor.submit(study.optimize, objective, n_trials=10)
    assert len(study.trials) == 100
174

175

176
@pytest.mark.parametrize("n_trials", (0, 1, 20, None))
@pytest.mark.parametrize("n_jobs", (1, 2, -1))
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_parallel_timeout(n_trials: int | None, n_jobs: int, storage_mode: str) -> None:
    """Timeout-bounded parallel optimization runs a bounded number of trials.

    Fix: ``n_trials`` is parametrized with ``None`` as well (run until timeout),
    so its annotation must be ``int | None`` rather than ``int``.
    """
    sleep_sec = 0.1
    timeout_sec = 1.0
    f = Func(sleep_sec=sleep_sec)

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(f, n_trials=n_trials, n_jobs=n_jobs, timeout=timeout_sec)

        assert f.n_calls == len(study.trials)

        if n_trials is not None:
            assert f.n_calls <= n_trials

        # A thread can process at most (timeout_sec / sleep_sec + 1) trials.
        n_jobs_actual = n_jobs if n_jobs != -1 else multiprocessing.cpu_count()
        max_calls = (timeout_sec / sleep_sec + 1) * n_jobs_actual
        assert f.n_calls <= max_calls

        check_study(study)
199

200

201
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_with_catch(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        def assert_all_failed(expected_len: int) -> None:
            # Every recorded trial should have ended in the FAIL state.
            assert len(study.trials) == expected_len
            assert all(t.state == TrialState.FAIL for t in study.trials)

        # By default no exception is caught: the first failure aborts the run.
        with pytest.raises(ValueError):
            study.optimize(fail_objective, n_trials=20)
        assert_all_failed(1)

        # Listing the raised exception in `catch` lets all trials run.
        study.optimize(fail_objective, n_trials=20, catch=(ValueError,))
        assert_all_failed(21)

        # A `catch` tuple that misses the raised type aborts again.
        with pytest.raises(ValueError):
            study.optimize(fail_objective, n_trials=20, catch=(ArithmeticError,))
        assert_all_failed(22)
222

223

224
@pytest.mark.parametrize("catch", [ValueError, (ValueError,), [ValueError], {ValueError}])
def test_optimize_with_catch_valid_type(catch: Any) -> None:
    # `catch` accepts a single exception class or any iterable of classes.
    create_study().optimize(fail_objective, n_trials=20, catch=catch)
228

229

230
@pytest.mark.parametrize("catch", [None, 1])
def test_optimize_with_catch_invalid_type(catch: Any) -> None:
    # Values that are neither exception classes nor iterables of them are rejected.
    invalid_catch_study = create_study()

    with pytest.raises(TypeError):
        invalid_catch_study.optimize(fail_objective, n_trials=20, catch=catch)
236

237

238
@pytest.mark.parametrize("n_jobs", (2, -1))
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_with_reseeding(n_jobs: int, storage_mode: str) -> None:
    objective = Func()

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        sampler = study.sampler
        # Multi-threaded optimization should reseed the sampler RNG exactly once.
        with patch.object(sampler, "reseed_rng", wraps=sampler.reseed_rng) as reseed_spy:
            study.optimize(objective, n_trials=1, n_jobs=2)
            assert reseed_spy.call_count == 1
249

250

251
def test_call_another_study_optimize_in_optimize() -> None:
    # Optimizing a *different* study inside an objective is allowed
    # (only re-entrant optimization of the same study is forbidden).
    def inner_objective(t: Trial) -> float:
        return t.suggest_float("x", -10, 10)

    def objective(t: Trial) -> float:
        inner_study = create_study()
        inner_study.enqueue_trial({"x": t.suggest_int("initial_point", -10, 10)})
        inner_study.optimize(inner_objective, n_trials=10)
        return inner_study.best_value

    outer_study = create_study()
    outer_study.optimize(objective, n_trials=10)
    assert len(outer_study.trials) == 10
264

265

266
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_study_set_and_get_user_attrs(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        # A user attribute set on the study must be readable back.
        attr_study = create_study(storage=storage)
        attr_study.set_user_attr("dataset", "MNIST")
        assert attr_study.user_attrs["dataset"] == "MNIST"
273

274

275
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_trial_set_and_get_user_attrs(storage_mode: str) -> None:
    def objective(trial: Trial) -> float:
        # The attribute is visible within the running trial...
        trial.set_user_attr("train_accuracy", 1)
        assert trial.user_attrs["train_accuracy"] == 1
        return 0.0

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(objective, n_trials=1)
        # ...and persists on the frozen trial afterwards.
        assert study.trials[0].user_attrs["train_accuracy"] == 1
287

288

289
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@pytest.mark.parametrize("include_best_trial", [True, False])
def test_get_all_study_summaries(storage_mode: str, include_best_trial: bool) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(func, n_trials=5)

        matching = [
            s
            for s in get_all_study_summaries(study._storage, include_best_trial)
            if s._study_id == study._study_id
        ]
        summary = matching[0]

        assert summary.study_name == study.study_name
        assert summary.n_trials == 5
        # The best trial is attached to the summary only on request.
        if include_best_trial:
            assert summary.best_trial is not None
        else:
            assert summary.best_trial is None
305

306

307
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_study_summaries_with_no_trials(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        empty_study = create_study(storage=storage)

        matching = [
            s
            for s in get_all_study_summaries(empty_study._storage)
            if s._study_id == empty_study._study_id
        ]
        summary = matching[0]

        # An untouched study has no trials and no start timestamp.
        assert summary.study_name == empty_study.study_name
        assert summary.n_trials == 0
        assert summary.datetime_start is None
318

319

320
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_study_names(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        n_studies = 5
        created = [create_study(storage=storage) for _ in range(n_studies)]

        names = get_all_study_names(storage)

        # Names come back one per study, in creation order.
        assert len(names) == n_studies
        for created_study, name in zip(created, names):
            assert name == created_study.study_name
331

332

333
def test_study_pickle() -> None:
    original = create_study()
    original.optimize(func, n_trials=10)
    check_study(original)
    assert len(original.trials) == 10
    payload = pickle.dumps(original)

    # A round-tripped study keeps its trials...
    restored = pickle.loads(payload)
    check_study(restored)
    assert len(restored.trials) == 10

    # ...and can continue optimizing.
    restored.optimize(func, n_trials=10)
    check_study(restored)
    assert len(restored.trials) == 20
347

348

349
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_study(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        # A brand-new study can always be created.
        new_study = create_study(storage=storage, load_if_exists=False)

        # With `load_if_exists=True` an existing name is reused silently.
        create_study(study_name=new_study.study_name, storage=storage, load_if_exists=True)

        # Without it, a duplicate name is an error.
        with pytest.raises(DuplicatedStudyError):
            create_study(study_name=new_study.study_name, storage=storage, load_if_exists=False)
360

361

362
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_load_study(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        if storage is None:
            # `InMemoryStorage` can not be used with `load_study` function.
            return

        study_name = str(uuid.uuid4())

        # Loading a study that does not exist yet must fail.
        with pytest.raises(KeyError):
            load_study(study_name=study_name, storage=storage)

        created_study = create_study(study_name=study_name, storage=storage)

        # After creation, loading by name resolves to the same study.
        loaded_study = load_study(study_name=study_name, storage=storage)
        assert created_study._study_id == loaded_study._study_id
381

382

383
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_load_study_study_name_none(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        if storage is None:
            # `InMemoryStorage` can not be used with `load_study` function.
            return

        only_name = str(uuid.uuid4())
        _ = create_study(study_name=only_name, storage=storage)

        # With a single study in storage, `study_name=None` resolves to it.
        resolved = load_study(study_name=None, storage=storage)
        assert resolved.study_name == only_name

        second_name = str(uuid.uuid4())
        _ = create_study(study_name=second_name, storage=storage)

        # With two studies the name is ambiguous and must be given explicitly.
        with pytest.raises(ValueError):
            load_study(study_name=None, storage=storage)
405

406

407
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_delete_study(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        # Deleting an unknown name is an error.
        with pytest.raises(KeyError):
            delete_study(study_name="invalid-study-name", storage=storage)

        # An existing study can be deleted exactly once.
        doomed = create_study(storage=storage, load_if_exists=False)
        delete_study(study_name=doomed.study_name, storage=storage)

        # A second deletion of the same name fails.
        with pytest.raises(KeyError):
            delete_study(study_name=doomed.study_name, storage=storage)
421

422

423
@pytest.mark.parametrize("from_storage_mode", STORAGE_MODES)
@pytest.mark.parametrize("to_storage_mode", STORAGE_MODES)
def test_copy_study(from_storage_mode: str, to_storage_mode: str) -> None:
    with StorageSupplier(from_storage_mode) as from_storage:
        with StorageSupplier(to_storage_mode) as to_storage:
            # Populate the source study with attrs and multi-objective trials.
            from_study = create_study(storage=from_storage, directions=["maximize", "minimize"])
            from_study._storage.set_study_system_attr(from_study._study_id, "foo", "bar")
            from_study.set_user_attr("baz", "qux")
            from_study.optimize(
                lambda t: (t.suggest_float("x0", 0, 1), t.suggest_float("x1", 0, 1)), n_trials=3
            )

            copy_study(
                from_study_name=from_study.study_name,
                from_storage=from_storage,
                to_storage=to_storage,
            )

            to_study = load_study(study_name=from_study.study_name, storage=to_storage)

            # Name, directions, system/user attributes and trials all carry over.
            assert to_study.study_name == from_study.study_name
            assert to_study.directions == from_study.directions
            copied_sys = to_study._storage.get_study_system_attrs(to_study._study_id)
            source_sys = from_study._storage.get_study_system_attrs(from_study._study_id)
            assert copied_sys == source_sys
            assert to_study.user_attrs == from_study.user_attrs
            assert len(to_study.trials) == len(from_study.trials)
451

452

453
@pytest.mark.parametrize("from_storage_mode", STORAGE_MODES)
@pytest.mark.parametrize("to_storage_mode", STORAGE_MODES)
def test_copy_study_to_study_name(from_storage_mode: str, to_storage_mode: str) -> None:
    with StorageSupplier(from_storage_mode) as from_storage:
        with StorageSupplier(to_storage_mode) as to_storage:
            from_study = create_study(study_name="foo", storage=from_storage)
            _ = create_study(study_name="foo", storage=to_storage)

            # Copying onto an existing name without renaming fails.
            with pytest.raises(DuplicatedStudyError):
                copy_study(
                    from_study_name=from_study.study_name,
                    from_storage=from_storage,
                    to_storage=to_storage,
                )

            # An explicit `to_study_name` resolves the clash.
            copy_study(
                from_study_name=from_study.study_name,
                from_storage=from_storage,
                to_storage=to_storage,
                to_study_name="bar",
            )

            _ = load_study(study_name="bar", storage=to_storage)
477

478

479
def test_nested_optimization() -> None:
    # Re-entrant optimization of the *same* study must raise inside the objective.
    def objective(trial: Trial) -> float:
        with pytest.raises(RuntimeError):
            trial.study.optimize(lambda _: 0.0, n_trials=1)
        return 1.0

    outer = create_study()
    outer.optimize(objective, n_trials=10, catch=())
488

489

490
def test_stop_in_objective() -> None:
    # Stopping from the objective: the run halts after trial number 4 completes.
    stopping_study = create_study()
    stopping_study.optimize(stop_objective(4), n_trials=10)
    assert len(stopping_study.trials) == 5

    # A later `optimize` call resumes and halts again at trial number 11.
    stopping_study.optimize(stop_objective(11), n_trials=10)
    assert len(stopping_study.trials) == 12
499

500

501
def test_stop_in_callback() -> None:
    # A callback may also request the stop; the triggering trial still counts.
    def stopper(study: Study, trial: FrozenTrial) -> None:
        if trial.number >= 4:
            study.stop()

    callback_study = create_study()
    callback_study.optimize(lambda _: 1.0, n_trials=10, callbacks=[stopper])
    assert len(callback_study.trials) == 5
510

511

512
def test_stop_n_jobs() -> None:
    def stopper(study: Study, trial: FrozenTrial) -> None:
        if trial.number >= 4:
            study.stop()

    # With two workers, the in-flight trial may still finish after the stop,
    # so the final count is 5 or 6.
    parallel_study = create_study()
    parallel_study.optimize(lambda _: 1.0, n_trials=None, callbacks=[stopper], n_jobs=2)
    assert 5 <= len(parallel_study.trials) <= 6
520

521

522
def test_stop_outside_optimize() -> None:
    # `stop` is only meaningful while `optimize` is running.
    idle_study = create_study()
    with pytest.raises(RuntimeError):
        idle_study.stop()

    # The study stays usable after the error is caught.
    idle_study.optimize(lambda _: 1.0, n_trials=1)
530

531

532
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_add_trial(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        # A manually-added completed trial is numbered and counted as best.
        added = create_trial(value=0.8)
        study.add_trial(added)
        assert len(study.trials) == 1
        assert study.trials[0].number == 0
        assert study.best_value == 0.8
543

544

545
def test_add_trial_invalid_values_length() -> None:
    # Two values into a single-objective study: rejected.
    single_obj = create_study()
    with pytest.raises(ValueError):
        single_obj.add_trial(create_trial(values=[0, 0]))

    # One value into a two-objective study: also rejected.
    multi_obj = create_study(directions=["minimize", "minimize"])
    with pytest.raises(ValueError):
        multi_obj.add_trial(create_trial(value=0))
555

556

557
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_add_trials(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        # Adding an empty list is a no-op.
        study.add_trials([])
        assert len(study.trials) == 0

        study.add_trials([create_trial(value=i) for i in range(3)])
        assert len(study.trials) == 3
        for expected, trial in enumerate(study.trials):
            assert trial.number == expected
            assert trial.value == expected

        # Frozen trials of one study can seed another, with renumbering.
        other_study = create_study(storage=storage)
        other_study.add_trials(study.trials)
        assert len(other_study.trials) == 3
        for expected, trial in enumerate(other_study.trials):
            assert trial.number == expected
            assert trial.value == expected
579

580

581
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_properly_sets_param_values(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        # Two fully-fixed trials queued up front.
        study.enqueue_trial(params={"x": -5, "y": 5})
        study.enqueue_trial(params={"x": -1, "y": 0})

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.optimize(objective, n_trials=2)

        # Queued values are replayed in order.
        first = study.trials[0]
        assert first.params["x"] == -5
        assert first.params["y"] == 5

        second = study.trials[1]
        assert second.params["x"] == -1
        assert second.params["y"] == 0
603

604

605
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_with_unfixed_parameters(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        # Only `x` is fixed; `y` is left to the sampler.
        study.enqueue_trial(params={"x": -5})

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.optimize(objective, n_trials=1)
        queued = study.trials[0]
        assert queued.params["x"] == -5
        assert -10 <= queued.params["y"] <= 10
622

623

624
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_properly_sets_user_attr(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        # User attributes supplied at enqueue time travel with the trial.
        study.enqueue_trial(params={"x": -5, "y": 5}, user_attrs={"is_optimal": False})
        study.enqueue_trial(params={"x": 0, "y": 0}, user_attrs={"is_optimal": True})

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.optimize(objective, n_trials=2)

        first, second = study.trials[0], study.trials[1]
        assert first.user_attrs == {"is_optimal": False}
        assert second.user_attrs == {"is_optimal": True}
644

645

646
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_with_non_dict_parameters(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        # `params` must be a mapping; a list is rejected up front.
        with pytest.raises(TypeError):
            study.enqueue_trial(params=[17, 12])  # type: ignore[arg-type]
654

655

656
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_with_out_of_range_parameters(storage_mode: str) -> None:
    fixed_value = 11

    # An out-of-range queued value is used anyway, but with a warning.
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": fixed_value})

        def objective(trial: Trial) -> float:
            return trial.suggest_int("x", -10, 10)

        with pytest.warns(UserWarning):
            study.optimize(objective, n_trials=1)
        assert study.trials[0].params["x"] == fixed_value

    # Internal logic might differ when distribution contains a single element.
    # Test it explicitly.
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": fixed_value})

        def objective(trial: Trial) -> float:
            return trial.suggest_int("x", 1, 1)  # Single element.

        with pytest.warns(UserWarning):
            study.optimize(objective, n_trials=1)
        assert study.trials[0].params["x"] == fixed_value
689

690

691
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_skips_existing_finished(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.enqueue_trial({"x": -5, "y": 5})
        study.optimize(objective, n_trials=1)

        finished = study.trials[0]
        assert finished.params["x"] == -5
        assert finished.params["y"] == 5

        # Re-enqueueing params that already finished is skipped.
        count_before = len(study.trials)
        study.enqueue_trial({"x": -5, "y": 5}, skip_if_exists=True)
        assert count_before == len(study.trials)
713

714

715
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_skips_existing_waiting(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        # A duplicate of a still-waiting queued trial is also skipped.
        study.enqueue_trial({"x": -5, "y": 5})
        count_before = len(study.trials)
        study.enqueue_trial({"x": -5, "y": 5}, skip_if_exists=True)
        assert count_before == len(study.trials)

        study.optimize(objective, n_trials=1)
        waiting = study.trials[0]
        assert waiting.params["x"] == -5
        assert waiting.params["y"] == 5
736

737

738
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@pytest.mark.parametrize(
    "new_params", [{"x": -5, "y": 5, "z": 5}, {"x": -5}, {"x": -5, "z": 5}, {"x": -5, "y": 6}]
)
def test_enqueue_trial_skip_existing_allows_unfixed(
    storage_mode: str, new_params: dict[str, int]
) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            # Only the second trial introduces the extra parameter `z`.
            if trial.number == 1:
                z = trial.suggest_int("z", -10, 10)
                return x**2 + y**2 + z**2
            return x**2 + y**2

        study.enqueue_trial({"x": -5, "y": 5})
        study.optimize(objective, n_trials=1)
        first = study.trials[0]
        assert first.params["x"] == -5
        assert first.params["y"] == 5

        # Params differing from the finished trial (by value or by key set)
        # are NOT skipped; unfixed keys fall back to the sampler.
        study.enqueue_trial(new_params, skip_if_exists=True)
        study.optimize(objective, n_trials=1)

        sampled_keys = {"x", "y", "z"} - set(new_params)
        second = study.trials[1]
        assert all(second.params[key] == new_params[key] for key in new_params)
        assert all(-10 <= second.params[key] <= 10 for key in sampled_keys)
770

771

772
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@pytest.mark.parametrize(
    "param", ["foo", 1, 1.1, 1e17, 1e-17, float("inf"), float("-inf"), float("nan"), None]
)
def test_enqueue_trial_skip_existing_handles_common_types(storage_mode: str, param: Any) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        # Duplicate detection must work for strings, ints, extreme floats,
        # infinities, NaN and None alike.
        study.enqueue_trial({"x": param})
        count_before = len(study.trials)
        study.enqueue_trial({"x": param}, skip_if_exists=True)
        assert count_before == len(study.trials)
784

785

786
@patch("optuna.study._optimize.gc.collect")
def test_optimize_with_gc(collect_mock: Mock) -> None:
    # `gc_after_trial=True` triggers a collection after every trial.
    gc_study = create_study()
    gc_study.optimize(func, n_trials=10, gc_after_trial=True)
    check_study(gc_study)
    assert collect_mock.call_count == 10
792

793

794
@patch("optuna.study._optimize.gc.collect")
def test_optimize_without_gc(collect_mock: Mock) -> None:
    # `gc_after_trial=False` must never invoke the collector.
    no_gc_study = create_study()
    no_gc_study.optimize(func, n_trials=10, gc_after_trial=False)
    check_study(no_gc_study)
    assert collect_mock.call_count == 0
800

801

802
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_with_progbar(n_jobs: int, capsys: _pytest.capture.CaptureFixture) -> None:
    progbar_study = create_study()
    progbar_study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs, show_progress_bar=True)
    _, captured_err = capsys.readouterr()

    # The progress bar writes its state to stderr.
    assert "Best trial: 0" in captured_err
    assert "Best value: 1" in captured_err
    assert "10/10" in captured_err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" in captured_err
815

816

817
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar(n_jobs: int, capsys: _pytest.capture.CaptureFixture) -> None:
    silent_study = create_study()
    silent_study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs)
    _, captured_err = capsys.readouterr()

    # With the progress bar disabled (the default), stderr stays clean.
    assert "Best trial: 0" not in captured_err
    assert "Best value: 1" not in captured_err
    assert "10/10" not in captured_err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" not in captured_err
829

830

831
def test_optimize_with_progbar_timeout(capsys: _pytest.capture.CaptureFixture) -> None:
    """With only a timeout budget the bar counts elapsed seconds."""
    study = create_study()
    study.optimize(lambda _: 1.0, timeout=2.0, show_progress_bar=True)
    stderr = capsys.readouterr().err

    for fragment in ("Best trial: 0", "Best value: 1", "00:02/00:02"):
        assert fragment in stderr
    if platform.system() != "Windows":
        # The bar occasionally stalls at 99% on Windows, so only check elsewhere.
        assert "100%" in stderr
842

843

844
def test_optimize_with_progbar_parallel_timeout(capsys: _pytest.capture.CaptureFixture) -> None:
    """A timeout-driven bar is unsupported with parallel jobs: warn, draw nothing."""
    study = create_study()
    expected_warning = "The timeout-based progress bar is not supported with n_jobs != 1."
    with pytest.warns(UserWarning, match=expected_warning):
        study.optimize(lambda _: 1.0, timeout=2.0, show_progress_bar=True, n_jobs=2)
    stderr = capsys.readouterr().err

    # "|" forms the bar's borders, so its absence means no bar was drawn.
    assert "|" not in stderr
854

855

856
@pytest.mark.parametrize(
    "timeout,expected",
    [
        (59.0, "/00:59"),
        (60.0, "/01:00"),
        (60.0 * 60, "/1:00:00"),
        (60.0 * 60 * 24, "/24:00:00"),
        (60.0 * 60 * 24 * 10, "/240:00:00"),
    ],
)
def test_optimize_with_progbar_timeout_formats(
    timeout: float, expected: str, capsys: _pytest.capture.CaptureFixture
) -> None:
    """The timeout bar renders durations with the expected H:MM:SS formatting."""
    study = create_study()
    study.optimize(stop_objective(5), timeout=timeout, show_progress_bar=True)
    assert expected in capsys.readouterr().err
873

874

875
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar_timeout(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    """Timeout-budgeted runs draw no bar unless `show_progress_bar` is set."""
    study = create_study()
    study.optimize(lambda _: 1.0, timeout=2.0, n_jobs=n_jobs)
    stderr = capsys.readouterr().err

    for fragment in ("Best trial: 0", "Best value: 1.0", "00:02/00:02"):
        assert fragment not in stderr
    if platform.system() != "Windows":
        # The bar occasionally stalls at 99% on Windows, so only check elsewhere.
        assert "100%" not in stderr
889

890

891
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_progbar_n_trials_prioritized(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    """When both budgets are given the bar counts trials, not elapsed time."""
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs, timeout=10.0, show_progress_bar=True)
    stderr = capsys.readouterr().err

    for fragment in ("Best trial: 0", "Best value: 1", "10/10"):
        assert fragment in stderr
    if platform.system() != "Windows":
        # The bar occasionally stalls at 99% on Windows, so only check elsewhere.
        assert "100%" in stderr
    # A trial-based bar reports an iteration ("it") rate.
    assert "it" in stderr
906

907

908
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar_n_trials_prioritized(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    """Even with both budgets set, no bar appears unless explicitly requested."""
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs, timeout=10.0)
    stderr = capsys.readouterr().err

    # "|" forms the bar's borders, so its absence means no bar was drawn.
    assert "|" not in stderr
918

919

920
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_progbar_no_constraints(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    """Without any budget a bar cannot be sized, so none may be drawn."""
    study = create_study()
    with warnings.catch_warnings():
        # The unsupported-progress-bar warning is expected here; silence it.
        warnings.simplefilter("ignore", category=UserWarning)
        study.optimize(stop_objective(5), n_jobs=n_jobs, show_progress_bar=True)
    stderr = capsys.readouterr().err

    # stderr may contain unrelated output, so instead of requiring it to be
    # empty we look for "|", the character that forms the bar's borders.
    assert "|" not in stderr
934

935

936
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar_no_constraints(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    """No budget and no `show_progress_bar`: nothing bar-like may be printed."""
    study = create_study()
    study.optimize(stop_objective(5), n_jobs=n_jobs)
    stderr = capsys.readouterr().err

    # "|" forms the bar's borders, so its absence means no bar was drawn.
    assert "|" not in stderr
946

947

948
@pytest.mark.parametrize("n_jobs", [1, 4])
def test_callbacks(n_jobs: int) -> None:
    """Callbacks run once per finished trial, also when trials fail and are caught.

    Each callback is wrapped in a shared lock because with ``n_jobs=4`` the
    callbacks append to plain lists from several threads concurrently.
    """
    lock = threading.Lock()

    def with_lock(f: CallbackFuncType) -> CallbackFuncType:
        # Serialize the wrapped callback so list appends below are thread-safe.
        def callback(study: Study, trial: FrozenTrial) -> None:
            with lock:
                f(study, trial)

        return callback

    study = create_study()

    def objective(trial: Trial) -> float:
        # Degenerate search space: every trial yields exactly 1.
        return trial.suggest_int("x", 1, 1)

    # Empty callback list.
    study.optimize(objective, callbacks=[], n_trials=10, n_jobs=n_jobs)

    # One callback.
    values = []
    callbacks = [with_lock(lambda study, trial: values.append(trial.value))]
    study.optimize(objective, callbacks=callbacks, n_trials=10, n_jobs=n_jobs)
    assert values == [1] * 10

    # Two callbacks.
    values = []
    params = []
    callbacks = [
        with_lock(lambda study, trial: values.append(trial.value)),
        with_lock(lambda study, trial: params.append(trial.params)),
    ]
    study.optimize(objective, callbacks=callbacks, n_trials=10, n_jobs=n_jobs)
    assert values == [1] * 10
    assert params == [{"x": 1}] * 10

    # If a trial is failed with an exception and the exception is caught by the study,
    # callbacks are invoked.
    states = []
    callbacks = [with_lock(lambda study, trial: states.append(trial.state))]
    study.optimize(
        lambda t: 1 / 0,
        callbacks=callbacks,
        n_trials=10,
        n_jobs=n_jobs,
        catch=(ZeroDivisionError,),
    )
    assert states == [TrialState.FAIL] * 10

    # If a trial is failed with an exception and the exception isn't caught by the study,
    # callbacks aren't invoked.
    states = []
    callbacks = [with_lock(lambda study, trial: states.append(trial.state))]
    with pytest.raises(ZeroDivisionError):
        study.optimize(lambda t: 1 / 0, callbacks=callbacks, n_trials=10, n_jobs=n_jobs, catch=())
    assert states == []
1004

1005

1006
def test_optimize_infinite_budget_progbar() -> None:
    """Requesting a bar with neither n_trials nor timeout warns but still runs."""

    def stop_study(study: Study, trial: FrozenTrial) -> None:
        # Terminate immediately so the unbounded run finishes.
        study.stop()

    study = create_study()
    with pytest.warns(UserWarning):
        study.optimize(
            func, n_trials=None, timeout=None, show_progress_bar=True, callbacks=[stop_study]
        )
1016

1017

1018
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trials(storage_mode: str) -> None:
    """`get_trials` deep-copies only on request; `study.trials` always copies."""
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(lambda t: t.suggest_int("x", 1, 5), n_trials=5)

        with patch("copy.deepcopy", wraps=copy.deepcopy) as deepcopy_spy:
            shallow = study.get_trials(deepcopy=False)
            # No copy is made when deepcopy=False.
            assert deepcopy_spy.call_count == 0
            assert len(shallow) == 5

            deep = study.get_trials(deepcopy=True)
            assert deepcopy_spy.call_count > 0
            assert shallow == deep

            # The `trials` property behaves like `get_trials(deepcopy=True)`.
            calls_so_far = deepcopy_spy.call_count
            via_property = study.trials
            assert deepcopy_spy.call_count > calls_so_far
            assert shallow == via_property
1038

1039

1040
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trials_state_option(storage_mode: str) -> None:
    """`get_trials(states=...)` filters trials by their terminal state."""
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        def objective(trial: Trial) -> float:
            # Trials 0 and 1 complete; trial 2 is pruned.
            if trial.number in (0, 1):
                return 0.0
            if trial.number == 2:
                raise TrialPruned
            assert False

        study.optimize(objective, n_trials=3)

        # `states=None` disables filtering.
        assert len(study.get_trials(states=None)) == 3

        completed = study.get_trials(states=(TrialState.COMPLETE,))
        assert len(completed) == 2
        assert all(t.state == TrialState.COMPLETE for t in completed)

        finished = study.get_trials(states=(TrialState.COMPLETE, TrialState.PRUNED))
        assert len(finished) == 3
        assert all(t.state in (TrialState.COMPLETE, TrialState.PRUNED) for t in finished)

        # An empty tuple matches nothing.
        assert len(study.get_trials(states=())) == 0

        # No trial is in any state other than COMPLETE or PRUNED.
        for state in TrialState:
            if state in (TrialState.COMPLETE, TrialState.PRUNED):
                continue
            assert len(study.get_trials(states=(state,))) == 0
1077

1078

1079
def test_log_completed_trial(capsys: _pytest.capture.CaptureFixture) -> None:
    """Completed-trial log lines respect the library's verbosity setting."""
    # We need to reconstruct our default handler to properly capture stderr.
    logging._reset_library_root_logger()
    logging.set_verbosity(logging.INFO)

    study = create_study()
    # INFO and DEBUG emit the per-trial line; WARNING suppresses it.
    # Trial numbers advance across the loop iterations (0, 1, 2).
    for verbosity, label, should_log in (
        (logging.INFO, "Trial 0", True),
        (logging.WARNING, "Trial 1", False),
        (logging.DEBUG, "Trial 2", True),
    ):
        logging.set_verbosity(verbosity)
        study.optimize(lambda _: 1.0, n_trials=1)
        stderr = capsys.readouterr().err
        assert (label in stderr) == should_log
1098

1099

1100
def test_log_completed_trial_skip_storage_access() -> None:
    """_log_completed_trial must query storage only when the line is emitted.

    Building the log message requires fetching the best trial from storage;
    when WARNING verbosity suppresses the message, that fetch must be skipped.
    """
    study = create_study()

    # Create a trial to retrieve it as the `study.best_trial`.
    study.optimize(lambda _: 0.0, n_trials=1)
    frozen_trial = study.best_trial

    storage = study._storage

    # At the current verbosity the line is emitted: exactly one storage fetch.
    with patch.object(storage, "get_best_trial", wraps=storage.get_best_trial) as mock_object:
        study._log_completed_trial(frozen_trial)
        assert mock_object.call_count == 1

    # WARNING silences the line, so storage must not be touched at all.
    logging.set_verbosity(logging.WARNING)
    with patch.object(storage, "get_best_trial", wraps=storage.get_best_trial) as mock_object:
        study._log_completed_trial(frozen_trial)
        assert mock_object.call_count == 0

    # DEBUG re-enables the line and thus the storage access.
    logging.set_verbosity(logging.DEBUG)
    with patch.object(storage, "get_best_trial", wraps=storage.get_best_trial) as mock_object:
        study._log_completed_trial(frozen_trial)
        assert mock_object.call_count == 1
1122

1123

1124
def test_create_study_with_multi_objectives() -> None:
    """Direction lists configure (multi-)objective studies; conflicts are rejected."""
    single = create_study(directions=["maximize"])
    assert single.direction == StudyDirection.MAXIMIZE
    assert not single._is_multi_objective()

    multi = create_study(directions=["maximize", "minimize"])
    assert multi.directions == [StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE]
    assert multi._is_multi_objective()

    # An empty `directions` list is rejected.
    with pytest.raises(ValueError):
        create_study(directions=[])

    # `direction` and `directions` cannot be combined.
    with pytest.raises(ValueError):
        create_study(direction="minimize", directions=["maximize"])

    with pytest.raises(ValueError):
        create_study(direction="minimize", directions=[])
1142

1143

1144
def test_create_study_with_direction_object() -> None:
    """`StudyDirection` members are accepted directly, not only strings."""
    maximize_study = create_study(direction=StudyDirection.MAXIMIZE)
    assert maximize_study.direction == StudyDirection.MAXIMIZE

    both = [StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE]
    assert create_study(directions=both).directions == both
1150

1151

1152
@pytest.mark.parametrize("n_objectives", [2, 3])
def test_optimize_with_multi_objectives(n_objectives: int) -> None:
    """Each completed trial carries one value per optimization direction."""
    study = create_study(directions=["minimize"] * n_objectives)

    def objective(trial: Trial) -> list[float]:
        return [trial.suggest_float("v{}".format(i), 0, 5) for i in range(n_objectives)]

    study.optimize(objective, n_trials=10)

    assert len(study.trials) == 10
    for trial in study.trials:
        assert trial.values
        assert len(trial.values) == n_objectives
1167

1168

1169
def test_best_trials() -> None:
    """`best_trials` returns the Pareto front of a multi-objective study."""
    study = create_study(directions=["minimize", "maximize"])
    for point in ([2, 2], [1, 1], [3, 1]):
        study.optimize(lambda t, values=point: values, n_trials=1)
    # (3, 1) is dominated by (1, 1); the other two points are non-dominated.
    assert {tuple(t.values) for t in study.best_trials} == {(1, 1), (2, 2)}
1175

1176

1177
def test_wrong_n_objectives() -> None:
    """Returning more values than there are directions fails every trial."""
    n_objectives = 2
    study = create_study(directions=["minimize"] * n_objectives)

    def objective(trial: Trial) -> list[float]:
        # One value too many compared with the study's directions.
        return [trial.suggest_float("v{}".format(i), 0, 5) for i in range(n_objectives + 1)]

    study.optimize(objective, n_trials=10)

    assert all(trial.state is TrialState.FAIL for trial in study.trials)
1189

1190

1191
def test_ask() -> None:
    """`Study.ask` hands back a live `Trial` instance."""
    study = create_study()
    assert isinstance(study.ask(), Trial)
1196

1197

1198
def test_ask_enqueue_trial() -> None:
    """`ask` pops enqueued parameters and user attributes first."""
    study = create_study()
    study.enqueue_trial({"x": 0.5}, user_attrs={"memo": "this is memo"})

    trial = study.ask()
    # The suggested value and user attributes come from the queued entry.
    assert trial.suggest_float("x", 0, 1) == 0.5
    assert trial.user_attrs == {"memo": "this is memo"}
1206

1207

1208
def test_ask_fixed_search_space() -> None:
    """`ask(fixed_distributions=...)` pre-samples every given parameter."""
    search_space = {
        "x": distributions.FloatDistribution(0, 1),
        "y": distributions.CategoricalDistribution(["bacon", "spam"]),
    }

    study = create_study()
    trial = study.ask(fixed_distributions=search_space)

    sampled = trial.params
    assert len(sampled) == 2
    assert 0 <= sampled["x"] < 1
    assert sampled["y"] in ["bacon", "spam"]
1221

1222

1223
# Deprecated distributions are internally converted to corresponding distributions.
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_ask_distribution_conversion() -> None:
    """Deprecated distribution classes are converted and raise FutureWarning."""
    deprecated = {
        "ud": distributions.UniformDistribution(low=0, high=10),
        "dud": distributions.DiscreteUniformDistribution(low=0, high=10, q=2),
        "lud": distributions.LogUniformDistribution(low=1, high=10),
        "id": distributions.IntUniformDistribution(low=0, high=10),
        "idd": distributions.IntUniformDistribution(low=0, high=10, step=2),
        "ild": distributions.IntLogUniformDistribution(low=1, high=10),
    }

    study = create_study()

    with pytest.warns(
        FutureWarning,
        match="See https://github.com/optuna/optuna/issues/2941",
    ) as warning_records:
        trial = study.ask(fixed_distributions=deprecated)
        # One deprecation warning per deprecated distribution.
        assert len(warning_records) == 6

    converted = {
        "ud": distributions.FloatDistribution(low=0, high=10, log=False, step=None),
        "dud": distributions.FloatDistribution(low=0, high=10, log=False, step=2),
        "lud": distributions.FloatDistribution(low=1, high=10, log=True, step=None),
        "id": distributions.IntDistribution(low=0, high=10, log=False, step=1),
        "idd": distributions.IntDistribution(low=0, high=10, log=False, step=2),
        "ild": distributions.IntDistribution(low=1, high=10, log=True, step=1),
    }

    assert trial.distributions == converted
1254

1255

1256
# It confirms that ask doesn't convert non-deprecated distributions.
def test_ask_distribution_conversion_noop() -> None:
    """Current distribution classes pass through `ask` unchanged."""
    search_space = {
        "ud": distributions.FloatDistribution(low=0, high=10, log=False, step=None),
        "dud": distributions.FloatDistribution(low=0, high=10, log=False, step=2),
        "lud": distributions.FloatDistribution(low=1, high=10, log=True, step=None),
        "id": distributions.IntDistribution(low=0, high=10, log=False, step=1),
        "idd": distributions.IntDistribution(low=0, high=10, log=False, step=2),
        "ild": distributions.IntDistribution(low=1, high=10, log=True, step=1),
        "cd": distributions.CategoricalDistribution(choices=["a", "b", "c"]),
    }

    study = create_study()
    trial = study.ask(fixed_distributions=search_space)

    # The distributions arrive on the trial exactly as given.
    assert trial.distributions == search_space
1274

1275

1276
def test_tell() -> None:
    """`Study.tell` finishes asked trials with values, numbers, or explicit states.

    The assertions are cumulative: each `ask`/`tell` pair grows `study.trials`
    by one, so the expected counts advance step by step.
    """
    study = create_study()
    assert len(study.trials) == 0

    # An asked trial exists but is not COMPLETE yet.
    trial = study.ask()
    assert len(study.trials) == 1
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 0

    # A scalar value completes the trial.
    study.tell(trial, 1.0)
    assert len(study.trials) == 1
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 1

    # A one-element list also works for a single-objective study.
    study.tell(study.ask(), [1.0])
    assert len(study.trials) == 2
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 2

    # `trial` could be int.
    study.tell(study.ask().number, 1.0)
    assert len(study.trials) == 3
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 3

    # Inf is supported as values.
    study.tell(study.ask(), float("inf"))
    assert len(study.trials) == 4
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 4

    # Finished states can be reported without values.
    study.tell(study.ask(), state=TrialState.PRUNED)
    assert len(study.trials) == 5
    assert len(study.get_trials(states=(TrialState.PRUNED,))) == 1

    study.tell(study.ask(), state=TrialState.FAIL)
    assert len(study.trials) == 6
    assert len(study.get_trials(states=(TrialState.FAIL,))) == 1
1309

1310

1311
def test_tell_pruned() -> None:
    """Pruned trials take their value from the latest reported intermediate."""
    study = create_study()

    def last_trial() -> FrozenTrial:
        return study.trials[-1]

    # Without any intermediate report the pruned trial carries no value.
    study.tell(study.ask(), state=TrialState.PRUNED)
    assert last_trial().value is None
    assert last_trial().state == TrialState.PRUNED

    # The last reported intermediate becomes the trial's value.
    trial = study.ask()
    trial.report(2.0, step=1)
    study.tell(trial, state=TrialState.PRUNED)
    assert last_trial().value == 2.0
    assert last_trial().state == TrialState.PRUNED

    # Infinity is an acceptable intermediate value.
    trial = study.ask()
    trial.report(float("inf"), step=1)
    study.tell(trial, state=TrialState.PRUNED)
    assert last_trial().value == float("inf")
    assert last_trial().state == TrialState.PRUNED

    # NaN intermediates are discarded, leaving the value unset.
    trial = study.ask()
    trial.report(float("nan"), step=1)
    study.tell(trial, state=TrialState.PRUNED)
    assert last_trial().value is None
    assert last_trial().state == TrialState.PRUNED
1338

1339

1340
def test_tell_automatically_fail() -> None:
    """Invalid `tell` payloads mark the trial FAILED (with a warning), not raise.

    Each scenario appends exactly one failed, value-less trial, so the
    expected trial counts advance from 1 to 6.
    """
    study = create_study()

    # Check invalid values, e.g. str cannot be cast to float.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), "a")  # type: ignore
        assert len(study.trials) == 1
        assert study.trials[-1].state == TrialState.FAIL
        assert study.trials[-1].values is None

    # Check invalid values, e.g. `None` that cannot be cast to float.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), None)
        assert len(study.trials) == 2
        assert study.trials[-1].state == TrialState.FAIL
        assert study.trials[-1].values is None

    # Check number of values.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), [])
        assert len(study.trials) == 3
        assert study.trials[-1].state == TrialState.FAIL
        assert study.trials[-1].values is None

    # Check wrong number of values, e.g. two values for single direction.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), [1.0, 2.0])
        assert len(study.trials) == 4
        assert study.trials[-1].state == TrialState.FAIL
        assert study.trials[-1].values is None

    # Both state and values are not specified.
    with pytest.warns(UserWarning):
        study.tell(study.ask())
        assert len(study.trials) == 5
        assert study.trials[-1].state == TrialState.FAIL
        assert study.trials[-1].values is None

    # Nan is not supported.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), float("nan"))
        assert len(study.trials) == 6
        assert study.trials[-1].state == TrialState.FAIL
        assert study.trials[-1].values is None
1384

1385

1386
def test_tell_multi_objective() -> None:
    """One value per direction completes a multi-objective trial."""
    study = create_study(directions=["minimize", "maximize"])
    trial = study.ask()
    study.tell(trial, [1.0, 2.0])
    assert len(study.trials) == 1
1390

1391

1392
def test_tell_multi_objective_automatically_fail() -> None:
    """Value lists that don't match the directions fail the trial with a warning."""
    # Number of values doesn't match the length of directions.
    study = create_study(directions=["minimize", "maximize"])

    bad_payloads = (
        [],  # No values at all.
        [1.0],  # Too few values.
        [1.0, 2.0, 3.0],  # Too many values.
        [1.0, None],  # One value is missing.
        [None, None],  # All values are missing.
        1.0,  # A bare scalar for a two-objective study.
    )
    for i, payload in enumerate(bad_payloads):
        with pytest.warns(UserWarning):
            study.tell(study.ask(), payload)  # type: ignore
            assert len(study.trials) == i + 1
            assert study.trials[-1].state == TrialState.FAIL
            assert study.trials[-1].values is None
1431

1432

1433
def test_tell_invalid() -> None:
    """`tell` raises for inconsistent argument combinations."""
    study = create_study()

    # A COMPLETE trial requires exactly one finite value.  `None` covers both
    # the missing-values call and the explicit-None call (values defaults to None).
    for bad_values in (None, "a", None, [], [1.0, 2.0], float("nan")):
        with pytest.raises(ValueError):
            study.tell(study.ask(), bad_values, state=TrialState.COMPLETE)  # type: ignore

    # `state` must be None or finished state.
    for unfinished in (TrialState.RUNNING, TrialState.WAITING):
        with pytest.raises(ValueError):
            study.tell(study.ask(), state=unfinished)

    # `value` must be None for `TrialState.PRUNED` and `TrialState.FAIL`.
    for finished in (TrialState.PRUNED, TrialState.FAIL):
        with pytest.raises(ValueError):
            study.tell(study.ask(), values=1, state=finished)

    # Trial that has not been asked for cannot be told.
    with pytest.raises(ValueError):
        study.tell(study.ask().number + 1, 1.0)

    # Waiting trial cannot be told.
    with pytest.raises(ValueError):
        study.enqueue_trial({})
        study.tell(study.trials[-1].number, 1.0)

    # It must be Trial or int for trial.
    with pytest.raises(TypeError):
        study.tell("1", 1.0)  # type: ignore
1484

1485

1486
def test_tell_duplicate_tell() -> None:
    """Re-telling a finished trial errors unless `skip_if_finished` is set."""
    study = create_study()

    trial = study.ask()
    study.tell(trial, 1.0)

    # With skip_if_finished=True the duplicate report is silently ignored.
    study.tell(trial, 1.0, skip_if_finished=True)

    # Without it, a duplicate report is an error.
    with pytest.raises(ValueError):
        study.tell(trial, 1.0, skip_if_finished=False)
1497

1498

1499
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueued_trial_datetime_start(storage_mode: str) -> None:
    """An enqueued trial only gets its start timestamp once it actually runs."""
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        def objective(trial: Trial) -> float:
            time.sleep(1)
            return trial.suggest_int("x", -10, 10)

        # While merely queued, the trial has not started.
        study.enqueue_trial(params={"x": 1})
        assert study.trials[0].datetime_start is None

        # Running it assigns the start timestamp.
        study.optimize(objective, n_trials=1)
        assert study.trials[0].datetime_start is not None
1514

1515

1516
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_study_summary_datetime_start_calculation(storage_mode: str) -> None:
    """A study summary's datetime_start ignores trials still sitting in the queue."""
    with StorageSupplier(storage_mode) as storage:

        def objective(trial: Trial) -> float:
            return trial.suggest_int("x", -10, 10)

        study = create_study(storage=storage)
        study.enqueue_trial(params={"x": 1})

        # Only waiting trials: the summary has no start time yet.
        summaries = get_all_study_summaries(study._storage, include_best_trial=True)
        assert summaries[0].datetime_start is None

        # Once a trial has run, the summary carries a start time even though
        # another trial is re-enqueued afterwards.
        study.optimize(objective, n_trials=1)
        study.enqueue_trial(params={"x": 1}, skip_if_exists=False)
        summaries = get_all_study_summaries(study._storage, include_best_trial=True)
        assert summaries[0].datetime_start is not None
1537

1538

1539
def _process_tell(study: Study, trial: Trial | int, values: float) -> None:
    """Report `values` for `trial`; module-level so multiprocessing can pickle it."""
    study.tell(trial, values)
1541

1542

1543
def test_tell_from_another_process() -> None:
    """`Study.tell` must work from a worker process sharing the same storage.

    Fix: the original created ``multiprocessing.Pool()`` without ever closing
    it, leaking worker processes for the remainder of the test session.  Using
    the pool as a context manager terminates the workers on exit.
    """
    with StorageSupplier("sqlite") as storage, multiprocessing.Pool() as pool:
        # Create a study and ask for a new trial.
        study = create_study(storage=storage)
        trial0 = study.ask()

        # Test normal behaviour: tell via a Trial object from a worker process.
        pool.starmap(_process_tell, [(study, trial0, 1.2)])

        assert len(study.trials) == 1
        assert study.best_trial.state == TrialState.COMPLETE
        assert study.best_value == 1.2

        # Test study.tell using trial number.
        trial = study.ask()
        pool.starmap(_process_tell, [(study, trial.number, 1.5)])

        assert len(study.trials) == 2
        assert study.best_trial.state == TrialState.COMPLETE
        # 1.2 is still the minimum, hence still the best value.
        assert study.best_value == 1.2

        # Should fail because the trial0 is already finished.
        with pytest.raises(ValueError):
            pool.starmap(_process_tell, [(study, trial0, 1.2)])
1569

1570

1571
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_pop_waiting_trial_thread_safe(storage_mode: str) -> None:
    """Concurrent `_pop_waiting_trial_id` calls must never hand out duplicates."""
    if storage_mode in ("sqlite", "cached_sqlite"):
        pytest.skip("study._pop_waiting_trial is not thread-safe on SQLite3")

    num_enqueued = 10
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        for i in range(num_enqueued):
            study.enqueue_trial({"i": i})

        # Pop all waiting trials from ten threads at once.
        with ThreadPoolExecutor(10) as executor:
            futures = [
                executor.submit(study._pop_waiting_trial_id) for _ in range(num_enqueued)
            ]
            popped_ids = {future.result() for future in as_completed(futures)}

        # All popped ids are distinct, i.e. no trial was handed out twice.
        assert len(popped_ids) == num_enqueued
1592

1593

1594
def test_set_metric_names() -> None:
    """`set_metric_names` persists the names as a study system attribute."""
    names = ["v0", "v1"]
    study = create_study(directions=["minimize", "minimize"])
    study.set_metric_names(names)

    system_attrs = study._storage.get_study_system_attrs(study._study_id)
    stored = system_attrs.get(_SYSTEM_ATTR_METRIC_NAMES)
    assert stored is not None
    assert names == stored
1604

1605

1606
def test_set_metric_names_experimental_warning() -> None:
    """`set_metric_names` is experimental and must emit the corresponding warning."""
    study = create_study()
    with pytest.warns(ExperimentalWarning):
        study.set_metric_names(["v0"])
1610

1611

1612
def test_set_invalid_metric_names() -> None:
    """The number of metric names must match the number of directions."""
    study = create_study(directions=["minimize", "minimize"])
    # Three names for two directions is rejected.
    with pytest.raises(ValueError):
        study.set_metric_names(["v0", "v1", "v2"])
1617

1618

1619
def test_get_metric_names() -> None:
    """`metric_names` reflects the latest `set_metric_names` call."""
    study = create_study()
    assert study.metric_names is None
    # Repeated calls overwrite the previous names.
    for names in (["v0"], ["v1"]):
        study.set_metric_names(names)
        assert study.metric_names == names
1626

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.