optuna

Форк
0
/
test_optimize.py 
155 строк · 6.1 Кб
1
from typing import Callable
2
from typing import Generator
3
from typing import Optional
4
from unittest import mock
5

6
from _pytest.logging import LogCaptureFixture
7
import pytest
8

9
from optuna import create_study
10
from optuna import logging
11
from optuna import Trial
12
from optuna import TrialPruned
13
from optuna.study import _optimize
14
from optuna.study._tell import _tell_with_warning
15
from optuna.study._tell import STUDY_TELL_WARNING_KEY
16
from optuna.testing.objectives import fail_objective
17
from optuna.testing.storages import STORAGE_MODES
18
from optuna.testing.storages import StorageSupplier
19
from optuna.trial import TrialState
20

21

22
@pytest.fixture(autouse=True)
def logging_setup() -> Generator[None, None, None]:
    """Give every test in this module a freshly configured optuna logger.

    Setup resets optuna's library root logger, re-enables the default
    handler, sets verbosity to INFO and turns propagation on, so that the
    records optuna emits can be asserted on through pytest's ``caplog``.
    Teardown turns propagation back off.
    """
    # We need to reconstruct our default handler to properly capture stderr.
    logging._reset_library_root_logger()
    logging.enable_default_handler()
    logging.set_verbosity(logging.INFO)
    # Propagation lets records travel up to the root logger, where caplog listens.
    logging.enable_propagation()

    yield

    # After testing, restore default propagation setting.
    logging.disable_propagation()
34

35

36
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """A finite or infinite objective value completes the trial and is logged.

    All cases run on the same study, so the trial numbers in the expected
    log messages are 0, 1, 2 in the order the cases are listed.
    """
    # (objective return value, log fragment expected for that trial)
    cases = (
        (1.0, "Trial 0 finished with value: 1.0 and parameters"),
        (float("inf"), "Trial 1 finished with value: inf and parameters"),
        (-float("inf"), "Trial 2 finished with value: -inf and parameters"),
    )
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        for expected_value, expected_log in cases:
            # Only inspect the log output produced by this trial.
            caplog.clear()
            # Bind the value as a default so the lambda does not late-bind the loop variable.
            frozen_trial = _optimize._run_trial(
                study, lambda _, value=expected_value: value, catch=()
            )
            assert frozen_trial.state == TrialState.COMPLETE
            assert frozen_trial.value == expected_value
            assert expected_log in caplog.text
58

59

60
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_automatically_fail(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """Objective return values that are not valid results fail the trial.

    NaN, ``None``, an arbitrary object and a list are each expected to leave
    the trial in ``TrialState.FAIL`` with no value recorded.
    """
    # NOTE(review): ``caplog`` is requested but never inspected in this test.
    invalid_objectives = (
        lambda _: float("nan"),
        lambda _: None,
        lambda _: object(),
        lambda _: [0, 1],
    )
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        for objective in invalid_objectives:
            frozen_trial = _optimize._run_trial(study, objective, catch=())  # type: ignore[arg-type]
            assert frozen_trial.state == TrialState.FAIL
            assert frozen_trial.value is None
80

81

82
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_pruned(storage_mode: str, caplog: LogCaptureFixture) -> None:
    """An objective that raises ``TrialPruned`` leaves the trial PRUNED.

    The pruned trial's value is the reported intermediate value, except that
    no value is recorded when nothing was reported or when the reported value
    is NaN.
    """

    def gen_func(intermediate: Optional[float] = None) -> Callable[[Trial], float]:
        """Objective factory: report ``intermediate`` at step 1 if given, then prune."""

        def func(trial: Trial) -> float:
            if intermediate is not None:
                trial.report(step=1, value=intermediate)
            raise TrialPruned

        return func

    # (intermediate value to report, expected trial value, expected log fragment)
    cases = (
        (None, None, "Trial 0 pruned."),
        (1, 1, "Trial 1 pruned."),
        (float("nan"), None, "Trial 2 pruned."),
    )
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        for intermediate, expected_value, expected_log in cases:
            # Only inspect the log output produced by this trial.
            caplog.clear()
            frozen_trial = _optimize._run_trial(study, gen_func(intermediate), catch=())
            assert frozen_trial.state == TrialState.PRUNED
            if expected_value is None:
                assert frozen_trial.value is None
            else:
                assert frozen_trial.value == expected_value
            assert expected_log in caplog.text
112

113

114
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_catch_exception(storage_mode: str) -> None:
    """An exception whose type is listed in ``catch`` fails the trial instead of propagating.

    The failed trial must also not carry the ``STUDY_TELL_WARNING_KEY``
    system attribute.
    """
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        # fail_objective is expected to raise ValueError, which catch=(ValueError,) absorbs.
        frozen_trial = _optimize._run_trial(study, fail_objective, catch=(ValueError,))
        assert frozen_trial.state == TrialState.FAIL
        assert STUDY_TELL_WARNING_KEY not in frozen_trial.system_attrs
121

122

123
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_exception(storage_mode: str) -> None:
    """An objective exception propagates when its type is not listed in ``catch``.

    Two scenarios, each with a fresh storage and study: an empty ``catch``,
    and a ``catch`` that lists only an unrelated type (``ArithmeticError``).
    In both, the ValueError raised by ``fail_objective`` must escape.
    """
    # Nothing is caught at all.
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        with pytest.raises(ValueError):
            _optimize._run_trial(study, fail_objective, ())

    # Only an unrelated exception type is caught, so the ValueError still escapes.
    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)
        with pytest.raises(ValueError):
            _optimize._run_trial(study, fail_objective, (ArithmeticError,))
135

136

137
@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_run_trial_invoke_tell_with_suppressing_warning(storage_mode: str) -> None:
    """``_run_trial`` tells its result via ``_tell_with_warning(suppress_warning=True)``.

    The patch delegates to the real ``_tell_with_warning`` through
    ``side_effect``, so the trial is still told normally while the call
    arguments are recorded for inspection.
    """

    def func_numerical(trial: Trial) -> float:
        return trial.suggest_float("v", 0, 10)

    with StorageSupplier(storage_mode) as storage:
        study = create_study(storage=storage)

        with mock.patch(
            "optuna.study._optimize._tell_with_warning", side_effect=_tell_with_warning
        ) as mock_obj:
            _optimize._run_trial(study, func_numerical, ())
            # Only the suppress_warning flag is pinned; the other arguments are wildcards.
            mock_obj.assert_called_once_with(
                study=mock.ANY,
                trial=mock.ANY,
                value_or_values=mock.ANY,
                state=mock.ANY,
                suppress_warning=True,
            )
156

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.