OpenDelta

Форк
0
343 строки · 14.2 Кб
1
# coding=utf-8
# Copyright OpenDelta Team and THUNLP lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""
A unified running script for most models to do downstream tasks in a
prompt-learning fashion, i.e., no classification head; all tasks are cast
to mask-prediction or span-prediction tasks.

Processing relevant to different backbone models is stored in ../backbones/

A few lines are added to integrate the Delta tuning methods.

You can also adapt this script to your own tasks.
"""
26

27
import os
import sys

# MKL / tokenizer environment tweaks must be in place before the heavyweight
# numerical libraries below are imported.
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make the repository root and the current working directory importable.
sys.path.append(os.path.join(os.getcwd(), "../"))
sys.path.append(os.path.join(os.getcwd()))

import functools
import logging

import torch
import json
import numpy as np

import transformers
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    Seq2SeqTrainer,
    set_seed,
)
from transformers.trainer_utils import is_main_process, get_last_checkpoint

from data_processors import AutoTask
from utils import read_json, save_json
from utils.args import (
    ModelArguments,
    TrainingArguments,
    DataTrainingArguments,
    RemainArgHfArgumentParser,
    DeltaArguments,
)

# Module-level logger, configured in main() once the arguments are parsed.
logger = logging.getLogger(__name__)
63

64

65
def _setup_logging(training_args):
    """Configure stdout logging; verbose only on the main process."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # One-line summary on every process.  (Bug fix: the two f-strings were
    # previously concatenated without a separator, producing
    # "n_gpu: 1distributed training: ...".)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Transformers' own logging is verbose only on the main process.
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()


# Model-name prefix -> backbone module under examples_prompt.backbones.
# The tuple order mirrors the original if/elif chain.
_BACKBONE_BY_PREFIX = (
    ("t5", "t5"),
    ("blenderbot", "blenderbot"),
    ("roberta", "bert"),
    ("bert", "bert"),
    ("albert", "bert"),
    ("beit", "beit"),
    ("bart", "bart"),
    ("bigbird", "bigbird"),
    ("clip", "clip"),
)


def _load_backbone_helpers(model_name_or_path):
    """Import and return the backbone-specific helper module.

    The returned module provides ``get_backbone``, ``preprocess_function``,
    ``get_remove_columns``, ``get_prompts`` and backbone-specific ``Trainer``
    and ``DataCollator`` classes.

    Raises:
        ValueError: for an unrecognized backbone.  (Bug fix: the original
        if/elif chain fell through silently and later crashed with a
        NameError on ``get_backbone``.)
    """
    import importlib

    base_name = os.path.basename(model_name_or_path)
    for prefix, module_name in _BACKBONE_BY_PREFIX:
        if base_name.startswith(prefix):
            return importlib.import_module("examples_prompt.backbones." + module_name)
    raise ValueError(
        "No backbone support for model '{}'; expected its basename to start with "
        "one of {}".format(model_name_or_path, [p for p, _ in _BACKBONE_BY_PREFIX]))


def main():
    """Evaluate an (optionally delta-tuned) backbone on a downstream task in
    a prompt-learning fashion and return the collected metrics.

    The first CLI argument may be a ``.json`` config file whose contents are
    parsed as arguments; any following ``--arg`` options override or extend
    it.  Without a json file all arguments come from the command line.

    Returns:
        dict: ``{"test": {task_name: metrics}}`` for the evaluated task.
    """
    parser = RemainArgHfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))

    json_file, cmd_args = ((os.path.abspath(sys.argv[1]), sys.argv[2:])
                           if sys.argv[1].endswith(".json")
                           else (None, sys.argv[1:]))
    model_args, data_args, training_args, delta_args, remain_args = \
        parser.parse_json_file_with_cmd_args(json_file=json_file, command_line_args=cmd_args)
    logger.warning("The following arguments not used! {}".format(remain_args))

    _setup_logging(training_args)
    logger.info("{}\n{}\n{}\n{}".format(training_args, model_args, data_args, delta_args))

    # Set seed before initializing the model.
    set_seed(training_args.seed)

    # Backbone-specific helpers; the module's Trainer/DataCollator
    # intentionally take the place of the transformers classes imported at
    # module level.
    backbone = _load_backbone_helpers(model_args.model_name_or_path)
    config, tokenizer, model = backbone.get_backbone(model_args=model_args)

    # Optional naive model parallelism (T5-style .parallelize()).
    if hasattr(training_args, "model_parallel") and training_args.model_parallel:
        logger.info('parallelize model!')
        model.parallelize()

    from bigmodelvis import Visualization
    Visualization(model).structure_graph()

    if delta_args.delta_type.lower() != "none":
        # NOTE(review): only the Adapter delta model is loaded here regardless
        # of which delta_type was requested -- confirm this is intended.
        from opendelta.delta_models.adapter import AdapterConfig, AdapterModel
        delta_config = AdapterConfig.from_finetuned(
            finetuned_delta_path=delta_args.finetuned_delta_path)
        delta_model = AdapterModel.from_finetuned(
            finetuned_delta_path=delta_args.finetuned_delta_path,
            delta_config=delta_config,
            backbone_model=model,
            force_download=delta_args.force_download,
            cache_dir=delta_args.delta_cache_dir)
        delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)

    non_empty_splits_names = []
    if training_args.do_test:
        non_empty_splits_names.append("test")

    # Bug fix: the original left eval_task/verbalizer unbound when do_test was
    # False and crashed with a NameError at the Trainer() call below.
    eval_task = None
    verbalizer = None
    splits = {}
    for split_name in ['test']:
        if split_name not in non_empty_splits_names:
            splits[split_name] = None
            continue

        task = AutoTask.get(data_args.task_name,
                            data_args.dataset_config_name,
                            data_args=data_args,
                            seed=data_args.data_sample_seed)

        # NOTE(review): n_obs caps the *test* split with max_train_samples --
        # looks like it should be a test-specific limit; confirm upstream.
        dataset = task.get(split=split_name,
                           split_validation_test=training_args.split_validation_test,
                           n_obs=data_args.max_train_samples)

        template, _verbalizer, tokenizer_wrapper = backbone.get_prompts(task, tokenizer, data_args)

        dataset = dataset.map(
            functools.partial(backbone.preprocess_function,
                              data_args=data_args,
                              tokenizer=tokenizer,
                              template=template,
                              verbalizer=_verbalizer,
                              tokenizer_wrapper=tokenizer_wrapper,
                              split=split_name),
            batched=False,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=backbone.get_remove_columns(list(dataset.features.keys())),
            load_from_cache_file=not data_args.overwrite_cache,
        )
        splits[split_name] = dataset
        if split_name == "test":
            eval_task = task
            verbalizer = _verbalizer

    trainer = backbone.Trainer(
        model=model,
        verbalizer=verbalizer,
        eval_task=eval_task,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=backbone.DataCollator(tokenizer),
    )

    # Save the launch config next to the outputs (main process only).
    # Bug fix: the original passed sys.argv[1] unconditionally, which crashed
    # in read_json() when the script was launched without a json config file.
    if trainer.is_world_process_zero() and json_file is not None:
        save_json(os.path.join(training_args.output_dir, "training_config.json"),
                  read_json(json_file))

    # Evaluation on the test split.
    all_results = {'test': {}}
    if training_args.do_test:
        logger.info("*** Test ***")
        metrics = trainer.evaluate(eval_dataset=splits['test'],
                                   metric_key_prefix="test")
        trainer.log_metrics(f"{data_args.task_name}_test", metrics)
        trainer.save_metrics(f"{data_args.task_name}_test", metrics)
        all_results['test'][data_args.task_name] = metrics

    print(all_results)
    return all_results
338

339

340

341

342
# Script entry point: run the evaluation and keep the returned metrics dict.
if __name__ == "__main__":
    result = main()
344

345

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.