# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import distutils.util
import os
import random
import time
from functools import partial

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.metric import Accuracy
from utils import LinearDecayWithWarmup, convert_example, create_dataloader

from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import AccuracyAndF1, MultiLabelsMetric
from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer

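# Evaluation metric per CBLUE task: plain accuracy for the KUAKE query tasks,
# macro-F1 (via MultiLabelsMetric) for CHIP-CTC and CHIP-STS, and accuracy
# with micro-F1 (AccuracyAndF1) for the binary CHIP-CDN-2C task.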
METRIC_CLASSES = {
    "KUAKE-QIC": Accuracy,
    "KUAKE-QQR": Accuracy,
    "KUAKE-QTR": Accuracy,
    "CHIP-CTC": MultiLabelsMetric,
    "CHIP-STS": MultiLabelsMetric,
    "CHIP-CDN-2C": AccuracyAndF1,
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    choices=["KUAKE-QIC", "KUAKE-QQR", "KUAKE-QTR", "CHIP-STS", "CHIP-CTC", "CHIP-CDN-2C"],
    default="KUAKE-QIC",
    type=str,
    help="Dataset for sequence classification tasks.",
)
50parser.add_argument("--seed", default=1000, type=int, help="Random seed for initialization.")51parser.add_argument(52"--device",53choices=["cpu", "gpu", "xpu", "npu"],54default="gpu",55help="Select which device to train model, default to gpu.",56)
57parser.add_argument("--epochs", default=3, type=int, help="Total number of training epochs.")58parser.add_argument(59"--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override epochs."60)
61parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")62parser.add_argument(63"--learning_rate", default=6e-5, type=float, help="Learning rate for fine-tuning sequence classification task."64)
65parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay of optimizer if we apply some.")66parser.add_argument(67"--warmup_proportion",68default=0.1,69type=float,70help="Linear warmup proportion of learning rate over the training process.",71)
parser.add_argument(
    "--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization."
)
75parser.add_argument("--init_from_ckpt", default=None, type=str, help="The path of checkpoint to be loaded.")76parser.add_argument("--logging_steps", default=10, type=int, help="The interval steps to logging.")77parser.add_argument(78"--save_dir",79default="./checkpoint",80type=str,81help="The output directory where the model checkpoints will be written.",82)
83parser.add_argument("--save_steps", default=100, type=int, help="The interval steps to save checkpoints.")84parser.add_argument("--valid_steps", default=100, type=int, help="The interval steps to evaluate model performance.")85parser.add_argument("--use_amp", default=False, type=distutils.util.strtobool, help="Enable mixed precision training.")86parser.add_argument("--scale_loss", default=128, type=float, help="The value of scale_loss for fp16.")87
args = parser.parse_args()

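# Illustrative invocation (the script filename here is an assumption):
#   python train_classification.py --dataset KUAKE-QIC --device gpu --epochs 3
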
def set_seed(seed):
    """Set random seeds for random, numpy and paddle to make runs reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    """
    Given a dataset, evaluate the model and compute the metric.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        criterion(obj:`paddle.nn.Layer`): The loss function.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
    """
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, position_ids, labels = batch
        logits = model(input_ids, token_type_ids, position_ids)
        loss = criterion(logits, labels)
        losses.append(loss.numpy())
        correct = metric.compute(logits, labels)
        metric.update(correct)
    if isinstance(metric, Accuracy):
        metric_name = "accuracy"
        result = metric.accumulate()
    elif isinstance(metric, MultiLabelsMetric):
        metric_name = "macro f1"
        _, _, result = metric.accumulate("macro")
    else:
        metric_name = "micro f1"
        _, _, _, result, _ = metric.accumulate()

129print("eval loss: %.5f, %s: %.5f" % (np.mean(losses), metric_name, result))130model.train()131metric.reset()132
133
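# do_train() below calls evaluate() periodically; evaluate() switches the model
# to eval mode and restores train mode before returning, so it is safe to call
# mid-training.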
def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset("cblue", args.dataset, splits=["train", "dev"])

    model = ElectraForSequenceClassification.from_pretrained(
        "ernie-health-chinese", num_labels=len(train_ds.label_list)
    )
    tokenizer = ElectraTokenizer.from_pretrained("ernie-health-chinese")

    trans_func = partial(convert_example, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
    batchify_fn = lambda samples, fn=Tuple(  # noqa: E731
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # token type ids
        Pad(axis=0, pad_val=args.max_seq_length - 1, dtype="int64"),  # position ids
        Stack(dtype="int64"),  # labels
    ): [data for data in fn(samples)]
    train_data_loader = create_dataloader(
        train_ds, mode="train", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )
    dev_data_loader = create_dataloader(
        dev_ds, mode="dev", batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func
    )

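    # Each batch produced by the loaders above is a 4-tuple of int64 tensors:
    # (input_ids, token_type_ids, position_ids, labels), padded per batch to the
    # longest sequence. Position ids are padded with max_seq_length - 1, the
    # largest valid position index, presumably so padded slots still index a
    # real position embedding.
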
    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        # Strip the "discriminator." prefix (present in ELECTRA-style
        # pretraining checkpoints) so the keys match this classification model.
        state_keys = {x: x.replace("discriminator.", "") for x in state_dict.keys() if "discriminator." in x}
        if len(state_keys) > 0:
            state_dict = {state_keys[k]: state_dict[k] for k in state_keys.keys()}
        model.set_dict(state_dict)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

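    # When --max_steps is positive it takes precedence over --epochs; epochs is
    # then recomputed so the loop below covers exactly num_training_steps steps.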
    num_training_steps = args.max_steps if args.max_steps > 0 else len(train_data_loader) * args.epochs
    args.epochs = (num_training_steps - 1) // len(train_data_loader) + 1

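    # LinearDecayWithWarmup (from utils) is assumed to follow the usual schedule:
    # the learning rate rises linearly to args.learning_rate over the warmup
    # fraction of the steps, then decays linearly to zero.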
    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params,
    )

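    # Cross-entropy is used for every task; the metric object and the name of
    # the reported figure depend on the dataset (see METRIC_CLASSES above).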
    criterion = paddle.nn.loss.CrossEntropyLoss()
    if METRIC_CLASSES[args.dataset] is Accuracy:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = "accuracy"
    elif METRIC_CLASSES[args.dataset] is MultiLabelsMetric:
        metric = METRIC_CLASSES[args.dataset](num_labels=len(train_ds.label_list))
        metric_name = "macro f1"
    else:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = "micro f1"
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, labels = batch
            with paddle.amp.auto_cast(
                args.use_amp,
                custom_white_list=["layer_norm", "softmax", "gelu", "tanh"],
            ):
                logits = model(input_ids, token_type_ids, position_ids)
                loss = criterion(logits, labels)
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)

            if isinstance(metric, Accuracy):
                result = metric.accumulate()
            elif isinstance(metric, MultiLabelsMetric):
                _, _, result = metric.accumulate("macro")
            else:
                _, _, _, result, _ = metric.accumulate()

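            # Under AMP, the loss is scaled before backward to avoid fp16
            # gradient underflow; scaler.minimize() unscales the gradients and
            # applies the optimizer update.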
            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, %s: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss, metric_name, result, args.logging_steps / time_diff)
                )
                # Reset the timer here so time_diff spans exactly logging_steps
                # steps, matching the printed speed.
                tic_train = time.time()

            if global_step % args.valid_steps == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                # With DataParallel the underlying model lives at model._layers.
                if paddle.distributed.get_world_size() > 1:
                    model._layers.save_pretrained(save_dir)
                else:
                    model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)

            if global_step >= num_training_steps:
                return

    if rank == 0 and total_train_time > 0:
        print("Speed: %.2f steps/s" % (global_step / total_train_time))


if __name__ == "__main__":
    do_train()