# paddlenlp data utilities (130 lines, 4.4 KB)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import numpy as np
import paddle

from paddlenlp.datasets import MapDataset
20
def create_dataloader(dataset, mode="train", batch_size=1, batchify_fn=None, trans_fn=None):
    """Build a ``paddle.io.DataLoader`` over *dataset*.

    Args:
        dataset: A map-style dataset supporting ``.map(fn)``.
        mode (str): ``"train"`` enables shuffling and a distributed batch
            sampler; any other value gives a deterministic sequential sampler.
        batch_size (int): Number of examples per batch.
        batchify_fn: Optional collate function applied to each batch.
        trans_fn: Optional per-example transform applied before batching.

    Returns:
        paddle.io.DataLoader: Loader returning batches as lists
        (``return_list=True``).
    """
    if trans_fn:
        dataset = dataset.map(trans_fn)

    # Only shuffle during training; evaluation/prediction stays deterministic.
    shuffle = mode == "train"
    if mode == "train":
        # DistributedBatchSampler shards batches across trainers when running
        # under multi-card training; it degrades to a plain sampler otherwise.
        batch_sampler = paddle.io.DistributedBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=batchify_fn, return_list=True)
33
34def read_text_pair(data_path):35"""Reads data."""36with open(data_path, "r", encoding="utf-8") as f:37for line in f:38data = line.rstrip().split("\t")39if len(data) != 2:40continue41yield {"query": data[0], "title": data[1]}42
43
def convert_pointwise_example(example, tokenizer, max_seq_length=512, is_test=False):
    """Tokenize a query/title pair for pointwise matching.

    Args:
        example (dict): Has "query" and "title" keys; "label" is required
            unless ``is_test`` is True.
        tokenizer: Callable accepting ``text``, ``text_pair`` and
            ``max_seq_len`` keyword arguments and returning a dict with
            "input_ids" and "token_type_ids".
        max_seq_length (int): Maximum sequence length passed to the tokenizer.
        is_test (bool): When True, the label is omitted from the result.

    Returns:
        ``(input_ids, token_type_ids)`` for test data, otherwise
        ``(input_ids, token_type_ids, label)`` where label is an int64
        numpy array of shape (1,).
    """
    encoded = tokenizer(text=example["query"], text_pair=example["title"], max_seq_len=max_seq_length)
    input_ids = encoded["input_ids"]
    token_type_ids = encoded["token_type_ids"]

    if is_test:
        return input_ids, token_type_ids

    label = np.array([example["label"]], dtype="int64")
    return input_ids, token_type_ids, label
59
def convert_pairwise_example(example, tokenizer, max_seq_length=512, phase="train"):
    """Tokenize an example for pairwise matching.

    Args:
        example (dict): Needs "query" and "title"; "neg_title" is required
            for the train phase and "label" for the eval phase.
        tokenizer: Callable accepting ``text``, ``text_pair`` and
            ``max_seq_len`` keywords, returning a dict with "input_ids"
            and "token_type_ids".
        max_seq_length (int): Maximum sequence length passed to the tokenizer.
        phase (str): One of "train", "eval" or "predict".

    Returns:
        train: ``(pos_input_ids, pos_token_type_ids, neg_input_ids,
        neg_token_type_ids)``; eval: ``(input_ids, token_type_ids, label)``;
        predict: ``(input_ids, token_type_ids)``.

    Raises:
        ValueError: If *phase* is not one of the supported values.
    """
    if phase == "train":
        query = example["query"]
        pos_enc = tokenizer(text=query, text_pair=example["title"], max_seq_len=max_seq_length)
        neg_enc = tokenizer(text=query, text_pair=example["neg_title"], max_seq_len=max_seq_length)
        return (
            pos_enc["input_ids"],
            pos_enc["token_type_ids"],
            neg_enc["input_ids"],
            neg_enc["token_type_ids"],
        )

    # eval/predict share a single query-title encoding.
    encoded = tokenizer(text=example["query"], text_pair=example["title"], max_seq_len=max_seq_length)
    if phase == "eval":
        return encoded["input_ids"], encoded["token_type_ids"], example["label"]
    if phase == "predict":
        return encoded["input_ids"], encoded["token_type_ids"]
    raise ValueError("not supported phase:{}".format(phase))
89
def gen_pair(dataset, pool_size=100):
    """Generate pseudo triplets randomly from a labeled pair dataset.

    Walks over *dataset*, keeps only positive examples (``label != 0``) and,
    once ``pool_size`` positives are buffered, shuffles their titles and
    assigns each buffered example one of the shuffled titles as its
    ``neg_title`` (pseudo negative).

    Args:
        dataset: A `MapDataset` (must support ``len``) whose examples have
            "query", "title" and "label" keys.
        pool_size (int): Number of examples buffered before negatives are
            sampled; clipped to ``len(dataset)``.

    Returns:
        MapDataset: Examples carrying "query", "title" and a new "neg_title"
        key.

    NOTE(review): input example dicts are mutated in place, and a trailing
    partial pool (fewer than ``pool_size`` positives) is silently dropped —
    both behaviors are preserved from the original implementation. An
    example may also draw its own title as its "negative".
    """
    if len(dataset) < pool_size:
        pool_size = len(dataset)

    new_examples = []
    pool = []
    tmp_examples = []

    for example in dataset:
        # Filter out negative examples; negatives are synthesized below.
        if example["label"] == 0:
            continue

        tmp_examples.append(example)
        pool.append(example["title"])

        if len(pool) >= pool_size:
            # Shuffle the buffered titles and pair each buffered example
            # with one of them as its pseudo negative title.
            np.random.shuffle(pool)
            for idx, buffered in enumerate(tmp_examples):
                buffered["neg_title"] = pool[idx]
                new_examples.append(buffered)
            tmp_examples = []
            pool = []

    return MapDataset(new_examples)