paddlenlp
190 lines · 7.4 KB
1# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse16import os17import random18import time19from functools import partial20
21import numpy as np22import paddle23from data import convert_pointwise_example as convert_example24from data import create_dataloader25from model import PointwiseMatching26
27from paddlenlp.data import Pad, Stack, Tuple28from paddlenlp.datasets import load_dataset29from paddlenlp.transformers import AutoModel, AutoTokenizer, LinearDecayWithWarmup30
31
def set_seed(seed):
    """Seed the Python, NumPy, and Paddle RNGs so runs are reproducible."""
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
38
@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, phase="dev"):
    """
    Run one full evaluation pass and print the mean loss and accuracy.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        criterion(obj:`paddle.nn.Layer`): It can compute the loss.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
        phase(obj:`str`, optional): Label printed alongside the results.
    """
    model.eval()
    metric.reset()
    batch_losses = []
    for input_ids, token_type_ids, labels in data_loader:
        logits = model(input_ids=input_ids, token_type_ids=token_type_ids)
        batch_losses.append(criterion(logits, labels).numpy())
        metric.update(metric.compute(logits, labels))
    print("eval {} loss: {:.5}, accu: {:.5}".format(phase, np.mean(batch_losses), metric.accumulate()))
    # Restore training mode and clear the metric for the next training window.
    model.train()
    metric.reset()
65
def do_train(args):
    """Fine-tune a pointwise matching model on LCQMC.

    Sets up (distributed) devices, data loaders, optimizer and LR schedule,
    then runs the training loop with periodic evaluation and checkpointing.

    Args:
        args: Parsed command-line arguments (see the argparse definitions at
            the bottom of this file).
    """
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset("lcqmc", splits=["train", "dev"])

    pretrained_model = AutoModel.from_pretrained("ernie-3.0-medium-zh")
    tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")

    trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # text_pair_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # text_pair_segment
        Stack(dtype="int64"),  # label
    ): [data for data in fn(samples)]

    train_data_loader = create_dataloader(
        train_ds, mode="train", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )

    dev_data_loader = create_dataloader(
        dev_ds, mode="dev", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )

    model = PointwiseMatching(pretrained_model)

    # Optionally warm-start from a previously saved checkpoint.
    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)

    model = paddle.DataParallel(model)

    num_training_steps = len(train_data_loader) * args.epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
    )

    criterion = paddle.nn.loss.CrossEntropyLoss()
    metric = paddle.metric.Accuracy()

    global_step = 0
    tic_train = time.time()
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, labels = batch
            probs = model(input_ids=input_ids, token_type_ids=token_type_ids)
            loss = criterion(probs, labels)
            correct = metric.compute(probs, labels)
            metric.update(correct)
            acc = metric.accumulate()

            global_step += 1
            if global_step % 10 == 0 and rank == 0:
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, acc, 10 / (time.time() - tic_train))
                )
                tic_train = time.time()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_step % args.eval_step == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)

            # BUG FIX: checkpointing was previously tied to eval_step, leaving
            # the --save_step argument dead; saving now honors save_step as its
            # help text ("Step interval for saving checkpoint.") describes.
            if global_step % args.save_step == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir, "model")
                os.makedirs(save_dir, exist_ok=True)
                tokenizer.save_pretrained(save_dir)
                # Unwrap DataParallel so the saved state_dict has plain keys.
                model_to_save = model._layers if isinstance(model, paddle.DataParallel) else model
                paddle.save(model_to_save.state_dict(), os.path.join(save_dir, "model_state.pdparams"))

            # BUG FIX: the original `global_step > args.max_steps` check aborted
            # training after the very first step under the default max_steps=-1;
            # only stop early when max_steps is explicitly set (> 0).
            if args.max_steps > 0 and global_step >= args.max_steps:
                return
153
if __name__ == "__main__":
    # Command-line interface: every option maps 1:1 onto a field of `args`
    # consumed by do_train().
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg(
        "--save_dir",
        type=str,
        default="./checkpoint",
        help="The output directory where the model checkpoints will be written.",
    )
    arg(
        "--max_seq_length",
        type=int,
        default=128,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    arg("--batch_size", type=int, default=32, help="Batch size per GPU/CPU for training.")
    arg("--learning_rate", type=float, default=5e-5, help="The initial learning rate for Adam.")
    arg("--weight_decay", type=float, default=0.0, help="Weight decay if we apply some.")
    arg("--epochs", type=int, default=3, help="Total number of training epochs to perform.")
    arg("--eval_step", type=int, default=100, help="Step interval for evaluation.")
    arg("--save_step", type=int, default=10000, help="Step interval for saving checkpoint.")
    arg("--warmup_proportion", type=float, default=0.0, help="Linear warmup proportion over the training process.")
    arg("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
    arg("--seed", type=int, default=1000, help="Random seed for initialization.")
    arg("--max_steps", type=int, default=-1, help="If > 0: set total number of training steps to perform.")
    arg(
        "--device",
        choices=["cpu", "gpu", "npu", "xpu"],
        default="gpu",
        help="Select which device to train model, defaults to gpu.",
    )
    do_train(parser.parse_args())