4.1_all_tasks_are_generation.py 
# There is a recent trend to unify all tasks (such as classification) into generation tasks.
# In fact, unifying these tasks as text generation can be done neatly with prompts.
# In OpenPrompt, we provide a GenerationVerbalizer for this purpose.
# Here we go!

from openprompt.pipeline_base import PromptForGeneration
from openprompt.prompts.generation_verbalizer import GenerationVerbalizer
from tokenizers import PreTokenizedString
from tqdm import tqdm
from openprompt.data_utils import PROCESSORS
import torch
from openprompt.data_utils.utils import InputExample
import argparse
import numpy as np

from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer
from openprompt.prompts import SoftTemplate
from openprompt import PromptForClassification
import time
import os
import re
from openprompt.utils.crossfit_metrics import evaluate as crossfit_evaluate

parser = argparse.ArgumentParser("")
parser.add_argument("--shot", type=int, default=-1)
parser.add_argument("--seed", type=int, default=144)
parser.add_argument("--plm_eval_mode", action="store_true", help="Whether to turn off dropout in the frozen model. Set to True to turn it off.")
parser.add_argument("--tune_plm", action="store_true", help="Whether to tune the plm, defaults to False.")
parser.add_argument("--model", type=str, default='t5-lm', help="We test both t5 and t5-lm in this script; the corresponding tokenizer wrapper is loaded automatically.")
parser.add_argument("--model_name_or_path", default='../../plm_cache/t5-large-lm-adapt/')
parser.add_argument("--project_root", default="/home/hushengding/OpenPrompt_CameraReady/OpenPrompt/", help="The project root in the file system, i.e. the absolute path of OpenPrompt.")
parser.add_argument("--template_id", type=int, default=0)
parser.add_argument("--verbalizer_id", type=int, default=0)
parser.add_argument("--data_dir", type=str, default="../../huggingface_datasets/saved_to_disk/") # sometimes huggingface datasets cannot be downloaded automatically due to network issues; please refer to 0_basic.py line 15 for solutions.
parser.add_argument("--dataset", type=str)
parser.add_argument("--result_file", type=str, default="../results.txt")
parser.add_argument("--max_steps", default=20000, type=int)
parser.add_argument("--prompt_lr", type=float, default=0.3)
parser.add_argument("--warmup_step_prompt", type=int, default=500)
parser.add_argument("--init_from_vocab", action="store_false")
parser.add_argument("--eval_every_steps", type=int, default=500)
parser.add_argument("--soft_token_num", type=int, default=100)
parser.add_argument("--optimizer", type=str, default="Adafactor")
args = parser.parse_args()

args.result_file = os.path.join(args.project_root, args.result_file)
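# Example invocation (paths are illustrative and depend on your local setup):
#   python 4.1_all_tasks_are_generation.py --dataset rte --model t5-lm \
#       --model_name_or_path ../../plm_cache/t5-large-lm-adapt/ --project_root <path_to_OpenPrompt>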

content_write = "="*20+"\n"
content_write += f"dataset {args.dataset}\t"
content_write += f"temp {args.template_id}\t"
content_write += f"verb {args.verbalizer_id}\t"
content_write += f"model {args.model}\t"
content_write += f"seed {args.seed}\t"
content_write += f"shot {args.shot}\t"
content_write += f"plm_eval_mode {args.plm_eval_mode}\t"
content_write += f"init_from_vocab {args.init_from_vocab}\t"
content_write += f"eval_every_steps {args.eval_every_steps}\t"
content_write += f"prompt_lr {args.prompt_lr}\t"
content_write += f"optimizer {args.optimizer}\t"
content_write += f"warmup_step_prompt {args.warmup_step_prompt}\t"
content_write += f"soft_token_num {args.soft_token_num}\t"
content_write += "\n"

print(content_write)

import random
this_run_unicode = str(random.randint(0, 10**10))  # random.randint needs integer bounds, so use 10**10 rather than 1e10

from openprompt.utils.reproduciblity import set_seed
set_seed(args.seed)

# Use the LM-adapted version or the t5-v1.1 checkpoint. Note that the original T5 checkpoint was pretrained
# on part of the GLUE benchmark and therefore should not be used.
from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from openprompt.data_utils.data_sampler import FewShotSampler
from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)
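# load_plm returns the pretrained model, its tokenizer, its config, and a tokenizer wrapper class;
# the wrapper matching the chosen --model (t5 or t5-lm) is picked automatically and is later handed
# to PromptDataLoader as tokenizer_wrapper_class.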
dataset = {}

# Below are multiple dataset examples, including few-shot ones.
if args.dataset == "boolq":
    Processor = PROCESSORS["super_glue.boolq"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/BoolQ"
    scriptformat = "txt"
    dataset_decoder_max_length = 10
    max_seq_l = 480 # this should be set according to the capacity of the GPU you run on
    if args.tune_plm: # tuning the entire plm uses more GPU memory, so we use a smaller batch size
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True # if multiple GPUs are available, one can use model_parallelize
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "multirc":
    Processor = PROCESSORS["super_glue.multirc"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/MultiRC"
    scriptformat = "txt"
    dataset_decoder_max_length = 10
    max_seq_l = 480 # may be a bit short for MultiRC, but we use 480 to keep the training overhead small
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "rte":
    Processor = PROCESSORS["super_glue.rte"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/RTE"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 2
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "cb":
    Processor = PROCESSORS["super_glue.cb"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/CB"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "wic":
    Processor = PROCESSORS["super_glue.wic"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/WiC"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "copa":
    Processor = PROCESSORS["super_glue.copa"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/COPA"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 50
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "wsc":
    Processor = PROCESSORS["super_glue.wsc"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/WSC"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 10
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
elif args.dataset == "record":
    Processor = PROCESSORS["super_glue.record"]
    dataset['train'] = Processor().get_train_examples(args.data_dir)
    dataset['validation'] = Processor().get_dev_examples(args.data_dir)
    dataset['test'] = Processor().get_test_examples(args.data_dir)
    class_labels = Processor().get_labels()
    scriptsbase = "SuperGLUE/RECORD"
    scriptformat = "txt"
    max_seq_l = 480
    dataset_decoder_max_length = 20
    if args.tune_plm:
        batchsize_t = 4
        batchsize_e = 4
        gradient_accumulation_steps = 8
        model_parallelize = True
    else:
        batchsize_t = 8
        batchsize_e = 4
        gradient_accumulation_steps = 4
        model_parallelize = False
else:
    raise NotImplementedError
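# Each branch above fixes the dataset-specific knobs: dataset_decoder_max_length (how many target tokens
# the decoder sees during teacher forcing, larger for longer targets such as COPA and ReCoRD),
# max_seq_l (input truncation length), and the batch size / gradient accumulation / model-parallel
# settings, which differ depending on whether the full PLM is tuned or only the soft prompt.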

# Now define the template and verbalizer.
# Note that a soft template can be combined with a hard template by loading the hard template from a file.
# For example, the template in soft_template.txt is {}
# choice_id 1 is the hard template.
mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, num_tokens=args.soft_token_num, initialize_from_vocab=args.init_from_vocab).from_file(f"scripts/{scriptsbase}/soft_template.txt", choice=args.template_id)
if os.path.exists(f"scripts/{scriptsbase}/generation_verbalizer.{scriptformat}"):
    myverbalizer = GenerationVerbalizer(tokenizer, classes=class_labels, is_rule=True).from_file(f"scripts/{scriptsbase}/generation_verbalizer.{scriptformat}", choice=args.verbalizer_id)
else:
    myverbalizer = GenerationVerbalizer(tokenizer, classes=class_labels, is_rule=False).from_file(f"scripts/{scriptsbase}/manual_verbalizer.{scriptformat}", choice=args.verbalizer_id)
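# SoftTemplate prepends args.soft_token_num trainable soft-prompt embeddings to every input
# (optionally initialized from the vocabulary embeddings, controlled by --init_from_vocab).
# GenerationVerbalizer maps each class label to a piece of target text that the model is trained to
# generate; with is_rule=True the target text is presumably assembled from fields of the input example
# rather than being a fixed label word.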

use_cuda = True
prompt_model = PromptForGeneration(plm=plm, template=mytemplate, freeze_plm=(not args.tune_plm), plm_eval_mode=args.plm_eval_mode)
if use_cuda:
    prompt_model = prompt_model.cuda()

if model_parallelize:
    prompt_model.parallelize()

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer, # be sure to add the verbalizer
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=dataset_decoder_max_length,  # be sure to use a larger decoder_max_length for teacher forcing
    batch_size=batchsize_t, shuffle=True, teacher_forcing=True, predict_eos_token=True,  # be sure to use teacher_forcing=True and predict_eos_token=True
    truncate_method="tail")

validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3,
    batch_size=batchsize_e, shuffle=False, teacher_forcing=False, predict_eos_token=False, # predict_eos_token=True and False are both fine here
    truncate_method="tail")

test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, verbalizer=myverbalizer, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3,
    batch_size=batchsize_e, shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="tail")

print("truncate rate: {}".format(test_dataloader.tokenizer_wrapper.truncate_rate), flush=True)

generation_arguments = {
    "max_length": dataset_decoder_max_length,
}
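# generation_arguments is forwarded to the underlying HuggingFace generate() call, so other decoding
# options (e.g. "num_beams") could presumably be added here as well; this script only caps the
# generated length at the dataset-specific dataset_decoder_max_length.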

def evaluate(prompt_model, dataloader):
    predictions = []
    ground_truths = []

    for step, inputs in enumerate(dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        _, output_sentence = prompt_model.generate(inputs, **generation_arguments, verbose=False)
        predictions.extend(output_sentence)
        ground_truths.extend(inputs['tgt_text'])
    assert len(predictions) == len(ground_truths), (len(predictions), len(ground_truths))
    predictions = [prediction.strip() for prediction in predictions]
    ground_truths = [ground_truth.strip() for ground_truth in ground_truths]
    # show one example
    print(f"predictions {predictions[0]}, ground_truths {ground_truths[0]}")
    score = crossfit_evaluate(predictions, ground_truths, metric="ACC")
    return score
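# Because every task has been cast as generation, "accuracy" here is simply the CrossFit exact-match
# ACC metric over generated strings: each prediction is stripped and compared against the target text
# (tgt_text) that the GenerationVerbalizer produced for the gold label.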


from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup  # using AdamW is standard practice for transformers
from transformers.optimization import Adafactor, AdafactorSchedule  # Adafactor is the default setting for T5
loss_func = torch.nn.CrossEntropyLoss()

tot_step = args.max_steps


if args.tune_plm: # normally we freeze the model when using a soft template; however, we keep the option to tune the plm
    no_decay = ['bias', 'LayerNorm.weight'] # it is good practice to exclude bias and LayerNorm parameters from weight decay
    optimizer_grouped_parameters1 = [
        {'params': [p for n, p in prompt_model.plm.named_parameters() if (not any(nd in n for nd in no_decay))], 'weight_decay': 0.01},
        {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
    scheduler1 = get_linear_schedule_with_warmup(
        optimizer1,
        num_warmup_steps=500, num_training_steps=tot_step)
else:
    optimizer1 = None
    scheduler1 = None


optimizer_grouped_parameters2 = [{'params': [p for name, p in prompt_model.template.named_parameters() if 'raw_embedding' not in name]}] # note that you have to exclude the raw_embedding manually from the optimization
if args.optimizer.lower() == "adafactor":
    optimizer2 = Adafactor(optimizer_grouped_parameters2,
                            lr=args.prompt_lr,
                            relative_step=False,
                            scale_parameter=False,
                            warmup_init=False)  # with lr = 0.3, this matches the configuration of https://arxiv.org/abs/2104.08691
    scheduler2 = get_constant_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_step_prompt) # with num_warmup_steps = 0, this matches the configuration of https://arxiv.org/abs/2104.08691
elif args.optimizer.lower() == "adamw":
    optimizer2 = AdamW(optimizer_grouped_parameters2, lr=args.prompt_lr) # usually lr = 0.5
    scheduler2 = get_linear_schedule_with_warmup(
                    optimizer2,
                    num_warmup_steps=args.warmup_step_prompt, num_training_steps=tot_step) # usually num_warmup_steps is 500
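# Two optimizer/scheduler pairs are kept deliberately: optimizer1 (AdamW, lr 3e-5, linear warmup) updates
# the PLM weights and only exists when --tune_plm is set, while optimizer2 updates just the soft-prompt
# parameters of the template at the much larger --prompt_lr (0.3 with Adafactor, following the
# prompt-tuning setup of https://arxiv.org/abs/2104.08691).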


tot_loss = 0
log_loss = 0
best_val_acc = 0
glb_step = 0
actual_step = 0
leave_training = False

acc_traces = []
tot_train_time = 0
pbar_update_freq = 10
prompt_model.train()

pbar = tqdm(total=tot_step, desc="Train")
for epoch in range(1000000):
    print(f"Begin epoch {epoch}")
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        tot_train_time -= time.time()
        loss = prompt_model(inputs)
        loss.backward()
        tot_loss += loss.item()
        actual_step += 1

        if actual_step % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0)
            glb_step += 1
            if glb_step % pbar_update_freq == 0:
                aveloss = (tot_loss - log_loss) / pbar_update_freq
                pbar.update(10)
                pbar.set_postfix({'loss': aveloss})
                log_loss = tot_loss

            # step the optimizers and schedulers once per accumulated batch
            if optimizer1 is not None:
                optimizer1.step()
                optimizer1.zero_grad()
            if scheduler1 is not None:
                scheduler1.step()
            if optimizer2 is not None:
                optimizer2.step()
                optimizer2.zero_grad()
            if scheduler2 is not None:
                scheduler2.step()

        tot_train_time += time.time()

        if actual_step % gradient_accumulation_steps == 0 and glb_step > 0 and glb_step % args.eval_every_steps == 0:
            val_acc = evaluate(prompt_model, validation_dataloader)
            if val_acc >= best_val_acc:
                torch.save(prompt_model.state_dict(), f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt")
                best_val_acc = val_acc

            acc_traces.append(val_acc)
            print("Glb_step {}, val_acc {}, average time {}".format(glb_step, val_acc, tot_train_time / actual_step), flush=True)
            prompt_model.train()

        if glb_step > args.max_steps:
            leave_training = True
            break

    if leave_training:
        break
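
# A note on the bookkeeping above: actual_step counts mini-batches and glb_step counts optimizer updates,
# so the effective batch size is batchsize_t * gradient_accumulation_steps (e.g. 8 * 4 = 32 in the
# frozen-PLM setting). Every args.eval_every_steps updates, the model is evaluated on the validation
# split and the best checkpoint so far is saved under <project_root>/../ckpts/.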


# The super_glue test split cannot be evaluated without submitting the results to its website,
# so we skip it here and keep the code as comments.
#
# prompt_model.load_state_dict(torch.load(f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt"))
# prompt_model = prompt_model.cuda()
# test_acc = evaluate(prompt_model, test_dataloader)

# a simple measure of convergence speed
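# step98/step99/step100 below record the earliest training step (in units of eval_every_steps) at which
# the validation accuracy first reached 98%, 99%, and 100% of its best value; they default to max_steps
# if a threshold was never reached.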
thres99 = 0.99 * best_val_acc
thres98 = 0.98 * best_val_acc
thres100 = best_val_acc
step100 = step98 = step99 = args.max_steps
for val_time, acc in enumerate(acc_traces):
    if acc >= thres98:
        step98 = min(val_time * args.eval_every_steps, step98)
        if acc >= thres99:
            step99 = min(val_time * args.eval_every_steps, step99)
            if acc >= thres100:
                step100 = min(val_time * args.eval_every_steps, step100)


content_write += f"BestValAcc:{best_val_acc}\tEndValAcc:{acc_traces[-1]}\tcritical_steps:{[step98, step99, step100]}\n"
content_write += "\n"

print(content_write)

with open(f"{args.result_file}", "a") as fout:
    fout.write(content_write)

# clean up the temporary checkpoint of this run (os is already imported above)
os.remove(f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt")
