# # Conditional Generation with Prefix Tuning.
# In this tutorial, we do conditional generation with a prefix-tuning template.

# Here we again use WebNLG as the example dataset. Note that the generated results should be evaluated
# with the scripts provided by https://github.com/Yale-LILY/dart/tree/master/evaluation,
# which we do not include here.

import argparse
import torch

parser = argparse.ArgumentParser("")
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--plm_eval_mode", action="store_true")
parser.add_argument("--model", type=str, default='t5')  # tested models are gpt2/t5
parser.add_argument("--model_name_or_path", default='t5-base')
args = parser.parse_args()
print(args)
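# Example invocation (flags are this script's own; values are illustrative):
#   python 2.1_conditional_generation.py --model t5 --model_name_or_path t5-base --lr 5e-5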

from openprompt.data_utils.conditional_generation_dataset import WebNLGProcessor
dataset = {}
dataset['train'] = WebNLGProcessor().get_train_examples("./datasets/CondGen/webnlg_2017/")
dataset['validation'] = WebNLGProcessor().get_dev_examples("./datasets/CondGen/webnlg_2017/")
dataset['test'] = WebNLGProcessor().get_test_examples("./datasets/CondGen/webnlg_2017/")
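# Each split is a list of InputExample objects: the linearized triples live in `text_a` and the
# reference sentence in `tgt_text` (both fields are used below). Optional quick look:
#   print(dataset['train'][0])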


# Load a pretrained model, its tokenizer, its config, and its TokenizerWrapper with one function.
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)

# Instantiate the PrefixTuning template.
from openprompt.prompts.prefix_tuning_template import PrefixTuningTemplate
# We can use the plain-text default, i.e.
# mytemplate = PrefixTuningTemplate(model=plm, tokenizer=tokenizer)
# which is equivalent to
# mytemplate = PrefixTuningTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"mask"}')
mytemplate = PrefixTuningTemplate(model=plm, tokenizer=tokenizer, text=' {"placeholder":"text_a"} {"special": "<eos>"} {"mask"} ', using_decoder_past_key_values=False)

# To better understand how the template wraps an example, we visualize one instance.
# You may observe that the example doesn't end with an <|endoftext|> token. Don't worry: the
# end-of-text token is language-model-specific, and the TokenizerWrapper adds it for you
# once you pass `predict_eos_token=True`.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
print(wrapped_example)
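# (Rough note; the exact structure may vary by OpenPrompt version: the wrapped output pairs each text
# piece with bookkeeping flags such as loss_ids marking the {"mask"} position, and keeps meta fields
# like tgt_text untouched.)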

# You can iterate over the dataset yourself by successively calling mytemplate.wrap_one_example and the
# tokenizer wrapper (a commented sketch follows the dataloader below), but we provide a PromptDataLoader
# that does this for you.
from openprompt import PromptDataLoader
train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=256,
    batch_size=5, shuffle=True, teacher_forcing=True, predict_eos_token=True,  # be sure to pass predict_eos_token=True if your template doesn't contain one, or your model may fail to stop generating.
    truncate_method="head")
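# A rough sketch of the manual equivalent mentioned above (constructor/method names follow OpenPrompt's
# tokenizer wrappers as I understand them; treat this as illustrative, not canonical):
#   wrapper = WrapperClass(max_seq_length=256, decoder_max_length=256, tokenizer=tokenizer,
#                          truncate_method="head")
#   for example in dataset["train"]:
#       wrapped = mytemplate.wrap_one_example(example)
#       tokenized = wrapper.tokenize_one_example(wrapped, teacher_forcing=True)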

validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=256,
    batch_size=5, shuffle=False, teacher_forcing=False, predict_eos_token=True,
    truncate_method="head")

test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=256,
    batch_size=5, shuffle=False, teacher_forcing=False, predict_eos_token=True,
    truncate_method="head")

# Load the pipeline model PromptForGeneration.
from openprompt import PromptForGeneration
use_cuda = True
prompt_model = PromptForGeneration(plm=plm, template=mytemplate, freeze_plm=True, tokenizer=tokenizer, plm_eval_mode=args.plm_eval_mode)
if use_cuda:
    prompt_model = prompt_model.cuda()

from transformers import AdamW
# Following PrefixTuning (https://github.com/XiangLi1999/PrefixTuning), we also freeze the language model
# and only include the template's parameters in training.

no_decay = ["bias", "LayerNorm.weight"]
# Note: both groups below use weight_decay=0.0; raise the first group's value if you want decay on the
# non-bias/non-LayerNorm parameters.
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in mytemplate.named_parameters() if (not any(nd in n for nd in no_decay)) and p.requires_grad],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in mytemplate.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
        "weight_decay": 0.0,
    },
]

optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=1e-8)
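# Optional sanity check (not in the original tutorial; plain PyTorch): with freeze_plm=True, only the
# template's parameters should require gradients, so this counts what the optimizer will actually update.
num_trainable = sum(p.numel() for p in mytemplate.parameters() if p.requires_grad)
print(f"trainable prefix parameters: {num_trainable}", flush=True)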

from transformers.optimization import get_linear_schedule_with_warmup

tot_step = len(train_dataloader) * 5  # total optimization steps: batches per epoch times the 5 training epochs
scheduler = get_linear_schedule_with_warmup(optimizer, 0, tot_step)

# We provide a generation metric; you can also define your own. Note that its scores are not directly
# comparable to those of WebNLG's evaluation scripts.
from openprompt.utils.metrics import generation_metric
# Define the evaluation function.
def evaluate(prompt_model, dataloader):
    generated_sentence = []
    groundtruth_sentence = []
    prompt_model.eval()

    for step, inputs in enumerate(dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        _, output_sentence = prompt_model.generate(inputs, **generation_arguments)
        generated_sentence.extend(output_sentence)
        groundtruth_sentence.extend(inputs['tgt_text'])
    score = generation_metric(generated_sentence, groundtruth_sentence, "sentence_bleu")
    print("test_score", score, flush=True)
    return generated_sentence
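# Note: `generation_arguments` is defined below at module level; Python resolves the name when
# evaluate() is first called (by which point the dict exists), so the ordering is fine.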

generation_arguments = {
    "max_length": 512,
    "max_new_tokens": None,
    "min_length": 5,
    "temperature": 1.0,
    "do_sample": False,
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
    "num_beams": 5,
    "bad_words_ids": [[628], [198]]
}
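# The bad_words_ids above are, to my reading, newline-style token ids from the GPT-2 vocabulary (blocked
# so generation doesn't pad its output with blank lines); they carry no special meaning for a T5
# tokenizer, so adjust or drop them when switching models.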

# Training and generation.
global_step = 0
tot_loss = 0
log_loss = 0
for epoch in range(5):
    prompt_model.train()
    for step, inputs in enumerate(train_dataloader):
        global_step += 1
        if use_cuda:
            inputs = inputs.cuda()
        loss = prompt_model(inputs)
        loss.backward()
        tot_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(mytemplate.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if global_step % 500 == 0:
            print("Epoch {}, global_step {} average loss: {} lr: {}".format(epoch, global_step, (tot_loss - log_loss) / 500, scheduler.get_last_lr()[0]), flush=True)
            log_loss = tot_loss
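    # Optional (not in the original script): track dev performance after each epoch, since
    # validation_dataloader is otherwise unused.
    # evaluate(prompt_model, validation_dataloader)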
generated_sentence = evaluate(prompt_model, test_dataloader)
with open(f"../../Generated_sentence_webnlg_{args.model}_{args.plm_eval_mode}.txt", 'w') as f:
    for i in generated_sentence:
        f.write(i + "\n")