OpenDelta
113 строк · 4.4 Кб
1
2import os
3import argparse
4import random
5import json
6from examples_prompt.search_space import AllBackboneSearchSpace, AllDeltaSearchSpace, BaseSearchSpace, DatasetSearchSpace
7import optuna
8from functools import partial
9from optuna.samplers import TPESampler
10import shutil
11import time
12
13import subprocess
14
15
def objective_singleseed(args, unicode, search_space_sample):
    """Run one training job for a sampled configuration and return its mean test metric.

    The function creates the run directory ``{args.output_dir}/{unicode}``,
    writes the sampled configuration there as ``this_configs.json``, launches
    ``{args.pythonpath} {args.main_file_name} <config path>`` through the shell
    (with ``CUDA_VISIBLE_DEVICES`` set to ``args.cuda_id`` and the working
    directory set to ``args.pathbase``), waits for it to finish and reads
    ``results.json`` from the run directory.  Afterwards every entry of the run
    directory except the configuration file is deleted.

    Args:
        args: Namespace providing ``output_dir``, ``cuda_id``, ``pythonpath``,
            ``main_file_name`` and ``pathbase``.
        unicode: Unique identifier of this run; used as the run directory name.
        search_space_sample: Dict of sampled settings.  It is updated in place
            with an ``"output_dir"`` entry pointing at the run directory.

    Returns:
        The arithmetic mean of
        ``results["test"][name]["test_average_metrics"]`` over all entries of
        ``results["test"]``.

    Raises:
        FileExistsError: If the run directory already exists.
        RuntimeError: If the training subprocess exits with a non-zero status.
        ValueError: If ``results.json`` contains no test entries.
    """
    run_dir = f"{args.output_dir}/{unicode}"
    config_path = f"{run_dir}/this_configs.json"
    log_path = f"{run_dir}/output.log"

    # makedirs (not mkdir) so the first run of a study also works when the
    # study-level output directory has not been created yet.  It still raises
    # FileExistsError if the run directory itself is already there.
    os.makedirs(run_dir)
    search_space_sample.update({"output_dir": run_dir})

    with open(config_path, 'w') as fout:
        json.dump(search_space_sample, fout, indent=4, sort_keys=True)

    # The command is a shell string on purpose: it relies on the shell for the
    # environment-variable prefix and for the output redirection.
    # NOTE(review): the pieces come straight from the command line and are not
    # quoted, so paths containing spaces or shell metacharacters will break.
    command = "CUDA_VISIBLE_DEVICES={} ".format(args.cuda_id)
    command += f"{args.pythonpath} {args.main_file_name} "
    command += config_path
    command += f" >> {log_path} 2>&1"

    print("======"*5+"\n"+command)
    p = subprocess.Popen(command, cwd=f"{args.pathbase}", shell=True)
    print(f"wait for subprocess \"{command}\" to complete")
    status_code = p.wait()

    # Fail loudly and point at the log instead of letting a crashed run surface
    # later as a missing results.json.  The run directory is left untouched in
    # this case so the log can be inspected.  No retry (e.g. on CUDA
    # out-of-memory) is attempted.
    if status_code != 0:
        raise RuntimeError(
            f"error in {unicode}: subprocess exited with status {status_code}, "
            f"see {log_path}"
        )

    with open(f"{run_dir}/results.json", 'r') as fret:
        results = json.load(fret)

    # Keep only the configuration file; logs, results and checkpoints go away.
    for filename in os.listdir(run_dir):
        if filename.endswith("this_configs.json"):
            continue
        full_file_name = f"{run_dir}/{filename}"
        # rmtree refuses to operate on a symlink, so links to directories are
        # unlinked like regular files.
        if os.path.isdir(full_file_name) and not os.path.islink(full_file_name):
            shutil.rmtree(full_file_name)
        else:
            os.remove(full_file_name)

    print("results:", results)
    results_all_test_datasets = [
        per_dataset['test_average_metrics']
        for per_dataset in results['test'].values()
    ]
    if not results_all_test_datasets:
        raise ValueError(f"results.json of run {unicode} contains no test results")

    return sum(results_all_test_datasets) / len(results_all_test_datasets)
65
66
67
def objective(trial, args=None):
    """Optuna objective: sample one configuration and score it over several seeds.

    The configuration is assembled by merging, in order, the configs produced
    by the base, backbone (selected by ``args.model_name``), dataset
    (``args.dataset``) and delta (selected by ``args.delta_type``) search
    spaces.  For every seed in ``42 .. 42 + args.repeat_time - 1`` a run
    directory name that does not exist yet under ``args.output_dir`` is chosen,
    stored on the trial as the ``"trial_dir"`` user attribute (each seed
    overwrites the previous value), and ``objective_singleseed`` is executed.

    Returns:
        The negated mean of the per-seed scores (presumably because the study
        minimizes its objective -- confirm against the study's direction).
    """
    # Factories are evaluated lazily, one at a time, so each search space is
    # constructed only right before its config is sampled.
    space_factories = (
        lambda: BaseSearchSpace(),
        lambda: AllBackboneSearchSpace[args.model_name](),
        lambda: DatasetSearchSpace(args.dataset),
        lambda: AllDeltaSearchSpace[args.delta_type](),
    )
    sampled_config = {}
    for make_space in space_factories:
        sampled_config.update(make_space().get_config(trial, args))

    scores = []
    for seed in range(42, 42 + args.repeat_time):
        sampled_config["seed"] = seed

        # Start from a random id and walk upwards until the name is free.
        run_id = random.randint(0, 100000000)
        while os.path.exists(f"{args.output_dir}/{run_id}"):
            run_id += 1

        trial.set_user_attr("trial_dir", f"{args.output_dir}/{run_id}")
        scores.append(
            objective_singleseed(
                args,
                unicode=run_id,
                search_space_sample=sampled_config,
            )
        )

    mean_score = sum(scores) / len(scores)
    return -mean_score
85
86
87
88
def _build_arg_parser():
    """Build the command-line parser for one hyper-parameter search worker."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--delta_type")
    parser.add_argument("--dataset")
    parser.add_argument("--model_name")
    parser.add_argument("--cuda_id", type=int)
    parser.add_argument("--main_file_name", type=str)
    parser.add_argument("--study_name")
    parser.add_argument("--num_trials", type=int)
    parser.add_argument("--repeat_time", type=int)
    # The description used to be passed as ``default``; argparse applies
    # ``type`` to string defaults, so omitting the flag aborted with
    # "invalid int value".  It is the help text; the default is no fixed seed.
    parser.add_argument("--optuna_seed", type=int, default=None,
                        help="the seed to sample suggest point")
    parser.add_argument("--pathbase", type=str, default="")
    parser.add_argument("--pythonpath", type=str, default="")
    parser.add_argument("--plm_path_base", type=str, default="")
    parser.add_argument("--datasets_load_from_disk", action="store_true")
    parser.add_argument("--datasets_saved_path", type=str)
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()

    # All run directories of this study live below this folder.
    args.output_dir = f"{args.pathbase}/outputs_search/{args.study_name}"

    # The study is loaded, not created, from ``<study_name>.db`` in the
    # current working directory.
    study = optuna.load_study(
        study_name=args.study_name,
        storage=f'sqlite:///{args.study_name}.db',
        sampler=TPESampler(seed=args.optuna_seed),
    )
    study.optimize(partial(objective, args=args), n_trials=args.num_trials)

    print("complete single!")
115
116
117
118