# NOTE(review): this file is a garbled extraction of a Python script — the
# bare integer lines (e.g. "2", "7", "18") look like original source line
# numbers interleaved with the code, indentation has been stripped, and
# several statements are missing (the module docstring quotes, the import
# lines for os/random/argparse/textwrap/tqdm, and the `def wrap(...)`
# header belonging to the `return` below). Recover the original file
# before attempting behavioral changes.
2
# Stranded module-docstring text: this script prints out shell commands
# for training and inference.
print out cmds for training and inference
7
# Task/cluster registries used to enumerate training and test tasks.
from DPR.dpr.utils.tasks import task_map, train_cluster_map, test_cluster_map
18
# Body of a `wrap(cmd, ...)` helper whose header is missing: soft-wraps a
# long shell command and rejoins the pieces with `bs` — presumably a
# backslash-newline separator; TODO confirm against the missing header.
return bs.join(textwrap.wrap(cmd,break_long_words=False,break_on_hyphens=False))
# --- Output directory layout under args.output_dir ---
# prompt_pool/  : per-cluster prompt pools
24
prompt_pool_dir = os.path.join(args.output_dir, 'prompt_pool')
25
# find_random/  : randomly sampled candidate prompts per task
random_sample_dir = os.path.join(args.output_dir, 'find_random')
26
# scored/       : scorer outputs (train/valid splits) per task
scored_dir = os.path.join(args.output_dir, 'scored')
28
# Experiment directory named after the train/test cluster strings.
exp_name = f'train_{args.train_clusters}_test_{args.test_clusters}'
29
exp_path = os.path.join(args.output_dir, 'experiment', exp_name)
30
os.makedirs(exp_path, exist_ok=True)
32
# Random main-process port for the `accelerate launch` commands below, so
# concurrent runs on one machine are unlikely to collide.
random_port = random.randint(21966,25000)
# Choose which training clusters to process: all registered clusters when
# --train_clusters is not given, otherwise the '+'-separated list.
# NOTE(review): the `else:` line (original line 36) is missing from this
# extraction — the second assignment is the else branch.
34
if args.train_clusters is None:
35
clusters = list(train_cluster_map.keys())
37
clusters = args.train_clusters.split('+')
# For every task in every selected training cluster, emit two commands:
# (1) find_random.py — sample candidate prompts from the prompt pool;
# (2) scorer.py (multi-GPU via accelerate) — score the sampled prompts and
#     split the results into train/valid files.
# NOTE(review): the assignments `find_random_cmd = \` (original line 45)
# and `run_scorer_cmd = \` (original line 53) are missing from this
# extraction; the bare f-strings below are their right-hand sides. Also
# `train_cmd_list` is appended to without a visible initialization.
39
for cluster in tqdm(clusters):
40
for task in train_cluster_map[cluster]:
41
echo_cmd = f'echo "scoring {task} task of {cluster} cluster..."'
42
# Task class provides per-task hyperparameters (finder_L, run_scorer_bsz).
task_cls = task_map.cls_dic[task]()
43
prompt_pool_path = os.path.join(prompt_pool_dir, cluster, task+'_prompts.json')
44
random_sample_path = os.path.join(random_sample_dir, cluster, task+'_random_samples.json')
46
# NOTE(review): `cache_dir=...{args.cache_dir}\` below lacks a space
# before the continuation backslash, so it concatenates with the next
# token — compare the scorer command, which has the space. Verify intent.
f'python find_random.py output_path=$PWD/{random_sample_path} \
47
task_name={task} +ds_size={args.ds_size} L={task_cls.finder_L} \
48
prompt_pool_path=$PWD/{prompt_pool_path} cache_dir=$PWD/{args.cache_dir}\
49
hydra.run.dir=$PWD/{exp_path}'
51
scored_train_path = os.path.join(scored_dir, cluster, task+'_scored_train.json')
52
scored_valid_path = os.path.join(scored_dir, cluster, task+'_scored_valid.json')
54
f'accelerate launch --multi_gpu --num_processes {args.gpus} --main_process_port {random_port} \
55
scorer.py example_file=$PWD/{random_sample_path} \
56
output_train_file=$PWD/{scored_train_path} \
57
output_valid_file=$PWD/{scored_valid_path} \
58
batch_size={task_cls.run_scorer_bsz} task_name={task} \
59
model_name={args.scr_model} \
60
prompt_pool_path=$PWD/{prompt_pool_path} cache_dir=$PWD/{args.cache_dir} \
61
hydra.run.dir=$PWD/{exp_path}'
63
train_cmd_list += [echo_cmd, find_random_cmd, run_scorer_cmd]
# Emit the DPR biencoder training command over all scored data, then wrap
# every accumulated command and write them to <exp_path>/train.sh.
66
echo_cmd = f'echo "start training the retriever..."'
67
train_retriever_cmd = \
68
f'python DPR/train_dense_encoder.py train_datasets=[uprise_dataset] dev_datasets=[uprise_valid_dataset] \
69
train=biencoder_uprise output_dir=$PWD/{exp_path} \
70
datasets.train_clusters={args.train_clusters} \
71
datasets.train_file=$PWD/{scored_dir} \
72
datasets.valid_file=$PWD/{scored_dir} \
73
datasets.hard_neg=true datasets.multi_task={args.multi_task} \
74
datasets.top_k={args.retriever_top_k} train.hard_negatives={args.retriever_top_k} \
75
train.batch_size={args.retriever_bsz} \
76
train.num_train_epochs={args.retriever_epoch} \
77
datasets.prompt_pool_path=$PWD/{prompt_pool_dir} \
78
datasets.prompt_setup_type={args.retriever_prompt_setup} \
79
# NOTE(review): `{args.cache_dir}\` below lacks a space before the
# continuation backslash (concatenates with the next token); other
# commands in this file include the space — verify intent.
datasets.task_setup_type=q encoder.cache_dir=$PWD/{args.cache_dir}\
80
hydra.run.dir=$PWD/{exp_path}'
82
train_cmd_list += [echo_cmd, train_retriever_cmd]
85
# Soft-wrap each command for readability in the generated shell script.
train_cmd_list = [wrap(cmd) for cmd in train_cmd_list]
88
with open(f"{exp_path}/train.sh","w") as f:
89
f.write("\n\n".join(train_cmd_list))
90
print('saved training cmds to: ', f"{exp_path}/train.sh")
# --- Inference command generation ---
# First command: encode the whole prompt pool with the best trained
# biencoder checkpoint into a dense index (single shard).
# NOTE(review): the assignment `gen_emb_cmd = \` (original line 96) is
# missing from this extraction; the bare f-string below is its value.
93
inference_cmd_list = []
95
echo_cmd = f'echo "encoding the whole prompt pool..."'
97
f"python DPR/generate_dense_embeddings.py model_file=$PWD/{exp_path}/dpr_biencoder.best_valid \
98
ctx_src=dpr_uprise shard_id=0 num_shards=1 \
99
out_file=$PWD/{exp_path}/dpr_enc_index \
100
ctx_sources.dpr_uprise.train_clusters={args.train_clusters} \
101
ctx_sources.dpr_uprise.prompt_pool_path=$PWD/{prompt_pool_dir} \
102
ctx_sources.dpr_uprise.prompt_setup_type={args.retriever_prompt_setup} \
103
encoder.cache_dir=$PWD/{args.cache_dir} \
104
hydra.run.dir=$PWD/{exp_path}"
106
inference_cmd_list += [echo_cmd, gen_emb_cmd]
# Build the [echo, inference] command pair for the current (cluster, task)
# of the enclosing loop, using `num_prompts` retrieved prompts from the
# named retriever. Relies on enclosing-scope names: task, cluster,
# retrieve_prompts_outpath, eval_res_outpath, exp_path, random_port, args.
#
# NOTE(review): several visible defects to fix once the full source is
# recovered:
#   * default retriever='uprise' is not in the asserted list
#     [None, 'Random', 'Bm25', 'Sbert', 'Uprise'] — the default would
#     trip the assert.
#   * `retriever == "random"` compares lowercase, but callers pass
#     'Random' — `random` is always False, so random_sample is never
#     enabled.
#   * the local name `random` shadows the `random` module used for
#     random_port above.
#   * pred_outpath interpolates args.num_prompts instead of the
#     num_prompts parameter, so the filename ignores the argument.
#   * original line 119 (between prompt_file and output_file, likely a
#     task_name setting) is missing from this extraction.
108
def get_inference_cmd(num_prompts=3, retriever='uprise'):
110
assert retriever in [None, 'Random', 'Bm25', 'Sbert', 'Uprise']
111
random = True if retriever == "random" else False
113
echo_cmd = f'echo "running inference on {task} task of {cluster} cluster with {retriever} retriever..."'
114
pred_outpath = os.path.join(exp_path, f'preds_for_{cluster}', f'{task}_prompts{args.num_prompts}_retriever{retriever}_preds.json')
116
run_inference_cmd = \
117
f"accelerate launch --num_processes 1 --main_process_port {random_port} \
118
inference.py prompt_file=$PWD/{retrieve_prompts_outpath} \
120
output_file=$PWD/{pred_outpath} \
121
res_file=$PWD/{eval_res_outpath} \
122
batch_size={args.inference_bsz} \
123
train_clusters={args.train_clusters} \
124
model_name={args.inf_model} \
125
prompt_pool_path=$PWD/{prompt_pool_dir} \
126
num_prompts={num_prompts} \
127
random_sample={random} random_seed=42 \
128
cache_dir=$PWD/{args.cache_dir} \
129
hydra.run.dir=$PWD/{exp_path}"
131
return [echo_cmd, run_inference_cmd]
# For each test cluster/task, emit the UPRISE dense-retrieval command that
# pulls the top n_docs prompts from the encoded prompt-pool index.
134
test_clusters = args.test_clusters.split('+')
135
for cluster in test_clusters:
136
# One shared evaluation-results file per cluster.
eval_res_outpath = os.path.join(exp_path, f'eval_res_for_{cluster}.txt')
137
for task in test_cluster_map[cluster]:
138
echo_cmd = f'echo "uprise retrieves on {task} task of {cluster} cluster..."'
139
retrieve_prompts_outpath = os.path.join(exp_path, f'uprise_prompts_for_{cluster}', f'{task}_prompts.json')
140
retrieve_prompts_cmd = \
141
f'python DPR/dense_retriever.py model_file=$PWD/{exp_path}/dpr_biencoder.best_valid \
142
# NOTE(review): `ctx_datatsets` below looks like a typo for
# `ctx_datasets`, but it is a runtime hydra key — confirm against the
# DPR config before changing.
qa_dataset=qa_uprise ctx_datatsets=[dpr_uprise] \
143
encoded_ctx_files=["$PWD/{exp_path}/dpr_enc_index_*"]\
144
out_file=$PWD/{retrieve_prompts_outpath} \
145
datasets.qa_uprise.task_name={task} \
146
datasets.qa_uprise.task_setup_type=q \
147
datasets.qa_uprise.cache_dir=$PWD/{args.cache_dir} \
148
n_docs={args.num_prompts} \
149
ctx_sources.dpr_uprise.prompt_pool_path=$PWD/{prompt_pool_dir} \
150
ctx_sources.dpr_uprise.train_clusters={args.train_clusters} \
151
ctx_sources.dpr_uprise.prompt_setup_type={args.retriever_prompt_setup} \
152
encoder.cache_dir=$PWD/{args.cache_dir} \
153
# NOTE(review): every other command prefixes hydra.run.dir with $PWD/;
# this one does not — likely an inconsistency to fix.
hydra.run.dir={exp_path}'
154
inference_cmd_list += [echo_cmd, retrieve_prompts_cmd]
# Inference commands for the current task: zero-shot (0 prompts, no
# retriever), UPRISE, and the optional Random / BM25 / SBERT baselines.
# NOTE(review): passing retriever=None with num_prompts=0 still formats a
# prompt_file from the last retrieve_prompts_outpath — verify the
# zero-shot command ignores it.
157
inference_cmd_list += get_inference_cmd(num_prompts=0, retriever=None)
160
inference_cmd_list += get_inference_cmd(num_prompts=args.num_prompts, retriever='Uprise')
163
if args.retrieve_random:
164
inference_cmd_list += get_inference_cmd(num_prompts=args.num_prompts, retriever='Random')
165
if args.retrieve_bm25:
166
# BM25 baseline: retrieve prompts with retrieve_bm25.py, then run
# inference on its output.
echo_cmd = f'echo "bm25 retrieves on {task} task of {cluster} cluster..."'
167
retrieve_prompts_outpath = os.path.join(exp_path, f'bm25_prompts_for_{cluster}', f'{task}_prompts.json')
168
retrieve_bm25_prompts_cmd = \
169
f'python retrieve_bm25.py \
170
train_clusters={args.train_clusters} \
171
task_name={task} cache_dir=$PWD/{args.cache_dir} \
172
prompt_pool_path=$PWD/{prompt_pool_dir} \
173
out_file=$PWD/{retrieve_prompts_outpath} \
174
prompt_setup_type={args.retriever_prompt_setup} n_docs={args.num_prompts} \
175
hydra.run.dir=$PWD/{exp_path} '
176
inference_cmd_list += [echo_cmd, retrieve_bm25_prompts_cmd]
177
inference_cmd_list += get_inference_cmd(num_prompts=args.num_prompts, retriever='Bm25')
178
if args.retrieve_sbert:
179
# SBERT baseline: same shape as the BM25 branch, using retrieve_sbert.py.
echo_cmd = f'echo "sbert retrieves on {task} task of {cluster} cluster..."'
180
retrieve_prompts_outpath = os.path.join(exp_path, f'sbert_prompts_for_{cluster}', f'{task}_prompts.json')
181
retrieve_sbert_prompts_cmd = \
182
f'python retrieve_sbert.py \
183
train_clusters={args.train_clusters} \
184
task_name={task} cache_dir=$PWD/{args.cache_dir} \
185
prompt_pool_path=$PWD/{prompt_pool_dir} \
186
out_file=$PWD/{retrieve_prompts_outpath} \
187
prompt_setup_type={args.retriever_prompt_setup} n_docs={args.num_prompts} \
188
hydra.run.dir=$PWD/{exp_path} '
189
inference_cmd_list += [echo_cmd, retrieve_sbert_prompts_cmd]
190
inference_cmd_list += get_inference_cmd(num_prompts=args.num_prompts, retriever='Sbert')
# Wrap all accumulated inference commands and write <exp_path>/inference.sh,
# mirroring the train.sh write above.
192
inference_cmd_list = [wrap(cmd) for cmd in inference_cmd_list]
195
with open(f"{exp_path}/inference.sh","w") as f:
196
f.write("\n\n".join(inference_cmd_list))
197
print('saved inference cmds to: ', f"{exp_path}/inference.sh")
# Script entry point: define the CLI and parse it.
# NOTE(review): this extraction drops most `default=...` values, `type=`
# lines, and the closing parentheses of the add_argument calls (original
# line numbers jump, e.g. 205->207, 211->215); recover the full file
# before editing. Help-string typos ("supoorted", "clsuters") and the
# copy-pasted "for training" in --test_clusters are runtime strings — fix
# them in the recovered source, not here.
202
if __name__ == "__main__":
203
parser = argparse.ArgumentParser()
204
# --- I/O and hardware ---
parser.add_argument('--output_dir',
205
type=str, help='Directory for saving all the intermediate and final outputs.',
207
parser.add_argument('--cache_dir',
208
type=str, help='Directory for caching the huggingface models and datasets.',
210
parser.add_argument('--gpus',
211
type=int, help='number of gpus to use',
215
# --- Training configuration ---
parser.add_argument('--train_clusters',
217
help='a string concatenating task clusters for training, \
218
e.g., `nli+common_reason` means nli and common_reason task clusters \
219
all supoorted clusters are in DPR.dpr.utils.tasks.train_cluster_map \
220
clusters=`all supported clsuters` when the passed value is None',
222
parser.add_argument('--retriever_prompt_setup',
224
help='setup type of prompt, recommend setting as `qa` for cross-task training \
225
and `q` for task-specific training',
227
parser.add_argument('--ds_size',
229
help='number of maximum data examples sampled from each training dataset',
231
parser.add_argument('--scr_model',
233
help='Huggingface model for scoring data',
234
default="EleutherAI/gpt-neo-2.7B")
235
parser.add_argument("--multi_task",
237
help="True for multi-task and False for task-specific, \
238
the difference reflects on the sampling of negative prompts ONLY \
239
refer to `UpriseDataset` in `DPR/dpr/data/biencoder_data.py` for details")
240
parser.add_argument('--retriever_top_k',
242
help='number of k (hard) negatives for training the retriever',
244
parser.add_argument('--retriever_bsz',
246
help='sum of batch size of all gpus, NOT per gpu',
248
parser.add_argument('--retriever_epoch',
250
help='maximum training epoch, recommend setting as `3` when cross-task training, \
251
and `10` when task-specific training',
255
# --- Inference configuration ---
parser.add_argument('--inf_model',
257
help='Huggingface model for inference',
258
default="EleutherAI/gpt-neo-2.7B")
259
parser.add_argument('--test_clusters',
261
help='a string concatenating task clusters for training, \
262
e.g., `nli+common_reason` means nli and common_reason task clusters \
263
all supoorted clusters are in DPR.dpr.utils.tasks.test_cluster_map',
264
default="nli+common_reason")
265
parser.add_argument('--num_prompts',
267
help='maximum number of retrieved prompts to be concatenated before the task input',
269
# --- Baseline toggles ---
parser.add_argument('--retrieve_random',
271
help='whether to random retrieve from our prompt pool, and run a baseline')
272
parser.add_argument('--retrieve_bm25',
274
help='whether to use bm25 retriever to retrieve from our prompt pool, and run a baseline')
275
parser.add_argument('--retrieve_sbert',
277
help='whether to use sbert to retrieve from our prompt pool, and run a baseline')
278
parser.add_argument('--inference_bsz',
280
help='sum of batch size of all gpus, NOT per gpu',
283
args = parser.parse_args()