# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from dataclasses import dataclass, field
from typing import Optional

import paddle
from dataset import DataCollatorForErnieHealth, MedicalCorpus

from paddlenlp.trainer import (
    PdArgumentParser,
    Trainer,
    TrainingArguments,
    get_last_checkpoint,
)
from paddlenlp.transformers import (
    ElectraConfig,
    ElectraTokenizer,
    ErnieHealthForTotalPretraining,
)
from paddlenlp.utils.log import logger

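# Map each supported model type to its (config, model, tokenizer) classes.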
MODEL_CLASSES = {
    "ernie-health": (ElectraConfig, ErnieHealthForTotalPretraining, ElectraTokenizer),
}


@dataclass
class DataArguments:
    """
    Arguments pertaining to the data we are going to feed to our model for training and evaluation.
    Using `PdArgumentParser`, we can turn this class into argparse arguments that can be
    specified on the command line.
    """

    input_dir: str = field(
        default=None,
        metadata={"help": "The directory containing the pre-training corpus files."},
    )
    max_seq_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    masked_lm_prob: float = field(
        default=0.15,
        metadata={"help": "The probability of masking a token for masked language modeling."},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to pre-train from.
    """

    model_type: Optional[str] = field(
        default="ernie-health",
        metadata={"help": "Only ernie-health pre-training is supported for now."},
    )
    model_name_or_path: str = field(
        default="ernie-health-chinese",
        metadata={
            "help": "Path to a pretrained model, or a model identifier from "
            "https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html"
        },
    )

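# A hypothetical invocation (the script name and paths are placeholders):
#   python run_pretrain.py --input_dir ./medical_corpus --output_dir ./checkpoints \
#       --model_name_or_path ernie-health-chinese --max_seq_length 512 \
#       --masked_lm_prob 0.15 --do_train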

def main():
    parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

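    # Hard-coded evaluation schedule: run 10 iterations per evaluation and 10x that for testing.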
    training_args.eval_iters = 10
    training_args.test_iters = training_args.eval_iters * 10
    # training_args.recompute = True

    # Log the model and data configs.
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log a brief summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bit training: {training_args.fp16}"
    )

    # Detect the last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    config_class, model_class, tokenizer_class = MODEL_CLASSES["ernie-health"]

    # Load the tokenizer and initialize the model.
    tokenizer = tokenizer_class.from_pretrained(model_args.model_name_or_path)

    model_config = config_class()
    model = model_class(model_config)
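    # Note: the model is built from a default config, so pre-training starts from
    # randomly initialized weights; only the tokenizer is loaded from `model_name_or_path`.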

    # Load the dataset.
    tic_load_data = time.time()
    logger.info("start loading data: %s" % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

    train_dataset = MedicalCorpus(data_path=data_args.input_dir, tokenizer=tokenizer)
    logger.info("loading data done, total: %s s" % (time.time() - tic_load_data))

    # Read data and generate mini-batches.
    data_collator = DataCollatorForErnieHealth(
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        mlm_prob=data_args.masked_lm_prob,
        return_dict=True,
    )
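    # The collator is expected to pad each batch to `max_seq_length` and mask tokens
    # with probability `masked_lm_prob` for ERNIE-Health's generator-discriminator
    # pre-training objective (see dataset.py for the exact behavior).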

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

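    # Prefer an explicitly passed checkpoint; otherwise fall back to the last one
    # auto-detected in the output directory.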
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()


if __name__ == "__main__":
    main()