# OpenDelta
# 127 lines · 5.6 KB
import argparse
import os
import shlex
import shutil
import subprocess

import optuna


def build_parser():
    """Build the command-line parser for the hyper-parameter search launcher."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--delta_type")
    parser.add_argument("--dataset")
    parser.add_argument("--model_name")
    parser.add_argument("--study_name", type=str, default=None)
    parser.add_argument("--cuda_ids", nargs='+', help="list")
    parser.add_argument("--mode", type=str, default="run", help="select from 'run' and 'read' ")
    parser.add_argument("--continue_study", type=int, default=0)
    parser.add_argument("--substudy_prefix", type=str, default="")
    parser.add_argument("--main_file_name", type=str)
    parser.add_argument("--num_trials", type=int)
    parser.add_argument("--pathbase", type=str, default="")
    parser.add_argument("--pythonpath", type=str, default="python")
    parser.add_argument("--plm_path_base", type=str, default="", help="The path where we cache the plms. Must be empty string or dir that ends with /")
    parser.add_argument("--datasets_load_from_disk", action="store_true")
    parser.add_argument("--datasets_saved_path", type=str)
    parser.add_argument("--repeat_time", type=int, default=1)
    return parser


def resolve_study_name(args):
    """Return the study name derived from the parsed arguments.

    The name is ``delta_type.dataset.model_name``; when ``--study_name`` was
    given it is used as a prefix and concatenated directly (no separator).
    """
    pardir = ".".join([args.delta_type, args.dataset, args.model_name])
    if args.study_name is None:
        return pardir
    return args.study_name + pardir


def split_trials(num_trials, num_chunks):
    """Split ``num_trials`` across ``num_chunks`` workers.

    Every worker gets ``num_trials // num_chunks`` trials; the last worker
    additionally receives the remainder.
    """
    base, remainder = divmod(num_trials, num_chunks)
    return [base] * (num_chunks - 1) + [base + remainder]


def build_worker_command(args, chunk_id, cuda_id, n_trials):
    """Return the argv list that starts one ``search_single.py`` worker.

    The command is an argument list rather than a shell string, so empty
    values (for example the default ``--pathbase ""``) reach the worker as
    real arguments instead of disappearing from the command line.
    """
    # Split on whitespace the way the shell used to, so a value such as
    # "python -u" keeps working as the interpreter invocation.
    command = shlex.split(args.pythonpath) + ["search_single.py"]
    # NOTE(review): one --cuda_ids entry may hold several ids separated by
    # spaces (the shell used to split them) -- confirm against search_single.py.
    command += ["--cuda_id", *shlex.split(str(cuda_id))]
    options = [
        ("--model_name", args.model_name),
        ("--dataset", args.dataset),
        ("--delta_type", args.delta_type),
        ("--study_name", args.study_name),
        # Seed is the literal "10" followed by the worker index.
        ("--optuna_seed", f"10{chunk_id}"),
        ("--main_file_name", args.main_file_name),
        ("--num_trials", n_trials),
        ("--pythonpath", args.pythonpath),
        ("--pathbase", args.pathbase),
        ("--repeat_time", args.repeat_time),
        ("--plm_path_base", args.plm_path_base),
        ("--datasets_saved_path", args.datasets_saved_path),
    ]
    for flag, value in options:
        command += [flag, str(value)]
    if args.datasets_load_from_disk:
        command.append("--datasets_load_from_disk")
    return command


def prepare_output_dir(output_dir, continue_study):
    """Create ``output_dir``, wiping an existing one unless continuing a study.

    The interactive confirmation that once guarded the removal was disabled
    (the answer was hard-coded to "yes"), so the old directory is removed
    without prompting.
    """
    if os.path.exists(output_dir) and not continue_study:
        shutil.rmtree(output_dir)
    # makedirs also creates the missing "outputs_search" parent directory.
    os.makedirs(output_dir, exist_ok=True)


def ensure_study(study_name, storage, continue_study):
    """Make sure an Optuna study called ``study_name`` exists in ``storage``.

    An existing study is kept when ``continue_study`` is truthy and is
    deleted and recreated otherwise.
    """
    try:
        optuna.create_study(study_name=study_name, storage=storage)
    except optuna.exceptions.DuplicatedStudyError:
        if not continue_study:
            optuna.delete_study(study_name=study_name, storage=storage)
            optuna.create_study(study_name=study_name, storage=storage)
        # Otherwise the existing study is reused by the workers.


def launch_workers(args):
    """Start one worker per ``--cuda_ids`` entry and wait for all of them.

    Each worker's stdout and stderr go to
    ``<output_dir>/<substudy_prefix><index>.log``.

    Returns:
        The list of worker exit codes, in launch order.
    """
    trial_counts = split_trials(args.num_trials, len(args.cuda_ids))
    workers = []
    for chunk_id, (cuda_id, n_trials) in enumerate(zip(args.cuda_ids, trial_counts)):
        command = build_worker_command(args, chunk_id, cuda_id, n_trials)
        log_path = f"{args.output_dir}/{args.substudy_prefix}{chunk_id}.log"
        # The log is opened here so its path is resolved in the same working
        # directory that created output_dir; the child keeps its own handle.
        with open(log_path, "w") as log_file:
            worker = subprocess.Popen(
                command,
                # An empty --pathbase means "stay in the current directory".
                cwd=args.pathbase or None,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
        workers.append(worker)
        print("id {} on cuda:{}, pid {}".format(chunk_id, cuda_id, worker.pid))
        print(" ".join(shlex.quote(part) for part in command))
        print()

    print("Wait for subprocesses to complete")
    exit_codes = [worker.wait() for worker in workers]
    for chunk_id, code in enumerate(exit_codes):
        if code != 0:
            print("id {} exited with code {}, see its log in {}".format(chunk_id, code, args.output_dir))
    print("All complete!")
    return exit_codes


def run_search(args):
    """Run mode: prepare the output directory and study, then run the workers."""
    if args.continue_study == 1:
        print("Continue study!")
    else:
        print("Creat new study!")

    prepare_output_dir(args.output_dir, args.continue_study)
    ensure_study(args.study_name, f"sqlite:///{args.study_name}.db", args.continue_study)
    return launch_workers(args)


def read_results(args):
    """Read mode: report the best trial, copy its config and save the plots."""
    study = optuna.load_study(study_name=args.study_name, storage=f"sqlite:///{args.study_name}.db")
    trial = study.best_trial
    finished = (len(study.trials) == args.num_trials)
    print("total num_trials: {}, {}".format(len(study.trials), "Finished!" if finished else "Not finished..." ))
    # The reported accuracy is the negated objective value.
    print("average acc {}".format(-trial.value))
    print("best config {}".format(trial.params))

    best_trial_dir = trial.user_attrs["trial_dir"]
    shutil.copyfile(f"{best_trial_dir}/this_configs.json", f"{args.output_dir}/best_config.json")

    plots = {
        "history": optuna.visualization.plot_optimization_history(study),
        "slice": optuna.visualization.plot_slice(study),
        "contour": optuna.visualization.plot_contour(study, params=['learning_rate', 'batch_size_base']),
        "contour2": optuna.visualization.plot_contour(study, params=['learning_rate', 'warmup_steps']),
    }
    for name, figure in plots.items():
        figure.write_image(f"{args.output_dir}/{name}.png")


def main(argv=None):
    """Parse the command line and dispatch to the run or read mode."""
    parser = build_parser()
    args = parser.parse_args(argv)

    # These three build the study name, so fail with a usage message instead
    # of a TypeError when one of them is missing.
    missing = [name for name in ("delta_type", "dataset", "model_name") if getattr(args, name) is None]
    if missing:
        parser.error("missing required arguments: " + ", ".join("--" + name for name in missing))

    args.study_name = resolve_study_name(args)
    # Kept as plain string formatting so the directory matches whatever the
    # workers derive from the same --pathbase value.
    args.output_dir = f"{args.pathbase}/outputs_search/{args.study_name}"

    if args.mode == "run":
        if args.num_trials is None or not args.cuda_ids:
            parser.error("--num_trials and --cuda_ids are required when --mode is 'run'")
        run_search(args)
    elif args.mode == 'read':
        read_results(args)
    else:
        parser.error("unknown --mode {!r}, select from 'run' and 'read'".format(args.mode))


if __name__ == "__main__":
    main()