# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

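"""
Fine-tunes a pointwise text-matching model on the LCQMC dataset, using an
ERNIE 3.0 Medium (Chinese) backbone from PaddleNLP together with the local
`data.py` (convert_pointwise_example, create_dataloader) and `model.py`
(PointwiseMatching) helpers.

Example invocation (the file name `train_pointwise.py` is assumed here;
adjust it to the actual script name):

    python -u train_pointwise.py --device gpu --save_dir ./checkpoint
"""
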
import argparse
import os
import random
import time
from functools import partial

import numpy as np
import paddle
from data import convert_pointwise_example as convert_example
from data import create_dataloader
from model import PointwiseMatching

from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import AutoModel, AutoTokenizer, LinearDecayWithWarmup


def set_seed(seed):
    """Sets the random seed for Python, NumPy and Paddle."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, phase="dev"):
    """
    Given a dataset, evaluates the model and computes the metric.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify text pairs.
        criterion(obj:`paddle.nn.Layer`): It can compute the loss.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
        phase(obj:`str`, optional): Name of the evaluated split, used only for logging. Defaults to "dev".
    """
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, labels = batch
        probs = model(input_ids=input_ids, token_type_ids=token_type_ids)
        loss = criterion(probs, labels)
        losses.append(loss.numpy())
        correct = metric.compute(probs, labels)
        metric.update(correct)
    accu = metric.accumulate()
    print("eval {} loss: {:.5}, accu: {:.5}".format(phase, np.mean(losses), accu))
    model.train()
    metric.reset()


def do_train(args):
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

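    # LCQMC is a Chinese question-matching corpus; PaddleNLP downloads it on
    # first use and returns one MapDataset per requested split.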
    train_ds, dev_ds = load_dataset("lcqmc", splits=["train", "dev"])

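    # The ERNIE 3.0 Medium Chinese checkpoint and its tokenizer are fetched
    # from the PaddleNLP model hub (and cached locally) on first use.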
    pretrained_model = AutoModel.from_pretrained("ernie-3.0-medium-zh")
    tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")

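    # convert_example comes from the local data.py and is expected to turn a
    # raw (query, title, label) example into (input_ids, token_type_ids, label).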
    trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length)

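    # Collate a list of converted examples into batch tensors: pad input_ids and
    # token_type_ids to the longest sequence in the batch, stack labels as int64.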
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # text_pair_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # text_pair_segment
        Stack(dtype="int64"),  # label
    ): [data for data in fn(samples)]

    train_data_loader = create_dataloader(
        train_ds, mode="train", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )

    dev_data_loader = create_dataloader(
        dev_ds, mode="dev", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )

    model = PointwiseMatching(pretrained_model)

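    # Optionally warm-start from a previously saved state dict.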
    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)

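    # Wrap the model for data-parallel training; gradients are synchronized
    # across devices when more than one process is launched.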
    model = paddle.DataParallel(model)

    num_training_steps = len(train_data_loader) * args.epochs

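    # LinearDecayWithWarmup increases the learning rate linearly from 0 over the
    # first warmup_proportion of steps, then decays it linearly back to 0.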
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
    )

    criterion = paddle.nn.loss.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

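    # Standard fine-tuning loop: forward pass, cross-entropy loss, running
    # accuracy, then backward / optimizer step / LR step / gradient reset.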
    global_step = 0
    tic_train = time.time()
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, labels = batch
            probs = model(input_ids=input_ids, token_type_ids=token_type_ids)
            loss = criterion(probs, labels)
            correct = metric.compute(probs, labels)
            metric.update(correct)
            acc = metric.accumulate()

            global_step += 1
            if global_step % 10 == 0 and rank == 0:
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, acc, 10 / (time.time() - tic_train))
                )
                tic_train = time.time()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

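            # Every eval_step steps, rank 0 runs a dev evaluation and saves the
            # tokenizer plus model weights under <save_dir>/model.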
            if global_step % args.eval_step == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)
                save_dir = os.path.join(args.save_dir, "model")
                tokenizer.save_pretrained(save_dir)
                model_to_save = model._layers if isinstance(model, paddle.DataParallel) else model
                paddle.save(model_to_save.state_dict(), os.path.join(save_dir, "model_state.pdparams"))

            # Stop early only when a positive --max_steps budget has been reached;
            # with the default of -1 a bare `global_step > args.max_steps` check
            # would exit after the very first step.
            if args.max_steps > 0 and global_step >= args.max_steps:
                return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        default="./checkpoint",
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay to apply, if any.")
    parser.add_argument("--epochs", default=3, type=int, help="Total number of training epochs to perform.")
    parser.add_argument("--eval_step", default=100, type=int, help="Step interval for evaluation.")
    parser.add_argument("--save_step", default=10000, type=int, help="Step interval for saving checkpoint.")
    parser.add_argument(
        "--warmup_proportion", default=0.0, type=float, help="Linear warmup proportion over the training process."
    )
    parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
    parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization.")
    parser.add_argument(
        "--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform."
    )
    parser.add_argument(
        "--device",
        choices=["cpu", "gpu", "npu", "xpu"],
        default="gpu",
        help="Select which device to train the model on; defaults to gpu.",
    )
    args = parser.parse_args()
    do_train(args)
