optuna

Форк
0
/
problem.py 
91 строка · 2.7 Кб
1
from __future__ import annotations
2

3
import sys
4
from typing import Any
5

6
from kurobako import problem
7
from naslib.utils import get_dataset_api
8

9

10
op_names = [
11
    "skip_connect",
12
    "none",
13
    "nor_conv_3x3",
14
    "nor_conv_1x1",
15
    "avg_pool_3x3",
16
]
17
edge_num = 4 * 3 // 2
18
max_epoch = 199
19

20
prune_start_epoch = 10
21
prune_epoch_step = 10
22

23

24
class NASLibProblemFactory(problem.ProblemFactory):
25
    def __init__(self, dataset: str) -> None:
26
        """Creates ProblemFactory for NASBench201.
27

28
        Args:
29
            dataset:
30
                Accepts one of "cifar10", "cifar100" or "ImageNet16-120".
31
        """
32
        self._dataset = dataset
33
        if dataset == "cifar10":
34
            # Set name used in dataset API.
35
            self._dataset = "cifar10-valid"
36
        self._dataset_api = get_dataset_api("nasbench201", dataset)
37

38
    def specification(self) -> problem.ProblemSpec:
39
        params = [
40
            problem.Var(f"x{i}", problem.CategoricalRange(op_names)) for i in range(edge_num)
41
        ]
42
        return problem.ProblemSpec(
43
            name=f"NASBench201-{self._dataset}",
44
            params=params,
45
            values=[problem.Var("value")],
46
            steps=list(range(prune_start_epoch, max_epoch, prune_epoch_step)) + [max_epoch],
47
        )
48

49
    def create_problem(self, seed: int) -> problem.Problem:
50
        return NASLibProblem(self._dataset, self._dataset_api)
51

52

53
class NASLibProblem(problem.Problem):
54
    def __init__(self, dataset: str, dataset_api: Any) -> None:
55
        super().__init__()
56
        self._dataset = dataset
57
        self._dataset_api = dataset_api
58

59
    def create_evaluator(self, params: list[float]) -> problem.Evaluator:
60
        ops = [op_names[int(x)] for x in params]
61
        arch_str = "|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|".format(*ops)
62
        return NASLibEvaluator(
63
            self._dataset_api["nb201_data"][arch_str][self._dataset]["eval_acc1es"]
64
        )
65

66

67
class NASLibEvaluator(problem.Evaluator):
68
    def __init__(self, learning_curve: list[float]) -> None:
69
        self._current_step = 0
70
        self._lc = learning_curve
71

72
    def current_step(self) -> int:
73
        return self._current_step
74

75
    def evaluate(self, next_step: int) -> list[float]:
76
        self._current_step = next_step
77
        return [-self._lc[next_step]]
78

79

80
if __name__ == "__main__":
81
    if len(sys.argv) < 1 + 2:
82
        print("Usage: python3 nas_bench_suite/problems.py <search_space> <dataset>")
83
        print("Example: python3 nas_bench_suite/problems.py nasbench201 cifar10")
84
        exit(1)
85

86
    search_space_name = sys.argv[1]
87
    # We currently do not support other benchmarks.
88
    assert search_space_name == "nasbench201"
89
    dataset = sys.argv[2]
90
    runner = problem.ProblemRunner(NASLibProblemFactory(dataset))
91
    runner.run()
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.