OpenDelta

Форк
0
/
search_distributed.py 
127 строк · 5.6 Кб
1
import optuna
2
import argparse
3
import os
4
import shutil
5
import subprocess
6

7

8

9

10
# Launcher for a distributed Optuna hyper-parameter search.
# "run" mode starts one `search_single.py` worker per entry of --cuda_ids, all
# pointed at the same SQLite-backed study; "read" mode summarizes that study,
# copies the best trial's config and writes a few diagnostic plots.


def parse_args(argv=None):
    """Parse and post-process the launcher's command-line options.

    Besides plain parsing, this derives two attributes used by both modes:
    ``study_name`` (``<delta_type>.<dataset>.<model_name>``, with the value of
    ``--study_name`` prepended when given) and ``output_dir``
    (``<pathbase>/outputs_search/<study_name>``).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options with ``study_name`` and
        ``output_dir`` resolved.

    Raises:
        SystemExit: Via ``parser.error`` on invalid options, including "run"
            mode without ``--cuda_ids`` or ``--num_trials``.
    """
    parser = argparse.ArgumentParser()
    # These three form the study name, so they are mandatory (a missing one
    # used to crash later with a TypeError inside str.join).
    parser.add_argument("--delta_type", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--study_name", type=str, default=None)
    parser.add_argument("--cuda_ids", nargs='+', help="list")
    parser.add_argument("--mode", type=str, default="run", choices=["run", "read"],
                        help="select from 'run' and 'read' ")
    parser.add_argument("--continue_study", type=int, default=0)
    parser.add_argument("--substudy_prefix", type=str, default="")
    parser.add_argument("--main_file_name", type=str)
    parser.add_argument("--num_trials", type=int)
    parser.add_argument("--pathbase", type=str, default="")
    parser.add_argument("--pythonpath", type=str, default="python")
    parser.add_argument("--plm_path_base", type=str, default="",
                        help="The path where we cache the plms. Must be empty string or dir that ends with /")
    parser.add_argument("--datasets_load_from_disk", action="store_true")
    parser.add_argument("--datasets_saved_path", type=str)
    parser.add_argument("--repeat_time", type=int, default=1)
    args = parser.parse_args(argv)

    if args.mode == "run":
        if not args.cuda_ids:
            parser.error("--cuda_ids is required in 'run' mode")
        if args.num_trials is None or args.num_trials < 0:
            parser.error("--num_trials must be a non-negative integer in 'run' mode")

    pardir = ".".join([args.delta_type, args.dataset, args.model_name])
    # --study_name acts as a prefix and is concatenated without a separator.
    args.study_name = pardir if args.study_name is None else args.study_name + pardir
    # os.path.join keeps the default empty --pathbase relative to the current
    # directory instead of resolving to "/outputs_search/..." at the fs root.
    args.output_dir = os.path.join(args.pathbase, "outputs_search", args.study_name)
    return args


def split_trials(num_trials, num_workers):
    """Split ``num_trials`` across ``num_workers`` workers.

    Every worker gets ``num_trials // num_workers`` trials; the last one also
    takes the remainder.

    Returns:
        list[int]: Per-worker trial counts, summing to ``num_trials``.

    Raises:
        ValueError: If ``num_workers`` is not positive.
    """
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    base, remainder = divmod(num_trials, num_workers)
    counts = [base] * num_workers
    counts[-1] += remainder
    return counts


def build_worker_argv(args, worker_id, cuda, n_trials):
    """Build the argument vector that launches one ``search_single.py`` worker.

    The result is a list (no shell involved), so empty strings such as the
    default ``--plm_path_base ""`` and values containing spaces reach the
    worker intact. Options whose value is ``None`` are omitted rather than
    sent as the literal string "None", so the worker falls back to its own
    default (presumably ``None`` as well -- TODO confirm in search_single.py).

    Args:
        args: Namespace returned by :func:`parse_args`.
        worker_id: Index of this worker; also seeds Optuna as ``10<worker_id>``.
        cuda: The ``--cuda_ids`` entry assigned to this worker.
        n_trials: Number of trials this worker should run.

    Returns:
        list[str]: Interpreter (split like a shell would) followed by the
        script and its options.
    """
    import shlex

    argv = shlex.split(args.pythonpath) + ["search_single.py"]
    options = [
        ("--cuda_id", cuda),
        ("--model_name", args.model_name),
        ("--dataset", args.dataset),
        ("--delta_type", args.delta_type),
        ("--study_name", args.study_name),
        ("--optuna_seed", f"10{worker_id}"),
        ("--main_file_name", args.main_file_name),
        ("--num_trials", n_trials),
        ("--pythonpath", args.pythonpath),
        ("--pathbase", args.pathbase),
        ("--repeat_time", args.repeat_time),
        ("--plm_path_base", args.plm_path_base),
        ("--datasets_saved_path", args.datasets_saved_path),
    ]
    for flag, value in options:
        if value is not None:
            argv += [flag, str(value)]
    if args.datasets_load_from_disk:
        argv.append("--datasets_load_from_disk")
    return argv


def prepare_output_dir(output_dir, continue_study):
    """Ensure ``output_dir`` exists, wiping previous outputs for a fresh study.

    Missing parent directories are created as well. The interactive
    "remove old study?" confirmation was disabled upstream (always "yes"),
    so starting a new study removes the existing directory unconditionally.
    """
    if os.path.exists(output_dir) and not continue_study:
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)


def prepare_study(args, storage):
    """Create the shared Optuna study, replacing an existing one unless continuing.

    When ``--continue_study`` is set and the study already exists, it is left
    as is for the workers to load.
    """
    try:
        optuna.create_study(study_name=args.study_name, storage=storage)
    except optuna.exceptions.DuplicatedStudyError:
        if args.continue_study:
            return  # keep the existing study and its trials
        optuna.delete_study(study_name=args.study_name, storage=storage)
        optuna.create_study(study_name=args.study_name, storage=storage)


def run_search(args, storage):
    """Launch one worker per CUDA id and wait for all of them.

    Each worker's stdout/stderr goes to
    ``<output_dir>/<substudy_prefix><worker_id>.log``.

    Returns:
        int: 0 if every worker exited with status 0, otherwise 1.
    """
    import shlex

    # Truthiness, matching how continue_study is interpreted everywhere else.
    print("Continue study!" if args.continue_study else "Create new study!")
    prepare_output_dir(args.output_dir, args.continue_study)
    prepare_study(args, storage)

    counts = split_trials(args.num_trials, len(args.cuda_ids))
    workers = []
    for worker_id, (cuda, n_trials) in enumerate(zip(args.cuda_ids, counts)):
        command = build_worker_argv(args, worker_id, cuda, n_trials)
        log_path = os.path.join(args.output_dir, f"{args.substudy_prefix}{worker_id}.log")
        # The child gets its own copy of the descriptor, so the parent's
        # handle can be closed as soon as the process is spawned.
        with open(log_path, "wb") as log_file:
            proc = subprocess.Popen(command, cwd=args.pathbase or None,
                                    stdout=log_file, stderr=subprocess.STDOUT)
        workers.append(proc)
        print("id {} on cuda:{}, pid {}".format(worker_id, cuda, proc.pid))
        print(" ".join(shlex.quote(part) for part in command) + " > " + shlex.quote(log_path) + " 2>&1")
        print()

    print("Wait for subprocesses to complete")
    exit_codes = [proc.wait() for proc in workers]
    print("All complete!")
    failed = [worker_id for worker_id, code in enumerate(exit_codes) if code != 0]
    if failed:
        print("Workers {} exited with non-zero status; see their logs in {}".format(failed, args.output_dir))
        return 1
    return 0


def read_study(args, storage):
    """Summarize the study, copy the best trial's config and export plots.

    Raises:
        KeyError: If the best trial has no ``trial_dir`` user attribute.
        ValueError: Presumably raised by Optuna when no trial has completed
            yet (``study.best_trial``) -- verify against the Optuna version in use.
    """
    study = optuna.load_study(study_name=args.study_name, storage=storage)
    trial = study.best_trial
    finished = len(study.trials) == args.num_trials
    print("total num_trials: {}, {}".format(len(study.trials), "Finished!" if finished else "Not finished..."))
    # The objective is presumably negated accuracy (the study minimizes by
    # default) -- TODO confirm in search_single.py.
    print("average acc {}".format(-trial.value))
    print("best config {}".format(trial.params))

    os.makedirs(args.output_dir, exist_ok=True)
    best_trial_dir = trial.user_attrs["trial_dir"]
    shutil.copyfile(os.path.join(best_trial_dir, "this_configs.json"),
                    os.path.join(args.output_dir, "best_config.json"))

    figures = {
        "history.png": optuna.visualization.plot_optimization_history(study),
        "slice.png": optuna.visualization.plot_slice(study),
        "contour.png": optuna.visualization.plot_contour(study, params=['learning_rate', 'batch_size_base']),
        "contour2.png": optuna.visualization.plot_contour(study, params=['learning_rate', 'warmup_steps']),
    }
    for file_name, figure in figures.items():
        figure.write_image(os.path.join(args.output_dir, file_name))


def main(argv=None):
    """Entry point; returns the process exit status."""
    args = parse_args(argv)
    # Relative to the launcher's working directory. Workers run with
    # cwd=--pathbase, so this presumably assumes the launcher is started from
    # --pathbase -- TODO confirm how search_single.py locates the database.
    storage = f"sqlite:///{args.study_name}.db"
    if args.mode == "run":
        return run_search(args, storage)
    read_study(args, storage)
    return 0


if __name__ == "__main__":
    import sys

    sys.exit(main())
128

129

130

131

132

133

134

135

136

137

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.