# Soft-prompt tuning on SuperGLUE with OpenPrompt (T5 / T5-LM-adapted backbones).
# Standard library
import argparse
import os
import time

# Third party
import numpy as np
import torch
from tqdm import tqdm

# OpenPrompt
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.data_utils import PROCESSORS
from openprompt.data_utils.utils import InputExample
from openprompt.prompts import ManualVerbalizer, SoftTemplate
15
# Command-line configuration for the prompt-tuning run.
parser = argparse.ArgumentParser("")
parser.add_argument("--shot", type=int, default=-1)
parser.add_argument("--seed", type=int, default=144)
parser.add_argument("--plm_eval_mode", action="store_true",
                    help="whether to turn off the dropout in the freezed model. Set to true to turn off.")
parser.add_argument("--tune_plm", action="store_true")
parser.add_argument("--model", type=str, default='t5-lm',
                    help="We test both t5 and t5-lm in this scripts, the corresponding tokenizerwrapper will be automatically loaded.")
parser.add_argument("--model_name_or_path", default='../../plm_cache/t5-large-lm-adapt/')
parser.add_argument("--project_root", default="/mnt/sfs_turbo/hsd/OpenPrompt_official/OpenPrompt/",
                    help="The project root in the file system, i.e. the absolute path of OpenPrompt")
parser.add_argument("--template_id", type=int)
parser.add_argument("--verbalizer_id", type=int)
# sometimes, huggingface datasets can not be automatically downloaded due to
# network issue, please refer to 0_basic.py line 15 for solutions.
parser.add_argument("--data_dir", type=str, default="/mnt/sfs_turbo/huggingface_datasets/")
parser.add_argument("--dataset", type=str)
parser.add_argument("--result_file", type=str, default="../sfs_out/results.txt")
parser.add_argument("--max_steps", default=20000, type=int)
parser.add_argument("--prompt_lr", type=float, default=0.3)
parser.add_argument("--warmup_step_prompt", type=int, default=500)
# NOTE: store_false means this defaults to True; passing the flag DISABLES
# initializing the soft tokens from the vocabulary embedding.
parser.add_argument("--init_from_vocab", action="store_false")
parser.add_argument("--eval_every_steps", type=int, default=500)
parser.add_argument("--soft_token_num", type=int, default=20)
parser.add_argument("--optimizer", type=str, default="Adafactor")
args = parser.parse_args()
# Resolve the result file relative to the project root.
args.result_file = os.path.join(args.project_root, args.result_file)

# Tab-separated one-line record of every hyperparameter of this run; appended
# to the result file together with the final metrics at the end of training.
_header_fields = [
    f"dataset {args.dataset}",
    f"temp {args.template_id}",
    f"verb {args.verbalizer_id}",
    f"model {args.model}",
    f"seed {args.seed}",
    f"shot {args.shot}",
    f"plm_eval_mode {args.plm_eval_mode}",
    f"init_from_vocab {args.init_from_vocab}",
    f"eval_every_steps {args.eval_every_steps}",
    f"prompt_lr {args.prompt_lr}",
    f"optimizer {args.optimizer}",
    f"warmup_step_prompt {args.warmup_step_prompt}",
    f"soft_token_num {args.soft_token_num}",
]
content_write = "=" * 20 + "\n" + "\t".join(_header_fields) + "\t\n"

print(content_write)
import random
# Unique id for this run's temporary checkpoint file.  Drawn before set_seed()
# below, so it is not fixed by --seed.  The bound must be an int: passing the
# float 1e10 to random.randint raises TypeError on Python >= 3.12.
this_run_unicode = str(random.randint(0, 10**10))
# Seed python/numpy/torch RNGs for reproducibility.
from openprompt.utils.reproduciblity import set_seed

set_seed(args.seed)
# use lm-adapted version or t5-v1.1 checkpoint. Note that the original t5
# checkpoint has been pretrained on part of GLUE dataset, thus should not be used.
from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from openprompt.data_utils.data_sampler import FewShotSampler
from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)
dataset = {}
# Below are multiple dataset examples, including few-shot ones.
#
# The original version of this section was one near-identical if/elif branch
# per dataset; the per-dataset differences are captured in this table instead:
#   dataset key -> (PROCESSORS key, scripts subdir, few-shot?, grad-accum steps
#                   used when --tune_plm is set; only RTE differs here).
_DATASET_CONFIGS = {
    "boolq": ("super_glue.boolq", "SuperGLUE/BoolQ", False, 8),
    "multirc": ("super_glue.multirc", "SuperGLUE/MultiRC", False, 8),
    "rte": ("super_glue.rte", "SuperGLUE/RTE", False, 2),
    "cb": ("super_glue.cb", "SuperGLUE/CB", False, 8),
    "wic": ("super_glue.wic", "SuperGLUE/WiC", False, 8),
    "fewshot_boolq": ("super_glue.boolq", "SuperGLUE/BoolQ", True, 8),
    "fewshot_multirc": ("super_glue.multirc", "SuperGLUE/MultiRC", True, 8),
    "fewshot_wic": ("super_glue.wic", "SuperGLUE/WiC", True, 8),
}

if args.dataset not in _DATASET_CONFIGS:
    raise NotImplementedError

_processor_key, scriptsbase, _is_fewshot, _tune_gas = _DATASET_CONFIGS[args.dataset]
Processor = PROCESSORS[_processor_key]
dataset['train'] = Processor().get_train_examples(args.data_dir)
dataset['validation'] = Processor().get_dev_examples(args.data_dir)
dataset['test'] = Processor().get_test_examples(args.data_dir)
class_labels = Processor().get_labels()
scriptformat = "txt"
max_seq_l = 480  # this should be specified according to the running GPU's capacity

# Few-shot variants subsample 32 training examples per label.
if _is_fewshot:
    sampler = FewShotSampler(num_examples_per_label=32)
    dataset['train'] = sampler(dataset['train'], seed=args.seed)

if args.tune_plm:
    # tune the entire plm will use more gpu-memories, thus we should use a
    # smaller batch_size.
    batchsize_t = 4
    batchsize_e = 4
    gradient_accumulation_steps = _tune_gas
    model_parallelize = True  # if multiple gpus are available, one can use model_parallelize
else:
    batchsize_t = 8
    batchsize_e = 4
    gradient_accumulation_steps = 4
    model_parallelize = False
236
# Now define the template and verbalizer.
# Note that soft template can be combined with hard template, by loading the
# hard template from file. For example, the template in soft_template.txt is {}
# The choice_id 1 is the hard template
mytemplate = SoftTemplate(
    model=plm,
    tokenizer=tokenizer,
    num_tokens=args.soft_token_num,
    initialize_from_vocab=args.init_from_vocab,
).from_file(f"scripts/{scriptsbase}/soft_template.txt", choice=args.template_id)
myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(
    f"scripts/{scriptsbase}/manual_verbalizer.{scriptformat}", choice=args.verbalizer_id)

# Sanity check: show how the template wraps one training example.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
print(wrapped_example)
246
use_cuda = True
prompt_model = PromptForClassification(
    plm=plm,
    template=mytemplate,
    verbalizer=myverbalizer,
    freeze_plm=(not args.tune_plm),  # soft-prompt tuning keeps the PLM frozen unless --tune_plm
    plm_eval_mode=args.plm_eval_mode,
)
if use_cuda:
    prompt_model = prompt_model.cuda()

if model_parallelize:
    # Shard the (large) T5 across the available GPUs.
    prompt_model.parallelize()
255
def _build_loader(split, batch_size, shuffle):
    """PromptDataLoader over dataset[split] with the shared tokenization settings."""
    return PromptDataLoader(
        dataset=dataset[split], template=mytemplate, tokenizer=tokenizer,
        tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l,
        decoder_max_length=3, batch_size=batch_size, shuffle=shuffle,
        teacher_forcing=False, predict_eos_token=False, truncate_method="tail")


train_dataloader = _build_loader("train", batchsize_t, shuffle=True)
validation_dataloader = _build_loader("validation", batchsize_e, shuffle=False)
# zero-shot test
test_dataloader = _build_loader("test", batchsize_e, shuffle=False)

print("truncate rate: {}".format(test_dataloader.tokenizer_wrapper.truncate_rate), flush=True)
def evaluate(prompt_model, dataloader, desc):
    """Return the accuracy of `prompt_model` over `dataloader`.

    `desc` is kept for interface compatibility (callers pass "Valid"/"Test");
    it is not used here.
    """
    prompt_model.eval()
    allpreds = []
    alllabels = []
    # torch.no_grad(): inference only -- the original built the autograd graph
    # for every eval batch, wasting memory for gradients that are never used.
    with torch.no_grad():
        for step, inputs in enumerate(dataloader):
            if use_cuda:
                inputs = inputs.cuda()
            logits = prompt_model(inputs)
            labels = inputs['label']
            alllabels.extend(labels.cpu().tolist())
            allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    acc = sum(int(i == j) for i, j in zip(allpreds, alllabels)) / len(allpreds)
    return acc
# use AdamW is a standard practice for transformer
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup
# use Adafactor is the default setting for T5
from transformers.optimization import Adafactor, AdafactorSchedule

loss_func = torch.nn.CrossEntropyLoss()
tot_step = args.max_steps

if args.tune_plm:
    # normally we freeze the model when using soft_template. However, we keep
    # the option to tune plm.  it's always good practice to set no decay to
    # biase and LayerNorm parameters
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters1 = [
        {'params': [p for n, p in prompt_model.plm.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in prompt_model.plm.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
    scheduler1 = get_linear_schedule_with_warmup(
        optimizer1, num_warmup_steps=500, num_training_steps=tot_step)
else:
    optimizer1 = None
    scheduler1 = None
310
# Soft-prompt parameters: everything in the template except the PLM's raw
# embedding table.  note that you have to remove the raw_embedding manually
# from the optimization.
optimizer_grouped_parameters2 = [
    {'params': [p for name, p in prompt_model.template.named_parameters()
                if 'raw_embedding' not in name]}
]
if args.optimizer.lower() == "adafactor":
    # when lr is 0.3, it is the same as the configuration of
    # https://arxiv.org/abs/2104.08691
    optimizer2 = Adafactor(optimizer_grouped_parameters2,
                           lr=args.prompt_lr,
                           relative_step=False,
                           scale_parameter=False,
                           warmup_init=False)
    # when num_warmup_steps is 0, it is the same as the configuration of
    # https://arxiv.org/abs/2104.08691
    scheduler2 = get_constant_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_step_prompt)
elif args.optimizer.lower() == "adamw":
    optimizer2 = AdamW(optimizer_grouped_parameters2, lr=args.prompt_lr)  # usually lr = 0.5
    scheduler2 = get_linear_schedule_with_warmup(
        optimizer2,
        num_warmup_steps=args.warmup_step_prompt, num_training_steps=tot_step)  # usually num_warmup_steps is 500
else:
    # Fail fast: the training loop below assumes optimizer2/scheduler2 exist.
    # Previously an unknown --optimizer left them undefined and crashed with a
    # NameError only after the first gradient-accumulation boundary.
    raise ValueError(f"Unsupported --optimizer {args.optimizer!r}; expected 'Adafactor' or 'AdamW'")
325
# Running state for the training loop below.
tot_loss = 0          # cumulative training loss
log_loss = 0          # tot_loss snapshot at the last progress-bar update
best_val_acc = 0      # best validation accuracy seen so far
glb_step = 0          # optimizer steps (after gradient accumulation)
actual_step = 0       # forward/backward passes
leave_training = False

acc_traces = []       # validation accuracy after each evaluation
tot_train_time = 0    # accumulated wall-clock training time (excludes eval)
pbar_update_freq = 10
prompt_model.train()
pbar = tqdm(total=tot_step, desc="Train")
# The epoch bound is effectively infinite; the loop exits via max_steps.
for epoch in range(1000000):
    print(f"Begin epoch {epoch}")
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        tot_train_time -= time.time()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        actual_step += 1

        # Optimizer step only at gradient-accumulation boundaries.
        if actual_step % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0)
            glb_step += 1
            if glb_step % pbar_update_freq == 0:
                avg_loss = (tot_loss - log_loss) / pbar_update_freq
                pbar.update(10)
                pbar.set_postfix({'loss': avg_loss})
                log_loss = tot_loss

            if optimizer1 is not None:
                optimizer1.step()
                optimizer1.zero_grad()
            if scheduler1 is not None:
                scheduler1.step()
            if optimizer2 is not None:
                optimizer2.step()
                optimizer2.zero_grad()
            if scheduler2 is not None:
                scheduler2.step()

        tot_train_time += time.time()

        # Periodic validation; keep the best checkpoint on disk.
        if actual_step % gradient_accumulation_steps == 0 and glb_step > 0 and glb_step % args.eval_every_steps == 0:
            val_acc = evaluate(prompt_model, validation_dataloader, desc="Valid")
            if val_acc >= best_val_acc:
                torch.save(prompt_model.state_dict(), f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt")
                best_val_acc = val_acc

            acc_traces.append(val_acc)
            print("Glb_step {}, val_acc {}, average time {}".format(glb_step, val_acc, tot_train_time/actual_step), flush=True)
            prompt_model.train()  # evaluate() put the model into eval mode

        if glb_step > args.max_steps:
            leave_training = True
            break

    if leave_training:
        break
392
393# # super_glue test split can not be evaluated without submitting the results to their website. So we skip it here and keep them as comments.
394#
# prompt_model.load_state_dict(torch.load(f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt"))  # note: must match the save path used during training
396# prompt_model = prompt_model.cuda()
397# test_acc = evaluate(prompt_model, test_dataloader, desc="Test")
398# test_acc = evaluate(prompt_model, test_dataloader, desc="Test")
399
# a simple measure for the convergence speed: the earliest global step at which
# validation accuracy first reached 98% / 99% / 100% of its best value.
def _critical_steps(traces, best_acc, eval_every, default_step):
    """Return (step98, step99, step100) for the given accuracy trace.

    traces[i] was recorded at glb_step (i + 1) * eval_every: the first
    evaluation happens only after `eval_every` optimizer steps, never at
    step 0.  (The original used val_time * eval_every, which mapped the first
    evaluation to step 0 -- off by one evaluation interval.)
    """
    thresholds = (0.98 * best_acc, 0.99 * best_acc, best_acc)
    steps = [default_step, default_step, default_step]
    for val_time, acc in enumerate(traces):
        step_at = (val_time + 1) * eval_every
        for k, thres in enumerate(thresholds):
            if acc >= thres:
                steps[k] = min(step_at, steps[k])
    return tuple(steps)


step98, step99, step100 = _critical_steps(
    acc_traces, best_val_acc, args.eval_every_steps, args.max_steps)
413
# Append the final metrics (best/last validation accuracy and the steps at
# which 98%/99%/100% of the best accuracy were first reached) to the header
# built earlier, echo it, and append the record to the shared result file.
content_write += (
    f"BestValAcc:{best_val_acc}\tEndValAcc:{acc_traces[-1]}"
    f"\tcritical_steps:{[step98,step99,step100]}\n\n"
)

print(content_write)

with open(args.result_file, "a") as fout:
    fout.write(content_write)
# Clean up the temporary best-model checkpoint.  The original removed the
# cwd-relative "../ckpts/..." path, which does not match the path the training
# loop saved to (relative to args.project_root); also guard against the case
# where no checkpoint was ever written (e.g. training ended before the first
# evaluation).
_ckpt_path = f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt"
if os.path.exists(_ckpt_path):
    os.remove(_ckpt_path)