# There is a recent trend to unify all tasks (such as classification) into generation tasks.
# In fact, this unification can be carried out neatly using prompts.
# OpenPrompt provides a GenerationVerbalizer for this purpose.
# Here we go!
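# For instance, BoolQ can be cast as generation: the model reads the passage and
# question and generates a short string such as "yes" or "no", which is then
# compared against the verbalized gold label.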

from openprompt.pipeline_base import PromptForGeneration
from openprompt.prompts.generation_verbalizer import GenerationVerbalizer
from tqdm import tqdm
from openprompt.data_utils import PROCESSORS
import torch
import argparse

from openprompt import PromptDataLoader
from openprompt.prompts import SoftTemplate
import time
import os
from openprompt.utils.crossfit_metrics import evaluate as crossfit_evaluate


parser = argparse.ArgumentParser("")
parser.add_argument("--shot", type=int, default=-1)
parser.add_argument("--seed", type=int, default=144)
parser.add_argument("--plm_eval_mode", action="store_true", help="Whether to turn off dropout in the frozen model. Set to True to turn it off.")
parser.add_argument("--tune_plm", action="store_true", help="Whether to tune the PLM; defaults to False.")
parser.add_argument("--model", type=str, default='t5-lm', help="We test both t5 and t5-lm in this script; the corresponding tokenizer wrapper is loaded automatically.")
parser.add_argument("--model_name_or_path", default='../../plm_cache/t5-large-lm-adapt/')
parser.add_argument("--project_root", default="/home/hushengding/OpenPrompt_CameraReady/OpenPrompt/", help="The project root in the file system, i.e. the absolute path of OpenPrompt.")
parser.add_argument("--template_id", type=int, default=0)
parser.add_argument("--verbalizer_id", type=int, default=0)
parser.add_argument("--data_dir", type=str, default="../../huggingface_datasets/saved_to_disk/") # sometimes huggingface datasets cannot be downloaded automatically due to network issues; please refer to 0_basic.py line 15 for solutions.
parser.add_argument("--dataset", type=str)
parser.add_argument("--result_file", type=str, default="../results.txt")
parser.add_argument("--max_steps", default=20000, type=int)
parser.add_argument("--prompt_lr", type=float, default=0.3)
parser.add_argument("--warmup_step_prompt", type=int, default=500)
parser.add_argument("--init_from_vocab", action="store_false", help="Initialize soft tokens from vocabulary embeddings (default True; pass this flag to disable).")
parser.add_argument("--eval_every_steps", type=int, default=500)
parser.add_argument("--soft_token_num", type=int, default=100)
parser.add_argument("--optimizer", type=str, default="Adafactor")
args = parser.parse_args()
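# Example invocation (the script name and paths below are placeholders):
#   python soft_prompt_generation.py --dataset boolq --model t5-lm \
#       --model_name_or_path ../../plm_cache/t5-large-lm-adapt/ --soft_token_num 100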

args.result_file = os.path.join(args.project_root, args.result_file)

content_write = "="*20+"\n"
content_write += f"dataset {args.dataset}\t"
content_write += f"temp {args.template_id}\t"
content_write += f"verb {args.verbalizer_id}\t"
content_write += f"model {args.model}\t"
content_write += f"seed {args.seed}\t"
content_write += f"shot {args.shot}\t"
content_write += f"plm_eval_mode {args.plm_eval_mode}\t"
content_write += f"init_from_vocab {args.init_from_vocab}\t"
content_write += f"eval_every_steps {args.eval_every_steps}\t"
content_write += f"prompt_lr {args.prompt_lr}\t"
content_write += f"optimizer {args.optimizer}\t"
content_write += f"warmup_step_prompt {args.warmup_step_prompt}\t"
content_write += f"soft_token_num {args.soft_token_num}\t"
content_write += "\n"

print(content_write)

import random
this_run_unicode = str(random.randint(0, 10**10)) # use an integer bound; random.randint does not accept floats

from openprompt.utils.reproduciblity import set_seed
set_seed(args.seed)

# Use the lm-adapted version or the t5-v1.1 checkpoint. Note that the original t5 checkpoint
# was pretrained on part of the GLUE datasets and thus should not be used.
from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)
dataset = {}

# Below are multiple dataset examples, including few-shot ones.
if args.dataset == "boolq":
    Processor = PROCESSORS["super_glue.boolq"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/BoolQ"
    scriptformat = "txt"
    dataset_decoder_max_length = 10
    max_seq_l = 480 # this should be specified according to the running GPU's capacity
    if args.tune_plm: # tuning the entire PLM uses more GPU memory, so we use a smaller batch size
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True # if multiple GPUs are available, one can use model parallelism
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "multirc":
    Processor = PROCESSORS["super_glue.multirc"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/MultiRC"
    scriptformat = "txt"
    dataset_decoder_max_length = 10
    max_seq_l = 480 # could be a bit smaller, but we use 480 to keep the training overhead low
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "rte":
    Processor = PROCESSORS["super_glue.rte"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/RTE"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 2
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "cb":
    Processor = PROCESSORS["super_glue.cb"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/CB"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "wic":
    Processor = PROCESSORS["super_glue.wic"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/WiC"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "copa":
    Processor = PROCESSORS["super_glue.copa"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/COPA"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 50
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "wsc":
    Processor = PROCESSORS["super_glue.wsc"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/WSC"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "record":
    Processor = PROCESSORS["super_glue.record"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/RECORD"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 20
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
else:
    raise NotImplementedError


# Now define the template and verbalizer.
# Note that a soft template can be combined with a hard template by loading the hard template from file.
# For example, the template in soft_template.txt is {}
# The choice_id 1 is the hard template
mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, num_tokens=args.soft_token_num, initialize_from_vocab=args.init_from_vocab).from_file(f"scripts/{scriptsbase}/soft_template.txt", choice=args.template_id)
if os.path.exists(f"scripts/{scriptsbase}/generation_verbalizer.{scriptformat}"):
    myverbalizer = GenerationVerbalizer(tokenizer, classes=class_labels, is_rule=True).from_file(f"scripts/{scriptsbase}/generation_verbalizer.{scriptformat}", choice=args.verbalizer_id)
else:
    myverbalizer = GenerationVerbalizer(tokenizer, classes=class_labels, is_rule=False).from_file(f"scripts/{scriptsbase}/manual_verbalizer.{scriptformat}", choice=args.verbalizer_id)
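# As a sketch of what these files may contain (the scripts/ directory is not shown here):
# a manual_verbalizer file typically lists one label word per class, e.g. "no" and "yes"
# for BoolQ, while a generation_verbalizer file (is_rule=True) may hold full target
# templates that reference fields of each InputExample.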


use_cuda = True
prompt_model = PromptForGeneration(plm=plm, template=mytemplate, freeze_plm=(not args.tune_plm), plm_eval_mode=args.plm_eval_mode)
if use_cuda:
    prompt_model = prompt_model.cuda()

if model_parallelize:
    prompt_model.parallelize()


train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer, # be sure to add the verbalizer
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=dataset_decoder_max_length, # be sure to use a larger decoder_max_length for teacher forcing
    batch_size=batchsize_t, shuffle=True, teacher_forcing=True, predict_eos_token=True, # be sure to use teacher_forcing and predict_eos_token=True
    truncate_method="tail")
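# With teacher_forcing=True the decoder is fed the gold target text during training,
# so decoder_max_length must be large enough to hold the full verbalized target
# (hence the dataset-specific dataset_decoder_max_length above).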

validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3,
    batch_size=batchsize_e, shuffle=False, teacher_forcing=False, predict_eos_token=False, # predict_eos_token=True and False are both OK
    truncate_method="tail")

test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3,
    batch_size=batchsize_e, shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="tail")

print("truncate rate: {}".format(test_dataloader.tokenizer_wrapper.truncate_rate), flush=True)


generation_arguments = {
    "max_length": dataset_decoder_max_length,
}
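# These arguments are forwarded to the underlying HuggingFace generate() call; with only
# max_length set, decoding is greedy. Other generate() kwargs (e.g. num_beams) could be
# added here if beam search is preferred.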

def evaluate(prompt_model, dataloader):
    predictions = []
    ground_truths = []

    for step, inputs in enumerate(dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        _, output_sentence = prompt_model.generate(inputs, **generation_arguments, verbose=False)
        predictions.extend(output_sentence)
        ground_truths.extend(inputs['tgt_text'])
    assert len(predictions) == len(ground_truths), (len(predictions), len(ground_truths))
    predictions = [prediction.strip() for prediction in predictions]
    ground_truths = [ground_truth.strip() for ground_truth in ground_truths]
    # show one example
    print(f"predictions {predictions[0]}, ground_truths {ground_truths[0]}")
    score = crossfit_evaluate(predictions, ground_truths, metric="ACC")
    return score
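# crossfit_evaluate with metric="ACC" scores the generated strings against the gold
# targets by exact string match, i.e. string-level accuracy.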


from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup # using AdamW is standard practice for transformers
from transformers.optimization import Adafactor, AdafactorSchedule # Adafactor is the default setting for T5
loss_func = torch.nn.CrossEntropyLoss()

tot_step = args.max_steps


if args.tune_plm: # normally we freeze the model when using soft templates; however, we keep the option to tune the PLM
    no_decay = ['bias', 'LayerNorm.weight'] # it is always good practice to exempt biases and LayerNorm parameters from weight decay
    optimizer_grouped_parameters1 = [
        {'params': [p for n, p in prompt_model.plm.named_parameters() if (not any(nd in n for nd in no_decay))], 'weight_decay': 0.01},
        {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
    scheduler1 = get_linear_schedule_with_warmup(
        optimizer1,
        num_warmup_steps=500, num_training_steps=tot_step)
else:
    optimizer1 = None
    scheduler1 = None


optimizer_grouped_parameters2 = [{'params': [p for name, p in prompt_model.template.named_parameters() if 'raw_embedding' not in name]}] # note that you have to exclude raw_embedding from the optimization manually
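# 'raw_embedding' is the PLM's input-embedding table, which SoftTemplate keeps a reference
# to; only the newly introduced soft-prompt parameters should be optimized here.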
if args.optimizer.lower() == "adafactor":
    optimizer2 = Adafactor(optimizer_grouped_parameters2,
                           lr=args.prompt_lr,
                           relative_step=False,
                           scale_parameter=False,
                           warmup_init=False) # with lr 0.3, this matches the configuration of https://arxiv.org/abs/2104.08691
    scheduler2 = get_constant_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_step_prompt) # with num_warmup_steps 0, this matches the configuration of https://arxiv.org/abs/2104.08691
elif args.optimizer.lower() == "adamw":
    optimizer2 = AdamW(optimizer_grouped_parameters2, lr=args.prompt_lr) # usually lr = 0.5
    scheduler2 = get_linear_schedule_with_warmup(
        optimizer2,
        num_warmup_steps=args.warmup_step_prompt, num_training_steps=tot_step) # usually num_warmup_steps is 500


tot_loss = 0
log_loss = 0
best_val_acc = 0
glb_step = 0
actual_step = 0
leave_training = False

acc_traces = []
tot_train_time = 0
pbar_update_freq = 10
prompt_model.train()

pbar = tqdm(total=tot_step, desc="Train")
for epoch in range(1000000):
    print(f"Begin epoch {epoch}")
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        tot_train_time -= time.time()
        loss = prompt_model(inputs)
        loss.backward()
        tot_loss += loss.item()
        actual_step += 1

        if actual_step % gradient_accumulation_steps == 0: # only step the optimizers once per accumulation cycle
            torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0)
            glb_step += 1
            if glb_step % pbar_update_freq == 0:
                aveloss = (tot_loss - log_loss) / pbar_update_freq
                pbar.update(pbar_update_freq)
                pbar.set_postfix({'loss': aveloss})
                log_loss = tot_loss

            if optimizer1 is not None:
                optimizer1.step()
                optimizer1.zero_grad()
            if scheduler1 is not None:
                scheduler1.step()
            if optimizer2 is not None:
                optimizer2.step()
                optimizer2.zero_grad()
            if scheduler2 is not None:
                scheduler2.step()

        tot_train_time += time.time()

        if actual_step % gradient_accumulation_steps == 0 and glb_step > 0 and glb_step % args.eval_every_steps == 0:
            val_acc = evaluate(prompt_model, validation_dataloader)
            if val_acc >= best_val_acc:
                torch.save(prompt_model.state_dict(), f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt")
                best_val_acc = val_acc

            acc_traces.append(val_acc)
            print("Glb_step {}, val_acc {}, average time {}".format(glb_step, val_acc, tot_train_time / actual_step), flush=True)
            prompt_model.train()

        if glb_step > args.max_steps:
            leave_training = True
            break

    if leave_training:
        break


# The super_glue test split cannot be evaluated without submitting the results to the
# leaderboard website, so we skip it here and keep the code as comments.
#
# prompt_model.load_state_dict(torch.load(f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt"))
# prompt_model = prompt_model.cuda()
# test_acc = evaluate(prompt_model, test_dataloader)

# A simple measure of convergence speed: the first global step at which the validation
# accuracy reaches 98%, 99%, and 100% of the best validation accuracy.
thres99 = 0.99 * best_val_acc
thres98 = 0.98 * best_val_acc
thres100 = best_val_acc
step100 = step98 = step99 = args.max_steps
for val_time, acc in enumerate(acc_traces):
    # acc_traces[i] was recorded at global step (i+1)*eval_every_steps
    if acc >= thres98:
        step98 = min((val_time + 1) * args.eval_every_steps, step98)
    if acc >= thres99:
        step99 = min((val_time + 1) * args.eval_every_steps, step99)
    if acc >= thres100:
        step100 = min((val_time + 1) * args.eval_every_steps, step100)


content_write += f"BestValAcc:{best_val_acc}\tEndValAcc:{acc_traces[-1]}\tcritical_steps:{[step98, step99, step100]}\n"
content_write += "\n"

print(content_write)

with open(args.result_file, "a") as fout:
    fout.write(content_write)

os.remove(f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt") # clean up the checkpoint saved during training