OpenDelta
# coding=utf-8
# Copyright OpenDelta Team and THUNLP lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""
16A unified runing scripts for most models to do down stream tasks in a
17prompt learning fashion, i.e., No classification head, all tasks are casted
18to mask prediction or span prediction tasks.
19
20Processing relevant to different backbone models are stored in ../backbones/
21
22Adding A few lines to integrate the Delta tuning methods.
23
24You can also adapt this script on your own tasks.
25"""

import os
import sys
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
sys.path.append(os.path.join(os.getcwd(), "../"))
sys.path.append(os.path.join(os.getcwd()))

import functools
import logging
import torch
import json
import numpy as np

import transformers
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    # HfArgumentParser,
    # MBartTokenizer,
    # default_data_collator,
    Trainer,
    Seq2SeqTrainer,
    set_seed,
)
from transformers.trainer_utils import is_main_process, get_last_checkpoint

from data_processors import AutoTask  # , TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
from utils import read_json, save_json
from utils.args import ModelArguments, TrainingArguments, DataTrainingArguments, RemainArgHfArgumentParser, DeltaArguments


logger = logging.getLogger(__name__)


def main():
    parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, DeltaArguments))

    # You can provide a json file that contains the arguments and use --argument some_arg to override or append to the json file.
    json_file, cmd_args = (os.path.abspath(sys.argv[1]), sys.argv[2:]) if sys.argv[1].endswith(".json") else (None, sys.argv[1:])
    model_args, data_args, training_args, delta_args, remain_args = parser.parse_json_file_with_cmd_args(json_file=json_file, command_line_args=cmd_args)
    logger.warning("The following arguments were not used: {}".format(remain_args))

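    # For illustration, a hypothetical argument file might look like the sketch
    # below. The exact keys depend on ModelArguments / DataTrainingArguments /
    # TrainingArguments / DeltaArguments as defined in utils/args.py, so treat
    # the field names and values here as assumptions, not a verified schema:
    #
    #     {
    #         "model_name_or_path": "t5-base",
    #         "task_name": "superglue-boolq",
    #         "output_dir": "outputs/t5-base/superglue-boolq",
    #         "do_test": true,
    #         "delta_type": "adapter"
    #     }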
    # # exit()
    # # Detecting last checkpoint.
    # last_checkpoint = None
    # if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    #     last_checkpoint = get_last_checkpoint(training_args.output_dir)
    #     print("#### last_checkpoint ", last_checkpoint)
    #     if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
    #         '''
    #         raise ValueError(
    #             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
    #             "Use --overwrite_output_dir to overcome."
    #         )
    #         '''
    #         pass
    #     elif last_checkpoint is not None:
    #         logger.info(
    #             f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
    #             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
    #         )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    # logger.info("Training/evaluation parameters %s", training_args, model_args, data_args, delta_args)
    logger.info("{}\n{}\n{}\n{}".format(training_args, model_args, data_args, delta_args))


    # Set seed before initializing model.
    set_seed(training_args.seed)


    if os.path.basename(model_args.model_name_or_path).startswith("t5"):
        from examples_prompt.backbones.t5 import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.t5 import Trainer, DataCollator
    elif os.path.basename(model_args.model_name_or_path).startswith("blenderbot"):
        from examples_prompt.backbones.blenderbot import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.blenderbot import Trainer, DataCollator
    elif os.path.basename(model_args.model_name_or_path).startswith("roberta") \
            or os.path.basename(model_args.model_name_or_path).startswith("bert") \
            or os.path.basename(model_args.model_name_or_path).startswith("albert"):
        from examples_prompt.backbones.bert import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.bert import Trainer, DataCollator
    elif os.path.basename(model_args.model_name_or_path).startswith("beit"):
        from examples_prompt.backbones.beit import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.beit import Trainer, DataCollator
    elif os.path.basename(model_args.model_name_or_path).startswith("bart"):
        from examples_prompt.backbones.bart import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.bart import Trainer, DataCollator
    elif os.path.basename(model_args.model_name_or_path).startswith("bigbird"):
        from examples_prompt.backbones.bigbird import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.bigbird import Trainer, DataCollator
    elif os.path.basename(model_args.model_name_or_path).startswith("clip"):
        from examples_prompt.backbones.clip import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
        from examples_prompt.backbones.clip import Trainer, DataCollator


    config, tokenizer, model = get_backbone(model_args=model_args)

    # model parallelize
    if hasattr(training_args, "model_parallel") and training_args.model_parallel:
        logger.info('parallelize model!')
        model.parallelize()

    from bigmodelvis import Visualization
    Visualization(model).structure_graph()

    if delta_args.delta_type.lower() != "none":
        from opendelta.delta_models.adapter import AdapterConfig, AdapterModel
        delta_config = AdapterConfig.from_finetuned(finetuned_delta_path=delta_args.finetuned_delta_path)
        delta_model = AdapterModel.from_finetuned(finetuned_delta_path=delta_args.finetuned_delta_path,
                                                  delta_config=delta_config,
                                                  backbone_model=model,
                                                  force_download=delta_args.force_download,
                                                  cache_dir=delta_args.delta_cache_dir)
        # delta_model.freeze_module(set_state_dict = True)
        delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)

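    # The branch above loads an already finetuned delta checkpoint. If you instead
    # want to attach a fresh adapter and train it in this script, a minimal sketch
    # (keyword arguments assumed from typical OpenDelta usage, not taken from this
    # file) would be:
    #
    #     delta_model = AdapterModel(backbone_model=model)
    #     delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
    #     delta_model.log()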
    performance_metrics = {}


    non_empty_splits_names = []
    # if training_args.do_train:
    #     non_empty_splits_names.append("train")
    # if training_args.do_eval:
    #     non_empty_splits_names.append("eval")
    if training_args.do_test:
        non_empty_splits_names.append("test")
    splits = {}
    for split_name in ['test']:
        if split_name not in non_empty_splits_names:
            splits[split_name] = None
            continue

        task = AutoTask.get(data_args.task_name,
                            data_args.dataset_config_name,
                            data_args=data_args,
                            seed=data_args.data_sample_seed)

        dataset = task.get(split=split_name,
                           split_validation_test=training_args.split_validation_test,
                           n_obs=data_args.max_train_samples)


        template, _verbalizer, tokenizer_wrapper = get_prompts(task, tokenizer, data_args)

        dataset = dataset.map(
            functools.partial(preprocess_function,
                              data_args=data_args,
                              tokenizer=tokenizer,
                              template=template,
                              verbalizer=_verbalizer,
                              tokenizer_wrapper=tokenizer_wrapper,
                              split=split_name),
            batched=False,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=get_remove_columns(list(dataset.features.keys())),
            load_from_cache_file=not data_args.overwrite_cache,
        )
        # from IPython import embed; embed()
        splits[split_name] = dataset
        if split_name == "test":
            eval_task = task
            verbalizer = _verbalizer


    trainer = Trainer(
        model=model,
        verbalizer=verbalizer,
        eval_task=eval_task,
        args=training_args,
        # train_dataset=splits['train'],
        # eval_dataset=splits['eval'],
        tokenizer=tokenizer,
        data_collator=DataCollator(tokenizer),
    )


    def save_training_config(config_file, output_dir):
        json_data = read_json(config_file)
        save_json(os.path.join(output_dir, "training_config.json"), json_data)


    # Saves training config.
    if trainer.is_world_process_zero():
        save_training_config(sys.argv[1], training_args.output_dir)

    # # Training
    # if training_args.do_train:
    #     checkpoint = None
    #     if training_args.resume_from_checkpoint is not None:
    #         checkpoint = training_args.resume_from_checkpoint
    #     elif last_checkpoint is not None:
    #         checkpoint = last_checkpoint

    #     if training_args.compute_time:
    #         torch.cuda.synchronize()  # wait for move to complete
    #         start = torch.cuda.Event(enable_timing=True)
    #         end = torch.cuda.Event(enable_timing=True)
    #         start.record()

    #     train_result = trainer.train(resume_from_checkpoint=checkpoint)

    #     if training_args.compute_time:
    #         end.record()
    #         torch.cuda.synchronize()  # wait for all_reduce to complete
    #         total_time = start.elapsed_time(end) / (1000 * 60)
    #         performance_metrics.update({"total_time in minutes ": total_time})

    #     trainer.save_model()  # Saves the tokenizer too for easy upload
    #     train_metrics = train_result.metrics
    #     max_train_samples = (
    #         data_args.max_train_samples if data_args.max_train_samples is not None else len(splits['train'])
    #     )
    #     train_metrics["train_samples"] = min(max_train_samples, len(splits['train']))
    #     trainer.log_metrics("train", train_metrics)
    #     trainer.save_metrics("train", train_metrics)
    #     trainer.save_state()

    # if torch.cuda.is_available() and training_args.compute_memory:
    #     peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2) / 1000
    #     print(
    #         "Memory utilization",
    #         peak_memory,
    #         "GB"
    #     )
    #     performance_metrics.update({"peak_memory": peak_memory})
    # if training_args.compute_memory or training_args.compute_time:
    #     print("Efficiency Statistics {}".format(performance_metrics))
    #     trainer.save_metrics("performance", performance_metrics)

    # Evaluation
    all_results = {}

    # all_results['evaluate'] = {}

    # if training_args.do_eval:
    #     logger.info("*** Evaluate ***")

    #     metrics = trainer.evaluate(eval_dataset=splits['eval'],
    #                                )
    #     trainer.log_metrics(f"{data_args.task_name}_eval", metrics)
    #     trainer.save_metrics(f"{data_args.task_name}_eval", metrics)
    #     all_results['evaluate'][data_args.task_name] = metrics

    # Test
    all_results['test'] = {}
    if training_args.do_test:
        logger.info("*** Test ***")
        metrics = trainer.evaluate(eval_dataset=splits['test'],
                                   metric_key_prefix="test"
                                   )
        trainer.log_metrics(f"{data_args.task_name}_test", metrics)
        trainer.save_metrics(f"{data_args.task_name}_test", metrics)
        all_results['test'][data_args.task_name] = metrics

    # from opendelta.utils.delta_hub import create_hub_repo_name
    # from opendelta.utils.delta_center import create_delta_center_args, create_repo_name

    # repo_name = create_hub_repo_name(root="DeltaHub",
    #                                  dataset=data_args.task_name,
    #                                  delta_type=delta_args.delta_type,
    #                                  model_name_or_path=model_args.model_name_or_path)

    # center_args =
    # repo_name = create_repo_name(prefix="", center_args=center_args)
    # all_results['repo_name'] = repo_name


    # delta_model.save_finetuned(push_to_hf=training_args.push_to_hf,
    #                            push_to_dc=training_args.push_to_dc,
    #                            center_args={},
    #                            center_args_pool={**vars(model_args), **vars(data_args), **vars(training_args), **vars(delta_args)},
    #                            delay_push=True,
    #                            )

    print(all_results)



    # with open(f"{training_args.output_dir}/results.json", 'w') as fout:
    #     string = json.dumps(all_results, indent=4, sort_keys=True)
    #     fout.write(string + "\n")

    return all_results


if __name__ == "__main__":
    result = main()