optuna

Форк
0
/
run_naslib.py 
116 строк · 4.5 Кб
1
import argparse
2
import os
3
import subprocess
4

5

6
def run(args: argparse.Namespace) -> None:
7
    kurobako_cmd = os.path.join(args.path_to_kurobako, "kurobako")
8
    subprocess.run(f"{kurobako_cmd} --version", shell=True)
9

10
    os.makedirs(args.out_dir, exist_ok=True)
11
    study_json_filename = os.path.join(args.out_dir, "studies.json")
12
    solvers_filename = os.path.join(args.out_dir, "solvers.json")
13
    problems_filename = os.path.join(args.out_dir, "problems.json")
14

15
    # Ensure all files are empty.
16
    for filename in [study_json_filename, solvers_filename, problems_filename]:
17
        with open(filename, "w"):
18
            pass
19

20
    searchspace_datasets = [
21
        "nasbench201 cifar10",
22
        "nasbench201 cifar100",
23
        "nasbench201 ImageNet16-120",
24
    ]
25

26
    for searchspace_dataset in searchspace_datasets:
27
        python_command = f"benchmarks/naslib/problem.py {searchspace_dataset}"
28
        cmd = (
29
            f"{kurobako_cmd} problem command python3 {python_command}"
30
            f"| tee -a {problems_filename}"
31
        )
32
        subprocess.run(cmd, shell=True)
33

34
    # Create solvers.
35
    sampler_list = args.sampler_list.split()
36
    sampler_kwargs_list = args.sampler_kwargs_list.split()
37
    pruner_list = args.pruner_list.split()
38
    pruner_kwargs_list = args.pruner_kwargs_list.split()
39

40
    if len(sampler_list) != len(sampler_kwargs_list):
41
        raise ValueError(
42
            "The number of samplers does not match the given keyword arguments. \n"
43
            f"sampler_list: {sampler_list}, sampler_kwargs_list: {sampler_kwargs_list}."
44
        )
45

46
    if len(pruner_list) != len(pruner_kwargs_list):
47
        raise ValueError(
48
            "The number of pruners does not match the given keyword arguments. \n"
49
            f"pruner_list: {pruner_list}, pruner_kwargs_list: {pruner_kwargs_list}."
50
        )
51

52
    for i, (sampler, sampler_kwargs) in enumerate(zip(sampler_list, sampler_kwargs_list)):
53
        sampler_name = sampler
54
        if sampler_list.count(sampler) > 1:
55
            sampler_name += f"_{sampler_list[:i].count(sampler)}"
56
        for j, (pruner, pruner_kwargs) in enumerate(zip(pruner_list, pruner_kwargs_list)):
57
            pruner_name = pruner
58
            if pruner_list.count(pruner) > 1:
59
                pruner_name += f"_{pruner_list[:j].count(pruner)}"
60
            name = f"{args.name_prefix}_{sampler_name}_{pruner_name}"
61
            cmd = (
62
                f"{kurobako_cmd} solver --name {name} optuna --loglevel debug "
63
                f"--sampler {sampler} --sampler-kwargs {sampler_kwargs} "
64
                f"--pruner {pruner} --pruner-kwargs {pruner_kwargs} "
65
                f"| tee -a {solvers_filename}"
66
            )
67
            subprocess.run(cmd, shell=True)
68

69
    # Create study.
70
    cmd = (
71
        f"{kurobako_cmd} studies --budget {args.budget} "
72
        f"--solvers $(cat {solvers_filename}) --problems $(cat {problems_filename}) "
73
        f"--repeats {args.n_runs} --seed {args.seed} --concurrency {args.n_concurrency} "
74
        f"> {study_json_filename}"
75
    )
76
    subprocess.run(cmd, shell=True)
77

78
    result_filename = os.path.join(args.out_dir, "results.json")
79
    cmd = (
80
        f"cat {study_json_filename} | {kurobako_cmd} run --parallelism {args.n_jobs} -q "
81
        f"> {result_filename}"
82
    )
83
    subprocess.run(cmd, shell=True)
84

85
    # Report.
86
    report_filename = os.path.join(args.out_dir, "report.md")
87
    cmd = f"cat {result_filename} | {kurobako_cmd} report > {report_filename}"
88
    subprocess.run(cmd, shell=True)
89

90
    cmd = (
91
        f"cat {result_filename} | {kurobako_cmd} plot curve --errorbar -o {args.out_dir} --xmin 10"
92
    )
93
    subprocess.run(cmd, shell=True)
94

95

96
if __name__ == "__main__":
97
    parser = argparse.ArgumentParser()
98
    parser.add_argument("--path-to-kurobako", type=str, default="")
99
    parser.add_argument("--name-prefix", type=str, default="")
100
    parser.add_argument("--budget", type=int, default=100)
101
    parser.add_argument("--n-runs", type=int, default=10)
102
    parser.add_argument("--n-jobs", type=int, default=10)
103
    parser.add_argument("--n-concurrency", type=int, default=1)
104
    parser.add_argument("--sampler-list", type=str, default="RandomSampler TPESampler")
105
    parser.add_argument(
106
        "--sampler-kwargs-list",
107
        type=str,
108
        default=r"{} {\"multivariate\":true\,\"constant_liar\":true}",
109
    )
110
    parser.add_argument("--pruner-list", type=str, default="NopPruner")
111
    parser.add_argument("--pruner-kwargs-list", type=str, default="{}")
112
    parser.add_argument("--seed", type=int, default=0)
113
    parser.add_argument("--out-dir", type=str, default="out")
114
    args = parser.parse_args()
115

116
    run(args)
117

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.