# lmops / Форк 0 / get_cmds.py · 285 строк · 14.3 Кб
# (web-page header captured together with the source; not part of the script)
1
'''
2
print out cmds for training and inference
3
'''
4

5
import argparse
6
import os
7
from DPR.dpr.utils.tasks import task_map, train_cluster_map, test_cluster_map
8
import random
9
import textwrap
10
from tqdm import tqdm
11

12

13
def wrap(cmd):
    '''
    Format one long shell command as a multi-line, backslash-continued string.

    The command is split only at whitespace into lines no longer than
    textwrap's default width (70 columns); words and hyphenated tokens are
    never broken. Consecutive lines are joined with a space, a backslash,
    a newline, a tab and a space, so the shell still reads a single command.
    '''
    pieces = textwrap.wrap(
        cmd,
        break_long_words=False,
        break_on_hyphens=False,
    )
    continuation = ' \\' + '\n' + '\t '
    return continuation.join(pieces)


# Labels accepted by the inference-command builder. `None` is the vanilla
# zero-shot run; the label also appears in the echo message and preds filename.
_RETRIEVERS = (None, 'Random', 'Bm25', 'Sbert', 'Uprise')


def _join_cmd(*parts):
    '''
    join shell-command fragments with single spaces, skipping empty fragments
    '''
    return ' '.join(part for part in parts if part)


def _parse_clusters(value, cluster_map, kind):
    '''
    split a `+`-joined cluster string (e.g. `nli+common_reason`) into a list.

    `None` selects every key of `cluster_map`. Unknown names raise ValueError
    up front instead of a bare KeyError halfway through building the cmds.
    '''
    if value is None:
        return list(cluster_map.keys())
    clusters = value.split('+')
    unknown = [cluster for cluster in clusters if cluster not in cluster_map]
    if unknown:
        raise ValueError(
            f'unknown {kind} cluster(s) {unknown}; supported: {list(cluster_map)}')
    return clusters


def _write_script(path, cmds):
    '''
    write `cmds` to `path`, each passed through `wrap` and separated by a blank line
    '''
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n\n'.join(wrap(cmd) for cmd in cmds))
        f.write('\n')  # end the script with a newline like any text file


def _train_cmds(args, clusters, exp_path, prompt_pool_dir, random_port):
    '''
    build the train-stage cmds.

    1. for every task of every cluster in `clusters`: sample random prompts
       with `find_random.py`, then score them with `scorer.py` (accelerate)
    2. train the dense retriever on the scored data with
       `DPR/train_dense_encoder.py`
    '''
    random_sample_dir = os.path.join(args.output_dir, 'find_random')
    scored_dir = os.path.join(args.output_dir, 'scored')
    cache_dir = f'$PWD/{args.cache_dir}'
    hydra_dir = f'hydra.run.dir=$PWD/{exp_path}'
    # `--multi_gpu` is only meaningful when more than one process is launched
    multi_gpu = '--multi_gpu' if args.gpus > 1 else ''

    cmds = []
    # 1. random sample prompts and score data
    for cluster in tqdm(clusters):
        for task in train_cluster_map[cluster]:
            task_cls = task_map.cls_dic[task]()
            prompt_pool_path = os.path.join(prompt_pool_dir, cluster, task + '_prompts.json')
            random_sample_path = os.path.join(random_sample_dir, cluster, task + '_random_samples.json')
            scored_train_path = os.path.join(scored_dir, cluster, task + '_scored_train.json')
            scored_valid_path = os.path.join(scored_dir, cluster, task + '_scored_valid.json')

            cmds.append(f'echo "scoring {task} task of {cluster} cluster..."')
            cmds.append(_join_cmd(
                'python find_random.py',
                f'output_path=$PWD/{random_sample_path}',
                f'task_name={task}',
                f'+ds_size={args.ds_size}',
                f'L={task_cls.finder_L}',
                f'prompt_pool_path=$PWD/{prompt_pool_path}',
                f'cache_dir={cache_dir}',
                hydra_dir,
            ))
            cmds.append(_join_cmd(
                'accelerate launch',
                multi_gpu,
                f'--num_processes {args.gpus}',
                f'--main_process_port {random_port}',
                'scorer.py',
                f'example_file=$PWD/{random_sample_path}',
                f'output_train_file=$PWD/{scored_train_path}',
                f'output_valid_file=$PWD/{scored_valid_path}',
                f'batch_size={task_cls.run_scorer_bsz}',
                f'task_name={task}',
                f'model_name={args.scr_model}',
                f'prompt_pool_path=$PWD/{prompt_pool_path}',
                f'cache_dir={cache_dir}',
                hydra_dir,
            ))

    # 2. train a retriever
    cmds.append('echo "start training the retriever..."')
    cmds.append(_join_cmd(
        'python DPR/train_dense_encoder.py',
        'train_datasets=[uprise_dataset]',
        'dev_datasets=[uprise_valid_dataset]',
        'train=biencoder_uprise',
        f'output_dir=$PWD/{exp_path}',
        f'datasets.train_clusters={args.train_clusters}',
        f'datasets.train_file=$PWD/{scored_dir}',
        f'datasets.valid_file=$PWD/{scored_dir}',
        'datasets.hard_neg=true',
        f'datasets.multi_task={args.multi_task}',
        f'datasets.top_k={args.retriever_top_k}',
        f'train.hard_negatives={args.retriever_top_k}',
        f'train.batch_size={args.retriever_bsz}',
        f'train.num_train_epochs={args.retriever_epoch}',
        f'datasets.prompt_pool_path=$PWD/{prompt_pool_dir}',
        f'datasets.prompt_setup_type={args.retriever_prompt_setup}',
        'datasets.task_setup_type=q',
        f'encoder.cache_dir={cache_dir}',
        hydra_dir,
    ))
    return cmds


def _inference_cmds(args, clusters, exp_path, prompt_pool_dir, random_port):
    '''
    build the inference-stage cmds.

    1. encode the whole prompt pool with the prompt encoder of the trained
       retriever (`DPR/generate_dense_embeddings.py`)
    2. for every task of every cluster in `clusters`: retrieve prompts with the
       trained retriever, run vanilla zero-shot and uprise inference, plus the
       optional random / bm25 / sbert ablations selected in `args`
    '''
    cache_dir = f'$PWD/{args.cache_dir}'
    hydra_dir = f'hydra.run.dir=$PWD/{exp_path}'
    model_file = f'model_file=$PWD/{exp_path}/dpr_biencoder.best_valid'

    def inference_cmd(task, cluster, prompt_file, eval_res_path, num_prompts, retriever):
        '''
        echo + `inference.py` cmds for one task; everything that varies per
        call is passed explicitly rather than read from enclosing loop variables
        '''
        if retriever not in _RETRIEVERS:
            raise ValueError(
                f'unsupported retriever {retriever!r}; expected one of {_RETRIEVERS}')
        # only the Random ablation enables random sampling; the comparison
        # must use the same capitalised label as _RETRIEVERS
        random_sample = retriever == 'Random'
        pred_path = os.path.join(
            exp_path, f'preds_for_{cluster}',
            f'{task}_prompts{num_prompts}_retriever{retriever}_preds.json')
        return [
            f'echo "running inference on {task} task of {cluster} cluster with {retriever} retriever..."',
            _join_cmd(
                'accelerate launch --num_processes 1',
                f'--main_process_port {random_port}',
                'inference.py',
                f'prompt_file=$PWD/{prompt_file}',
                f'task_name={task}',
                f'output_file=$PWD/{pred_path}',
                f'res_file=$PWD/{eval_res_path}',
                f'batch_size={args.inference_bsz}',
                f'train_clusters={args.train_clusters}',
                f'model_name={args.inf_model}',
                f'prompt_pool_path=$PWD/{prompt_pool_dir}',
                f'num_prompts={num_prompts}',
                f'random_sample={random_sample} random_seed=42',
                f'cache_dir={cache_dir}',
                hydra_dir,
            ),
        ]

    # 1. encode the whole prompt pool
    cmds = [
        'echo "encoding the whole prompt pool..."',
        _join_cmd(
            'python DPR/generate_dense_embeddings.py',
            model_file,
            'ctx_src=dpr_uprise shard_id=0 num_shards=1',
            f'out_file=$PWD/{exp_path}/dpr_enc_index',
            f'ctx_sources.dpr_uprise.train_clusters={args.train_clusters}',
            f'ctx_sources.dpr_uprise.prompt_pool_path=$PWD/{prompt_pool_dir}',
            f'ctx_sources.dpr_uprise.prompt_setup_type={args.retriever_prompt_setup}',
            f'encoder.cache_dir={cache_dir}',
            hydra_dir,
        ),
    ]

    # 2. retrieve prompts from the prompt pool and run inference, per test task
    for cluster in clusters:
        eval_res_path = os.path.join(exp_path, f'eval_res_for_{cluster}.txt')
        for task in test_cluster_map[cluster]:
            uprise_prompts = os.path.join(
                exp_path, f'uprise_prompts_for_{cluster}', f'{task}_prompts.json')
            cmds.append(f'echo "uprise retrieves on {task} task of {cluster} cluster..."')
            cmds.append(_join_cmd(
                'python DPR/dense_retriever.py',
                model_file,
                'qa_dataset=qa_uprise',
                # NOTE(review): `ctx_datatsets` (sic) is kept verbatim; it is
                # presumably the key DPR/dense_retriever.py reads — verify
                # there before renaming.
                'ctx_datatsets=[dpr_uprise]',
                f'encoded_ctx_files=["$PWD/{exp_path}/dpr_enc_index_*"]',
                f'out_file=$PWD/{uprise_prompts}',
                f'datasets.qa_uprise.task_name={task}',
                'datasets.qa_uprise.task_setup_type=q',
                f'datasets.qa_uprise.cache_dir={cache_dir}',
                f'n_docs={args.num_prompts}',
                f'ctx_sources.dpr_uprise.prompt_pool_path=$PWD/{prompt_pool_dir}',
                f'ctx_sources.dpr_uprise.train_clusters={args.train_clusters}',
                f'ctx_sources.dpr_uprise.prompt_setup_type={args.retriever_prompt_setup}',
                f'encoder.cache_dir={cache_dir}',
                hydra_dir,
            ))

            # vanilla zero shot: num_prompts=0, so no prompts are concatenated
            cmds += inference_cmd(task, cluster, uprise_prompts, eval_res_path, 0, None)
            # uprise zero shot
            cmds += inference_cmd(task, cluster, uprise_prompts, eval_res_path,
                                  args.num_prompts, 'Uprise')

            # Ablations: replace uprise retriever with random, bm25 and sbert
            if args.retrieve_random:
                cmds += inference_cmd(task, cluster, uprise_prompts, eval_res_path,
                                      args.num_prompts, 'Random')
            for name, label, enabled in (('bm25', 'Bm25', args.retrieve_bm25),
                                         ('sbert', 'Sbert', args.retrieve_sbert)):
                if not enabled:
                    continue
                prompts = os.path.join(
                    exp_path, f'{name}_prompts_for_{cluster}', f'{task}_prompts.json')
                cmds.append(f'echo "{name} retrieves on {task} task of {cluster} cluster..."')
                cmds.append(_join_cmd(
                    f'python retrieve_{name}.py',
                    f'train_clusters={args.train_clusters}',
                    f'task_name={task}',
                    f'cache_dir={cache_dir}',
                    f'prompt_pool_path=$PWD/{prompt_pool_dir}',
                    f'out_file=$PWD/{prompts}',
                    f'prompt_setup_type={args.retriever_prompt_setup}',
                    f'n_docs={args.num_prompts}',
                    hydra_dir,
                ))
                cmds += inference_cmd(task, cluster, prompts, eval_res_path,
                                      args.num_prompts, label)
    return cmds


def get_cmds(args):
    '''
    print out cmds for training and inference.

    Writes `train.sh` and `inference.sh` into
    `<output_dir>/experiment/train_<train_clusters>_test_<test_clusters>/`
    (created if needed) and prints where each script was saved; the cmds
    themselves are not executed. `args` is the namespace produced by this
    script's argument parser. A `None` cluster string selects every cluster
    of the corresponding cluster map.

    Raises ValueError if a requested train/test cluster is unknown.
    '''
    # validate both cluster lists before touching the filesystem
    train_clusters = _parse_clusters(args.train_clusters, train_cluster_map, 'train')
    test_clusters = _parse_clusters(args.test_clusters, test_cluster_map, 'test')

    prompt_pool_dir = os.path.join(args.output_dir, 'prompt_pool')
    exp_name = f'train_{args.train_clusters}_test_{args.test_clusters}'
    exp_path = os.path.join(args.output_dir, 'experiment', exp_name)
    os.makedirs(exp_path, exist_ok=True)

    # one port shared by every `accelerate launch`; within a script the cmds
    # are written one after another, so the launches run sequentially
    random_port = random.randint(21966, 25000)

    # ================================== Train Stage ===================================
    train_sh = os.path.join(exp_path, 'train.sh')
    _write_script(train_sh, _train_cmds(args, train_clusters, exp_path,
                                        prompt_pool_dir, random_port))
    print('saved training cmds to: ', train_sh)

    # ================================== Inference Stage ===================================
    inference_sh = os.path.join(exp_path, 'inference.sh')
    _write_script(inference_sh, _inference_cmds(args, test_clusters, exp_path,
                                                prompt_pool_dir, random_port))
    print('saved inference cmds to: ', inference_sh)


def build_parser():
    '''
    build the command-line parser whose namespace `get_cmds` consumes
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir',
                        type=str,
                        help='Directory for saving all the intermediate and final outputs.',
                        default='my_data')
    parser.add_argument('--cache_dir',
                        type=str,
                        help='Directory for caching the huggingface models and datasets.',
                        default='../cache')
    parser.add_argument('--gpus',
                        type=int,
                        help='number of gpus to use',
                        default=8)

    # training
    parser.add_argument('--train_clusters',
                        type=str,
                        help=('a string concatenating task clusters for training, '
                              'e.g., `nli+common_reason` means nli and common_reason task clusters; '
                              'all supported clusters are in DPR.dpr.utils.tasks.train_cluster_map; '
                              'all supported clusters are used when this option is omitted'),
                        default=None)
    parser.add_argument('--retriever_prompt_setup',
                        type=str,
                        help=('setup type of prompt, recommend setting as `qa` for cross-task training '
                              'and `q` for task-specific training'),
                        default="qa")
    parser.add_argument('--ds_size',
                        type=int,
                        help='number of maximum data examples sampled from each training dataset',
                        default=10000)
    parser.add_argument('--scr_model',
                        type=str,
                        help='Huggingface model for scoring data',
                        default="EleutherAI/gpt-neo-2.7B")
    parser.add_argument("--multi_task",
                        action="store_true",
                        help=("True for multi-task and False for task-specific, "
                              "the difference reflects on the sampling of negative prompts ONLY; "
                              "refer to `UpriseDataset` in `DPR/dpr/data/biencoder_data.py` for details"))
    parser.add_argument('--retriever_top_k',
                        type=int,
                        help='number of k (hard) negatives for training the retriever',
                        default=20)
    parser.add_argument('--retriever_bsz',
                        type=int,
                        help='sum of batch size of all gpus, NOT per gpu',
                        default=16)
    parser.add_argument('--retriever_epoch',
                        type=int,
                        help=('maximum training epoch, recommend setting as `3` when cross-task training, '
                              'and `10` when task-specific training'),
                        default=3)

    # inference
    parser.add_argument('--inf_model',
                        type=str,
                        help='Huggingface model for inference',
                        default="EleutherAI/gpt-neo-2.7B")
    parser.add_argument('--test_clusters',
                        type=str,
                        help=('a string concatenating task clusters for testing, '
                              'e.g., `nli+common_reason` means nli and common_reason task clusters; '
                              'all supported clusters are in DPR.dpr.utils.tasks.test_cluster_map'),
                        default="nli+common_reason")
    parser.add_argument('--num_prompts',
                        type=int,
                        help='maximum number of retrieved prompts to be concatenated before the task input',
                        default=3)
    parser.add_argument('--retrieve_random',
                        action="store_true",
                        help='whether to random retrieve from our prompt pool, and run a baseline')
    parser.add_argument('--retrieve_bm25',
                        action="store_true",
                        help='whether to use bm25 retriever to retrieve from our prompt pool, and run a baseline')
    parser.add_argument('--retrieve_sbert',
                        action="store_true",
                        help='whether to use sbert to retrieve from our prompt pool, and run a baseline')
    parser.add_argument('--inference_bsz',
                        type=int,
                        help='sum of batch size of all gpus, NOT per gpu',
                        default=1)
    return parser


if __name__ == "__main__":
    get_cmds(build_parser().parse_args())

# (web-page cookie banner captured together with the source; not part of the script)
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.