optuna

Форк
0
/
test_multi_objective.py 
163 строки · 7.0 Кб
1
from __future__ import annotations
2

3
import numpy as np
4
import pytest
5

6
from optuna.study import StudyDirection
7
from optuna.study._multi_objective import _dominates
8
from optuna.study._multi_objective import _fast_non_dominated_sort
9
from optuna.trial import create_trial
10
from optuna.trial import TrialState
11

12

@pytest.mark.parametrize(
    ("v1", "v2"),
    [
        (-1, 1),
        (-float("inf"), 0),
        (0, float("inf")),
        (-float("inf"), float("inf")),
    ],
)
def test_dominates_1d_not_equal(v1: float, v2: float) -> None:
    """With one objective, the strictly smaller value wins under MINIMIZE and loses under MAXIMIZE."""
    smaller = create_trial(values=[v1])
    larger = create_trial(values=[v2])

    # Minimization: the smaller value dominates, never the other way round.
    assert _dominates(smaller, larger, [StudyDirection.MINIMIZE])
    assert not _dominates(larger, smaller, [StudyDirection.MINIMIZE])

    # Maximization: the relation flips.
    assert _dominates(larger, smaller, [StudyDirection.MAXIMIZE])
    assert not _dominates(smaller, larger, [StudyDirection.MAXIMIZE])


@pytest.mark.parametrize("v", [0, -float("inf"), float("inf")])
@pytest.mark.parametrize("direction", [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE])
def test_dominates_1d_equal(v: float, direction: StudyDirection) -> None:
    """Two trials sharing the same single objective value must not dominate each other."""
    first = create_trial(values=[v])
    second = create_trial(values=[v])
    assert not _dominates(first, second, [direction])


def test_dominates_2d() -> None:
    """Exhaustively check `_dominates` over every pair of cells in a 4x4 grid of values.

    Objective 0 is minimized and objective 1 is maximized. For every target cell,
    each cell expected to dominate it must do so, and every other cell (including
    the target itself) must not.
    """
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    # Check all pairs of trials consisting of these values, i.e.,
    # [-inf, -inf], [-inf, -1], [-inf, 1], [-inf, inf], [-1, -inf], ...
    # These values must be specified in ascending order: the index arithmetic
    # below relies on a smaller index meaning a smaller value.
    vals = [-float("inf"), -1, 1, float("inf")]

    # The following table illustrates an example of dominance relations.
    # The "d" cells in the table dominate the "t" cell in the (MINIMIZE, MAXIMIZE) setting.
    #
    #                        value1
    #        ╔═════╤═════╤═════╤═════╤═════╗
    #        ║     │ -∞  │ -1  │  1  │  ∞  ║
    #        ╟─────┼─────┼─────┼─────┼─────╢
    #        ║ -∞  │     │     │  d  │  d  ║
    #        ╟─────┼─────┼─────┼─────┼─────╢
    #        ║ -1  │     │     │  d  │  d  ║
    # value0 ╟─────┼─────┼─────┼─────┼─────╢
    #        ║  1  │     │     │  t  │  d  ║
    #        ╟─────┼─────┼─────┼─────┼─────╢
    #        ║  ∞  │     │     │     │     ║
    #        ╚═════╧═════╧═════╧═════╧═════╝
    #
    # In the following code, we check that the relation above holds for every
    # position of the "t" cell.

    n = len(vals)
    # The set of all possible (value0, value1) index pairs.
    all_indices = {(i, j) for i in range(n) for j in range(n)}
    # Build every grid trial once; the loops below only compare them.
    trials = {(i, j): create_trial(values=[vals[i], vals[j]]) for i, j in all_indices}

    for t_i, t_j in all_indices:
        target = trials[(t_i, t_j)]

        # Cells that dominate the target: value0 no larger (MINIMIZE) and
        # value1 no smaller (MAXIMIZE), excluding the target cell itself.
        dominating_indices = {
            (d_i, d_j) for d_i in range(t_i + 1) for d_j in range(t_j, n)
        } - {(t_i, t_j)}

        for idx in dominating_indices:
            assert _dominates(trials[idx], target, directions), (idx, (t_i, t_j))

        # Every remaining cell, the target included, must not dominate it.
        for idx in all_indices - dominating_indices:
            assert not _dominates(trials[idx], target, directions), (idx, (t_i, t_j))


def test_dominates_invalid() -> None:
    """`_dominates` raises ValueError when objective counts or direction counts disagree."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    # Case 1: the two trials report different numbers of objectives.
    single = create_trial(values=[1])
    double = create_trial(values=[1, 2])
    with pytest.raises(ValueError):
        _dominates(single, double, directions)

    # Case 2: the trials agree with each other (one objective each), but two
    # directions are supplied.
    lhs = create_trial(values=[1])
    rhs = create_trial(values=[1])
    with pytest.raises(ValueError):
        _dominates(lhs, rhs, directions)


@pytest.mark.parametrize("t1_state", [TrialState.FAIL, TrialState.WAITING, TrialState.PRUNED])
@pytest.mark.parametrize("t2_state", [TrialState.FAIL, TrialState.WAITING, TrialState.PRUNED])
def test_dominates_incomplete_vs_incomplete(t1_state: TrialState, t2_state: TrialState) -> None:
    """Between two non-complete trials without values, neither dominates the other."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    first = create_trial(values=None, state=t1_state)
    second = create_trial(values=None, state=t2_state)

    # Check both orderings: (second vs first), then (first vs second).
    for lhs, rhs in ((second, first), (first, second)):
        assert not _dominates(lhs, rhs, list(directions))


@pytest.mark.parametrize("t1_state", [TrialState.FAIL, TrialState.WAITING, TrialState.PRUNED])
def test_dominates_complete_vs_incomplete(t1_state: TrialState) -> None:
    """A COMPLETE trial dominates a value-less FAIL/WAITING/PRUNED trial, never the reverse."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    unfinished = create_trial(values=None, state=t1_state)
    finished = create_trial(values=[1, 1], state=TrialState.COMPLETE)

    assert _dominates(finished, unfinished, list(directions))
    assert not _dominates(unfinished, finished, list(directions))


@pytest.mark.parametrize(
    ("trial_values", "trial_ranks"),
    [
        ([[10], [20], [20], [30]], [0, 1, 1, 2]),  # Single objective
        ([[10, 30], [10, 10], [20, 20], [30, 10], [15, 15]], [1, 0, 2, 1, 1]),  # Two objectives
        (
            [[5, 5, 4], [5, 5, 5], [9, 9, 0], [5, 7, 5], [0, 0, 9], [0, 9, 9]],
            [0, 1, 0, 2, 0, 1],
        ),  # Three objectives
        (
            [[-5, -5, -4], [-5, -5, 5], [-9, -9, 0], [5, 7, 5], [0, 0, -9], [0, -9, 9]],
            [0, 1, 0, 2, 0, 1],
        ),  # Negative values are included.
        (
            [[1, 1], [1, float("inf")], [float("inf"), 1], [float("inf"), float("inf")]],
            [0, 1, 1, 2],
        ),  # +infs are included.
        (
            [[1, 1], [1, -float("inf")], [-float("inf"), 1], [-float("inf"), -float("inf")]],
            [2, 1, 1, 0],
        ),  # -infs are included.
        (
            [[1, 1], [1, 1], [1, 2], [2, 1], [0, 1.5], [1.5, 0], [0, 1.5]],
            [0, 0, 1, 1, 0, 0, 0],
        ),  # Two objectives with duplicate values are included.
        (
            [[1, 1], [1, 1], [1, 2], [2, 1], [1, 1], [0, 1.5], [0, 1.5]],
            [0, 0, 1, 1, 0, 0, 0],
        ),  # Two objectives with a point repeated three times.
        (
            [[1, 1, 1], [1, 1, 1], [1, 1, 2], [1, 2, 1], [2, 1, 1], [0, 1.5, 1.5], [0, 1.5, 1.5]],
            [0, 0, 1, 1, 1, 0, 0],
        ),  # Three objectives with duplicate values are included.
    ],
)
def test_fast_non_dominated_sort(trial_values: list[list[float]], trial_ranks: list[int]) -> None:
    """Check the ranks from `_fast_non_dominated_sort` on small hand-computed cases.

    The expected ranks treat every objective as minimized (smaller is better),
    give rank 0 to non-dominated points, and give identical points the same rank.

    Args:
        trial_values: One row of objective values per trial.
        trial_ranks: Expected non-domination rank of each row, in the same order.
    """
    ranks = list(_fast_non_dominated_sort(np.array(trial_values)))
    assert np.array_equal(ranks, trial_ranks)


def test_fast_non_dominated_sort_invalid() -> None:
    """A `penalty` array with three entries for two trials must raise ValueError."""
    two_trials = np.array([[1.0, 2.0], [3.0, 4.0]])
    three_penalties = np.array([1.0, 2.0, 3.0])
    with pytest.raises(ValueError):
        _fast_non_dominated_sort(two_trials, penalty=three_penalties)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.