openprompt

Форк
0
/
5.1_BMInf_CPM.py 
133 строки · 5.5 Кб
1
import sys

# Allow running this example from the repo root or from the examples directory.
sys.path.extend([".", ".."])

from openprompt.data_utils import InputExample
from openprompt.data_utils.ZH import ChnSentiCorp
from openprompt.data_utils.data_sampler import FewShotSampler

# Load the ChnSentiCorp (Chinese sentiment) dataset splits.
# TODO other chinese datasets are not fully adapted yet
processor = ChnSentiCorp()
trainset = processor.get_train_examples("datasets/ZH/ChnSentiCorp")
devset = processor.get_dev_examples("datasets/ZH/ChnSentiCorp")

# Optionally sub-sample a few-shot split instead of using the full data:
# sampler  = FewShotSampler(num_examples_per_label=8, num_examples_per_label_dev=8, also_sample_dev=True)
# trainset, devset = sampler(trainset, devset)
import bminf.torch as bt

# Which CPM generation to load: 1 = autoregressive LM, 2 = seq2seq.
use_cpm_version = 2

if use_cpm_version == 1:
    from openprompt.plms.lm import LMTokenizerWrapper
    plm = bt.models.CPM1()
    tokenizer = plm.tokenizer
    WrapperClass = LMTokenizerWrapper
elif use_cpm_version == 2:
    from openprompt.plms.seq2seq import CPM2TokenizerWrapper
    plm = bt.models.CPM2()
    tokenizer = plm.tokenizer
    WrapperClass = CPM2TokenizerWrapper
else:
    # Fail fast: without this, an unsupported value would leave `plm`,
    # `tokenizer` and `WrapperClass` undefined and crash later with NameError.
    raise ValueError(f"Unsupported use_cpm_version: {use_cpm_version!r} (expected 1 or 2)")
from openprompt.prompts import SoftTemplate, MixedTemplate

# Soft template: trainable soft tokens plus a fixed Chinese prompt asking
# whether the sentiment of the given text is positive or negative.
# Alternative template texts kept for reference:
# text = '{"meta": "context", "shortenable": True} 上文中,{"meta": "entity"} 是一个{"mask"}。选项:{"meta": "options", "post_processing": lambda lis: ",".join([f"{i}:{choice}" for i, choice in enumerate(lis)])}',
# text = '前提:{"meta": "premise", "shortenable": True} 假设: {"meta": "hypothesis", "shortenable": True} 问题:前提和假设是什么关系? 选项:{"meta": "options", "post_processing": lambda lis: ",".join([f"{i}:{choice}" for i, choice in enumerate(lis)])} 回答:{"mask"}',
template_text = '文本:{"meta": "context", "shortenable": True} 问题:上述文本所表达的情感是积极的还是消极的? 回答:{"mask"}'
mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, text=template_text)

# Sanity check: show how one training example is wrapped by the template.
wrapped_example = mytemplate.wrap_one_example(trainset[0])
print("Wrapped Example:", wrapped_example)
# ## Define the verbalizer
# In classification, you need to define your verbalizer, which is a mapping
# from logits on the vocabulary to the final label probability.

from openprompt.prompts import ManualVerbalizer
import torch

# The verbalizer may contain multiple label words per class; here the label
# words come directly from the dataset processor's label mapping.
label_words = processor.labels_mapped
myverbalizer = ManualVerbalizer(tokenizer, num_classes=len(label_words), label_words=label_words, prefix = '')
print("Verbalizer token id:", myverbalizer.label_words_ids.data)

from openprompt import PromptForClassification

# Fall back to CPU gracefully when no GPU is present: the original hard-coded
# use_cuda = True and crashed on CPU-only machines at the .cuda() call.
use_cuda = torch.cuda.is_available()
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()
59
# ## below is standard training

from openprompt import PromptDataLoader


def _build_loader(split, shuffle):
    # Shared loader configuration for both the train and validation splits.
    return PromptDataLoader(dataset=split, template=mytemplate, tokenizer=tokenizer,
        tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=8,
        batch_size=16, shuffle=shuffle, teacher_forcing=False, predict_eos_token=False,
        truncate_method="head")


train_dataloader = _build_loader(trainset, shuffle=True)
# next(iter(train_dataloader))

validation_dataloader = _build_loader(devset, shuffle=False)
74
from transformers import get_linear_schedule_with_warmup
# transformers.AdamW is deprecated (and removed in recent transformers
# releases); torch.optim.AdamW is the drop-in replacement recommended upstream.
from torch.optim import AdamW

loss_func = torch.nn.CrossEntropyLoss()

# It's good practice to exempt biases and LayerNorm weights from weight decay.
no_decay = ['bias', 'LayerNorm.weight']

print("names: ", [n for n, p in prompt_model.plm.named_parameters()])
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

print("names: ", [n for n, p in prompt_model.template.named_parameters()])
# Use a separate optimizer for the prompt (template) parameters.
optimizer_grouped_parameters2 = [
    # {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
    {'params': [p for n,p in prompt_model.template.named_parameters()]}
]

# NOTE(review): lr=0 appears deliberate — it keeps the PLM effectively frozen
# while only the soft-prompt optimizer updates parameters. The prompt lr is
# pre-divided by 1024 to cancel the x1024 loss scaling in the training loop.
optimizer1 = AdamW(optimizer_grouped_parameters1, lr=0)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=5e-1/1024)
96
for epoch in range(3):
    # ---- train ----
    prompt_model.train()

    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        # Loss is scaled by 1024 (matched by the 1/1024 factor in the prompt
        # optimizer's learning rate) — presumably to keep gradients in a
        # usable range for the quantized BMInf backend; confirm upstream.
        loss = loss_func(logits, labels)*1024
        loss.backward()
        # print(prompt_model.template.soft_embeds.grad)
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        print(f"epoch {epoch} - step {step}: ", loss.item(), tot_loss/(step+1))

    # ---- evaluate ----
    prompt_model = prompt_model.eval()

    allpreds = []
    alllabels = []
    with torch.no_grad():
        for step, inputs in enumerate(validation_dataloader):
            if use_cuda:
                inputs = inputs.cuda()
            logits = prompt_model(inputs)
            labels = inputs['label']
            alllabels.extend(labels.cpu().tolist())
            allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
            print("step :", step)

    # Guard against an empty validation set: the original divided by
    # len(allpreds) unconditionally, raising ZeroDivisionError.
    acc = sum(int(i == j) for i, j in zip(allpreds, alllabels)) / len(allpreds) if allpreds else 0.0
    print("accuracy:", acc)
134

135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.