# optuna
# 43 lines · 968.0 bytes
import json
import sys
from typing import Any
from typing import Dict

from kurobako import solver
from kurobako.solver.optuna import OptunaSolverFactory
import optuna
7
8
# Detach the default handler from Optuna's root logger so Optuna does not emit
# its own log output while this script runs as a kurobako solver (presumably to
# keep the solver's output clean — TODO confirm).
optuna.logging.disable_default_handler()
10
11
def _resolve_sampler_class(sampler_name: str) -> type:
    """Return the sampler class called ``sampler_name``.

    ``optuna.samplers`` is searched first. ``optuna.integration`` is consulted
    only when ``optuna.samplers`` has no such attribute, so built-in samplers
    never touch the integration namespace.

    Raises:
        ValueError: If neither module exposes a ``BaseSampler`` subclass with
            that name.
    """
    sampler_cls = getattr(optuna.samplers, sampler_name, None)
    if sampler_cls is None:
        # Fallback lookup is done lazily, only when actually needed.
        sampler_cls = getattr(optuna.integration, sampler_name, None)

    # Reject attributes that exist but are not sampler classes (helper
    # functions, submodules, ...) with a clear error instead of a confusing
    # failure when they are called with the sampler kwargs.
    is_sampler_class = isinstance(sampler_cls, type) and issubclass(
        sampler_cls, optuna.samplers.BaseSampler
    )
    if not is_sampler_class:
        raise ValueError("Unknown sampler: {}.".format(sampler_name))
    return sampler_cls


def _parse_sampler_kwargs(raw: str) -> Dict[str, Any]:
    """Decode the JSON sampler-kwargs argument, requiring a JSON object.

    Raises:
        ValueError: If ``raw`` is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or does not decode to a JSON object.
    """
    sampler_kwargs = json.loads(raw)
    if not isinstance(sampler_kwargs, dict):
        raise ValueError("Sampler kwargs must be a JSON object, got: {}.".format(raw))
    return sampler_kwargs


def create_study(seed: int) -> optuna.Study:
    """Create the two-objective study used for one kurobako run.

    The sampler is configured from the command line:

    * ``sys.argv[1]``: name of a sampler class, looked up in
      ``optuna.samplers`` and then in ``optuna.integration``.
    * ``sys.argv[2]``: a JSON object whose entries are passed as keyword
      arguments to that sampler class.

    Args:
        seed: Seed handed in by the caller (this function is registered with
            ``OptunaSolverFactory``). It is intentionally ignored here. To seed
            the sampler, pass the seed through the JSON kwargs, if the chosen
            sampler accepts one.

    Returns:
        A study with two ``"minimize"`` directions, the configured sampler, and
        a ``NopPruner`` (no pruning).

    Raises:
        ValueError: If either command-line argument is missing, the sampler
            name is unknown, or the kwargs are not a JSON object.
    """
    # The seed is deliberately unused; ``del`` makes that explicit and keeps
    # linters from flagging the unused argument.
    del seed

    n_objectives = 2
    directions = ["minimize"] * n_objectives

    if len(sys.argv) < 3:
        raise ValueError(
            "Expected two arguments: SAMPLER_NAME and SAMPLER_KWARGS_JSON; "
            "got {}.".format(sys.argv[1:])
        )

    sampler_cls = _resolve_sampler_class(sys.argv[1])
    sampler_kwargs = _parse_sampler_kwargs(sys.argv[2])
    sampler = sampler_cls(**sampler_kwargs)

    return optuna.create_study(
        directions=directions,
        sampler=sampler,
        pruner=optuna.pruners.NopPruner(),
    )
38
39
if __name__ == "__main__":
    # Register create_study with kurobako's Optuna solver factory, then start
    # kurobako's solver runner with that factory.
    solver.SolverRunner(OptunaSolverFactory(create_study)).run()
44