import sys
sys.path.append(".")
sys.path.append("..")

from openprompt.data_utils import InputExample
from openprompt.data_utils.ZH import ChnSentiCorp
from openprompt.data_utils.data_sampler import FewShotSampler

processor = ChnSentiCorp()
# TODO: other Chinese datasets are not fully adapted yet
trainset = processor.get_train_examples("datasets/ZH/ChnSentiCorp")
devset = processor.get_dev_examples("datasets/ZH/ChnSentiCorp")
# sampler = FewShotSampler(num_examples_per_label=8, num_examples_per_label_dev=8, also_sample_dev=True)
# trainset, devset = sampler(trainset, devset)
import bminf.torch as bt

use_cpm_version = 2
if use_cpm_version == 1:
    from openprompt.plms.lm import LMTokenizerWrapper
    plm = bt.models.CPM1()
    tokenizer = plm.tokenizer
    WrapperClass = LMTokenizerWrapper
elif use_cpm_version == 2:
    from openprompt.plms.seq2seq import CPM2TokenizerWrapper
    plm = bt.models.CPM2()
    tokenizer = plm.tokenizer
    WrapperClass = CPM2TokenizerWrapper
from openprompt.prompts import SoftTemplate, MixedTemplate
mytemplate = SoftTemplate(
    model=plm,
    tokenizer=tokenizer,
    # text='{"meta": "context", "shortenable": True} 上文中,{"meta": "entity"} 是一个{"mask"}。选项:{"meta": "options", "post_processing": lambda lis: ",".join([f"{i}:{choice}" for i, choice in enumerate(lis)])}',
    # text='前提:{"meta": "premise", "shortenable": True} 假设: {"meta": "hypothesis", "shortenable": True} 问题:前提和假设是什么关系? 选项:{"meta": "options", "post_processing": lambda lis: ",".join([f"{i}:{choice}" for i, choice in enumerate(lis)])} 回答:{"mask"}',
    text='文本:{"meta": "context", "shortenable": True} 问题:上述文本所表达的情感是积极的还是消极的? 回答:{"mask"}',
)
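# In OpenPrompt's template syntax, {"meta": "context"} is filled from the
# example's meta dict, "shortenable": True marks text that may be truncated
# to fit max_seq_length, and {"mask"} marks the slot the model should fill in.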
wrapped_example = mytemplate.wrap_one_example(trainset[0])
print("Wrapped Example:", wrapped_example)
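# The wrapped example is the template applied to one InputExample: roughly a
# list of text pieces carrying flags such as loss_ids (marking the {"mask"}
# position) and shortenable_ids (marking truncatable text), plus the label.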
# ## Define the verbalizer
# In classification, you need to define a verbalizer, which maps logits over the vocabulary to final label probabilities. Let's look at the verbalizer details:
from openprompt.prompts import ManualVerbalizer
import torch

# For example, the verbalizer can contain multiple label words for each class.
label_words = processor.labels_mapped
myverbalizer = ManualVerbalizer(tokenizer, num_classes=len(label_words), label_words=label_words, prefix='')
print("Verbalizer token ids:", myverbalizer.label_words_ids.data)
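# A minimal sketch (plain PyTorch, hypothetical vocab size and token ids, not
# OpenPrompt's actual implementation) of what the verbalizer computes: gather
# the logits of each class's label words at the mask position and normalize
# them into label probabilities.
_vocab_logits = torch.randn(1, 30000)               # [batch, vocab_size] logits at the mask
_label_word_ids = torch.tensor([101, 102])          # hypothetical ids, one label word per class
_class_logits = _vocab_logits[:, _label_word_ids]   # [batch, num_classes]
_label_probs = torch.softmax(_class_logits, dim=-1)
print("Sketch label probabilities:", _label_probs)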
from openprompt import PromptForClassification
use_cuda = True
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()
# ## Below is standard training

from openprompt import PromptDataLoader
train_dataloader = PromptDataLoader(dataset=trainset, template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=8,
    batch_size=16, shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
# next(iter(train_dataloader))

validation_dataloader = PromptDataLoader(dataset=devset, template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=8,
    batch_size=16, shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
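# Note: max_seq_length=256 bounds the encoder input (truncate_method="head"
# controls which end of the shortenable text is cut), while decoder_max_length=8
# bounds the decoder side; the template only asks for a single {"mask"} answer,
# so a short decoder length suffices.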
from transformers import AdamW, get_linear_schedule_with_warmup

loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']
print("names: ", [n for n, p in prompt_model.plm.named_parameters()])
# It is good practice to exclude biases and LayerNorm weights from weight decay.
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
print("names: ", [n for n, p in prompt_model.template.named_parameters()])
# Use a separate optimizer for the prompt parameters and the model parameters.
optimizer_grouped_parameters2 = [
    # {'params': [p for n, p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
    {'params': [p for n, p in prompt_model.template.named_parameters()]}
]
optimizer1 = AdamW(optimizer_grouped_parameters1, lr=0)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=5e-1/1024)
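# Note (reading of the hyperparameters): lr=0 means optimizer1 never actually
# updates the PLM weights, so only the soft prompt (optimizer2) is trained even
# though freeze_plm=False. The `loss * 1024` in the training loop below looks
# like manual loss scaling for half-precision gradients, with the prompt lr
# divided by 1024 to compensate.
# If warmup is wanted on the prompt optimizer, a sketch with the already
# imported scheduler (the step counts here are assumptions) could be:
# scheduler2 = get_linear_schedule_with_warmup(optimizer2, num_warmup_steps=50,
#                                              num_training_steps=3 * len(train_dataloader))
# followed by scheduler2.step() after each optimizer2.step().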
for epoch in range(3):
    # ## train
    prompt_model.train()

    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels) * 1024
        loss.backward()
        # print(prompt_model.template.soft_embeds.grad)
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        print(f"epoch {epoch} - step {step}: ", loss.item(), tot_loss/(step+1))
# ## evaluate

prompt_model = prompt_model.eval()
allpreds = []
alllabels = []
with torch.no_grad():
    for step, inputs in enumerate(validation_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
        print("step:", step)
acc = sum([int(i == j) for i, j in zip(allpreds, alllabels)]) / len(allpreds)
print("accuracy:", acc)