finetune_generation.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
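
# Fine-tunes a PaddleNLP GPT-family causal language model for text generation
# (default task: squad), with optional LoRA adapters and tensor/pipeline
# parallelism. Training and evaluation are driven by PdArgumentParser, which
# reads either CLI flags or a single JSON config file.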

import os
import sys
from dataclasses import dataclass, field
from functools import partial

import paddle
from utils import (
    DataCollatorForSupervisedDataset,
    GPTTrainer,
    compute_metrics,
    convert_example,
)

from paddlenlp.datasets import load_dataset
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.trainer import (
    PdArgumentParser,
    TrainingArguments,
    get_last_checkpoint,
    set_seed,
)
from paddlenlp.transformers import (
    AutoTokenizer,
    GPTConfig,
    GPTForCausalLM,
    GPTForCausalLMPipe,
)
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
    "gpt": (GPTConfig, GPTForCausalLM),
    # Map the default "gpt-cn" model type (e.g. the gpt-cpm-large-cn checkpoint)
    # to the same GPT classes, so running with the default --model_type does not
    # raise a KeyError.
    "gpt-cn": (GPTConfig, GPTForCausalLM),
}


@dataclass
class DataArgument:
    task_name: str = field(default="squad", metadata={"help": "The name of the dataset/task to load."})
    src_length: int = field(default=1024, metadata={"help": "The max length of the source text."})
    tgt_length: int = field(default=142, metadata={"help": "The max length of the target text."})
    generate_num: int = field(default=0, metadata={"help": "Save the generation results of the first k dev-set examples."})


@dataclass
class ModelArgument:
    model_type: str = field(
        default="gpt-cn", metadata={"help": "Model type of the built-in pretrained model."}
    )
    model_name_or_path: str = field(
        default="gpt-cpm-large-cn", metadata={"help": "Built-in pretrained model name or the path to a local model."}
    )
    use_flash_attn: bool = field(default=False, metadata={"help": "Whether to use flash attention."})
    enable_fuse_transformer: bool = field(
        default=False,
        metadata={"help": "GPT: enable the fused transformer implementation."},
    )

    fuse_attention_qkv: bool = field(
        default=False,
        metadata={"help": "GPT: fuse the attention QKV projections."},
    )
    eval_with_do_generation: bool = field(
        default=True, metadata={"help": "Evaluate with generation instead of computing the loss."}
    )
    lr_decay_ratio: float = field(default=0.1, metadata={"help": "The ratio for learning rate decay."})
    # lora
    lora: bool = field(default=False, metadata={"help": "Whether to use the LoRA technique."})
    lora_path: str = field(default=None, metadata={"help": "Path to an existing LoRA state dict to initialize from."})
    lora_rank: int = field(default=8, metadata={"help": "LoRA rank (attention dimension)."})
    merge_weights: bool = field(
        default=False, metadata={"help": "Merge the LoRA weights into the original model weights."}
    )


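# Example launch (hypothetical flag values; PdArgumentParser maps the dataclass
# fields above, plus paddlenlp TrainingArguments, to CLI flags):
#
#   python finetune_generation.py \
#       --model_type gpt-cn \
#       --model_name_or_path gpt-cpm-large-cn \
#       --task_name squad \
#       --output_dir ./checkpoints \
#       --do_train --do_eval \
#       --lora --lora_rank 8
#
# Alternatively, pass a single JSON file (hypothetical name) with the same keys:
#
#   python finetune_generation.py config.json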
def main():
    parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # data_args.always_pad_to_max_length = False
    data_args.always_pad_to_max_length = training_args.pipeline_parallel_degree > 1
    setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio)

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
    training_args.tgt_length = data_args.tgt_length
    paddle.set_device(training_args.device)

    set_seed(seed=training_args.seed)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    # Detect the last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

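    # Note: with fp16_opt_level "O2" (pure half precision) the parameters themselves
    # are kept in float16/bfloat16, so the load dtype must match; under "O1" the
    # model stays in float32 and AMP casts selected ops at runtime.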
    # Set the dtype for loading model
    dtype = "float32"
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        if training_args.bf16:
            dtype = "bfloat16"

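    # When pipeline parallelism is enabled, switch to the pipeline variant of the
    # model, which shards the transformer layers across pipeline stages.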
    config_class, model_class = MODEL_CLASSES[model_args.model_type]
    if training_args.pipeline_parallel_degree > 1:
        model_class = GPTForCausalLMPipe
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.padding_side = "left"

    # Load and set the pretrained configuration
    config = config_class.from_pretrained(model_args.model_name_or_path)
    config.enable_fuse_transformer = model_args.enable_fuse_transformer
    config.fuse_attention_qkv = model_args.fuse_attention_qkv
    config.use_flash_attn = model_args.use_flash_attn
    config.use_recompute = training_args.recompute

    config.tensor_parallel_degree = training_args.tensor_parallel_degree
    config.tensor_parallel_rank = training_args.tensor_parallel_rank
    config.ignore_index = tokenizer.pad_token_id

    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=dtype,
    )
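    # Optionally wrap the base model with LoRA: only the low-rank adapter matrices
    # injected into the matched projection layers are trained, while the original
    # weights stay frozen (mark_only_lora_as_trainable below).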
    if model_args.lora:
        if model_args.lora_path is None:
            target_modules = [
                ".*qkv_proj.*",
                ".*q_proj.*",
                ".*k_proj.*",
                ".*v_proj.*",
                ".*linear1.*",
                ".*linear2.*",
                ".*out_proj.*",
            ]
            lora_config = LoRAConfig(
                target_modules=target_modules,
                r=model_args.lora_rank,
                lora_alpha=2 * model_args.lora_rank,
                merge_weights=model_args.merge_weights,
                tensor_parallel_degree=training_args.tensor_parallel_degree,
                dtype=dtype,
            )
            model = LoRAModel(model, lora_config)
        else:
            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
        model.mark_only_lora_as_trainable()
        model.print_trainable_parameters()

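    # convert_example (defined in utils, not shown here) presumably tokenizes each
    # source/target pair; is_test switches the dev set to generation-style examples
    # when evaluation runs through generate().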
    # Load the dataset.
    if training_args.do_train or training_args.do_eval:
        train_ds, dev_ds = load_dataset(data_args.task_name, splits=["train_v1", "dev_v1"])
        trans_func = partial(
            convert_example,
            tokenizer=tokenizer,
            max_source_length=data_args.src_length,
            max_target_length=data_args.tgt_length,
        )

    if training_args.do_train:
        train_ds = train_ds.map(partial(trans_func))
    if training_args.do_eval:
        is_test = model_args.eval_with_do_generation
        dev_ds = dev_ds.map(partial(trans_func, is_test=is_test))

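    # Pad every batch to a fixed length of 1024 tokens when pipeline parallelism is
    # on (presumably because pipeline stages need static shapes); otherwise pad
    # dynamically per batch (max_length=0).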
    collate_fn = DataCollatorForSupervisedDataset(
        tokenizer, max_length=1024 if data_args.always_pad_to_max_length else 0
    )

    def compute_metrics_trainer(eval_preds, tokenizer):
        all_preds = []
        all_labels = []
        preds = eval_preds.predictions
        preds = [x[x != -100] for x in preds]
        all_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=False))
        labels = [x[x != -100] for x in eval_preds.label_ids]
        all_labels.extend(tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False))

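        # Note: str.strip("question:") treats its argument as a character set, so it
        # removes any of q/u/e/s/t/i/o/n/: from both ends of the string, not just a
        # literal "question:" prefix.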
        all_preds = [pred.strip() for pred in all_preds]
        all_labels = [label.strip() for label in all_labels]
        all_preds = [pred.strip("question:") for pred in all_preds]
        all_labels = [label.strip("question:") for label in all_labels]

        eval_result = compute_metrics(all_preds, all_labels)
        return eval_result

    compute_metrics_func = partial(
        compute_metrics_trainer,
        tokenizer=tokenizer,
    )

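    # compute_metrics is only attached when evaluation runs through generation;
    # otherwise the trainer falls back to reporting the evaluation loss.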
    trainer = GPTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds if training_args.do_train else None,
        eval_dataset=dev_ds if training_args.do_eval else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_func
        if (model_args.eval_with_do_generation and training_args.do_eval)
        else None,
        do_generation=model_args.eval_with_do_generation,
        data_collator=collate_fn,
    )

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    if training_args.do_eval:
        eval_result = trainer.evaluate()
        trainer.log_metrics("test", eval_result)


if __name__ == "__main__":
    main()