paddlenlp

Форк
0
303 строки · 11.8 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3

4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import numpy as np
17

18
# Public batchify helpers exposed by `from <module> import *`.
__all__ = ["Stack", "Pad", "Tuple", "Dict"]
24

25

26
class Stack(object):
    """
    Stacks the input data samples to construct the batch. The N input samples
    must have the same shape/length and will be stacked to construct a batch.
    Args:
        axis (int, optional): The axis in the result data along which the input
            data are stacked. Default: 0.
        dtype (str|numpy.dtype, optional): The value type of the output. If it
            is set to None, the type of input data is used. Default: None.
    """

    def __init__(self, axis=0, dtype=None):
        self._axis = axis
        self._dtype = dtype

    def __call__(self, data):
        """
        Batchifies the input data by stacking.
        Args:
            data (list[numpy.ndarray]): The input data samples. It is a list.
                Each element is a numpy.ndarray or list.
        Returns:
            numpy.ndarray: Stacked batch data. It is cast to `dtype` when
            `dtype` is not None.
        Example:
            .. code-block:: python
                from paddlenlp.data import Stack
                a = [1, 2, 3, 4]
                b = [3, 4, 5, 6]
                c = [5, 6, 7, 8]
                result = Stack()([a, b, c])
                '''
                [[1, 2, 3, 4],
                 [3, 4, 5, 6],
                 [5, 6, 7, 8]]
                '''
        """
        stacked = np.stack(data, axis=self._axis)
        # Compare against None explicitly (as `Pad` does) rather than relying on
        # truthiness: a dtype object is a valid request to cast even if it
        # happens to evaluate as falsy (numpy dtypes define `__len__`).
        if self._dtype is not None:
            stacked = stacked.astype(self._dtype)
        return stacked
64

65

66
class Pad(object):
    """
    Pads the input data samples to the largest length at `axis`.
    Args:
        pad_val (float|int, optional): The padding value. Default: 0.
        axis (int, optional): The axis to pad the arrays. The arrays will be
            padded to the largest length at `axis`. For example, assume the
            input arrays have shape (10, 8, 5), (6, 8, 5), (3, 8, 5) and the
            axis is 0. Each input will be padded into (10, 8, 5) and then
            stacked to form the final output, which has shape (3, 10, 8, 5).
            Default: 0.
        ret_length (bool|numpy.dtype, optional): If it is bool, indicate whether
            to return the valid length in the output, and the data type of
            returned length is int32 if True. If it is numpy.dtype, indicate the
            data type of returned length. Default: None.
        dtype (numpy.dtype, optional): The value type of the output. If it is
            set to None, the input data type is used. Default: None.
        pad_right (bool, optional): Whether the padding direction is right-side.
            If True, it indicates we pad to the right side, while False indicates
            we pad to the left side. Default: True.
    """

    def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None, pad_right=True):
        self._pad_val = pad_val
        self._axis = axis
        self._ret_length = ret_length
        self._dtype = dtype
        self._pad_right = pad_right

    def __call__(self, data):
        """
        Batchifies the input data by padding. The input will be padded to the
        largest dimension at `axis` and then stacked to form the final output.
        In addition, the function will output the original dimensions at the
        `axis` if `ret_length` is not None or False.
        Args:
            data (list[numpy.ndarray|list]): The input data samples. It is a
                list. Each element is a numpy.ndarray or list.
        Returns:
            numpy.ndarray|tuple[numpy.ndarray]: If `ret_length` is None or
            False, it is a numpy.ndarray representing the padded batch data and
            the shape is (N, …). Otherwise, it is a tuple, besides the padded
            batch data, the tuple also includes a numpy.ndarray representing
            original length at `axis` of all input samples, which shaped `(N,)`.
            The lengths are int32 when `ret_length` is True, otherwise they use
            `ret_length` as their dtype.
        Example:
            .. code-block:: python
                from paddlenlp.data import Pad
                a = [1, 2, 3, 4]
                b = [5, 6, 7]
                c = [8, 9]
                result = Pad(pad_val=0)([a, b, c])
                '''
                [[1, 2, 3, 4],
                 [5, 6, 7, 0],
                 [8, 9, 0, 0]]
                '''
        """

        # Return data itself for rare unexpected cases when a 1-D sequence of
        # scalars (samples that are neither lists nor arrays) is passed to Pad.
        if not isinstance(data[0], list) and not isinstance(data[0], np.ndarray):
            return np.asarray(data, dtype=self._dtype if self._dtype is not None else np.int64)

        arrs = [np.asarray(ele) for ele in data]
        original_length = [ele.shape[self._axis] for ele in arrs]
        max_size = max(original_length)
        ret_shape = list(arrs[0].shape)
        ret_shape[self._axis] = max_size
        ret_shape = (len(arrs),) + tuple(ret_shape)
        ret = np.full(
            shape=ret_shape, fill_value=self._pad_val, dtype=arrs[0].dtype if self._dtype is None else self._dtype
        )
        for i, arr in enumerate(arrs):
            if arr.shape[self._axis] == max_size:
                ret[i] = arr
                continue
            slices = [slice(None) for _ in range(arr.ndim)]
            if self._pad_right:
                slices[self._axis] = slice(0, arr.shape[self._axis])
            else:
                slices[self._axis] = slice(max_size - arr.shape[self._axis], max_size)
            # A zero-length sample contributes nothing; its row stays all pad_val.
            if slices[self._axis].start != slices[self._axis].stop:
                slices = [slice(i, i + 1)] + slices
                ret[tuple(slices)] = arr

        ret_length = self._ret_length
        if ret_length is None or isinstance(ret_length, (bool, int, np.bool_, np.integer)):
            # Flag form: falsy -> padded batch only, truthy -> int32 lengths.
            if not ret_length:
                return ret
            length_dtype = "int32"
        else:
            # Dtype form (e.g. "int64", np.int64, np.dtype("int64")). Previously
            # this branch was unreachable and int32 was always returned.
            length_dtype = ret_length
        return ret, np.asarray(original_length, dtype=length_dtype)
156

157

158
class Tuple(object):
    """
    Combines several batchify functions, applying each one to its own field.

    Every sample is expected to be a list or tuple of fields; the k-th
    batchify function held by this Tuple receives the k-th field of all
    samples.

    For instance, with samples shaped like (nd_data, label), the wrapper
    `Tuple(DataBatchify, LabelBatchify)` batchifies nd_data and label
    separately.
    Args:
        fn (callable|list[callable]|tuple[callable]): A batchify function, or a
            list/tuple of batchify functions.
        args (tuple[callable]): Extra batchify functions, only allowed when
            `fn` is a single callable.
    """

    def __init__(self, fn, *args):
        if not isinstance(fn, (list, tuple)):
            # Varargs form: Tuple(A, B, C).
            self._fn = (fn,) + args
        else:
            # Container form: Tuple([A, B, C]) / Tuple((A, B, C)); mixing it
            # with extra positional functions is rejected.
            assert not args, (
                "Input pattern not understood. The input of Tuple can be "
                "Tuple(A, B, C) or Tuple([A, B, C]) or Tuple((A, B, C)). "
                "Received fn=%s, args=%s" % (str(fn), str(args))
            )
            self._fn = fn
        for idx, batchify in enumerate(self._fn):
            assert callable(batchify), "Batchify functions must be callable! type(fn[%d]) = %s" % (
                idx,
                str(type(batchify)),
            )

    def __call__(self, data):
        """
        Builds a batch by handing each field's column to its batchify function.
        Args:
            data (list|tuple): The samples to batchify. Each sample should hold
                as many fields as there are batchify functions.
        Returns:
            tuple: The concatenated outputs of all batchify functions. An output
            that is itself a list/tuple is flattened into the result.
        Example:
            .. code-block:: python

                from paddlenlp.data import Stack, Pad, Tuple
                data = [
                        [[1, 2, 3, 4], [1]],
                        [[5, 6, 7], [0]],
                        [[8, 9], [1]],
                       ]
                batchify_fn = Tuple(Pad(pad_val=0), Stack())
                ids, label = batchify_fn(data)
                '''
                ids:
                [[1, 2, 3, 4],
                [5, 6, 7, 0],
                [8, 9, 0, 0]]
                label: [[1], [0], [1]]
                '''
        """
        num_fields = len(self._fn)
        # Only the first sample is checked against the number of functions.
        assert len(data[0]) == num_fields, (
            "The number of attributes in each data sample should contain {} elements".format(num_fields)
        )
        outputs = []
        for field_idx, batchify in enumerate(self._fn):
            field_batch = batchify([sample[field_idx] for sample in data])
            # Multi-output functions (e.g. Pad with ret_length) are flattened.
            outputs.extend(field_batch if isinstance(field_batch, (tuple, list)) else (field_batch,))
        return tuple(outputs)
230

231

232
class Dict(object):
    """
    Wraps multiple batchify functions together. The input functions will be
    applied to the corresponding input fields.

    Each sample should be a dict containing multiple fields. Each batchify
    function with key stored in `Dict` will be applied on the field which has
    the same key.

    For example, when data sample is {'tokens': tokens, 'labels': labels}, you
    can wrap two batchify functions using
    `Dict({'tokens': DataBatchify, 'labels': LabelBatchify})` to batchify tokens
    and labels correspondingly.
    Args:
        fn (dict): The batchify functions to wrap. It is a dict whose values are
            callable functions.
    """

    def __init__(self, fn):
        assert isinstance(fn, (dict)), (
            "Input pattern not understood. The input of Dict must be a dict with key of input column name and value of collate_fn "
            "Received fn=%s" % (str(fn))
        )

        self._fn = fn

        for col_name, ele_fn in self._fn.items():
            # Column names are dict keys (strings in the documented usage), so
            # format them with %r; the former %d raised TypeError while building
            # this message and masked the real error.
            assert callable(ele_fn), "Batchify functions must be callable! type(fn[%r]) = %s" % (
                col_name,
                str(type(ele_fn)),
            )

    def __call__(self, data):
        """
        Batchifies data samples by applying each function on the corresponding
        data field, and each data field is produced by stacking the field data
        with the same key as batchify functions of all samples.
        Args:
            data (list[dict]|tuple[dict]): The samples to batchify. Each sample
                in list/tuple is a dict with `N` key-values.

        Returns:
            tuple: A tuple composed of results from all included batchify
            functions, in the iteration order of the wrapped dict. A result
            that is itself a list/tuple is flattened into the output.

        Example:
            .. code-block:: python
                from paddlenlp.data import Stack, Pad, Dict
                data = [
                        {'labels':[1], 'token_ids':[1, 2, 3, 4]},
                        {'labels':[0], 'token_ids':[5, 6, 7]},
                        {'labels':[1], 'token_ids':[8, 9]},
                       ]
                batchify_fn = Dict({'token_ids':Pad(pad_val=0), 'labels':Stack()})
                ids, label = batchify_fn(data)
                '''
                ids:
                [[1, 2, 3, 4],
                [5, 6, 7, 0],
                [8, 9, 0, 0]]
                label: [[1], [0], [1]]
                '''
        """

        ret = []
        for col_name, ele_fn in self._fn.items():
            result = ele_fn([ele[col_name] for ele in data])
            # Multi-output functions (e.g. Pad with ret_length) are flattened.
            if isinstance(result, (tuple, list)):
                ret.extend(result)
            else:
                ret.append(result)
        return tuple(ret)
304

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.