# OpenDelta / 1_with_openprompt.py (156 lines, 7.4 KB; forks: 0)
"""
This tutorial is a copy of OpenPrompt's tutorial/1.1_mixed_template.py.
The OpenDelta-specific change is the block that imports `LoraModel`, attaches
LoRA modules to the backbone model, freezes the backbone except for the delta
parameters, and logs the resulting model.

1. OpenPrompt pre-processes the data, e.g. prompt-template formatting.
2. OpenPrompt pre-processes the model input, e.g. soft prompt embeddings.
3. OpenDelta modifies the backbone model, e.g. with Adapter, LoRA, Compactor, etc.
4. OpenPrompt post-processes the model output, e.g. extracting the logits at the
   <mask> position and applying the prompt verbalizer.
"""

# load dataset
import os

from datasets import load_dataset
from datasets import load_from_disk

# Directory holding a copy of SuperGLUE/CB saved earlier with `save_to_disk`.
# Override it with the CB_DATASET_PATH environment variable. If the directory
# does not exist, the dataset is fetched with `load_dataset` instead.
CB_DATASET_PATH = os.environ.get(
    "CB_DATASET_PATH", "/home/hx/huggingface_datasets/saved_to_disk/super_glue.cb"
)
if os.path.isdir(CB_DATASET_PATH):
    raw_dataset = load_from_disk(CB_DATASET_PATH)
else:
    raw_dataset = load_dataset('super_glue', 'cb')
# Note that if you are running this script inside a GPU cluster, chances are you
# cannot connect to the Hugging Face website directly. In this case we recommend:
#   1. run `raw_dataset = load_dataset(...)` on a machine with internet access;
#   2. use `raw_dataset.save_to_disk(path)` to save it to a local path;
#   3. upload the saved content to the machine in the cluster;
#   4. point CB_DATASET_PATH at it so `load_from_disk` loads the dataset.

from openprompt.data_utils import InputExample

# Convert every raw CB record of each split into an OpenPrompt InputExample:
# premise -> text_a, hypothesis -> text_b, idx -> guid.
dataset = {}
for split in ['train', 'validation', 'test']:
    dataset[split] = [
        InputExample(
            text_a=data['premise'],
            text_b=data['hypothesis'],
            label=int(data['label']),
            guid=data['idx'],
        )
        for data in raw_dataset[split]
    ]
print(dataset['train'][0])

# You can load the PLM-related objects provided by OpenPrompt simply by calling:
from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")

# Constructing Template
# A template can be constructed from a YAML config, but it can also be built by passing arguments directly.
from openprompt.prompts import MixedTemplate

template_text = '{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"placeholder":"text_b"}? {"soft"} {"soft"} {"soft"} {"mask"}.'
mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer, text=template_text)

# To better understand how the template wraps an example, we visualize one instance.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
print(wrapped_example)

# Now the wrapped example is ready to be passed into the tokenizer, producing the input for language models.
# You can tokenize the input yourself, but we recommend our wrapped tokenizer, which is tailored for InputExample.
# The wrapper is returned by our `load_plm` function; otherwise, choose the suitable wrapper based on
# the configuration in `openprompt.plms.__init__.py`.
# Note that when T5 is used for classification, we only need to pass <pad> <extra_id_0> <eos> to the decoder.
# The loss is calculated at <extra_id_0>, so passing decoder_max_length=3 saves space.
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer, truncate_method="head")
# or
from openprompt.plms import T5TokenizerWrapper

wrapped_t5tokenizer = T5TokenizerWrapper(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer, truncate_method="head")

# You can see what a tokenized example looks like:
tokenized_example = wrapped_t5tokenizer.tokenize_one_example(wrapped_example, teacher_forcing=False)
print(tokenized_example)
print(tokenizer.convert_ids_to_tokens(tokenized_example['input_ids']))
print(tokenizer.convert_ids_to_tokens(tokenized_example['decoder_input_ids']))

# Now it's time to convert the whole dataset into the input format!
# Simply loop over the dataset to achieve it. This is for illustration only:
# `model_inputs` is not used later in this script, where PromptDataLoader
# performs the wrapping and tokenization instead.
model_inputs = {}
for split in ['train', 'validation', 'test']:
    model_inputs[split] = [
        wrapped_t5tokenizer.tokenize_one_example(mytemplate.wrap_one_example(sample), teacher_forcing=False)
        for sample in dataset[split]
    ]

# We provide a `PromptDataLoader` class that does all of the above and wraps the result
# into a `torch.DataLoader`-style iterator.
from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(
    dataset=dataset["train"],
    template=mytemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,
    decoder_max_length=3,
    batch_size=4,
    shuffle=True,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="head",
)

# Define the verbalizer
# In classification, you need to define a verbalizer, which maps logits over the vocabulary
# to the final label probabilities. Let's have a look at the verbalizer details:
from openprompt.prompts import ManualVerbalizer
import torch

# One list of label words per class; a class may list several words (here one each).
# NOTE(review): the list order must match the integer CB labels used in `dataset`
# (presumably 0/1/2 = entailment/contradiction/neutral) — confirm against the dataset.
myverbalizer = ManualVerbalizer(tokenizer, num_classes=3, label_words=[["yes"], ["no"], ["maybe"]])

print("label_words_ids", myverbalizer.label_words_ids)

# OpenDelta: attach LoRA modules to the backbone sub-modules matched by
# "SelfAttention.q" and "SelfAttention.v", then freeze the backbone except for
# the delta parameters and print a summary of the modified model.
from opendelta import LoraModel

delta_model = LoraModel(backbone_model=plm, modified_modules=["SelfAttention.q", "SelfAttention.v"])
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
delta_model.log()

# Although you can manually combine the PLM, template, and verbalizer, we provide a pipeline
# model that takes batched data from the PromptDataLoader and produces class-wise logits.
from openprompt import PromptForClassification

# Use the GPU only when one is available, instead of crashing in `.cuda()` on CPU-only hosts.
use_cuda = torch.cuda.is_available()
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer)
if use_cuda:
    prompt_model = prompt_model.cuda()

# Now the training is standard.
loss_func = torch.nn.CrossEntropyLoss()
# It's good practice not to apply weight decay to biases and LayerNorm weights.
no_decay = ['bias', 'LayerNorm.weight']
# Hand only the parameters that still require gradients to the optimizer;
# parameters frozen by `freeze_module` never receive gradients, so listing them
# only adds overhead.
trainable_named_parameters = [(n, p) for n, p in prompt_model.named_parameters() if p.requires_grad]
optimizer_grouped_parameters = [
    {'params': [p for n, p in trainable_named_parameters if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in trainable_named_parameters if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
print([n for n, p in trainable_named_parameters])

# `transformers.AdamW` is deprecated (and removed in newer releases) in favour of
# `torch.optim.AdamW`; eps=1e-6 keeps the default of the transformers implementation.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-4, eps=1e-6)

# Make sure dropout and other training-time behaviour are enabled.
prompt_model.train()
for epoch in range(30):
    tot_loss = 0.0
    num_steps = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        num_steps += 1
        if step % 100 == 1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss / (step + 1)), flush=True)
    # Short epochs would otherwise only ever report step 1, so also log the epoch average.
    if num_steps:
        print("Epoch {} finished, average loss: {}".format(epoch, tot_loss / num_steps), flush=True)

# Evaluate
validation_dataloader = PromptDataLoader(
    dataset=dataset["validation"],
    template=mytemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,
    decoder_max_length=3,
    batch_size=4,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="head",
)

# Switch to inference mode: disable dropout, and skip building autograd graphs
# since no gradients are needed during evaluation.
prompt_model.eval()
allpreds = []
alllabels = []
with torch.no_grad():
    for inputs in validation_dataloader:
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

# Accuracy = fraction of matching predictions; NaN (not ZeroDivisionError) on an empty split.
acc = sum(int(i == j) for i, j in zip(allpreds, alllabels)) / len(allpreds) if allpreds else float("nan")
print(acc)

# (Page residue from the GitVerse file viewer, translated from Russian.)
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal data
# in order to improve our website and the GitVerse service and to make them more
# convenient to use.
#
# You can disable cookies yourself in your browser settings.