train_classification.py 
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import distutils.util
import os
import random
import time
from functools import partial

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.metric import Accuracy
from utils import LinearDecayWithWarmup, convert_example, create_dataloader

from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import AccuracyAndF1, MultiLabelsMetric
from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer

METRIC_CLASSES = {
    "KUAKE-QIC": Accuracy,
    "KUAKE-QQR": Accuracy,
    "KUAKE-QTR": Accuracy,
    "CHIP-CTC": MultiLabelsMetric,
    "CHIP-STS": MultiLabelsMetric,
    "CHIP-CDN-2C": AccuracyAndF1,
}
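# The KUAKE tasks above are scored with accuracy, CHIP-CTC and CHIP-STS with macro-F1,
# and CHIP-CDN-2C with micro-F1; see the matching branches in evaluate() below.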

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    choices=["KUAKE-QIC", "KUAKE-QQR", "KUAKE-QTR", "CHIP-STS", "CHIP-CTC", "CHIP-CDN-2C"],
    default="KUAKE-QIC",
    type=str,
    help="Dataset for sequence classification tasks.",
)
parser.add_argument("--seed", default=1000, type=int, help="Random seed for initialization.")
parser.add_argument(
    "--device",
    choices=["cpu", "gpu", "xpu", "npu"],
    default="gpu",
    help="Select which device to train the model on; defaults to gpu.",
)
parser.add_argument("--epochs", default=3, type=int, help="Total number of training epochs.")
parser.add_argument(
    "--max_steps", default=-1, type=int, help="If > 0: set the total number of training steps to perform. Overrides epochs."
)
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument(
    "--learning_rate", default=6e-5, type=float, help="Learning rate for fine-tuning the sequence classification task."
)
parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay applied by the optimizer, if any.")
parser.add_argument(
    "--warmup_proportion",
    default=0.1,
    type=float,
    help="Linear warmup proportion of learning rate over the training process.",
)
parser.add_argument(
    "--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization."
)
parser.add_argument("--init_from_ckpt", default=None, type=str, help="The path of the checkpoint to be loaded.")
parser.add_argument("--logging_steps", default=10, type=int, help="The interval steps for logging.")
parser.add_argument(
    "--save_dir",
    default="./checkpoint",
    type=str,
    help="The output directory where the model checkpoints will be written.",
)
parser.add_argument("--save_steps", default=100, type=int, help="The interval steps to save checkpoints.")
parser.add_argument("--valid_steps", default=100, type=int, help="The interval steps to evaluate model performance.")
parser.add_argument("--use_amp", default=False, type=distutils.util.strtobool, help="Enable mixed precision training.")
parser.add_argument("--scale_loss", default=128, type=float, help="The value of scale_loss for fp16.")

args = parser.parse_args()
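
# Example launch (illustrative only; assumes the CBLUE data is downloadable through
# paddlenlp.datasets and that utils.py, providing LinearDecayWithWarmup, convert_example
# and create_dataloader, sits next to this script):
#   python train_classification.py --dataset KUAKE-QIC --device gpu --batch_size 32 \
#       --learning_rate 6e-5 --epochs 3 --save_dir ./checkpoint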


def set_seed(seed):
    """set random seed"""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    """
    Given a dataset, it evaluates the model and computes the metric.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        criterion(obj:`paddle.nn.Layer`): It can compute the loss.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
    """
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, position_ids, labels = batch
        logits = model(input_ids, token_type_ids, position_ids)
        loss = criterion(logits, labels)
        losses.append(loss.numpy())
        correct = metric.compute(logits, labels)
        metric.update(correct)
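    # The metric classes differ in what accumulate() returns: Accuracy yields a scalar,
    # MultiLabelsMetric a 3-tuple whose last entry is reported as macro-F1 here, and
    # AccuracyAndF1 a 5-tuple whose fourth entry is reported as micro-F1.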
    if isinstance(metric, Accuracy):
        metric_name = "accuracy"
        result = metric.accumulate()
    elif isinstance(metric, MultiLabelsMetric):
        metric_name = "macro f1"
        _, _, result = metric.accumulate("macro")
    else:
        metric_name = "micro f1"
        _, _, _, result, _ = metric.accumulate()

    print("eval loss: %.5f, %s: %.5f" % (np.mean(losses), metric_name, result))
    model.train()
    metric.reset()


def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset("cblue", args.dataset, splits=["train", "dev"])

    model = ElectraForSequenceClassification.from_pretrained(
        "ernie-health-chinese", num_labels=len(train_ds.label_list)
    )
    tokenizer = ElectraTokenizer.from_pretrained("ernie-health-chinese")

    trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
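    # Collate converted examples into batch tensors: pad input ids and token type ids
    # with the tokenizer's pad values, pad position ids with max_seq_length - 1, and
    # stack the labels.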
    batchify_fn = lambda samples, fn=Tuple(  # noqa: E731
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # segment
        Pad(axis=0, pad_val=args.max_seq_length - 1, dtype="int64"),  # position
        Stack(dtype="int64"),
    ): [data for data in fn(samples)]
    train_data_loader = create_dataloader(
        train_ds, mode="train", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )
    dev_data_loader = create_dataloader(
        dev_ds, mode="dev", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )
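
    # Optionally warm-start from a checkpoint. ELECTRA-style pretraining checkpoints
    # may store weights under a "discriminator." prefix, so the prefix is stripped to
    # match the parameter names of the fine-tuning model.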
    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        state_keys = {x: x.replace("discriminator.", "") for x in state_dict.keys() if "discriminator." in x}
        if len(state_keys) > 0:
            state_dict = {state_keys[k]: state_dict[k] for k in state_keys.keys()}
        model.set_dict(state_dict)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)
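
    # A positive --max_steps overrides --epochs: num_training_steps is taken from it
    # and args.epochs is recomputed so the epoch loop still covers every step.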
    num_training_steps = args.max_steps if args.max_steps > 0 else len(train_data_loader) * args.epochs
    args.epochs = (num_training_steps - 1) // len(train_data_loader) + 1

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
    )

    criterion = paddle.nn.loss.CrossEntropyLoss()
    if METRIC_CLASSES[args.dataset] is Accuracy:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = "accuracy"
    elif METRIC_CLASSES[args.dataset] is MultiLabelsMetric:
        metric = METRIC_CLASSES[args.dataset](num_labels=len(train_ds.label_list))
        metric_name = "macro f1"
    else:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = "micro f1"
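    # With --use_amp, loss scaling guards fp16 gradients against underflow; the scaler
    # is used in the training loop below via scaler.scale(loss).backward() and
    # scaler.minimize(optimizer, loss).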
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, labels = batch
            with paddle.amp.auto_cast(
                args.use_amp,
                custom_white_list=["layer_norm", "softmax", "gelu", "tanh"],
            ):
                logits = model(input_ids, token_type_ids, position_ids)
                loss = criterion(logits, labels)
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)

            if isinstance(metric, Accuracy):
                result = metric.accumulate()
            elif isinstance(metric, MultiLabelsMetric):
                _, _, result = metric.accumulate("macro")
            else:
                _, _, _, result, _ = metric.accumulate()

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, %s: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, metric_name, result, args.logging_steps / time_diff)
                )

            if global_step % args.valid_steps == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
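                # Under DataParallel the model is wrapped, so save the underlying
                # ._layers module; the tokenizer is saved alongside for later reuse.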
                if paddle.distributed.get_world_size() > 1:
                    model._layers.save_pretrained(save_dir)
                else:
                    model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)

            if global_step >= num_training_steps:
                return
            tic_train = time.time()

    if rank == 0 and total_train_time > 0:
        print("Speed: %.2f steps/s" % (global_step / total_train_time))


if __name__ == "__main__":
    do_train()