openprompt
124 строки · 4.7 Кб
# ## Load the dataset
# CB (CommitmentBank) from SuperGLUE: 3-way premise/hypothesis classification
# (entailment / contradiction / neutral).
from datasets import load_dataset

raw_dataset = load_dataset('super_glue', 'cb', cache_dir="../datasets/.cache/huggingface_datasets")

from openprompt.data_utils import InputExample

# Convert each HuggingFace row into OpenPrompt's InputExample:
# text_a = premise, text_b = hypothesis, label coerced to int, guid = row index.
dataset = {
    split: [
        InputExample(
            text_a=data['premise'],
            text_b=data['hypothesis'],
            label=int(data['label']),
            guid=data['idx'],
        )
        for data in raw_dataset[split]
    ]
    for split in ['train', 'validation', 'test']
}
print(dataset['train'][0])
from openprompt.plms import load_plm

# Load T5-base together with its tokenizer, model config, and the
# tokenizer-wrapper class that builds encoder-decoder inputs for prompting.
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")
20
21# # Try more prompt!
22# You can use templates other than manual template, for example the mixedtemplate is a good place to start.
23# In MixedTemplate, you can use {"soft"} to denote a tunable template. More syntax and usage, please refer
24# to `How to write a template`
25from openprompt.prompts import MixedTemplate26
27mytemplate1 = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft": "Question:"} {"placeholder":"text_b"}? Is it correct? {"mask"}.')28
29mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"placeholder":"text_b"} {"soft"} {"mask"}.')30
31
32wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])33print(wrapped_example)34
# NOTE(review): this wrapper instance is never used afterwards — the
# PromptDataLoader below builds its own wrapper from WrapperClass, and with
# max_seq_length=256 rather than the 128 here. Kept for illustration only.
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer, truncate_method="head")

from openprompt import PromptDataLoader

# Training batches of wrapped + tokenized examples.
# decoder_max_length=3 suffices for T5 decoding of a single label word.
train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4, shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
# next(iter(train_dataloader))
# ## Define the verbalizer
# In classification, you need to define your verbalizer, which is a mapping from logits on the vocabulary to the final label probability. Let's have a look at the verbalizer details:

from openprompt.prompts import ManualVerbalizer
import torch

# for example the verbalizer contains multiple label words in each class
# (here just one word per class; label order assumed to match the dataset's
# 0/1/2 label ids — TODO confirm against the CB dataset card)
myverbalizer = ManualVerbalizer(tokenizer, num_classes=3,
                        label_words=[["yes"], ["no"], ["maybe"]])

print(myverbalizer.label_words_ids)
logits = torch.randn(2,len(tokenizer)) # creating a pseudo output from the plm
# Demonstration only: map vocabulary logits to per-class label probabilities.
myverbalizer.process_logits(logits)
from openprompt import PromptForClassification

# Combine PLM + template + verbalizer into one classification model.
# freeze_plm=False: the full T5 is fine-tuned together with the soft prompt.
# Was hard-coded `use_cuda = True`, which crashes on CPU-only machines;
# detect availability instead (later cells read this flag).
use_cuda = torch.cuda.is_available()
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()
# ## below is standard training

from transformers import AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# it's always good practice to set no weight decay on bias and LayerNorm parameters
optimizer_grouped_parameters1 = [
    # PLM parameters, split into decayed vs. non-decayed groups by name.
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# Using different optimizer for prompt parameters and model parameters.
# "raw_embedding" is excluded — those weights belong to the frozen copy of
# the PLM embedding held by the template, not to the soft prompt itself.
optimizer_grouped_parameters2 = [
    {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
]

# Lower LR for the full PLM, higher LR for the few soft-prompt parameters.
optimizer1 = AdamW(optimizer_grouped_parameters1, lr=1e-4)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=1e-3)
# Standard fine-tuning loop. The PLM (optimizer1) and the soft prompt
# (optimizer2) are stepped with their separate optimizers each batch.
for epoch in range(10):
    prompt_model.train()  # was missing: ensures dropout etc. are in train mode
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
    # Average training loss over the epoch.
    print(tot_loss/(step+1))
# ## evaluate

# %%
# Validation loader: same settings as training except shuffle=False so that
# predictions stay aligned with labels.
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4, shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
113
114allpreds = []115alllabels = []116for step, inputs in enumerate(validation_dataloader):117if use_cuda:118inputs = inputs.cuda()119logits = prompt_model(inputs)120labels = inputs['label']121alllabels.extend(labels.cpu().tolist())122allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())123
124acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)125print(acc)126