openprompt

Форк
0
/
1.4_soft_template.py 
423 строки · 17.4 Кб
1
from tqdm import tqdm
2
from openprompt.data_utils import PROCESSORS
3
import torch
4
from openprompt.data_utils.utils import InputExample
5
import argparse
6
import numpy as np
7

8
from openprompt import PromptDataLoader
9
from openprompt.prompts import ManualVerbalizer
10
from openprompt.prompts import SoftTemplate
11
from openprompt import PromptForClassification
12
import time
13
import os
14

15

16
def _build_arg_parser():
    """Assemble the command-line parser for the soft-template tuning script."""
    p = argparse.ArgumentParser("")
    p.add_argument("--shot", type=int, default=-1)
    p.add_argument("--seed", type=int, default=144)
    p.add_argument("--plm_eval_mode", action="store_true", help="whether to turn off the dropout in the freezed model. Set to true to turn off.")
    p.add_argument("--tune_plm", action="store_true")
    p.add_argument("--model", type=str, default='t5-lm', help="We test both t5 and t5-lm in this scripts, the corresponding tokenizerwrapper will be automatically loaded.")
    p.add_argument("--model_name_or_path", default='../../plm_cache/t5-large-lm-adapt/')
    p.add_argument("--project_root", default="/mnt/sfs_turbo/hsd/OpenPrompt_official/OpenPrompt/", help="The project root in the file system, i.e. the absolute path of OpenPrompt")
    p.add_argument("--template_id", type=int)
    p.add_argument("--verbalizer_id", type=int)
    # sometimes, huggingface datasets can not be automatically downloaded due to
    # network issue, please refer to 0_basic.py line 15 for solutions.
    p.add_argument("--data_dir", type=str, default="/mnt/sfs_turbo/huggingface_datasets/")
    p.add_argument("--dataset", type=str)
    p.add_argument("--result_file", type=str, default="../sfs_out/results.txt")
    p.add_argument("--max_steps", default=20000, type=int)
    p.add_argument("--prompt_lr", type=float, default=0.3)
    p.add_argument("--warmup_step_prompt", type=int, default=500)
    # store_false: initializing soft tokens from the vocabulary is ON by default.
    p.add_argument("--init_from_vocab", action="store_false")
    p.add_argument("--eval_every_steps", type=int, default=500)
    p.add_argument("--soft_token_num", type=int, default=20)
    p.add_argument("--optimizer", type=str, default="Adafactor")
    return p


args = _build_arg_parser().parse_args()

# Resolve the result file relative to the project root.
args.result_file = os.path.join(args.project_root, args.result_file)
39

40
# Run-configuration header that will be echoed now and appended to the result
# file at the end of the run. Field order matters only for readability.
_run_fields = [
    ("dataset", args.dataset),
    ("temp", args.template_id),
    ("verb", args.verbalizer_id),
    ("model", args.model),
    ("seed", args.seed),
    ("shot", args.shot),
    ("plm_eval_mode", args.plm_eval_mode),
    ("init_from_vocab", args.init_from_vocab),
    ("eval_every_steps", args.eval_every_steps),
    ("prompt_lr", args.prompt_lr),
    ("optimizer", args.optimizer),
    ("warmup_step_prompt", args.warmup_step_prompt),
    ("soft_token_num", args.soft_token_num),
]
content_write = "=" * 20 + "\n"
content_write += "".join(f"{name} {value}\t" for name, value in _run_fields)
content_write += "\n"

print(content_write)
57

58
import random

# Unique tag for this run's checkpoint file.
# Use an int upper bound: random.randint(0, 1e10) passes a float, which is
# deprecated since Python 3.10 and a TypeError from Python 3.12 onward.
this_run_unicode = str(random.randint(0, 10**10))
60

61
from openprompt.utils.reproduciblity import set_seed
62
set_seed(args.seed)
63

64
# use lm-adapted version or t5-v1.1 checkpoint. Note that the original t5 checkpoint has been pretrained
65
# on part of GLUE dataset, thus should not be used.
66
from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper
67
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
68
from openprompt.data_utils.data_sampler import FewShotSampler
69
from openprompt.plms import load_plm
70

71
plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)
72
dataset = {}
73

74
# Below are multiple dataset examples, including few-shot ones.
#
# Dataset registry replacing eight near-identical if/elif branches.
# Each entry: (PROCESSORS key, script directory, gradient-accumulation steps
# used when --tune_plm is on, whether to subsample a few-shot train split).
# Note RTE deliberately uses a smaller accumulation (2) when tuning the PLM.
_DATASET_CONFIGS = {
    "boolq": ("super_glue.boolq", "SuperGLUE/BoolQ", 8, False),
    "multirc": ("super_glue.multirc", "SuperGLUE/MultiRC", 8, False),
    "rte": ("super_glue.rte", "SuperGLUE/RTE", 2, False),
    "cb": ("super_glue.cb", "SuperGLUE/CB", 8, False),
    "wic": ("super_glue.wic", "SuperGLUE/WiC", 8, False),
    "fewshot_boolq": ("super_glue.boolq", "SuperGLUE/BoolQ", 8, True),
    "fewshot_multirc": ("super_glue.multirc", "SuperGLUE/MultiRC", 8, True),
    "fewshot_wic": ("super_glue.wic", "SuperGLUE/WiC", 8, True),
}

if args.dataset not in _DATASET_CONFIGS:
    raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

_processor_key, scriptsbase, _tune_grad_acc, _is_fewshot = _DATASET_CONFIGS[args.dataset]
Processor = PROCESSORS[_processor_key]
_processor = Processor()
dataset['train'] = _processor.get_train_examples(args.data_dir)
dataset['validation'] = _processor.get_dev_examples(args.data_dir)
dataset['test'] = _processor.get_test_examples(args.data_dir)
class_labels = _processor.get_labels()
scriptformat = "txt"
max_seq_l = 480  # this should be specified according to the running GPU's capacity

# Few-shot variants subsample 32 training examples per label.
if _is_fewshot:
    sampler = FewShotSampler(num_examples_per_label=32)
    dataset['train'] = sampler(dataset['train'], seed=args.seed)

if args.tune_plm:
    # Tuning the entire PLM uses more GPU memory, so use a smaller batch size
    # and (when multiple GPUs are available) model parallelism.
    batchsize_t = 4
    batchsize_e = 4
    gradient_accumulation_steps = _tune_grad_acc
    model_parallelize = True
else:
    batchsize_t = 8
    batchsize_e = 4
    gradient_accumulation_steps = 4
    model_parallelize = False
235

236

237
# Now define the template and verbalizer.
238
# Note that soft template can be combined with hard template, by loading the hard template from file.
239
# For example, the template in soft_template.txt is {}
240
# The choice_id 1 is the hard template
241
mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, num_tokens=args.soft_token_num, initialize_from_vocab=args.init_from_vocab).from_file(f"scripts/{scriptsbase}/soft_template.txt", choice=args.template_id)
242
myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(f"scripts/{scriptsbase}/manual_verbalizer.{scriptformat}", choice=args.verbalizer_id)
243
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
244
print(wrapped_example)
245

246

247
use_cuda = True
# freeze_plm is the inverse of --tune_plm: by default only the soft-prompt
# parameters are trainable. plm_eval_mode additionally turns off dropout in
# the frozen PLM (see the --plm_eval_mode help text).
prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=(not args.tune_plm), plm_eval_mode=args.plm_eval_mode)
if use_cuda:
    prompt_model=  prompt_model.cuda()

# Spread the model across multiple GPUs (enabled only with --tune_plm above).
if model_parallelize:
    prompt_model.parallelize()
254

255

256
def _make_dataloader(split, batch_size, shuffle):
    """Build a PromptDataLoader over ``dataset[split]`` with shared settings.

    All three loaders share max_seq_l, a 3-token decoder budget (presumably
    enough for the single-token verbalizer labels — confirm against the
    verbalizer files), and tail truncation; only the split, batch size and
    shuffling differ.
    """
    return PromptDataLoader(dataset=dataset[split], template=mytemplate, tokenizer=tokenizer,
        tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3,
        batch_size=batch_size, shuffle=shuffle, teacher_forcing=False, predict_eos_token=False,
        truncate_method="tail")


train_dataloader = _make_dataloader("train", batchsize_t, shuffle=True)
validation_dataloader = _make_dataloader("validation", batchsize_e, shuffle=False)
# zero-shot test
test_dataloader = _make_dataloader("test", batchsize_e, shuffle=False)

# Report how often inputs had to be truncated to fit max_seq_l.
print("truncate rate: {}".format(test_dataloader.tokenizer_wrapper.truncate_rate), flush=True)
273

274
def evaluate(prompt_model, dataloader, desc):
    """Compute classification accuracy of ``prompt_model`` over ``dataloader``.

    Args:
        prompt_model: model returning class logits for a batch of inputs.
        dataloader: iterable of batches carrying a 'label' field.
        desc: human-readable tag for this pass (unused; kept for the
            caller-facing interface).

    Returns:
        float: fraction of correct predictions; 0.0 for an empty dataloader
        (the original raised ZeroDivisionError in that case).
    """
    prompt_model.eval()
    allpreds = []
    alllabels = []

    # no_grad: evaluation must not build autograd graphs (saves memory/time).
    with torch.no_grad():
        for step, inputs in enumerate(dataloader):
            if use_cuda:
                inputs = inputs.cuda()
            logits = prompt_model(inputs)
            labels = inputs['label']
            alllabels.extend(labels.cpu().tolist())
            allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    if not allpreds:  # guard against division by zero on an empty split
        return 0.0
    acc = sum([int(i == j) for i, j in zip(allpreds, alllabels)]) / len(allpreds)
    return acc
288

289
from transformers import  AdamW, get_linear_schedule_with_warmup,get_constant_schedule_with_warmup  # AdamW is standard practice for transformers
from transformers.optimization import Adafactor, AdafactorSchedule  # Adafactor is the default optimizer setting for T5
loss_func = torch.nn.CrossEntropyLoss()

# Total number of global (post-accumulation) optimizer steps.
tot_step = args.max_steps


if args.tune_plm: # normally the PLM stays frozen with a soft template; this branch keeps the option to tune it
    no_decay = ['bias', 'LayerNorm.weight'] # good practice: no weight decay on biases and LayerNorm parameters
    # Two parameter groups over the PLM: decayed weights vs. no-decay params.
    optimizer_grouped_parameters1 = [
        {'params': [p for n, p in prompt_model.plm.named_parameters() if (not any(nd in n for nd in no_decay))], 'weight_decay': 0.01},
        {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
    scheduler1 = get_linear_schedule_with_warmup(
        optimizer1,
        num_warmup_steps=500, num_training_steps=tot_step)
else:
    # PLM frozen: no optimizer/scheduler for it.
    optimizer1 = None
    scheduler1 = None
309

310

311
# Soft-prompt parameters: everything in the template except 'raw_embedding'
# (the frozen copy of the PLM input embedding), which must be excluded from
# optimization manually.
optimizer_grouped_parameters2 = [{'params': [p for name, p in prompt_model.template.named_parameters() if 'raw_embedding' not in name]}]
if args.optimizer.lower() == "adafactor":
    optimizer2 = Adafactor(optimizer_grouped_parameters2,
                            lr=args.prompt_lr,
                            relative_step=False,
                            scale_parameter=False,
                            warmup_init=False)  # when lr is 0.3, it is the same as the configuration of https://arxiv.org/abs/2104.08691
    scheduler2 = get_constant_schedule_with_warmup(optimizer2, num_warmup_steps=args.warmup_step_prompt) # when num_warmup_steps is 0, it is the same as the configuration of https://arxiv.org/abs/2104.08691
elif args.optimizer.lower() == "adamw":
    optimizer2 = AdamW(optimizer_grouped_parameters2, lr=args.prompt_lr) # usually lr = 0.5
    scheduler2 = get_linear_schedule_with_warmup(
                    optimizer2,
                    num_warmup_steps=args.warmup_step_prompt, num_training_steps=tot_step) # usually num_warmup_steps is 500
else:
    # Fail fast: without this, optimizer2/scheduler2 would be undefined and
    # the training loop would later crash with a confusing NameError.
    raise ValueError(f"Unsupported --optimizer: {args.optimizer!r} (expected 'Adafactor' or 'AdamW')")
324

325

326
tot_loss = 0
log_loss = 0
best_val_acc = 0
glb_step = 0     # global steps, i.e. optimizer updates after accumulation
actual_step = 0  # micro-batches processed
leave_training = False

acc_traces = []  # validation accuracy recorded at each evaluation
tot_train_time = 0
pbar_update_freq = 10
prompt_model.train()

pbar = tqdm(total=tot_step, desc="Train")
for epoch in range(1000000):  # effectively "forever"; exit is max_steps-driven
    print(f"Begin epoch {epoch}")
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        tot_train_time -= time.time()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        # Scale the backward loss so the accumulated gradient equals the mean
        # over the effective (accumulated) batch; log the unscaled value.
        (loss / gradient_accumulation_steps).backward()
        tot_loss += loss.item()
        actual_step += 1

        if actual_step % gradient_accumulation_steps == 0:
            # Bug fix: step optimizers/schedulers only once per accumulation
            # window. The original stepped (and zeroed gradients) after every
            # micro-batch, which silently disabled gradient accumulation and
            # advanced the LR schedules gradient_accumulation_steps times too
            # fast.
            torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0)
            if optimizer1 is not None:
                optimizer1.step()
                optimizer1.zero_grad()
            if scheduler1 is not None:
                scheduler1.step()
            if optimizer2 is not None:
                optimizer2.step()
                optimizer2.zero_grad()
            if scheduler2 is not None:
                scheduler2.step()

            glb_step += 1
            if glb_step % pbar_update_freq == 0:
                aveloss = (tot_loss - log_loss) / pbar_update_freq
                pbar.update(10)
                pbar.set_postfix({'loss': aveloss})
                log_loss = tot_loss

        tot_train_time += time.time()

        # Evaluate on the validation split every eval_every_steps global steps.
        if actual_step % gradient_accumulation_steps == 0 and glb_step > 0 and glb_step % args.eval_every_steps == 0:
            val_acc = evaluate(prompt_model, validation_dataloader, desc="Valid")
            if val_acc >= best_val_acc:
                # Keep only the best checkpoint, tagged by this run's id.
                torch.save(prompt_model.state_dict(), f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt")
                best_val_acc = val_acc

            acc_traces.append(val_acc)
            print("Glb_step {}, val_acc {}, average time {}".format(glb_step, val_acc, tot_train_time/actual_step ), flush=True)
            prompt_model.train()  # evaluate() switched the model to eval mode

        if glb_step > args.max_steps:
            leave_training = True
            break

    if leave_training:
        break
391

392

393
# The super_glue test split cannot be evaluated locally (labels are withheld;
# results must be submitted to the benchmark website), so test evaluation is
# kept as comments. Note the load path matches where the checkpoint is saved.
#
# prompt_model.load_state_dict(torch.load(f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt"))
# prompt_model = prompt_model.cuda()
# test_acc = evaluate(prompt_model, test_dataloader, desc="Test")

# A simple measure of convergence speed: the earliest global step at which
# validation accuracy reached 98% / 99% / 100% of its best value.
thres99 = 0.99*best_val_acc
thres98 = 0.98*best_val_acc
thres100 = best_val_acc
step100 = step98 = step99 = args.max_steps
for val_time, acc in enumerate(acc_traces):
    # Off-by-one fix: the first recorded accuracy was measured at global step
    # eval_every_steps (not 0), so the step for index i is (i+1)*interval.
    reached_step = (val_time + 1) * args.eval_every_steps
    if acc >= thres98:
        step98 = min(reached_step, step98)
        if acc >= thres99:
            step99 = min(reached_step, step99)
            if acc >= thres100:
                step100 = min(reached_step, step100)


# Guard: training may end (e.g. tiny max_steps) before any evaluation ran.
end_val_acc = acc_traces[-1] if acc_traces else float("nan")
content_write += f"BestValAcc:{best_val_acc}\tEndValAcc:{end_val_acc}\tcritical_steps:{[step98,step99,step100]}\n"
content_write += "\n"

print(content_write)

with open(f"{args.result_file}", "a") as fout:
    fout.write(content_write)

# Clean up the checkpoint. Use the same path it was saved under — the original
# removed a cwd-relative "../ckpts/..." which only works when cwd is the
# project root — and tolerate a missing file (nothing was saved if no
# evaluation ever ran).
ckpt_path = f"{args.project_root}/../ckpts/{this_run_unicode}.ckpt"
if os.path.exists(ckpt_path):
    os.remove(ckpt_path)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.