OpenDelta

Форк
0
113 строк · 4.4 Кб
1

2
import os
3
import argparse
4
import random
5
import json
6
from examples_prompt.search_space import AllBackboneSearchSpace, AllDeltaSearchSpace, BaseSearchSpace, DatasetSearchSpace
7
import optuna
8
from functools import partial
9
from optuna.samplers import TPESampler
10
import shutil
11
import time
12

13
import subprocess
14

15

16
def objective_singleseed(args, unicode, search_space_sample  ):
17
    os.mkdir(f"{args.output_dir}/{unicode}")
18
    search_space_sample.update({"output_dir": f"{args.output_dir}/{unicode}"})
19

20

21
    with open(f"{args.output_dir}/{unicode}/this_configs.json", 'w') as fout:
22
        json.dump(search_space_sample, fout, indent=4,sort_keys=True)
23

24

25
    command = "CUDA_VISIBLE_DEVICES={} ".format(args.cuda_id)
26
    command += f"{args.pythonpath} {args.main_file_name} "
27
    command += f"{args.output_dir}/{unicode}/this_configs.json"
28
    command += f" >> {args.output_dir}/{unicode}/output.log 2>&1"
29

30

31
    print("======"*5+"\n"+command)
32
    p = subprocess.Popen(command, cwd=f"{args.pathbase}", shell=True)
33
    print(f"wait for subprocess \"{command}\" to complete")
34
    p.wait()
35

36
    # if status_code != 0:
37
    #     with open(f"{args.output_dir}/{args.cuda_id}.log",'r') as flog:
38
    #         lastlines = " ".join(flog.readlines()[-100:])
39
    #         if "RuntimeError: CUDA out of memory." in lastlines:
40
    #             time.sleep(600)  # sleep ten minites and try again
41
    #             shutil.rmtree(f"{args.output_dir}/{unicode}/")
42
    #             return objective_singleseed(args, unicode, search_space_sample)
43
    #         else:
44
    #             raise RuntimeError("error in {}".format(unicode))
45

46

47

48
    with open(f"{args.output_dir}/{unicode}/results.json", 'r') as fret:
49
        results =json.load(fret)
50

51
    for filename in os.listdir(f"{args.output_dir}/{unicode}/"):
52
        if not filename.endswith("this_configs.json"):
53
            full_file_name = f"{args.output_dir}/{unicode}/{filename}"
54
            if os.path.isdir(full_file_name):
55
                shutil.rmtree(f"{args.output_dir}/{unicode}/{filename}")
56
            else:
57
                os.remove(full_file_name)
58

59
    results_all_test_datasets = []
60
    print("results:", results)
61
    for datasets in results['test']:
62
        results_all_test_datasets.append(results['test'][datasets]['test_average_metrics'])
63

64
    return sum(results_all_test_datasets)/len(results_all_test_datasets)#results['test']['average_metrics']
65

66

67

68
def objective(trial, args=None):
69
    search_space_sample = {}
70
    search_space_sample.update(BaseSearchSpace().get_config(trial, args))
71
    search_space_sample.update(AllBackboneSearchSpace[args.model_name]().get_config(trial, args))
72
    search_space_sample.update(DatasetSearchSpace(args.dataset).get_config(trial, args))
73
    search_space_sample.update(AllDeltaSearchSpace[args.delta_type]().get_config(trial, args))
74
    results = []
75
    for seed in range(42, 42+args.repeat_time):
76
        search_space_sample.update({"seed": seed})
77
        unicode = random.randint(0, 100000000)
78
        while os.path.exists(f"{args.output_dir}/{unicode}"):
79
            unicode = unicode+1
80
        trial.set_user_attr("trial_dir", f"{args.output_dir}/{unicode}")
81
        res = objective_singleseed(args, unicode = unicode, search_space_sample=search_space_sample)
82
        results.append(res)
83
    ave_res = sum(results)/len(results)
84
    return -ave_res
85

86

87

88

89
if __name__=="__main__":
90
    parser = argparse.ArgumentParser()
91
    parser.add_argument("--delta_type")
92
    parser.add_argument("--dataset")
93
    parser.add_argument("--model_name")
94
    parser.add_argument("--cuda_id", type=int)
95
    parser.add_argument("--main_file_name", type=str)
96
    parser.add_argument("--study_name")
97
    parser.add_argument("--num_trials", type=int)
98
    parser.add_argument("--repeat_time", type=int)
99
    parser.add_argument("--optuna_seed", type=int, default="the seed to sample suggest point")
100
    parser.add_argument("--pathbase", type=str, default="")
101
    parser.add_argument("--pythonpath", type=str, default="")
102
    parser.add_argument("--plm_path_base", type=str, default="")
103
    parser.add_argument("--datasets_load_from_disk", action="store_true")
104
    parser.add_argument("--datasets_saved_path", type=str)
105

106
    args = parser.parse_args()
107

108

109
    setattr(args, "output_dir", f"{args.pathbase}/outputs_search/{args.study_name}")
110

111
    study = optuna.load_study(study_name=args.study_name, storage=f'sqlite:///{args.study_name}.db', sampler=TPESampler(seed=args.optuna_seed))
112
    study.optimize(partial(objective, args=args), n_trials=args.num_trials)
113

114
    print("complete single!")
115

116

117

118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.