optuna
163 строки · 7.0 Кб
1from __future__ import annotations
2
3import numpy as np
4import pytest
5
6from optuna.study import StudyDirection
7from optuna.study._multi_objective import _dominates
8from optuna.study._multi_objective import _fast_non_dominated_sort
9from optuna.trial import create_trial
10from optuna.trial import TrialState
11
12
@pytest.mark.parametrize(
    ("v1", "v2"),
    [(-1, 1), (-float("inf"), 0), (0, float("inf")), (-float("inf"), float("inf"))],
)
def test_dominates_1d_not_equal(v1: float, v2: float) -> None:
    """With one objective and ``v1 < v2``, dominance follows the study direction.

    Under MINIMIZE the trial holding ``v1`` must dominate the one holding ``v2``
    (and not vice versa); under MAXIMIZE the relation flips.
    """
    smaller = create_trial(values=[v1])
    larger = create_trial(values=[v2])

    minimize = [StudyDirection.MINIMIZE]
    assert _dominates(smaller, larger, minimize)
    assert not _dominates(larger, smaller, minimize)

    maximize = [StudyDirection.MAXIMIZE]
    assert _dominates(larger, smaller, maximize)
    assert not _dominates(smaller, larger, maximize)
25
26
@pytest.mark.parametrize("v", [0, -float("inf"), float("inf")])
@pytest.mark.parametrize("direction", [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE])
def test_dominates_1d_equal(v: float, direction: StudyDirection) -> None:
    """Two single-objective trials holding the same value never dominate each other."""
    lhs = create_trial(values=[v])
    rhs = create_trial(values=[v])
    # Equal values (including equal infinities) must not count as dominance
    # in either direction.
    assert not _dominates(lhs, rhs, [direction])
31
32
def test_dominates_2d() -> None:
    """Exhaustively check ``_dominates`` on every pair of cells of a 4x4 value grid."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    # Check all pairs of trials consisting of these values, i.e.,
    # [-inf, -inf], [-inf, -1], [-inf, 1], [-inf, inf], [-1, -inf], ...
    # These values must be specified in ascending order: the expected dominance
    # relation below is derived purely from indices into this list.
    vals = [-float("inf"), -1, 1, float("inf")]

    # The following table illustrates an example of dominance relations.
    # "d" cells in the table dominate the "t" cell in the (MINIMIZE, MAXIMIZE) setting.
    #
    #                        value1
    #        ╔═════╤═════╤═════╤═════╤═════╗
    #        ║     │ -∞  │ -1  │  1  │  ∞  ║
    #        ╟─────┼─────┼─────┼─────┼─────╢
    #        ║ -∞  │     │     │  d  │  d  ║
    #        ╟─────┼─────┼─────┼─────┼─────╢
    #        ║ -1  │     │     │  d  │  d  ║
    # value0 ╟─────┼─────┼─────┼─────┼─────╢
    #        ║  1  │     │     │  t  │  d  ║
    #        ╟─────┼─────┼─────┼─────┼─────╢
    #        ║  ∞  │     │     │     │     ║
    #        ╚═════╧═════╧═════╧═════╧═════╝
    #
    # In the following code, we check that for each position of the "t" cell, the
    # relation above holds.

    # Generate the set of all possible indices.
    all_indices = {(i, j) for i in range(len(vals)) for j in range(len(vals))}
    for t_i, t_j in all_indices:
        # The "t" trial does not depend on the candidate, so build it once per cell.
        target = create_trial(values=[vals[t_i], vals[t_j]])

        # Cells dominating (t_i, t_j): value0 no larger (MINIMIZE) and value1 no
        # smaller (MAXIMIZE), excluding the cell itself.
        dominating_indices = {
            (d_i, d_j) for d_i in range(t_i + 1) for d_j in range(t_j, len(vals))
        }
        dominating_indices -= {(t_i, t_j)}

        for d_i, d_j in dominating_indices:
            candidate = create_trial(values=[vals[d_i], vals[d_j]])
            assert _dominates(candidate, target, directions)

        for d_i, d_j in all_indices - dominating_indices:
            candidate = create_trial(values=[vals[d_i], vals[d_j]])
            assert not _dominates(candidate, target, directions)
78
79
def test_dominates_invalid() -> None:
    """``_dominates`` raises ``ValueError`` when objective counts are inconsistent."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    mismatched_cases = [
        # The two trials disagree on the number of objectives (one vs. two).
        ([1], [1, 2]),
        # Both trials have one objective, but two directions are supplied.
        ([1], [1]),
    ]
    for lhs_values, rhs_values in mismatched_cases:
        lhs = create_trial(values=lhs_values)
        rhs = create_trial(values=rhs_values)
        with pytest.raises(ValueError):
            _dominates(lhs, rhs, directions)
94
95
@pytest.mark.parametrize("t1_state", [TrialState.FAIL, TrialState.WAITING, TrialState.PRUNED])
@pytest.mark.parametrize("t2_state", [TrialState.FAIL, TrialState.WAITING, TrialState.PRUNED])
def test_dominates_incomplete_vs_incomplete(t1_state: TrialState, t2_state: TrialState) -> None:
    """Between two non-COMPLETE trials, neither dominates the other."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    # Both trials are created without objective values.
    t1 = create_trial(values=None, state=t1_state)
    t2 = create_trial(values=None, state=t2_state)

    # ``directions`` is already a list; pass it directly instead of copying it.
    assert not _dominates(t2, t1, directions)
    assert not _dominates(t1, t2, directions)
106
107
@pytest.mark.parametrize("t1_state", [TrialState.FAIL, TrialState.WAITING, TrialState.PRUNED])
def test_dominates_complete_vs_incomplete(t1_state: TrialState) -> None:
    """A COMPLETE trial dominates the parametrized non-COMPLETE trial, not vice versa."""
    directions = [StudyDirection.MINIMIZE, StudyDirection.MAXIMIZE]

    t1 = create_trial(values=None, state=t1_state)
    t2 = create_trial(values=[1, 1], state=TrialState.COMPLETE)

    # ``directions`` is already a list; pass it directly instead of copying it.
    assert _dominates(t2, t1, directions)
    assert not _dominates(t1, t2, directions)
117
118
@pytest.mark.parametrize(
    ("trial_values", "trial_ranks"),
    [
        ([[10], [20], [20], [30]], [0, 1, 1, 2]),  # Single objective
        ([[10, 30], [10, 10], [20, 20], [30, 10], [15, 15]], [1, 0, 2, 1, 1]),  # Two objectives
        (
            [[5, 5, 4], [5, 5, 5], [9, 9, 0], [5, 7, 5], [0, 0, 9], [0, 9, 9]],
            [0, 1, 0, 2, 0, 1],
        ),  # Three objectives
        (
            [[-5, -5, -4], [-5, -5, 5], [-9, -9, 0], [5, 7, 5], [0, 0, -9], [0, -9, 9]],
            [0, 1, 0, 2, 0, 1],
        ),  # Negative values are included.
        (
            [[1, 1], [1, float("inf")], [float("inf"), 1], [float("inf"), float("inf")]],
            [0, 1, 1, 2],
        ),  # +infs are included.
        (
            [[1, 1], [1, -float("inf")], [-float("inf"), 1], [-float("inf"), -float("inf")]],
            [2, 1, 1, 0],
        ),  # -infs are included.
        (
            [[1, 1], [1, 1], [1, 2], [2, 1], [0, 1.5], [1.5, 0], [0, 1.5]],
            [0, 0, 1, 1, 0, 0, 0],
        ),  # Two objectives with duplicate values are included.
        (
            [[1, 1], [1, 1], [1, 2], [2, 1], [1, 1], [0, 1.5], [0, 1.5]],
            [0, 0, 1, 1, 0, 0, 0],
        ),  # Two objectives with duplicate values in a different arrangement.
        (
            [[1, 1, 1], [1, 1, 1], [1, 1, 2], [1, 2, 1], [2, 1, 1], [0, 1.5, 1.5], [0, 1.5, 1.5]],
            [0, 0, 1, 1, 1, 0, 0],
        ),  # Three objectives with duplicate values are included.
    ],
)
def test_fast_non_dominated_sort(trial_values: list[list[float]], trial_ranks: list[int]) -> None:
    """``_fast_non_dominated_sort`` assigns the expected non-domination rank to each row.

    Each row of ``trial_values`` is one objective vector. As the cases above show,
    smaller values are better in every objective (e.g. ``[[10], [20], [20], [30]]``
    is ranked ``[0, 1, 1, 2]``).
    """
    ranks = list(_fast_non_dominated_sort(np.array(trial_values)))
    assert np.array_equal(ranks, trial_ranks)
157
158
def test_fast_non_dominated_sort_invalid() -> None:
    """A ``penalty`` whose length differs from the number of rows is rejected."""
    loss_values = np.array([[1.0, 2.0], [3.0, 4.0]])  # 2 rows
    penalty = np.array([1.0, 2.0, 3.0])  # 3 entries: length mismatch
    with pytest.raises(ValueError):
        _fast_non_dominated_sort(loss_values, penalty=penalty)
164