paddlenlp

Форк
0
/
run_pretrain_trainer.py 
166 строк · 5.7 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import time
17
from dataclasses import dataclass, field
18
from typing import Optional
19

20
import paddle
21
from dataset import DataCollatorForErnieHealth, MedicalCorpus
22

23
from paddlenlp.trainer import (
24
    PdArgumentParser,
25
    Trainer,
26
    TrainingArguments,
27
    get_last_checkpoint,
28
)
29
from paddlenlp.transformers import (
30
    ElectraConfig,
31
    ElectraTokenizer,
32
    ErnieHealthForTotalPretraining,
33
)
34
from paddlenlp.utils.log import logger
35

36
# Registry mapping a model-type key to its (config class, pretraining model
# class, tokenizer class) triple. Only "ernie-health" is supported here.
MODEL_CLASSES = {
    "ernie-health": (ElectraConfig, ErnieHealthForTotalPretraining, ElectraTokenizer),
}
39

40

41
@dataclass
class DataArguments:
    """Command-line options describing the pre-training data.

    Each field below is turned into an argparse argument by
    `PdArgumentParser`, so all values can be set on the command line.
    """

    # Location of the pre-processed corpus fed to MedicalCorpus.
    input_dir: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    # Hard cap on tokenized sequence length (longer truncated, shorter padded).
    max_seq_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    # Probability with which a token is selected for masking.
    masked_lm_prob: float = field(default=0.15, metadata={"help": "Mask token prob."})
63

64

65
@dataclass
class ModelArguments:
    """Command-line options selecting the model/config/tokenizer to pre-train."""

    # Key into MODEL_CLASSES; only "ernie-health" is recognized by this script.
    model_type: Optional[str] = field(
        default="ernie-health",
        metadata={"help": "Only support for ernie-health pre-training for now."},
    )
    # Built-in model identifier or path to a local checkpoint directory.
    model_name_or_path: str = field(
        default="ernie-health-chinese",
        metadata={
            "help": "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html"
        },
    )
80

81

82
def main():
    """Run ERNIE-Health pre-training driven by command-line arguments.

    Parses model/data/training dataclasses, resolves checkpoint resumption,
    builds the tokenizer, model and medical corpus, then launches the
    Trainer when `--do_train` is set.
    """
    parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fixed evaluation budget; the test pass gets 10x the eval iterations.
    training_args.eval_iters = 10
    training_args.test_iters = training_args.eval_iters * 10

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detect the last checkpoint so an interrupted run can be resumed, unless
    # the user asked to overwrite the output directory.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Honor --model_type instead of hard-coding the registry key. The default
    # ("ernie-health") keeps previous behavior; any other value raises KeyError,
    # matching the help text ("only ernie-health supported").
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]

    # Loads the tokenizer and initializes a fresh model from the config.
    tokenizer = tokenizer_class.from_pretrained(model_args.model_name_or_path)

    model_config = config_class()
    model = model_class(model_config)

    # Loads dataset.
    tic_load_data = time.time()
    logger.info("start load data : %s" % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

    train_dataset = MedicalCorpus(data_path=data_args.input_dir, tokenizer=tokenizer)
    logger.info("load data done, total : %s s" % (time.time() - tic_load_data))

    # Collator performs on-the-fly masking for the ERNIE-Health objective.
    data_collator = DataCollatorForErnieHealth(
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        mlm_prob=data_args.masked_lm_prob,
        return_dict=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
    )

    # An explicit --resume_from_checkpoint wins over an auto-detected one.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
163

164

165
# Script entry point: run pre-training when executed directly.
if __name__ == "__main__":
    main()
167

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.