Prompt-Transferability
188 строк · 7.9 Кб
1import os
2import sys
3import torch
4import logging
5import random
6import numpy as np
7
8from datasets import load_dataset, load_metric
9from transformers.trainer_utils import get_last_checkpoint
10from transformers import (
11set_seed,
12AutoTokenizer,
13Trainer,
14TrainingArguments,
15DataCollatorWithPadding,
16EvalPrediction,
17default_data_collator,
18)
19from openprompt.data_utils.utils import InputExample
20from openprompt import PromptDataLoader, PromptForClassification
21from openprompt.plms import load_plm
22from openprompt.prompts import SoftTemplate, ManualVerbalizer
23
24from prompt_hub import task_to_keys, get_model
25from prompt_hub.hub import PromptHub
26from prompt_hub.training_args import PromptTrainingArguments, RemainArgHfArgumentParser
27
28
29logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
30datefmt='%m/%d/%Y %H:%M:%S',
31level=logging.INFO)
32
33logger = logging.getLogger(__name__)
34
35
36def main():
37# See all possible arguments in src/transformers/args.py
38# or by passing the --help flag to this script.
39# We now keep distinct sets of args, for a cleaner separation of concerns.
40
41parser = RemainArgHfArgumentParser((PromptTrainingArguments))
42if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
43# If we pass only one argument to the script and it's the path to a json file,
44# let's parse it to get our arguments.
45json_file=os.path.abspath(sys.argv[1])
46args, _ = parser.parse_json_file(json_file, return_remaining_args=True) #args = arg_string, return_remaining_strings=True) #parse_json_file(json_file=os.path.abspath(sys.argv[1]))
47else:
48args = parser.parse_args_into_dataclasses()[0]
49
50set_seed(args.seed)
51
52# Dataset
53is_regression = args.dataset in ['stsb']
54# raw_dataset = load_dataset("glue", self.args.dataset)
55# train_dataset = [InputExample(guid=e['idx'], text_a=e['question'], text_b=e['sentence'], label=e['label']) for e in raw_dataset['train']]#[:100]
56# eval_dataset = [InputExample(guid=e['idx'], text_a=e['question'], text_b=e['sentence'], label=e['label']) for e in raw_dataset['validation']]#[:100]
57
58# Model
59# plm, tokenizer, model_config, tokenizer_wrapper_class = load_plm('roberta', args.backbone)
60# template = '{"soft": None, "duplicate": ' + str(args.prompt_len) + ', "same": True} {"mask"} {"placeholder": "text_a"} {"placeholder": "text_b"}'
61# template = SoftTemplate(model=plm, text=template, tokenizer=tokenizer, num_tokens=args.prompt_len) # initialize_from_vocab=args.init_from_vocab
62# verbalizer = ManualVerbalizer(tokenizer, classes=raw_dataset['train'].features['label'].names).from_file(f'verbalizer/{args.dataset}.txt', choice=0)
63# model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer, freeze_plm=True)
64
65metric = load_metric("prompt_hub/glue_metrics.py", args.dataset)
66
67
68def compute_metrics(p: EvalPrediction):
69preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
70preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
71result = metric.compute(predictions=preds, references=p.label_ids)
72result["combined_score"] = np.mean(list(result.values())).item()
73
74return result
75
76
77
78# Train
79trainer = PromptHub(
80args=args,
81compute_metrics=compute_metrics,
82)
83
84
85
86train_results = trainer.train_prompt(args.backbone, args.dataset)
87print(train_results)
88
89eval_results = trainer.eval_prompt(args.backbone, args.dataset)
90print(eval_results)
91
92cross_task_results = trainer.cross_task_eval(args.backbone, 'rotten_tomatoes')
93print(cross_task_results)
94
95trainer.cross_model_train(args.backbone, 'roberta-large', args.dataset)
96trainer.cross_task_eval(args.backbone, 'roberta-large', args.dataset)
97
98# Trainer
99# data_collator = DataCollatorWithPadding(tokenizer, max_length=args.max_source_length, pad_to_multiple_of=8)
100# model=model,
101# template=template,
102# verbalizer=verbalizer,
103# tokenizer_wrapper_class=tokenizer_wrapper_class,
104# train_dataset=train_dataset,
105# eval_dataset=eval_dataset,
106# tokenizer=tokenizer,
107# classes=raw_dataset['train'].features['label'].names
108# data_collator=data_collator
109
110# if args.do_train:
111# # Detecting last checkpoint.
112# last_checkpoint = None
113# if os.path.isdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
114# last_checkpoint = get_last_checkpoint(args.output_dir)
115# if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
116# raise ValueError(
117# f"Output directory ({args.output_dir}) already exists and is not empty. "
118# "Use --overwrite_output_dir to overcome."
119# )
120# elif last_checkpoint is not None and args.resume_from_checkpoint is None:
121# logger.info(
122# f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
123# "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
124# )
125# train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
126# metrics = train_result.metrics
127# # metrics["train_samples"] = min(max_train_samples, len(train_dataset))
128
129# trainer.save_model() # Saves the tokenizer too for easy upload
130
131# trainer.log_metrics("train", metrics)
132# trainer.save_metrics("train", metrics)
133# trainer.save_state()
134
135# results = {}
136# # Evaluation
137# if args.do_eval:
138# logger.info("*** Evaluate ***")
139
140# # Loop to handle MNLI double evaluation (matched, mis-matched)
141# tasks = [data_args.task_name]
142# eval_datasets = [eval_dataset]
143# if data_args.task_name == "mnli":
144# tasks.append("mnli-mm")
145# eval_datasets.append(raw_datasets["validation_mismatched"])
146
147# for eval_dataset, task in zip(eval_datasets, tasks):
148# metrics = trainer.evaluate(eval_dataset=eval_dataset)
149
150# max_eval_samples = (
151# data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
152# )
153# metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
154
155# trainer.log_metrics("eval", metrics)
156# trainer.save_metrics("eval", metrics)
157# results['eval'] = metrics
158
159# if args.do_predict:
160# logger.info("*** Predict ***")
161
162# # Loop to handle MNLI double evaluation (matched, mis-matched)
163# tasks = [data_args.task_name]
164# predict_datasets = [predict_dataset]
165# if data_args.task_name == "mnli":
166# tasks.append("mnli-mm")
167# predict_datasets.append(raw_datasets["test_mismatched"])
168
169# for predict_dataset, task in zip(predict_datasets, tasks):
170# # Removing the `label` columns because it contains -1 and Trainer won't like that.
171# predict_dataset = predict_dataset.remove_columns("label")
172# predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
173# predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
174
175# output_predict_file = os.path.join(args.output_dir, f"predict_results_{task}.txt")
176# if trainer.is_world_process_zero():
177# with open(output_predict_file, "w") as writer:
178# logger.info(f"***** Predict results {task} *****")
179# writer.write("index\tprediction\n")
180# for index, item in enumerate(predictions):
181# if is_regression:
182# writer.write(f"{index}\t{item:3.3f}\n")
183# else:
184# item = label_list[item]
185# writer.write(f"{index}\t{item}\n")
186
187if __name__ == "__main__":
188main()
189