# 1.1_mixed_template.py

from datasets import load_dataset
# CommitmentBank (CB) from SuperGLUE: a 3-way NLI task over premise/hypothesis pairs
# (labels 0/1/2 correspond to entailment / contradiction / neutral)
raw_dataset = load_dataset('super_glue', 'cb', cache_dir="../datasets/.cache/huggingface_datasets")
raw_dataset['train'][0]

from openprompt.data_utils import InputExample

# convert each raw example into an OpenPrompt InputExample
dataset = {}
for split in ['train', 'validation', 'test']:
    dataset[split] = []
    for data in raw_dataset[split]:
        input_example = InputExample(text_a=data['premise'], text_b=data['hypothesis'], label=int(data['label']), guid=data['idx'])
        dataset[split].append(input_example)
print(dataset['train'][0])

from openprompt.plms import load_plm

# load a pre-trained T5 together with its tokenizer, config and tokenizer-wrapper class
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")
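# other model families can be loaded the same way, for example (illustrative, not run here):
# plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")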

# # Try more prompts!
# You can use templates other than the manual template; the MixedTemplate, for example, is a good place to start.
# In a MixedTemplate, you can use {"soft"} to mark tunable soft tokens. For more syntax and usage, please refer
# to `How to write a template`.
from openprompt.prompts import MixedTemplate

# a mixed template whose soft tokens are initialized from the text "Question:"
mytemplate1 = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft": "Question:"} {"placeholder":"text_b"}? Is it correct? {"mask"}.')

# a mixed template with four soft tokens (three before text_b and one after)
mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"placeholder":"text_b"} {"soft"} {"mask"}.')
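
# a quick sanity check (sketch): the template is a torch module, so we can count the trainable
# soft-prompt parameters it adds, excluding the PLM's raw embedding that it keeps a reference to
# (the same "raw_embedding" filter is used for the prompt optimizer further below)
num_soft_params = sum(p.numel() for n, p in mytemplate.named_parameters() if "raw_embedding" not in n)
print(num_soft_params)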

# wrap one example to see how the template fills in the placeholders
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
print(wrapped_example)

# the tokenizer wrapper that turns wrapped examples into model input features
# (PromptDataLoader below instantiates the same wrapper class internally)
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer, truncate_method="head")

from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4, shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
# next(iter(train_dataloader))
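
# each batch yielded by PromptDataLoader is an InputFeatures object; with the T5 wrapper it
# typically carries fields such as input_ids, attention_mask, decoder_input_ids, loss_ids and
# label (field names recalled from OpenPrompt's InputFeatures and may differ across versions)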

# ## Define the verbalizer
# In classification, you need to define your verbalizer, which maps logits over the vocabulary
# to the final label probabilities. Let's have a look at the verbalizer details:

from openprompt.prompts import ManualVerbalizer
import torch

# the verbalizer can contain multiple label words per class; here each class uses a single word
myverbalizer = ManualVerbalizer(tokenizer, num_classes=3,
                        label_words=[["yes"], ["no"], ["maybe"]])

print(myverbalizer.label_words_ids)
logits = torch.randn(2, len(tokenizer))  # creating a pseudo output from the plm
myverbalizer.process_logits(logits)
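
# process_logits aggregates the label-word logits into per-class scores, so the call above should
# yield a tensor of shape (batch_size, num_classes), i.e. (2, 3) for this pseudo batch
# (a sanity-check sketch; the exact normalization depends on the OpenPrompt version)
label_logits = myverbalizer.process_logits(logits)
print(label_logits.shape)  # expected: torch.Size([2, 3])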


from openprompt import PromptForClassification

use_cuda = True
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()
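
# note: freeze_plm=False fine-tunes the whole PLM together with the soft prompt;
# passing freeze_plm=True instead would keep the PLM frozen and train only the template parameters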

# ## Below is standard training

from transformers import AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# it's always good practice to apply no weight decay to bias and LayerNorm parameters
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# use a separate optimizer (and learning rate) for the prompt parameters and the model parameters
optimizer_grouped_parameters2 = [
    {'params': [p for n, p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
]

optimizer1 = AdamW(optimizer_grouped_parameters1, lr=1e-4)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=1e-3)
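
# (optional sketch) get_linear_schedule_with_warmup is imported above but never used; warmup
# schedules could be attached to both optimizers like this, with scheduler.step() called after
# each corresponding optimizer.step() in the training loop below (values are illustrative):
# num_training_steps = 10 * len(train_dataloader)
# scheduler1 = get_linear_schedule_with_warmup(optimizer1, num_warmup_steps=0, num_training_steps=num_training_steps)
# scheduler2 = get_linear_schedule_with_warmup(optimizer2, num_warmup_steps=0, num_training_steps=num_training_steps)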

for epoch in range(10):
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        print(tot_loss/(step+1))  # running average of the training loss

# ## Evaluate

# %%
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4, shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")


# switch to eval mode and disable gradient tracking for evaluation
prompt_model.eval()
allpreds = []
alllabels = []
with torch.no_grad():
    for step, inputs in enumerate(validation_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i == j) for i, j in zip(allpreds, alllabels)]) / len(allpreds)
print(acc)