optuna

Форк
0
/
optimize.py 
101 строка · 3.3 Кб
1
from __future__ import annotations
2

3
from typing import cast
4

5
import optuna
6
from optuna.samplers import BaseSampler
7
from optuna.samplers import CmaEsSampler
8
from optuna.samplers import NSGAIISampler
9
from optuna.samplers import RandomSampler
10
from optuna.samplers import TPESampler
11
from optuna.testing.storages import StorageSupplier
12

13

14
def parse_args(args: str) -> list[int | str]:
15
    ret: list[int | str] = []
16
    for arg in map(lambda s: s.strip(), args.split(",")):
17
        try:
18
            ret.append(int(arg))
19
        except ValueError:
20
            ret.append(arg)
21
    return ret
22

23

24
SAMPLER_MODES = [
25
    "random",
26
    "tpe",
27
    "cmaes",
28
]
29

30

31
def create_sampler(sampler_mode: str) -> BaseSampler:
32
    if sampler_mode == "random":
33
        return RandomSampler()
34
    elif sampler_mode == "tpe":
35
        return TPESampler()
36
    elif sampler_mode == "cmaes":
37
        return CmaEsSampler()
38
    elif sampler_mode == "nsgaii":
39
        return NSGAIISampler()
40
    else:
41
        assert False
42

43

44
class OptimizeSuite:
45
    def objective(self, trial: optuna.Trial) -> float:
46
        x = trial.suggest_float("x", -100, 100)
47
        y = trial.suggest_int("y", -100, 100)
48
        return x**2 + y**2
49

50
    def multi_objective(self, trial: optuna.Trial) -> tuple[float, float]:
51
        x = trial.suggest_float("x", -100, 100)
52
        y = trial.suggest_int("y", -100, 100)
53
        return (x**2 + y**2, (x - 2) ** 2 + (y - 2) ** 2)
54

55
    def optimize(
56
        self, storage_mode: str, sampler_mode: str, n_trials: int, objective_type: str
57
    ) -> None:
58
        with StorageSupplier(storage_mode) as storage:
59
            sampler = create_sampler(sampler_mode)
60
            if objective_type == "single":
61
                directions = ["minimize"]
62
            elif objective_type == "multi":
63
                directions = ["minimize", "minimize"]
64
            else:
65
                assert False
66
            study = optuna.create_study(storage=storage, sampler=sampler, directions=directions)
67
            if objective_type == "single":
68
                study.optimize(self.objective, n_trials=n_trials)
69
            elif objective_type == "multi":
70
                study.optimize(self.multi_objective, n_trials=n_trials)
71
            else:
72
                assert False
73

74
    def time_optimize(self, args: str) -> None:
75
        storage_mode, sampler_mode, n_trials, objective_type = parse_args(args)
76
        storage_mode = cast(str, storage_mode)
77
        sampler_mode = cast(str, sampler_mode)
78
        n_trials = cast(int, n_trials)
79
        objective_type = cast(str, objective_type)
80
        self.optimize(storage_mode, sampler_mode, n_trials, objective_type)
81

82
    params = (
83
        "inmemory, random, 1000, single",
84
        "inmemory, random, 10000, single",
85
        "inmemory, tpe, 1000, single",
86
        "inmemory, cmaes, 1000, single",
87
        "sqlite, random, 1000, single",
88
        "sqlite, tpe, 1000, single",
89
        "sqlite, cmaes, 1000, single",
90
        "journal, random, 1000, single",
91
        "journal, tpe, 1000, single",
92
        "journal, cmaes, 1000, single",
93
        "inmemory, tpe, 1000, multi",
94
        "inmemory, nsgaii, 1000, multi",
95
        "sqlite, tpe, 1000, multi",
96
        "sqlite, nsgaii, 1000, multi",
97
        "journal, tpe, 1000, multi",
98
        "journal, nsgaii, 1000, multi",
99
    )
100
    param_names = ["storage, sampler, n_trials, objective_type"]
101
    timeout = 600
102

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.