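# Tests for optuna.Study: optimization across storage back ends, parallel execution,
# ask-and-tell, trial enqueueing, callbacks, and study-level metadata.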
from __future__ import annotations

from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
import copy
import multiprocessing
import pickle
import platform
import threading
import time
from typing import Any
from typing import Callable
from unittest.mock import Mock
from unittest.mock import patch
import uuid
import warnings

import _pytest.capture
import pytest

from optuna import copy_study
from optuna import create_study
from optuna import create_trial
from optuna import delete_study
from optuna import distributions
from optuna import get_all_study_names
from optuna import get_all_study_summaries
from optuna import load_study
from optuna import logging
from optuna import Study
from optuna import Trial
from optuna import TrialPruned
from optuna.exceptions import DuplicatedStudyError
from optuna.exceptions import ExperimentalWarning
from optuna.study import StudyDirection
from optuna.study.study import _SYSTEM_ATTR_METRIC_NAMES
from optuna.testing.objectives import fail_objective
from optuna.testing.storages import STORAGE_MODES
from optuna.testing.storages import StorageSupplier
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


CallbackFuncType = Callable[[Study, FrozenTrial], None]


def func(trial: Trial) -> float:
    x = trial.suggest_float("x", -10.0, 10.0)
    y = trial.suggest_float("y", 20, 30, log=True)
    z = trial.suggest_categorical("z", (-1.0, 1.0))
    return (x - 2) ** 2 + (y - 25) ** 2 + z
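

# A callable objective that counts how many times it is invoked. The lock keeps the
# counter accurate when `optimize` runs trials concurrently (n_jobs > 1), and the
# optional sleep lengthens each trial so parallelism and timeouts can be observed.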
class Func:
    def __init__(self, sleep_sec: float | None = None) -> None:
        self.n_calls = 0
        self.sleep_sec = sleep_sec
        self.lock = threading.Lock()

    def __call__(self, trial: Trial) -> float:
        with self.lock:
            self.n_calls += 1

        # Sleep for testing parallelism.
        if self.sleep_sec is not None:
            time.sleep(self.sleep_sec)

        value = func(trial)
        check_params(trial.params)
        return value


def check_params(params: dict[str, Any]) -> None:
    assert sorted(params.keys()) == ["x", "y", "z"]
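

# Over the search space of `func`, |x - 2| <= 12, |y - 25| <= 5, and z is -1 or 1,
# so objective values must lie in [-1.0, 12**2 + 5**2 + 1].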
def check_value(value: float | None) -> None:
    assert isinstance(value, float)
    assert -1.0 <= value <= 12.0**2 + 5.0**2 + 1.0


def check_frozen_trial(frozen_trial: FrozenTrial) -> None:
    if frozen_trial.state == TrialState.COMPLETE:
        check_params(frozen_trial.params)
        check_value(frozen_trial.value)


def check_study(study: Study) -> None:
    for trial in study.trials:
        check_frozen_trial(trial)

    assert not study._is_multi_objective()

    complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    if len(complete_trials) == 0:
        with pytest.raises(ValueError):
            study.best_params
        with pytest.raises(ValueError):
            study.best_value
        with pytest.raises(ValueError):
            study.best_trial
    else:
        check_params(study.best_params)
        check_value(study.best_value)
        check_frozen_trial(study.best_trial)
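

# Builds an objective that requests `study.stop()` once `trial.number` reaches the
# threshold. The stopping trial still runs to completion, so `threshold_number + 1`
# trials finish before the optimization loop exits.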
def stop_objective(threshold_number: int) -> Callable[[Trial], float]:
    def objective(trial: Trial) -> float:
        if trial.number >= threshold_number:
            trial.study.stop()

        return trial.number

    return objective


def test_optimize_trivial_in_memory_new() -> None:
    study = create_study()
    study.optimize(func, n_trials=10)
    check_study(study)


def test_optimize_trivial_in_memory_resume() -> None:
    study = create_study()
    study.optimize(func, n_trials=10)
    study.optimize(func, n_trials=10)
    check_study(study)


def test_optimize_trivial_rdb_resume_study() -> None:
    study = create_study(storage="sqlite:///:memory:")
    study.optimize(func, n_trials=10)
    check_study(study)


def test_optimize_with_direction() -> None:
    study = create_study(direction="minimize")
    study.optimize(func, n_trials=10)
    assert study.direction == StudyDirection.MINIMIZE
    check_study(study)

    study = create_study(direction="maximize")
    study.optimize(func, n_trials=10)
    assert study.direction == StudyDirection.MAXIMIZE
    check_study(study)

    with pytest.raises(ValueError):
        create_study(direction="test")


@pytest.mark.parametrize("n_trials", (0, 1, 20))
@pytest.mark.parametrize("n_jobs", (1, 2, -1))
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_parallel(n_trials: int, n_jobs: int, storage_mode: str) -> None:
    f = Func()

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(f, n_trials=n_trials, n_jobs=n_jobs)
        assert f.n_calls == len(study.trials) == n_trials
        check_study(study)


def test_optimize_with_thread_pool_executor() -> None:
    def objective(t: Trial) -> float:
        return t.suggest_float("x", -10, 10)

    study = create_study()
    with ThreadPoolExecutor(max_workers=5) as pool:
        for _ in range(10):
            pool.submit(study.optimize, objective, n_trials=10)
    assert len(study.trials) == 100


@pytest.mark.parametrize("n_trials", (0, 1, 20, None))
@pytest.mark.parametrize("n_jobs", (1, 2, -1))
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_parallel_timeout(n_trials: int, n_jobs: int, storage_mode: str) -> None:
    sleep_sec = 0.1
    timeout_sec = 1.0
    f = Func(sleep_sec=sleep_sec)

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(f, n_trials=n_trials, n_jobs=n_jobs, timeout=timeout_sec)

        assert f.n_calls == len(study.trials)

        if n_trials is not None:
            assert f.n_calls <= n_trials

        # A thread can process at most (timeout_sec / sleep_sec + 1) trials.
        n_jobs_actual = n_jobs if n_jobs != -1 else multiprocessing.cpu_count()
        max_calls = (timeout_sec / sleep_sec + 1) * n_jobs_actual
        assert f.n_calls <= max_calls

        check_study(study)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_with_catch(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        # Test default exceptions.
        with pytest.raises(ValueError):
            study.optimize(fail_objective, n_trials=20)
        assert len(study.trials) == 1
        assert all(trial.state == TrialState.FAIL for trial in study.trials)

        # Test acceptable exception.
        study.optimize(fail_objective, n_trials=20, catch=(ValueError,))
        assert len(study.trials) == 21
        assert all(trial.state == TrialState.FAIL for trial in study.trials)

        # Test trial with unacceptable exception.
        with pytest.raises(ValueError):
            study.optimize(fail_objective, n_trials=20, catch=(ArithmeticError,))
        assert len(study.trials) == 22
        assert all(trial.state == TrialState.FAIL for trial in study.trials)


@pytest.mark.parametrize("catch", [ValueError, (ValueError,), [ValueError], {ValueError}])
def test_optimize_with_catch_valid_type(catch: Any) -> None:
    study = create_study()
    study.optimize(fail_objective, n_trials=20, catch=catch)


@pytest.mark.parametrize("catch", [None, 1])
def test_optimize_with_catch_invalid_type(catch: Any) -> None:
    study = create_study()

    with pytest.raises(TypeError):
        study.optimize(fail_objective, n_trials=20, catch=catch)
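

# With n_jobs != 1, `optimize` is expected to reseed the sampler's random number
# generator exactly once, so that concurrent workers do not share a seed.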
@pytest.mark.parametrize("n_jobs", (2, -1))
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_optimize_with_reseeding(n_jobs: int, storage_mode: str) -> None:
    f = Func()

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        sampler = study.sampler
        with patch.object(sampler, "reseed_rng", wraps=sampler.reseed_rng) as mock_object:
            study.optimize(f, n_trials=1, n_jobs=2)
            assert mock_object.call_count == 1


def test_call_another_study_optimize_in_optimize() -> None:
    def inner_objective(t: Trial) -> float:
        return t.suggest_float("x", -10, 10)

    def objective(t: Trial) -> float:
        inner_study = create_study()
        inner_study.enqueue_trial({"x": t.suggest_int("initial_point", -10, 10)})
        inner_study.optimize(inner_objective, n_trials=10)
        return inner_study.best_value

    study = create_study()
    study.optimize(objective, n_trials=10)
    assert len(study.trials) == 10


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_study_set_and_get_user_attrs(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        study.set_user_attr("dataset", "MNIST")
        assert study.user_attrs["dataset"] == "MNIST"


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_trial_set_and_get_user_attrs(storage_mode: str) -> None:
    def f(trial: Trial) -> float:
        trial.set_user_attr("train_accuracy", 1)
        assert trial.user_attrs["train_accuracy"] == 1
        return 0.0

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(f, n_trials=1)
        frozen_trial = study.trials[0]
        assert frozen_trial.user_attrs["train_accuracy"] == 1


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@pytest.mark.parametrize("include_best_trial", [True, False])
def test_get_all_study_summaries(storage_mode: str, include_best_trial: bool) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(func, n_trials=5)

        summaries = get_all_study_summaries(study._storage, include_best_trial)
        summary = [s for s in summaries if s._study_id == study._study_id][0]

        assert summary.study_name == study.study_name
        assert summary.n_trials == 5
        if include_best_trial:
            assert summary.best_trial is not None
        else:
            assert summary.best_trial is None


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_study_summaries_with_no_trials(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        summaries = get_all_study_summaries(study._storage)
        summary = [s for s in summaries if s._study_id == study._study_id][0]

        assert summary.study_name == study.study_name
        assert summary.n_trials == 0
        assert summary.datetime_start is None


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_all_study_names(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        n_studies = 5

        studies = [create_study(storage=storage) for _ in range(n_studies)]
        study_names = get_all_study_names(storage)

        assert len(study_names) == n_studies
        for study, study_name in zip(studies, study_names):
            assert study_name == study.study_name


def test_study_pickle() -> None:
    study_1 = create_study()
    study_1.optimize(func, n_trials=10)
    check_study(study_1)
    assert len(study_1.trials) == 10
    dumped_bytes = pickle.dumps(study_1)

    study_2 = pickle.loads(dumped_bytes)
    check_study(study_2)
    assert len(study_2.trials) == 10

    study_2.optimize(func, n_trials=10)
    check_study(study_2)
    assert len(study_2.trials) == 20


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_create_study(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        # Test creating a new study.
        study = create_study(storage=storage, load_if_exists=False)

        # Test `load_if_exists=True` with an existing study.
        create_study(study_name=study.study_name, storage=storage, load_if_exists=True)

        with pytest.raises(DuplicatedStudyError):
            create_study(study_name=study.study_name, storage=storage, load_if_exists=False)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_load_study(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        if storage is None:
            # `InMemoryStorage` cannot be used with the `load_study` function.
            return

        study_name = str(uuid.uuid4())

        with pytest.raises(KeyError):
            # Test loading a non-existing study.
            load_study(study_name=study_name, storage=storage)

        # Create a new study.
        created_study = create_study(study_name=study_name, storage=storage)

        # Test loading an existing study.
        loaded_study = load_study(study_name=study_name, storage=storage)
        assert created_study._study_id == loaded_study._study_id


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_load_study_study_name_none(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        if storage is None:
            # `InMemoryStorage` cannot be used with the `load_study` function.
            return

        study_name = str(uuid.uuid4())

        _ = create_study(study_name=study_name, storage=storage)

        loaded_study = load_study(study_name=None, storage=storage)

        assert loaded_study.study_name == study_name

        study_name = str(uuid.uuid4())

        _ = create_study(study_name=study_name, storage=storage)

        # With two studies in the storage, `study_name=None` is ambiguous.
        with pytest.raises(ValueError):
            load_study(study_name=None, storage=storage)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_delete_study(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        # Test deleting a non-existing study.
        with pytest.raises(KeyError):
            delete_study(study_name="invalid-study-name", storage=storage)

        # Test deleting an existing study.
        study = create_study(storage=storage, load_if_exists=False)
        delete_study(study_name=study.study_name, storage=storage)

        # Deleting a study that has already been deleted should fail.
        with pytest.raises(KeyError):
            delete_study(study_name=study.study_name, storage=storage)


@pytest.mark.parametrize("from_storage_mode", STORAGE_MODES)
@pytest.mark.parametrize("to_storage_mode", STORAGE_MODES)
def test_copy_study(from_storage_mode: str, to_storage_mode: str) -> None:
    with StorageSupplier(from_storage_mode) as from_storage, StorageSupplier(
        to_storage_mode
    ) as to_storage:
        from_study = create_study(storage=from_storage, directions=["maximize", "minimize"])
        from_study._storage.set_study_system_attr(from_study._study_id, "foo", "bar")
        from_study.set_user_attr("baz", "qux")
        from_study.optimize(
            lambda t: (t.suggest_float("x0", 0, 1), t.suggest_float("x1", 0, 1)), n_trials=3
        )

        copy_study(
            from_study_name=from_study.study_name,
            from_storage=from_storage,
            to_storage=to_storage,
        )

        to_study = load_study(study_name=from_study.study_name, storage=to_storage)

        assert to_study.study_name == from_study.study_name
        assert to_study.directions == from_study.directions
        to_study_system_attrs = to_study._storage.get_study_system_attrs(to_study._study_id)
        from_study_system_attrs = from_study._storage.get_study_system_attrs(from_study._study_id)
        assert to_study_system_attrs == from_study_system_attrs
        assert to_study.user_attrs == from_study.user_attrs
        assert len(to_study.trials) == len(from_study.trials)


@pytest.mark.parametrize("from_storage_mode", STORAGE_MODES)
@pytest.mark.parametrize("to_storage_mode", STORAGE_MODES)
def test_copy_study_to_study_name(from_storage_mode: str, to_storage_mode: str) -> None:
    with StorageSupplier(from_storage_mode) as from_storage, StorageSupplier(
        to_storage_mode
    ) as to_storage:
        from_study = create_study(study_name="foo", storage=from_storage)
        _ = create_study(study_name="foo", storage=to_storage)

        with pytest.raises(DuplicatedStudyError):
            copy_study(
                from_study_name=from_study.study_name,
                from_storage=from_storage,
                to_storage=to_storage,
            )

        copy_study(
            from_study_name=from_study.study_name,
            from_storage=from_storage,
            to_storage=to_storage,
            to_study_name="bar",
        )

        _ = load_study(study_name="bar", storage=to_storage)


def test_nested_optimization() -> None:
    def objective(trial: Trial) -> float:
        with pytest.raises(RuntimeError):
            trial.study.optimize(lambda _: 0.0, n_trials=1)

        return 1.0

    study = create_study()
    study.optimize(objective, n_trials=10, catch=())


def test_stop_in_objective() -> None:
    # Test stopping the optimization: it should stop once the trial number reaches 4.
    study = create_study()
    study.optimize(stop_objective(4), n_trials=10)
    assert len(study.trials) == 5

    # Test calling `optimize` again: it should stop once the trial number reaches 11.
    study.optimize(stop_objective(11), n_trials=10)
    assert len(study.trials) == 12


def test_stop_in_callback() -> None:
    def callback(study: Study, trial: FrozenTrial) -> None:
        if trial.number >= 4:
            study.stop()

    # Test stopping the optimization inside a callback.
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, callbacks=[callback])
    assert len(study.trials) == 5
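

# With two workers, a second trial may already be running when `stop()` is requested,
# so either 5 or 6 trials can finish.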
def test_stop_n_jobs() -> None:
    def callback(study: Study, trial: FrozenTrial) -> None:
        if trial.number >= 4:
            study.stop()

    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=None, callbacks=[callback], n_jobs=2)
    assert 5 <= len(study.trials) <= 6


def test_stop_outside_optimize() -> None:
    # Test stopping outside the optimization: it should raise `RuntimeError`.
    study = create_study()
    with pytest.raises(RuntimeError):
        study.stop()

    # Test calling `optimize` after the `RuntimeError` is caught.
    study.optimize(lambda _: 1.0, n_trials=1)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_add_trial(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        trial = create_trial(value=0.8)
        study.add_trial(trial)
        assert len(study.trials) == 1
        assert study.trials[0].number == 0
        assert study.best_value == 0.8


def test_add_trial_invalid_values_length() -> None:
    study = create_study()
    trial = create_trial(values=[0, 0])
    with pytest.raises(ValueError):
        study.add_trial(trial)

    study = create_study(directions=["minimize", "minimize"])
    trial = create_trial(value=0)
    with pytest.raises(ValueError):
        study.add_trial(trial)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_add_trials(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.add_trials([])
        assert len(study.trials) == 0

        trials = [create_trial(value=i) for i in range(3)]
        study.add_trials(trials)
        assert len(study.trials) == 3
        for i, trial in enumerate(study.trials):
            assert trial.number == i
            assert trial.value == i

        other_study = create_study(storage=storage)
        other_study.add_trials(study.trials)
        assert len(other_study.trials) == 3
        for i, trial in enumerate(other_study.trials):
            assert trial.number == i
            assert trial.value == i


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_properly_sets_param_values(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": -5, "y": 5})
        study.enqueue_trial(params={"x": -1, "y": 0})

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.optimize(objective, n_trials=2)
        t0 = study.trials[0]
        assert t0.params["x"] == -5
        assert t0.params["y"] == 5

        t1 = study.trials[1]
        assert t1.params["x"] == -1
        assert t1.params["y"] == 0


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_with_unfixed_parameters(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": -5})

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.optimize(objective, n_trials=1)
        t = study.trials[0]
        assert t.params["x"] == -5
        assert -10 <= t.params["y"] <= 10


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_properly_sets_user_attr(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": -5, "y": 5}, user_attrs={"is_optimal": False})
        study.enqueue_trial(params={"x": 0, "y": 0}, user_attrs={"is_optimal": True})

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.optimize(objective, n_trials=2)
        t0 = study.trials[0]
        assert t0.user_attrs == {"is_optimal": False}

        t1 = study.trials[1]
        assert t1.user_attrs == {"is_optimal": True}


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_with_non_dict_parameters(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        with pytest.raises(TypeError):
            study.enqueue_trial(params=[17, 12])  # type: ignore[arg-type]


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_with_out_of_range_parameters(storage_mode: str) -> None:
    fixed_value = 11

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": fixed_value})

        def objective(trial: Trial) -> float:
            return trial.suggest_int("x", -10, 10)

        with pytest.warns(UserWarning):
            study.optimize(objective, n_trials=1)
        t = study.trials[0]
        assert t.params["x"] == fixed_value

    # The internal logic might differ when the distribution contains a single element,
    # so test that case explicitly.
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        study.enqueue_trial(params={"x": fixed_value})

        def objective(trial: Trial) -> float:
            return trial.suggest_int("x", 1, 1)  # Single element.

        with pytest.warns(UserWarning):
            study.optimize(objective, n_trials=1)
        t = study.trials[0]
        assert t.params["x"] == fixed_value


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_skips_existing_finished(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.enqueue_trial({"x": -5, "y": 5})
        study.optimize(objective, n_trials=1)

        t0 = study.trials[0]
        assert t0.params["x"] == -5
        assert t0.params["y"] == 5

        before_enqueue = len(study.trials)
        study.enqueue_trial({"x": -5, "y": 5}, skip_if_exists=True)
        after_enqueue = len(study.trials)
        assert before_enqueue == after_enqueue


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueue_trial_skips_existing_waiting(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            return x**2 + y**2

        study.enqueue_trial({"x": -5, "y": 5})
        before_enqueue = len(study.trials)
        study.enqueue_trial({"x": -5, "y": 5}, skip_if_exists=True)
        after_enqueue = len(study.trials)
        assert before_enqueue == after_enqueue

        study.optimize(objective, n_trials=1)
        t0 = study.trials[0]
        assert t0.params["x"] == -5
        assert t0.params["y"] == 5


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@pytest.mark.parametrize(
    "new_params", [{"x": -5, "y": 5, "z": 5}, {"x": -5}, {"x": -5, "z": 5}, {"x": -5, "y": 6}]
)
def test_enqueue_trial_skip_existing_allows_unfixed(
    storage_mode: str, new_params: dict[str, int]
) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        assert len(study.trials) == 0

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            y = trial.suggest_int("y", -10, 10)
            if trial.number == 1:
                z = trial.suggest_int("z", -10, 10)
                return x**2 + y**2 + z**2
            return x**2 + y**2

        study.enqueue_trial({"x": -5, "y": 5})
        study.optimize(objective, n_trials=1)
        t0 = study.trials[0]
        assert t0.params["x"] == -5
        assert t0.params["y"] == 5

        study.enqueue_trial(new_params, skip_if_exists=True)
        study.optimize(objective, n_trials=1)

        unfixed_params = {"x", "y", "z"} - set(new_params)
        t1 = study.trials[1]
        assert all(t1.params[k] == new_params[k] for k in new_params)
        assert all(-10 <= t1.params[k] <= 10 for k in unfixed_params)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
@pytest.mark.parametrize(
    "param", ["foo", 1, 1.1, 1e17, 1e-17, float("inf"), float("-inf"), float("nan"), None]
)
def test_enqueue_trial_skip_existing_handles_common_types(storage_mode: str, param: Any) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.enqueue_trial({"x": param})
        before_enqueue = len(study.trials)
        study.enqueue_trial({"x": param}, skip_if_exists=True)
        after_enqueue = len(study.trials)
        assert before_enqueue == after_enqueue


@patch("optuna.study._optimize.gc.collect")
def test_optimize_with_gc(collect_mock: Mock) -> None:
    study = create_study()
    study.optimize(func, n_trials=10, gc_after_trial=True)
    check_study(study)
    assert collect_mock.call_count == 10


@patch("optuna.study._optimize.gc.collect")
def test_optimize_without_gc(collect_mock: Mock) -> None:
    study = create_study()
    study.optimize(func, n_trials=10, gc_after_trial=False)
    check_study(study)
    assert collect_mock.call_count == 0


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_with_progbar(n_jobs: int, capsys: _pytest.capture.CaptureFixture) -> None:
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs, show_progress_bar=True)
    _, err = capsys.readouterr()

    # Search for progress bar elements in stderr.
    assert "Best trial: 0" in err
    assert "Best value: 1" in err
    assert "10/10" in err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" in err


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar(n_jobs: int, capsys: _pytest.capture.CaptureFixture) -> None:
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs)
    _, err = capsys.readouterr()

    assert "Best trial: 0" not in err
    assert "Best value: 1" not in err
    assert "10/10" not in err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" not in err


def test_optimize_with_progbar_timeout(capsys: _pytest.capture.CaptureFixture) -> None:
    study = create_study()
    study.optimize(lambda _: 1.0, timeout=2.0, show_progress_bar=True)
    _, err = capsys.readouterr()

    assert "Best trial: 0" in err
    assert "Best value: 1" in err
    assert "00:02/00:02" in err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" in err


def test_optimize_with_progbar_parallel_timeout(capsys: _pytest.capture.CaptureFixture) -> None:
    study = create_study()
    with pytest.warns(
        UserWarning, match="The timeout-based progress bar is not supported with n_jobs != 1."
    ):
        study.optimize(lambda _: 1.0, timeout=2.0, show_progress_bar=True, n_jobs=2)
    _, err = capsys.readouterr()

    # Testing for a character that forms progress bar borders.
    assert "|" not in err


@pytest.mark.parametrize(
    "timeout,expected",
    [
        (59.0, "/00:59"),
        (60.0, "/01:00"),
        (60.0 * 60, "/1:00:00"),
        (60.0 * 60 * 24, "/24:00:00"),
        (60.0 * 60 * 24 * 10, "/240:00:00"),
    ],
)
def test_optimize_with_progbar_timeout_formats(
    timeout: float, expected: str, capsys: _pytest.capture.CaptureFixture
) -> None:
    study = create_study()
    study.optimize(stop_objective(5), timeout=timeout, show_progress_bar=True)
    _, err = capsys.readouterr()
    assert expected in err


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar_timeout(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    study = create_study()
    study.optimize(lambda _: 1.0, timeout=2.0, n_jobs=n_jobs)
    _, err = capsys.readouterr()

    assert "Best trial: 0" not in err
    assert "Best value: 1.0" not in err
    assert "00:02/00:02" not in err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" not in err


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_progbar_n_trials_prioritized(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs, timeout=10.0, show_progress_bar=True)
    _, err = capsys.readouterr()

    assert "Best trial: 0" in err
    assert "Best value: 1" in err
    assert "10/10" in err
    if platform.system() != "Windows":
        # Skip this assertion because the progress bar sometimes stops at 99% on Windows.
        assert "100%" in err
    assert "it" in err


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar_n_trials_prioritized(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=10, n_jobs=n_jobs, timeout=10.0)
    _, err = capsys.readouterr()

    # Testing for a character that forms progress bar borders.
    assert "|" not in err


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_progbar_no_constraints(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    study = create_study()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        study.optimize(stop_objective(5), n_jobs=n_jobs, show_progress_bar=True)
    _, err = capsys.readouterr()

    # We can't simply test if stderr is empty, since we're not sure
    # what else could write to it. Instead, we are testing for a character
    # that forms progress bar borders.
    assert "|" not in err


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_optimize_without_progbar_no_constraints(
    n_jobs: int, capsys: _pytest.capture.CaptureFixture
) -> None:
    study = create_study()
    study.optimize(stop_objective(5), n_jobs=n_jobs)
    _, err = capsys.readouterr()

    # Testing for a character that forms progress bar borders.
    assert "|" not in err
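

# Callbacks may run concurrently when n_jobs > 1, so each test callback mutates its
# result list while holding a shared lock.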
@pytest.mark.parametrize("n_jobs", [1, 4])
def test_callbacks(n_jobs: int) -> None:
    lock = threading.Lock()

    def with_lock(f: CallbackFuncType) -> CallbackFuncType:
        def callback(study: Study, trial: FrozenTrial) -> None:
            with lock:
                f(study, trial)

        return callback

    study = create_study()

    def objective(trial: Trial) -> float:
        return trial.suggest_int("x", 1, 1)

    # Empty callback list.
    study.optimize(objective, callbacks=[], n_trials=10, n_jobs=n_jobs)

    # One callback.
    values = []
    callbacks = [with_lock(lambda study, trial: values.append(trial.value))]
    study.optimize(objective, callbacks=callbacks, n_trials=10, n_jobs=n_jobs)
    assert values == [1] * 10

    # Two callbacks.
    values = []
    params = []
    callbacks = [
        with_lock(lambda study, trial: values.append(trial.value)),
        with_lock(lambda study, trial: params.append(trial.params)),
    ]
    study.optimize(objective, callbacks=callbacks, n_trials=10, n_jobs=n_jobs)
    assert values == [1] * 10
    assert params == [{"x": 1}] * 10

    # If a trial fails with an exception and the exception is caught by the study,
    # callbacks are invoked.
    states = []
    callbacks = [with_lock(lambda study, trial: states.append(trial.state))]
    study.optimize(
        lambda t: 1 / 0,
        callbacks=callbacks,
        n_trials=10,
        n_jobs=n_jobs,
        catch=(ZeroDivisionError,),
    )
    assert states == [TrialState.FAIL] * 10

    # If a trial fails with an exception and the exception isn't caught by the study,
    # callbacks aren't invoked.
    states = []
    callbacks = [with_lock(lambda study, trial: states.append(trial.state))]
    with pytest.raises(ZeroDivisionError):
        study.optimize(lambda t: 1 / 0, callbacks=callbacks, n_trials=10, n_jobs=n_jobs, catch=())
    assert states == []


def test_optimize_infinite_budget_progbar() -> None:
    def terminate_study(study: Study, trial: FrozenTrial) -> None:
        study.stop()

    study = create_study()

    with pytest.warns(UserWarning):
        study.optimize(
            func, n_trials=None, timeout=None, show_progress_bar=True, callbacks=[terminate_study]
        )


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trials(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        study.optimize(lambda t: t.suggest_int("x", 1, 5), n_trials=5)

        with patch("copy.deepcopy", wraps=copy.deepcopy) as mock_object:
            trials0 = study.get_trials(deepcopy=False)
            assert mock_object.call_count == 0
            assert len(trials0) == 5

            trials1 = study.get_trials(deepcopy=True)
            assert mock_object.call_count > 0
            assert trials0 == trials1

            # `study.trials` is equivalent to `study.get_trials(deepcopy=True)`.
            old_count = mock_object.call_count
            trials2 = study.trials
            assert mock_object.call_count > old_count
            assert trials0 == trials2


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trials_state_option(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        def objective(trial: Trial) -> float:
            if trial.number == 0:
                return 0.0  # TrialState.COMPLETE.
            elif trial.number == 1:
                return 0.0  # TrialState.COMPLETE.
            elif trial.number == 2:
                raise TrialPruned  # TrialState.PRUNED.
            else:
                assert False

        study.optimize(objective, n_trials=3)

        trials = study.get_trials(states=None)
        assert len(trials) == 3

        trials = study.get_trials(states=(TrialState.COMPLETE,))
        assert len(trials) == 2
        assert all(t.state == TrialState.COMPLETE for t in trials)

        trials = study.get_trials(states=(TrialState.COMPLETE, TrialState.PRUNED))
        assert len(trials) == 3
        assert all(t.state in (TrialState.COMPLETE, TrialState.PRUNED) for t in trials)

        trials = study.get_trials(states=())
        assert len(trials) == 0

        other_states = [
            s for s in list(TrialState) if s != TrialState.COMPLETE and s != TrialState.PRUNED
        ]
        for s in other_states:
            trials = study.get_trials(states=(s,))
            assert len(trials) == 0


def test_log_completed_trial(capsys: _pytest.capture.CaptureFixture) -> None:
    # We need to reconstruct our default handler to properly capture stderr.
    logging._reset_library_root_logger()
    logging.set_verbosity(logging.INFO)

    study = create_study()
    study.optimize(lambda _: 1.0, n_trials=1)
    _, err = capsys.readouterr()
    assert "Trial 0" in err

    logging.set_verbosity(logging.WARNING)
    study.optimize(lambda _: 1.0, n_trials=1)
    _, err = capsys.readouterr()
    assert "Trial 1" not in err

    logging.set_verbosity(logging.DEBUG)
    study.optimize(lambda _: 1.0, n_trials=1)
    _, err = capsys.readouterr()
    assert "Trial 2" in err
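

# Logging a completed trial needs one storage access (to fetch the best trial) only
# when the message will actually be emitted; at WARNING verbosity the lookup must be
# skipped entirely.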
def test_log_completed_trial_skip_storage_access() -> None:
    study = create_study()

    # Create a trial to retrieve it as the `study.best_trial`.
    study.optimize(lambda _: 0.0, n_trials=1)
    frozen_trial = study.best_trial

    storage = study._storage

    with patch.object(storage, "get_best_trial", wraps=storage.get_best_trial) as mock_object:
        study._log_completed_trial(frozen_trial)
        assert mock_object.call_count == 1

    logging.set_verbosity(logging.WARNING)
    with patch.object(storage, "get_best_trial", wraps=storage.get_best_trial) as mock_object:
        study._log_completed_trial(frozen_trial)
        assert mock_object.call_count == 0

    logging.set_verbosity(logging.DEBUG)
    with patch.object(storage, "get_best_trial", wraps=storage.get_best_trial) as mock_object:
        study._log_completed_trial(frozen_trial)
        assert mock_object.call_count == 1


def test_create_study_with_multi_objectives() -> None:
    study = create_study(directions=["maximize"])
    assert study.direction == StudyDirection.MAXIMIZE
    assert not study._is_multi_objective()

    study = create_study(directions=["maximize", "minimize"])
    assert study.directions == [StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE]
    assert study._is_multi_objective()

    with pytest.raises(ValueError):
        # An empty `directions` isn't allowed.
        _ = create_study(directions=[])

    with pytest.raises(ValueError):
        # `direction` and `directions` can't be specified at the same time.
        _ = create_study(direction="minimize", directions=["maximize"])

    with pytest.raises(ValueError):
        _ = create_study(direction="minimize", directions=[])


def test_create_study_with_direction_object() -> None:
    study = create_study(direction=StudyDirection.MAXIMIZE)
    assert study.direction == StudyDirection.MAXIMIZE

    study = create_study(directions=[StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE])
    assert study.directions == [StudyDirection.MAXIMIZE, StudyDirection.MINIMIZE]


@pytest.mark.parametrize("n_objectives", [2, 3])
def test_optimize_with_multi_objectives(n_objectives: int) -> None:
    directions = ["minimize" for _ in range(n_objectives)]
    study = create_study(directions=directions)

    def objective(trial: Trial) -> list[float]:
        return [trial.suggest_float("v{}".format(i), 0, 5) for i in range(n_objectives)]

    study.optimize(objective, n_trials=10)

    assert len(study.trials) == 10

    for trial in study.trials:
        assert trial.values
        assert len(trial.values) == n_objectives
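

# `best_trials` returns the Pareto front. With directions ("minimize", "maximize"),
# (1, 1) and (2, 2) are mutually non-dominated, while (3, 1) is dominated by (1, 1).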
def test_best_trials() -> None:
    study = create_study(directions=["minimize", "maximize"])
    study.optimize(lambda t: [2, 2], n_trials=1)
    study.optimize(lambda t: [1, 1], n_trials=1)
    study.optimize(lambda t: [3, 1], n_trials=1)
    assert {tuple(t.values) for t in study.best_trials} == {(1, 1), (2, 2)}


def test_wrong_n_objectives() -> None:
    n_objectives = 2
    directions = ["minimize" for _ in range(n_objectives)]
    study = create_study(directions=directions)

    def objective(trial: Trial) -> list[float]:
        return [trial.suggest_float("v{}".format(i), 0, 5) for i in range(n_objectives + 1)]

    study.optimize(objective, n_trials=10)

    for trial in study.trials:
        assert trial.state is TrialState.FAIL


def test_ask() -> None:
    study = create_study()

    trial = study.ask()
    assert isinstance(trial, Trial)


def test_ask_enqueue_trial() -> None:
    study = create_study()

    study.enqueue_trial({"x": 0.5}, user_attrs={"memo": "this is memo"})

    trial = study.ask()
    assert trial.suggest_float("x", 0, 1) == 0.5
    assert trial.user_attrs == {"memo": "this is memo"}


def test_ask_fixed_search_space() -> None:
    fixed_distributions = {
        "x": distributions.FloatDistribution(0, 1),
        "y": distributions.CategoricalDistribution(["bacon", "spam"]),
    }

    study = create_study()
    trial = study.ask(fixed_distributions=fixed_distributions)

    params = trial.params
    assert len(trial.params) == 2
    assert 0 <= params["x"] < 1
    assert params["y"] in ["bacon", "spam"]


# Deprecated distributions are internally converted to their corresponding new distributions.
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_ask_distribution_conversion() -> None:
    fixed_distributions = {
        "ud": distributions.UniformDistribution(low=0, high=10),
        "dud": distributions.DiscreteUniformDistribution(low=0, high=10, q=2),
        "lud": distributions.LogUniformDistribution(low=1, high=10),
        "id": distributions.IntUniformDistribution(low=0, high=10),
        "idd": distributions.IntUniformDistribution(low=0, high=10, step=2),
        "ild": distributions.IntLogUniformDistribution(low=1, high=10),
    }

    study = create_study()

    with pytest.warns(
        FutureWarning,
        match="See https://github.com/optuna/optuna/issues/2941",
    ) as record:
        trial = study.ask(fixed_distributions=fixed_distributions)
        assert len(record) == 6

    expected_distributions = {
        "ud": distributions.FloatDistribution(low=0, high=10, log=False, step=None),
        "dud": distributions.FloatDistribution(low=0, high=10, log=False, step=2),
        "lud": distributions.FloatDistribution(low=1, high=10, log=True, step=None),
        "id": distributions.IntDistribution(low=0, high=10, log=False, step=1),
        "idd": distributions.IntDistribution(low=0, high=10, log=False, step=2),
        "ild": distributions.IntDistribution(low=1, high=10, log=True, step=1),
    }

    assert trial.distributions == expected_distributions


# Confirms that `ask` doesn't convert non-deprecated distributions.
def test_ask_distribution_conversion_noop() -> None:
    fixed_distributions = {
        "ud": distributions.FloatDistribution(low=0, high=10, log=False, step=None),
        "dud": distributions.FloatDistribution(low=0, high=10, log=False, step=2),
        "lud": distributions.FloatDistribution(low=1, high=10, log=True, step=None),
        "id": distributions.IntDistribution(low=0, high=10, log=False, step=1),
        "idd": distributions.IntDistribution(low=0, high=10, log=False, step=2),
        "ild": distributions.IntDistribution(low=1, high=10, log=True, step=1),
        "cd": distributions.CategoricalDistribution(choices=["a", "b", "c"]),
    }

    study = create_study()

    trial = study.ask(fixed_distributions=fixed_distributions)

    # Check that `fixed_distributions` isn't changed.
    assert trial.distributions == fixed_distributions


def test_tell() -> None:
    study = create_study()
    assert len(study.trials) == 0

    trial = study.ask()
    assert len(study.trials) == 1
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 0

    study.tell(trial, 1.0)
    assert len(study.trials) == 1
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 1

    study.tell(study.ask(), [1.0])
    assert len(study.trials) == 2
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 2

    # `trial` can also be given as an int (the trial number).
    study.tell(study.ask().number, 1.0)
    assert len(study.trials) == 3
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 3

    # Inf is supported as a value.
    study.tell(study.ask(), float("inf"))
    assert len(study.trials) == 4
    assert len(study.get_trials(states=(TrialState.COMPLETE,))) == 4

    study.tell(study.ask(), state=TrialState.PRUNED)
    assert len(study.trials) == 5
    assert len(study.get_trials(states=(TrialState.PRUNED,))) == 1

    study.tell(study.ask(), state=TrialState.FAIL)
    assert len(study.trials) == 6
    assert len(study.get_trials(states=(TrialState.FAIL,))) == 1


def test_tell_pruned() -> None:
    study = create_study()

    study.tell(study.ask(), state=TrialState.PRUNED)
    assert study.trials[-1].value is None
    assert study.trials[-1].state == TrialState.PRUNED

    # The last reported intermediate value is stored as the trial's value.
    trial = study.ask()
    trial.report(2.0, step=1)
    study.tell(trial, state=TrialState.PRUNED)
    assert study.trials[-1].value == 2.0
    assert study.trials[-1].state == TrialState.PRUNED

    # Inf is also supported as a value.
    trial = study.ask()
    trial.report(float("inf"), step=1)
    study.tell(trial, state=TrialState.PRUNED)
    assert study.trials[-1].value == float("inf")
    assert study.trials[-1].state == TrialState.PRUNED

    # NaN is not supported as a value.
    trial = study.ask()
    trial.report(float("nan"), step=1)
    study.tell(trial, state=TrialState.PRUNED)
    assert study.trials[-1].value is None
    assert study.trials[-1].state == TrialState.PRUNED


def test_tell_automatically_fail() -> None:
    study = create_study()

    # Check invalid values, e.g. a str cannot be cast to float.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), "a")  # type: ignore
    assert len(study.trials) == 1
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    # Check invalid values, e.g. `None` cannot be cast to float.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), None)
    assert len(study.trials) == 2
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    # Check a wrong number of values, e.g. an empty list for a single direction.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), [])
    assert len(study.trials) == 3
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    # Check a wrong number of values, e.g. two values for a single direction.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), [1.0, 2.0])
    assert len(study.trials) == 4
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    # Neither state nor values is specified.
    with pytest.warns(UserWarning):
        study.tell(study.ask())
    assert len(study.trials) == 5
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    # NaN is not supported.
    with pytest.warns(UserWarning):
        study.tell(study.ask(), float("nan"))
    assert len(study.trials) == 6
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None


def test_tell_multi_objective() -> None:
    study = create_study(directions=["minimize", "maximize"])
    study.tell(study.ask(), [1.0, 2.0])
    assert len(study.trials) == 1


def test_tell_multi_objective_automatically_fail() -> None:
    # The number of values must match the number of directions.
    study = create_study(directions=["minimize", "maximize"])

    with pytest.warns(UserWarning):
        study.tell(study.ask(), [])
    assert len(study.trials) == 1
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    with pytest.warns(UserWarning):
        study.tell(study.ask(), [1.0])
    assert len(study.trials) == 2
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    with pytest.warns(UserWarning):
        study.tell(study.ask(), [1.0, 2.0, 3.0])
    assert len(study.trials) == 3
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    with pytest.warns(UserWarning):
        study.tell(study.ask(), [1.0, None])  # type: ignore
    assert len(study.trials) == 4
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    with pytest.warns(UserWarning):
        study.tell(study.ask(), [None, None])  # type: ignore
    assert len(study.trials) == 5
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None

    with pytest.warns(UserWarning):
        study.tell(study.ask(), 1.0)
    assert len(study.trials) == 6
    assert study.trials[-1].state == TrialState.FAIL
    assert study.trials[-1].values is None


def test_tell_invalid() -> None:
    study = create_study()

    # Missing values for completion.
    with pytest.raises(ValueError):
        study.tell(study.ask(), state=TrialState.COMPLETE)

    # Invalid values for completion.
    with pytest.raises(ValueError):
        study.tell(study.ask(), "a", state=TrialState.COMPLETE)  # type: ignore

    with pytest.raises(ValueError):
        study.tell(study.ask(), None, state=TrialState.COMPLETE)

    with pytest.raises(ValueError):
        study.tell(study.ask(), [], state=TrialState.COMPLETE)

    with pytest.raises(ValueError):
        study.tell(study.ask(), [1.0, 2.0], state=TrialState.COMPLETE)

    with pytest.raises(ValueError):
        study.tell(study.ask(), float("nan"), state=TrialState.COMPLETE)

    # `state` must be None or a finished state.
    with pytest.raises(ValueError):
        study.tell(study.ask(), state=TrialState.RUNNING)

    # `state` must be None or a finished state.
    with pytest.raises(ValueError):
        study.tell(study.ask(), state=TrialState.WAITING)

    # `values` must be None for `TrialState.PRUNED`.
    with pytest.raises(ValueError):
        study.tell(study.ask(), values=1, state=TrialState.PRUNED)

    # `values` must be None for `TrialState.FAIL`.
    with pytest.raises(ValueError):
        study.tell(study.ask(), values=1, state=TrialState.FAIL)

    # A trial that has not been asked for cannot be told.
    with pytest.raises(ValueError):
        study.tell(study.ask().number + 1, 1.0)

    # A waiting trial cannot be told.
    with pytest.raises(ValueError):
        study.enqueue_trial({})
        study.tell(study.trials[-1].number, 1.0)

    # `trial` must be a Trial or an int.
    with pytest.raises(TypeError):
        study.tell("1", 1.0)  # type: ignore


def test_tell_duplicate_tell() -> None:
    study = create_study()

    trial = study.ask()
    study.tell(trial, 1.0)

    # Telling a finished trial again is a no-op when `skip_if_finished=True`.
    study.tell(trial, 1.0, skip_if_finished=True)

    with pytest.raises(ValueError):
        study.tell(trial, 1.0, skip_if_finished=False)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueued_trial_datetime_start(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        def objective(trial: Trial) -> float:
            time.sleep(1)
            x = trial.suggest_int("x", -10, 10)
            return x

        study.enqueue_trial(params={"x": 1})
        assert study.trials[0].datetime_start is None

        study.optimize(objective, n_trials=1)
        assert study.trials[0].datetime_start is not None


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_study_summary_datetime_start_calculation(storage_mode: str) -> None:
    with StorageSupplier(storage_mode) as storage:

        def objective(trial: Trial) -> float:
            x = trial.suggest_int("x", -10, 10)
            return x

        # `StudySummary.datetime_start` tests.
        study = create_study(storage=storage)
        study.enqueue_trial(params={"x": 1})

        # A study summary with only enqueued trials should have a null `datetime_start`.
        summaries = get_all_study_summaries(study._storage, include_best_trial=True)
        assert summaries[0].datetime_start is None

        # A study summary with completed trials should have a non-null `datetime_start`.
        study.optimize(objective, n_trials=1)
        study.enqueue_trial(params={"x": 1}, skip_if_exists=False)
        summaries = get_all_study_summaries(study._storage, include_best_trial=True)
        assert summaries[0].datetime_start is not None


def _process_tell(study: Study, trial: Trial | int, values: float) -> None:
    study.tell(trial, values)
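

# File-based SQLite storage is used so that the worker processes spawned by the pool
# observe the same study; the `Study` object itself is pickled into each worker.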
def test_tell_from_another_process() -> None:
    pool = multiprocessing.Pool()

    with StorageSupplier("sqlite") as storage:
        # Create a study and ask for a new trial.
        study = create_study(storage=storage)
        trial0 = study.ask()

        # Test normal behaviour.
        pool.starmap(_process_tell, [(study, trial0, 1.2)])

        assert len(study.trials) == 1
        assert study.best_trial.state == TrialState.COMPLETE
        assert study.best_value == 1.2

        # Test `study.tell` using the trial number.
        trial = study.ask()
        pool.starmap(_process_tell, [(study, trial.number, 1.5)])

        assert len(study.trials) == 2
        assert study.best_trial.state == TrialState.COMPLETE
        assert study.best_value == 1.2

        # Should fail because `trial0` is already finished.
        with pytest.raises(ValueError):
            pool.starmap(_process_tell, [(study, trial0, 1.2)])
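

# Each waiting trial must be handed out exactly once: if `_pop_waiting_trial_id` raced,
# two threads could pop the same trial and the set would contain fewer than
# `num_enqueued` distinct IDs.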
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_pop_waiting_trial_thread_safe(storage_mode: str) -> None:
    if "sqlite" == storage_mode or "cached_sqlite" == storage_mode:
        pytest.skip("study._pop_waiting_trial is not thread-safe on SQLite3")

    num_enqueued = 10
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        for i in range(num_enqueued):
            study.enqueue_trial({"i": i})

        trial_id_set = set()
        with ThreadPoolExecutor(10) as pool:
            futures = []
            for i in range(num_enqueued):
                future = pool.submit(study._pop_waiting_trial_id)
                futures.append(future)

            for future in as_completed(futures):
                trial_id_set.add(future.result())
        assert len(trial_id_set) == num_enqueued


def test_set_metric_names() -> None:
    metric_names = ["v0", "v1"]
    study = create_study(directions=["minimize", "minimize"])
    study.set_metric_names(metric_names)

    got_metric_names = study._storage.get_study_system_attrs(study._study_id).get(
        _SYSTEM_ATTR_METRIC_NAMES
    )
    assert got_metric_names is not None
    assert metric_names == got_metric_names


def test_set_metric_names_experimental_warning() -> None:
    study = create_study()
    with pytest.warns(ExperimentalWarning):
        study.set_metric_names(["v0"])


def test_set_invalid_metric_names() -> None:
    # The number of metric names must match the number of directions.
    metric_names = ["v0", "v1", "v2"]
    study = create_study(directions=["minimize", "minimize"])
    with pytest.raises(ValueError):
        study.set_metric_names(metric_names)


def test_get_metric_names() -> None:
    study = create_study()
    assert study.metric_names is None
    study.set_metric_names(["v0"])
    assert study.metric_names == ["v0"]
    study.set_metric_names(["v1"])
    assert study.metric_names == ["v1"]