# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
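"""Fine-tune a GPT causal language model for generation with PaddleNLP.

Arguments are parsed from the command line (or a single JSON config file) into
ModelArgument, DataArgument, and TrainingArguments. Supports tensor/pipeline
parallelism, mixed precision, and optional LoRA adapters. Illustrative launch
(the script filename is assumed; flags follow the dataclasses below):

    python finetune_generation.py --model_name_or_path gpt-cpm-large-cn --task_name squad --output_dir ./checkpoints --do_train --do_eval
"""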
import os
import sys
from dataclasses import dataclass, field
from functools import partial

import paddle

# Helpers expected to live in the local utils.py next to this script.
from utils import (
    DataCollatorForSupervisedDataset,
    GPTTrainer,
    compute_metrics,
    convert_example,
)

from paddlenlp.datasets import load_dataset
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.trainer import (
    PdArgumentParser,
    TrainingArguments,
    get_last_checkpoint,
    set_seed,
)
from paddlenlp.transformers import (
    AutoTokenizer,
    GPTConfig,
    GPTForCausalLM,
    GPTForCausalLMPipe,
)
from paddlenlp.utils.log import logger
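
# Registry mapping the --model_type flag to its (config class, model class) pair.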
MODEL_CLASSES = {
    "gpt": (GPTConfig, GPTForCausalLM),
    "gpt-cn": (GPTConfig, GPTForCausalLM),  # Chinese GPT (e.g. gpt-cpm-large-cn) shares the GPT classes
}


@dataclass
class DataArgument:
    task_name: str = field(default="squad", metadata={"help": "The name of task."})
    src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
    tgt_length: int = field(default=142, metadata={"help": "The max length of target text."})
    generate_num: int = field(default=0, metadata={"help": "Save first k examples generation result in dev dataset"})


@dataclass
class ModelArgument:
    model_type: str = field(
        default="gpt-cn", metadata={"help": "Build-in pretrained model from the different model type."}
    )
    model_name_or_path: str = field(
        default="gpt-cpm-large-cn", metadata={"help": "Build-in pretrained model name or the path to local model."}
    )
    use_flash_attn: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
    enable_fuse_transformer: bool = field(
        default=False,
        metadata={"help": "gpt, enable_fuse_transformer"},
    )
    fuse_attention_qkv: bool = field(
        default=False,
        metadata={"help": "gpt, fuse_attention_qkv"},
    )
    eval_with_do_generation: bool = field(
        default=True, metadata={"help": "Evaluate with generation instead of calculating the loss."}
    )
    lr_decay_ratio: float = field(default=0.1, metadata={"help": "The ratio for learning rate decrease"})

    # LoRA
    lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
    lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
    lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
    merge_weights: bool = field(
        default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
    )


def main():
    parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments))
    # Accept either regular CLI flags or a single JSON config file.
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Pad every batch to max length only when pipeline parallelism is enabled (it needs static shapes).
    data_args.always_pad_to_max_length = training_args.pipeline_parallel_degree > 1
    setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio)

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
    training_args.tgt_length = data_args.tgt_length
    paddle.set_device(training_args.device)

    set_seed(seed=training_args.seed)

    # Log on each process a small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    # Detect a previous checkpoint in the output directory.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set the dtype used to load the model weights.
    dtype = "float32"
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        if training_args.bf16:
            dtype = "bfloat16"
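
    # Resolve the config/model classes for the chosen model type; pipeline parallelism
    # uses the pipeline variant of the causal LM.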
    config_class, model_class = MODEL_CLASSES[model_args.model_type]
    if training_args.pipeline_parallel_degree > 1:
        model_class = GPTForCausalLMPipe

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.padding_side = "left"

    # Load and set the pretrained configuration
    config = config_class.from_pretrained(model_args.model_name_or_path)
    config.enable_fuse_transformer = model_args.enable_fuse_transformer
    config.fuse_attention_qkv = model_args.fuse_attention_qkv
    config.use_flash_attn = model_args.use_flash_attn
    config.use_recompute = training_args.recompute

    config.tensor_parallel_degree = training_args.tensor_parallel_degree
    config.tensor_parallel_rank = training_args.tensor_parallel_rank
    config.ignore_index = tokenizer.pad_token_id
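
    # Instantiate the causal LM from the pretrained weights with the prepared config.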
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=dtype,
    )

    if model_args.lora:
        if model_args.lora_path is None:
            # Target modules are assumed here: the GPT attention and MLP projection layers.
            target_modules = [
                ".*qkv_proj.*",
                ".*q_proj.*",
                ".*k_proj.*",
                ".*v_proj.*",
                ".*linear1.*",
                ".*linear2.*",
                ".*out_proj.*",
            ]
            lora_config = LoRAConfig(
                target_modules=target_modules,
                r=model_args.lora_rank,
                lora_alpha=2 * model_args.lora_rank,
                merge_weights=model_args.merge_weights,
                tensor_parallel_degree=training_args.tensor_parallel_degree,
                dtype=dtype,
            )
            model = LoRAModel(model, lora_config)
        else:
            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
        # Freeze the base model so only the LoRA parameters are updated.
        model.mark_only_lora_as_trainable()
        model.print_trainable_parameters()
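
    # Load the dataset splits and convert raw examples into model inputs.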
    if training_args.do_train or training_args.do_eval:
        train_ds, dev_ds = load_dataset(data_args.task_name, splits=["train_v1", "dev_v1"])
        # convert_example is expected to be provided by the local utils module alongside this script.
        trans_func = partial(
            convert_example,
            tokenizer=tokenizer,
            max_source_length=data_args.src_length,
            max_target_length=data_args.tgt_length,
        )

    if training_args.do_train:
        train_ds = train_ds.map(partial(trans_func))
    if training_args.do_eval:
        is_test = model_args.eval_with_do_generation
        dev_ds = dev_ds.map(partial(trans_func, is_test=is_test))

    collate_fn = DataCollatorForSupervisedDataset(
        tokenizer, max_length=1024 if data_args.always_pad_to_max_length else 0
    )
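
    # Decode predictions and references, strip padding and the "question:" prefix,
    # then score them with compute_metrics.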
    def compute_metrics_trainer(eval_preds, tokenizer):
        all_preds = []
        all_labels = []
        preds = eval_preds.predictions
        preds = [x[x != -100] for x in preds]
        all_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=False))
        labels = [x[x != -100] for x in eval_preds.label_ids]
        all_labels.extend(tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False))

        all_preds = [pred.strip() for pred in all_preds]
        all_labels = [label.strip() for label in all_labels]
        all_preds = [pred.strip("question:") for pred in all_preds]
        all_labels = [label.strip("question:") for label in all_labels]

        eval_result = compute_metrics(all_preds, all_labels)
        return eval_result

    compute_metrics_func = partial(
        compute_metrics_trainer,
        tokenizer=tokenizer,
    )

    trainer = GPTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds if training_args.do_train else None,
        eval_dataset=dev_ds if training_args.do_eval else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_func
        if (model_args.eval_with_do_generation and training_args.do_eval)
        else None,
        do_generation=model_args.eval_with_do_generation,
        data_collator=collate_fn,
    )
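
    # Train (resuming from any detected checkpoint), then save the model and metrics;
    # tensor-parallel shards are merged on save when tensor parallelism is enabled.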
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)

    if training_args.do_eval:
        eval_result = trainer.evaluate()
        trainer.log_metrics("test", eval_result)


if __name__ == "__main__":
    main()