paddlenlp

Форк
0
130 строк · 4.4 Кб
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import paddle
16
import numpy as np
17

18
from paddlenlp.datasets import MapDataset
19

20

21
def create_dataloader(dataset, mode="train", batch_size=1, batchify_fn=None, trans_fn=None):
22
    if trans_fn:
23
        dataset = dataset.map(trans_fn)
24

25
    shuffle = True if mode == "train" else False
26
    if mode == "train":
27
        batch_sampler = paddle.io.DistributedBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
28
    else:
29
        batch_sampler = paddle.io.BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle)
30

31
    return paddle.io.DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=batchify_fn, return_list=True)
32

33

34
def read_text_pair(data_path):
35
    """Reads data."""
36
    with open(data_path, "r", encoding="utf-8") as f:
37
        for line in f:
38
            data = line.rstrip().split("\t")
39
            if len(data) != 2:
40
                continue
41
            yield {"query": data[0], "title": data[1]}
42

43

44
def convert_pointwise_example(example, tokenizer, max_seq_length=512, is_test=False):
45

46
    query, title = example["query"], example["title"]
47

48
    encoded_inputs = tokenizer(text=query, text_pair=title, max_seq_len=max_seq_length)
49

50
    input_ids = encoded_inputs["input_ids"]
51
    token_type_ids = encoded_inputs["token_type_ids"]
52

53
    if not is_test:
54
        label = np.array([example["label"]], dtype="int64")
55
        return input_ids, token_type_ids, label
56
    else:
57
        return input_ids, token_type_ids
58

59

60
def convert_pairwise_example(example, tokenizer, max_seq_length=512, phase="train"):
61

62
    if phase == "train":
63
        query, pos_title, neg_title = example["query"], example["title"], example["neg_title"]
64

65
        pos_inputs = tokenizer(text=query, text_pair=pos_title, max_seq_len=max_seq_length)
66
        neg_inputs = tokenizer(text=query, text_pair=neg_title, max_seq_len=max_seq_length)
67

68
        pos_input_ids = pos_inputs["input_ids"]
69
        pos_token_type_ids = pos_inputs["token_type_ids"]
70
        neg_input_ids = neg_inputs["input_ids"]
71
        neg_token_type_ids = neg_inputs["token_type_ids"]
72

73
        return (pos_input_ids, pos_token_type_ids, neg_input_ids, neg_token_type_ids)
74

75
    else:
76
        query, title = example["query"], example["title"]
77

78
        inputs = tokenizer(text=query, text_pair=title, max_seq_len=max_seq_length)
79

80
        input_ids = inputs["input_ids"]
81
        token_type_ids = inputs["token_type_ids"]
82
        if phase == "eval":
83
            return input_ids, token_type_ids, example["label"]
84
        elif phase == "predict":
85
            return input_ids, token_type_ids
86
        else:
87
            raise ValueError("not supported phase:{}".format(phase))
88

89

90
def gen_pair(dataset, pool_size=100):
91
    """
92
    Generate triplet randomly based on dataset
93

94
    Args:
95
        dataset: A `MapDataset` or `IterDataset` or a tuple of those.
96
            Each example is composed of 2 texts: example["query"], example["title"]
97
        pool_size: the number of example to sample negative example randomly
98

99
    Return:
100
        dataset: A `MapDataset` or `IterDataset` or a tuple of those.
101
        Each example is composed of 2 texts: example["query"], example["pos_title"]、example["neg_title"]
102
    """
103

104
    if len(dataset) < pool_size:
105
        pool_size = len(dataset)
106

107
    new_examples = []
108
    pool = []
109
    tmp_examples = []
110

111
    for example in dataset:
112
        label = example["label"]
113

114
        # Filter negative example
115
        if label == 0:
116
            continue
117

118
        tmp_examples.append(example)
119
        pool.append(example["title"])
120

121
        if len(pool) >= pool_size:
122
            np.random.shuffle(pool)
123
            for idx, example in enumerate(tmp_examples):
124
                example["neg_title"] = pool[idx]
125
                new_examples.append(example)
126
            tmp_examples = []
127
            pool = []
128
        else:
129
            continue
130
    return MapDataset(new_examples)
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.