Prompt-Transferability

import os
import sys
import torch
import logging
import random
import numpy as np

from datasets import load_dataset, load_metric
from transformers.trainer_utils import get_last_checkpoint
from transformers import (
    set_seed,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    EvalPrediction,
    default_data_collator,
)
from openprompt.data_utils.utils import InputExample
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.plms import load_plm
from openprompt.prompts import SoftTemplate, ManualVerbalizer

from prompt_hub import task_to_keys, get_model
from prompt_hub.hub import PromptHub
from prompt_hub.training_args import PromptTrainingArguments, RemainArgHfArgumentParser


logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

logger = logging.getLogger(__name__)


def main():
    # See all possible arguments in src/transformers/args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = RemainArgHfArgumentParser(PromptTrainingArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        json_file = os.path.abspath(sys.argv[1])
        args, _ = parser.parse_json_file(json_file, return_remaining_args=True)
    else:
        args = parser.parse_args_into_dataclasses()[0]
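    # A minimal sketch of the JSON config accepted by the branch above. Field names
    # follow how `args` is used in this script (backbone, dataset, prompt_len, seed,
    # output_dir); the authoritative schema is whatever PromptTrainingArguments
    # defines, so treat the values below as placeholders:
    # {
    #     "backbone": "roberta-base",
    #     "dataset": "sst2",
    #     "prompt_len": 100,
    #     "seed": 42,
    #     "output_dir": "outputs/roberta-base_sst2"
    # }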

    set_seed(args.seed)

    # Dataset
    is_regression = args.dataset in ['stsb']
    # raw_dataset = load_dataset("glue", self.args.dataset)
    # train_dataset = [InputExample(guid=e['idx'], text_a=e['question'], text_b=e['sentence'], label=e['label']) for e in raw_dataset['train']]#[:100]
    # eval_dataset = [InputExample(guid=e['idx'], text_a=e['question'], text_b=e['sentence'], label=e['label']) for e in raw_dataset['validation']]#[:100]

    # Model
    # plm, tokenizer, model_config, tokenizer_wrapper_class = load_plm('roberta', args.backbone)
    # template = '{"soft": None, "duplicate": ' + str(args.prompt_len) + ', "same": True} {"mask"} {"placeholder": "text_a"} {"placeholder": "text_b"}'
    # template = SoftTemplate(model=plm, text=template, tokenizer=tokenizer, num_tokens=args.prompt_len) # initialize_from_vocab=args.init_from_vocab
    # verbalizer = ManualVerbalizer(tokenizer, classes=raw_dataset['train'].features['label'].names).from_file(f'verbalizer/{args.dataset}.txt', choice=0)
    # model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer, freeze_plm=True)

    metric = load_metric("prompt_hub/glue_metrics.py", args.dataset)

    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        result = metric.compute(predictions=preds, references=p.label_ids)
        result["combined_score"] = np.mean(list(result.values())).item()

        return result
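
    # For illustration only: on a classification task such as "sst2" a GLUE-style
    # metric script typically returns something like {"accuracy": 0.94}, so `result`
    # becomes {"accuracy": 0.94, "combined_score": 0.94}; for "stsb" it carries
    # Pearson/Spearman correlations instead. The exact keys depend on
    # prompt_hub/glue_metrics.py.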

    # Train
    trainer = PromptHub(
        args=args,
        compute_metrics=compute_metrics,
    )

    # Train a soft prompt for (backbone, dataset), then evaluate it on the same task.
    train_results = trainer.train_prompt(args.backbone, args.dataset)
    print(train_results)

    eval_results = trainer.eval_prompt(args.backbone, args.dataset)
    print(eval_results)

    # Reuse the trained prompt on a different task to measure cross-task transfer.
    cross_task_results = trainer.cross_task_eval(args.backbone, 'rotten_tomatoes')
    print(cross_task_results)

    # Cross-model transfer to roberta-large on the same task
    # (cross_model_eval is assumed to mirror cross_model_train's signature).
    trainer.cross_model_train(args.backbone, 'roberta-large', args.dataset)
    trainer.cross_model_eval(args.backbone, 'roberta-large', args.dataset)

    # Trainer
    # data_collator = DataCollatorWithPadding(tokenizer, max_length=args.max_source_length, pad_to_multiple_of=8)
    # model=model,
    # template=template,
    # verbalizer=verbalizer,
    # tokenizer_wrapper_class=tokenizer_wrapper_class,
    # train_dataset=train_dataset,
    # eval_dataset=eval_dataset,
    # tokenizer=tokenizer,
    # classes=raw_dataset['train'].features['label'].names
    # data_collator=data_collator

    # if args.do_train:
    #     # Detecting last checkpoint.
    #     last_checkpoint = None
    #     if os.path.isdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
    #         last_checkpoint = get_last_checkpoint(args.output_dir)
    #         if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
    #             raise ValueError(
    #                 f"Output directory ({args.output_dir}) already exists and is not empty. "
    #                 "Use --overwrite_output_dir to overcome."
    #             )
    #         elif last_checkpoint is not None and args.resume_from_checkpoint is None:
    #             logger.info(
    #                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
    #                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
    #             )
    #     train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    #     metrics = train_result.metrics
    #     # metrics["train_samples"] = min(max_train_samples, len(train_dataset))

    #     trainer.save_model()  # Saves the tokenizer too for easy upload

    #     trainer.log_metrics("train", metrics)
    #     trainer.save_metrics("train", metrics)
    #     trainer.save_state()

    # results = {}
    # # Evaluation
    # if args.do_eval:
    #     logger.info("*** Evaluate ***")

    #     # Loop to handle MNLI double evaluation (matched, mis-matched)
    #     tasks = [data_args.task_name]
    #     eval_datasets = [eval_dataset]
    #     if data_args.task_name == "mnli":
    #         tasks.append("mnli-mm")
    #         eval_datasets.append(raw_datasets["validation_mismatched"])

    #     for eval_dataset, task in zip(eval_datasets, tasks):
    #         metrics = trainer.evaluate(eval_dataset=eval_dataset)

    #         max_eval_samples = (
    #             data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
    #         )
    #         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

    #         trainer.log_metrics("eval", metrics)
    #         trainer.save_metrics("eval", metrics)
    #     results['eval'] = metrics

    # if args.do_predict:
    #     logger.info("*** Predict ***")

    #     # Loop to handle MNLI double evaluation (matched, mis-matched)
    #     tasks = [data_args.task_name]
    #     predict_datasets = [predict_dataset]
    #     if data_args.task_name == "mnli":
    #         tasks.append("mnli-mm")
    #         predict_datasets.append(raw_datasets["test_mismatched"])

    #     for predict_dataset, task in zip(predict_datasets, tasks):
    #         # Removing the `label` columns because it contains -1 and Trainer won't like that.
    #         predict_dataset = predict_dataset.remove_columns("label")
    #         predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
    #         predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

    #         output_predict_file = os.path.join(args.output_dir, f"predict_results_{task}.txt")
    #         if trainer.is_world_process_zero():
    #             with open(output_predict_file, "w") as writer:
    #                 logger.info(f"***** Predict results {task} *****")
    #                 writer.write("index\tprediction\n")
    #                 for index, item in enumerate(predictions):
    #                     if is_regression:
    #                         writer.write(f"{index}\t{item:3.3f}\n")
    #                     else:
    #                         item = label_list[item]
    #                         writer.write(f"{index}\t{item}\n")


if __name__ == "__main__":
    main()
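

# Example invocation (illustrative; "train.py" stands in for this script's actual
# filename, and the flag names mirror the `args` fields used above, as ultimately
# defined by PromptTrainingArguments):
#
#   python train.py configs/roberta-base_sst2.json
#
# or with explicit command-line arguments instead of a JSON config:
#
#   python train.py --backbone roberta-base --dataset sst2 --seed 42 --output_dir outputs/sst2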