optuna
155 lines · 6.1 KB
1from typing import Callable
2from typing import Generator
3from typing import Optional
4from unittest import mock
5
6from _pytest.logging import LogCaptureFixture
7import pytest
8
9from optuna import create_study
10from optuna import logging
11from optuna import Trial
12from optuna import TrialPruned
13from optuna.study import _optimize
14from optuna.study._tell import _tell_with_warning
15from optuna.study._tell import STUDY_TELL_WARNING_KEY
16from optuna.testing.objectives import fail_objective
17from optuna.testing.storages import STORAGE_MODES
18from optuna.testing.storages import StorageSupplier
19from optuna.trial import TrialState
20
21
@pytest.fixture(autouse=True)
def logging_setup() -> Generator[None, None, None]:
    """Give every test in this module a freshly configured Optuna logger.

    Setup resets Optuna's library root logger and re-enables its default
    handler, turns on propagation (so records also reach the loggers that
    pytest's ``caplog`` fixture listens on) and sets the verbosity to ``INFO``.
    Teardown switches propagation back off, which is the library default.
    """
    # The default handler has to be rebuilt from scratch so that stderr is
    # captured properly during the test.
    logging._reset_library_root_logger()
    logging.enable_default_handler()
    logging.enable_propagation()

    # Emit INFO records; the tests in this module assert on per-trial log lines.
    logging.set_verbosity(logging.INFO)

    yield

    # Teardown: restore the default (disabled) propagation setting.
    logging.disable_propagation()


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """Finite and infinite objective values complete the trial and are logged."""

    def constant_objective(value: float) -> Callable[[Trial], float]:
        # Objective that ignores its trial and always returns ``value``.
        return lambda _: value

    # (value returned by the objective, log fragment expected for that trial)
    expectations = [
        (1.0, "Trial 0 finished with value: 1.0 and parameters"),
        (float("inf"), "Trial 1 finished with value: inf and parameters"),
        (-float("inf"), "Trial 2 finished with value: -inf and parameters"),
    ]

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        for returned, log_fragment in expectations:
            # Drop earlier output so only this trial's log lines are inspected.
            caplog.clear()
            result = _optimize._run_trial(study, constant_objective(returned), catch=())
            assert result.state == TrialState.COMPLETE
            assert result.value == returned
            assert log_fragment in caplog.text


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_automatically_fail(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """Objectives that return an unusable value end up as FAIL with no value.

    Covered returns: NaN, ``None``, an arbitrary ``object()`` and a two-element
    list (the study created here uses ``create_study``'s default of a single
    objective).
    """
    # NOTE(review): ``caplog`` is requested but never inspected in this test.
    invalid_objectives = [
        lambda _: float("nan"),
        lambda _: None,
        lambda _: object(),
        lambda _: [0, 1],
    ]

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        for objective in invalid_objectives:
            result = _optimize._run_trial(study, objective, catch=())  # type: ignore[arg-type]
            assert result.state == TrialState.FAIL
            assert result.value is None


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_pruned(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """An objective raising ``TrialPruned`` produces a PRUNED trial.

    The pruned trial's value is the intermediate value reported before pruning.
    It is ``None`` when nothing was reported or when the reported value was NaN.
    """

    def pruning_objective(intermediate: Optional[float] = None) -> Callable[[Trial], float]:
        # Objective that optionally reports ``intermediate`` at step 1, then prunes.
        def objective(trial: Trial) -> float:
            if intermediate is not None:
                trial.report(step=1, value=intermediate)
            raise TrialPruned

        return objective

    # (reported intermediate value, expected trial value, expected log fragment)
    scenarios = [
        (None, None, "Trial 0 pruned."),
        (1, 1, "Trial 1 pruned."),
        (float("nan"), None, "Trial 2 pruned."),
    ]

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        for intermediate, expected_value, expected_log in scenarios:
            # Drop earlier output so only this trial's log lines are inspected.
            caplog.clear()
            result = _optimize._run_trial(study, pruning_objective(intermediate), catch=())
            assert result.state == TrialState.PRUNED
            if expected_value is None:
                assert result.value is None
            else:
                assert result.value == expected_value
            assert expected_log in caplog.text


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_catch_exception(storage_mode: str) -> None:
    """An exception listed in ``catch`` is swallowed and the trial is marked FAIL."""
    with StorageSupplier(storage_mode) as storage:
        result = _optimize._run_trial(
            create_study(storage=storage), fail_objective, catch=(ValueError,)
        )
        assert result.state == TrialState.FAIL
        # No tell warning may be stored on the trial for a caught exception.
        # Presumably this key only records value-conversion problems;
        # verify against optuna.study._tell.
        assert STUDY_TELL_WARNING_KEY not in result.system_attrs


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_exception(storage_mode: str) -> None:
    """An exception not covered by ``catch`` propagates out of ``_run_trial``.

    Checked with an empty ``catch`` and with a ``catch`` that holds only an
    unrelated exception type (``ArithmeticError``), so ``ValueError`` is not
    acceptable in either case.
    """
    for catch in [(), (ArithmeticError,)]:
        # Each case gets its own storage and study.
        with StorageSupplier(storage_mode) as storage:
            study = create_study(storage=storage)
            with pytest.raises(ValueError):
                _optimize._run_trial(study, fail_objective, catch)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_invoke_tell_with_suppressing_warning(storage_mode: str) -> None:
    """``_run_trial`` must call ``_tell_with_warning`` with ``suppress_warning=True``."""

    def suggest_only(trial: Trial) -> float:
        # Objective that returns a single suggested float parameter.
        return trial.suggest_float("v", 0, 10)

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        # Patch the name as it is looked up inside ``optuna.study._optimize``.
        # ``side_effect`` forwards to the real function, so the trial is still
        # told normally while the call is recorded.
        spy_target = "optuna.study._optimize._tell_with_warning"
        with mock.patch(spy_target, side_effect=_tell_with_warning) as tell_spy:
            _optimize._run_trial(study, suggest_only, ())

        # The mock keeps its call record after unpatching. Require exactly one
        # keyword-only call; only the ``suppress_warning`` flag is pinned down.
        tell_spy.assert_called_once_with(
            study=mock.ANY,
            trial=mock.ANY,
            value_or_values=mock.ANY,
            state=mock.ANY,
            suppress_warning=True,
        )