optuna
116 строк · 4.5 Кб
import argparse
import os
import shlex
import subprocess
5
def _check_same_length(kind: str, names: list, kwargs_list: list) -> None:
    """Raise ``ValueError`` unless each ``kind`` entry has exactly one kwargs entry.

    ``kind`` (``"sampler"`` or ``"pruner"``) is only used to build the message.
    """
    if len(names) != len(kwargs_list):
        raise ValueError(
            f"The number of {kind}s does not match the given keyword arguments. \n"
            f"{kind}_list: {names}, {kind}_kwargs_list: {kwargs_list}."
        )


def _disambiguate(names: list) -> list:
    """Return ``names`` with repeated entries made unique.

    A name that occurs once is kept as is. Each occurrence of a repeated name
    gets ``_<k>`` appended, where ``k`` is the number of earlier occurrences,
    e.g. ``["TPESampler", "RandomSampler", "TPESampler"]`` ->
    ``["TPESampler_0", "RandomSampler", "TPESampler_1"]``.
    """
    return [
        f"{name}_{names[:i].count(name)}" if names.count(name) > 1 else name
        for i, name in enumerate(names)
    ]


def run(args: argparse.Namespace) -> None:
    """Run the NAS-Bench-201 benchmark with kurobako and write the results.

    Registers one kurobako problem per NAS-Bench-201 dataset and one optuna
    solver per (sampler, pruner) pair. It then generates the study matrix, runs
    it, and writes ``results.json``, ``report.md`` and curve plots into
    ``args.out_dir``.

    Args:
        args: Parsed CLI namespace. Reads ``path_to_kurobako``, ``out_dir``,
            ``name_prefix``, ``budget``, ``n_runs``, ``n_jobs``,
            ``n_concurrency`` and ``seed``. It also reads the
            whitespace-separated ``sampler_list``, ``sampler_kwargs_list``,
            ``pruner_list`` and ``pruner_kwargs_list``.

    Raises:
        ValueError: If the number of samplers (pruners) differs from the number
            of sampler (pruner) kwargs entries. Raised before any directory or
            file is created and before any command is run.
    """
    sampler_list = args.sampler_list.split()
    sampler_kwargs_list = args.sampler_kwargs_list.split()
    pruner_list = args.pruner_list.split()
    pruner_kwargs_list = args.pruner_kwargs_list.split()
    # Validate before any side effect so bad CLI input leaves no partial output.
    _check_same_length("sampler", sampler_list, sampler_kwargs_list)
    _check_same_length("pruner", pruner_list, pruner_kwargs_list)

    # Every command below runs through the shell (pipes, redirects, `$(...)`).
    # Paths and names are therefore quoted so they survive spaces or shell
    # metacharacters.
    kurobako_cmd = shlex.quote(os.path.join(args.path_to_kurobako, "kurobako"))
    subprocess.run(f"{kurobako_cmd} --version", shell=True)

    os.makedirs(args.out_dir, exist_ok=True)
    study_json_filename = os.path.join(args.out_dir, "studies.json")
    solvers_filename = os.path.join(args.out_dir, "solvers.json")
    problems_filename = os.path.join(args.out_dir, "problems.json")
    result_filename = os.path.join(args.out_dir, "results.json")
    report_filename = os.path.join(args.out_dir, "report.md")

    # Ensure all files are empty; problems/solvers are appended to with `tee -a`.
    for filename in (study_json_filename, solvers_filename, problems_filename):
        with open(filename, "w"):
            pass

    studies_q = shlex.quote(study_json_filename)
    solvers_q = shlex.quote(solvers_filename)
    problems_q = shlex.quote(problems_filename)
    results_q = shlex.quote(result_filename)

    # NOTE: exit statuses are not checked. For the `| tee` pipelines the status
    # would be tee's rather than kurobako's anyway.

    # Create problems: one per NAS-Bench-201 dataset.
    searchspace_datasets = [
        "nasbench201 cifar10",
        "nasbench201 cifar100",
        "nasbench201 ImageNet16-120",
    ]
    for searchspace_dataset in searchspace_datasets:
        # Unquoted on purpose: "<searchspace> <dataset>" must reach problem.py
        # as two separate arguments.
        python_command = f"benchmarks/naslib/problem.py {searchspace_dataset}"
        cmd = (
            f"{kurobako_cmd} problem command python3 {python_command} "
            f"| tee -a {problems_q}"
        )
        subprocess.run(cmd, shell=True)

    # Create solvers: one per (sampler, pruner) pair.
    sampler_names = _disambiguate(sampler_list)
    pruner_names = _disambiguate(pruner_list)
    for sampler, sampler_name, sampler_kwargs in zip(
        sampler_list, sampler_names, sampler_kwargs_list
    ):
        for pruner, pruner_name, pruner_kwargs in zip(
            pruner_list, pruner_names, pruner_kwargs_list
        ):
            name = shlex.quote(f"{args.name_prefix}_{sampler_name}_{pruner_name}")
            # The kwargs are deliberately left unquoted. CLI values arrive
            # pre-escaped for the shell (the --sampler-kwargs-list default uses
            # `\"` and `\,`), so the shell must be the one to unescape them.
            cmd = (
                f"{kurobako_cmd} solver --name {name} optuna --loglevel debug "
                f"--sampler {shlex.quote(sampler)} --sampler-kwargs {sampler_kwargs} "
                f"--pruner {shlex.quote(pruner)} --pruner-kwargs {pruner_kwargs} "
                f"| tee -a {solvers_q}"
            )
            subprocess.run(cmd, shell=True)

    # Create study. `$(cat ...)` is left unquoted so the shell word-splits the
    # recorded specs into separate arguments.
    cmd = (
        f"{kurobako_cmd} studies --budget {args.budget} "
        f"--solvers $(cat {solvers_q}) --problems $(cat {problems_q}) "
        f"--repeats {args.n_runs} --seed {args.seed} --concurrency {args.n_concurrency} "
        f"> {studies_q}"
    )
    subprocess.run(cmd, shell=True)

    # Run every study and collect the results.
    cmd = (
        f"cat {studies_q} | {kurobako_cmd} run --parallelism {args.n_jobs} -q "
        f"> {results_q}"
    )
    subprocess.run(cmd, shell=True)

    # Report.
    cmd = f"cat {results_q} | {kurobako_cmd} report > {shlex.quote(report_filename)}"
    subprocess.run(cmd, shell=True)

    cmd = (
        f"cat {results_q} | {kurobako_cmd} plot curve --errorbar "
        f"-o {shlex.quote(args.out_dir)} --xmin 10"
    )
    subprocess.run(cmd, shell=True)
95
if __name__ == "__main__":
    # (flag, type, default) for every CLI option, registered in this order.
    cli_options = [
        ("--path-to-kurobako", str, ""),
        ("--name-prefix", str, ""),
        ("--budget", int, 100),
        ("--n-runs", int, 10),
        ("--n-jobs", int, 10),
        ("--n-concurrency", int, 1),
        ("--sampler-list", str, "RandomSampler TPESampler"),
        # Shell-escaped on purpose (`\"`, `\,`): each whitespace-separated entry
        # is spliced unquoted into a shell command by run().
        (
            "--sampler-kwargs-list",
            str,
            r"{} {\"multivariate\":true\,\"constant_liar\":true}",
        ),
        ("--pruner-list", str, "NopPruner"),
        ("--pruner-kwargs-list", str, "{}"),
        ("--seed", int, 0),
        ("--out-dir", str, "out"),
    ]

    cli = argparse.ArgumentParser()
    for flag, option_type, default_value in cli_options:
        cli.add_argument(flag, type=option_type, default=default_value)

    run(cli.parse_args())