optuna
91 строка · 2.7 Кб
1from __future__ import annotations
2
3import sys
4from typing import Any
5
6from kurobako import problem
7from naslib.utils import get_dataset_api
8
9
# Candidate operations for every edge of a NAS-Bench-201 cell. A categorical
# parameter value produced by the optimizer is an index into this list.
op_names = [
    "skip_connect",
    "none",
    "nor_conv_3x3",
    "nor_conv_1x1",
    "avg_pool_3x3",
]

# Number of edges in a cell: one per unordered pair of the cell's 4 nodes,
# matching the six `{op}~k` slots of the architecture string.
edge_num = (4 * (4 - 1)) // 2

# Final evaluation step; it is also used directly as an index into the
# learning curve.
max_epoch = 199

# Intermediate evaluation checkpoints: every `prune_epoch_step` steps,
# starting at `prune_start_epoch`.
prune_start_epoch, prune_epoch_step = 10, 10
22
23
class NASLibProblemFactory(problem.ProblemFactory):
    """kurobako ``ProblemFactory`` for NAS-Bench-201, backed by NASLib's tabular data."""

    # Dataset names accepted by the constructor.
    _SUPPORTED_DATASETS = ("cifar10", "cifar100", "ImageNet16-120")

    def __init__(self, dataset: str) -> None:
        """Creates ProblemFactory for NASBench201.

        Args:
            dataset:
                Accepts one of "cifar10", "cifar100" or "ImageNet16-120".

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        # Fail fast with a clear message before any benchmark data is loaded.
        if dataset not in self._SUPPORTED_DATASETS:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; "
                f"expected one of {', '.join(self._SUPPORTED_DATASETS)}."
            )
        # Name used inside the dataset API ("cifar10-valid" for CIFAR-10).
        # get_dataset_api itself is still called with the user-facing name.
        self._dataset = "cifar10-valid" if dataset == "cifar10" else dataset
        self._dataset_api = get_dataset_api("nasbench201", dataset)

    def specification(self) -> problem.ProblemSpec:
        """Describe the search space: one categorical operation per cell edge."""
        params = [
            problem.Var(f"x{i}", problem.CategoricalRange(op_names)) for i in range(edge_num)
        ]
        # Intermediate checkpoints every `prune_epoch_step` steps starting at
        # `prune_start_epoch`, followed by the final step `max_epoch`.
        steps = list(range(prune_start_epoch, max_epoch, prune_epoch_step)) + [max_epoch]
        return problem.ProblemSpec(
            name=f"NASBench201-{self._dataset}",
            params=params,
            values=[problem.Var("value")],
            steps=steps,
        )

    def create_problem(self, seed: int) -> problem.Problem:
        """Create a problem instance.

        ``seed`` is unused: evaluation is a lookup into precomputed data.
        """
        return NASLibProblem(self._dataset, self._dataset_api)
51
52
class NASLibProblem(problem.Problem):
    """kurobako problem that looks up NAS-Bench-201 learning curves by architecture."""

    def __init__(self, dataset: str, dataset_api: Any) -> None:
        """
        Args:
            dataset: Key selecting the per-dataset record of an architecture.
            dataset_api: Benchmark data object; expected to expose an
                ``"nb201_data"`` mapping keyed by architecture string.
        """
        super().__init__()
        self._dataset, self._dataset_api = dataset, dataset_api

    def create_evaluator(self, params: list[float]) -> problem.Evaluator:
        """Build an evaluator for the architecture encoded by ``params``.

        Args:
            params: One categorical value per cell edge; each is truncated to an
                int and used as an index into ``op_names``.

        Returns:
            An evaluator over the architecture's ``"eval_acc1es"`` curve for
            this problem's dataset.
        """
        # NAS-Bench-201 architecture string: six `{op}~k` edge slots, with `+`
        # separating the groups (presumably one group per cell node).
        arch_str = "|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|".format(
            *(op_names[int(index)] for index in params)
        )
        records = self._dataset_api["nb201_data"]
        curve = records[arch_str][self._dataset]["eval_acc1es"]
        return NASLibEvaluator(curve)
65
66
class NASLibEvaluator(problem.Evaluator):
    """Replays a precomputed learning curve one requested step at a time."""

    def __init__(self, learning_curve: list[float]) -> None:
        """
        Args:
            learning_curve: Values indexed by step (presumably per-epoch
                validation accuracy; confirm against the dataset API).
        """
        self._lc = learning_curve
        self._current_step = 0

    def current_step(self) -> int:
        """Return the last step passed to ``evaluate`` (0 before any call)."""
        return self._current_step

    def evaluate(self, next_step: int) -> list[float]:
        """Advance to ``next_step`` and report the curve value there.

        The value is negated, presumably so that a minimizing solver
        maximizes accuracy.
        """
        # The step is recorded before the lookup, so it is updated even if
        # indexing fails.
        self._current_step = next_step
        value_at_step = self._lc[next_step]
        return [-value_at_step]
78
79
if __name__ == "__main__":
    # Expects two CLI arguments: the search space name and the dataset name.
    if len(sys.argv) < 1 + 2:
        # Usage errors go to stderr, keeping stdout for the runner's output.
        print(
            "Usage: python3 nas_bench_suite/problems.py <search_space> <dataset>",
            file=sys.stderr,
        )
        print(
            "Example: python3 nas_bench_suite/problems.py nasbench201 cifar10",
            file=sys.stderr,
        )
        # sys.exit, not the site-provided builtin `exit` (missing under `python -S`).
        sys.exit(1)

    search_space_name = sys.argv[1]
    # We currently do not support other benchmarks. An explicit check is used
    # because `assert` is stripped under `python -O`.
    if search_space_name != "nasbench201":
        sys.exit(
            f"Unsupported search space: {search_space_name!r} "
            "(only 'nasbench201' is supported)"
        )
    dataset = sys.argv[2]
    runner = problem.ProblemRunner(NASLibProblemFactory(dataset))
    runner.run()
92