1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
8
# http://www.apache.org/licenses/LICENSE-2.0
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
28
Stacks the input data samples to construct the batch. The N input samples
29
must have the same shape/length and will be stacked to construct a batch.
31
axis (int, optional): The axis in the result data along which the input
32
data are stacked. Default: 0.
33
dtype (str|numpy.dtype, optional): The value type of the output. If it
34
is set to None, the type of input data is used. Default: None.
37
def __init__(self, axis=0, dtype=None):
41
def __call__(self, data):
43
Batchifies the input data by stacking.
45
data (list[numpy.ndarray|list]): The input data samples. It is a list.
46
Each element is a numpy.ndarray or list.
48
numpy.ndarray: Stacked batch data.
50
.. code-block:: python
51
from paddlenlp.data import Stack
55
result = Stack()([a, b, c])
62
data = np.stack(data, axis=self._axis).astype(self._dtype) if self._dtype else np.stack(data, axis=self._axis)
68
Pads the input data samples to the largest length at `axis`.
70
pad_val (float|int, optional): The padding value. Default: 0.
71
axis (int, optional): The axis to pad the arrays. The arrays will be
72
padded to the largest length at `axis`. For example, assume the
73
input arrays have shape (10, 8, 5), (6, 8, 5), (3, 8, 5) and the
74
axis is 0. Each input will be padded into (10, 8, 5) and then
75
stacked to form the final output, which has shape (3, 10, 8, 5).
77
ret_length (bool|numpy.dtype, optional): If it is bool, it indicates whether
78
to return the valid length in the output, and the data type of
79
returned length is int32 if True. If it is numpy.dtype, it indicates the
80
data type of returned length. Default: None.
81
dtype (numpy.dtype, optional): The value type of the output. If it is
82
set to None, the input data type is used. Default: None.
83
pad_right (bool, optional): Whether the padding direction is right-side.
84
If True, it indicates we pad to the right side, while False indicates
85
we pad to the left side. Default: True.
88
def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None, pad_right=True):
89
self._pad_val = pad_val
91
self._ret_length = ret_length
93
self._pad_right = pad_right
95
def __call__(self, data):
97
Batchifies the input data by padding. The input will be padded to the
98
largest dimension at `axis` and then stacked to form the final output.
99
In addition, the function will output the original dimensions at the
100
`axis` if `ret_length` is neither None nor False.
102
data (list[numpy.ndarray|list]): The input data samples. It is a
103
list. Each element is a numpy.ndarray or list.
105
numpy.ndarray|tuple[numpy.ndarray]: If `ret_length` is None or False, it
106
is a numpy.ndarray representing the padded batch data and the
107
shape is (N, ...). Otherwise, it is a tuple; besides the padded batch
108
data, the tuple also includes a numpy.ndarray representing original
109
length at `axis` of all input samples, with shape `(N,)`.
111
.. code-block:: python
112
from paddlenlp.data import Pad
116
result = Pad(pad_val=0)([a, b, c])
124
# Samples are scalars (rare case: a flat 1-D sequence was passed to Pad), so there is nothing to pad; return them as an array.
125
if not isinstance(data[0], list) and not isinstance(data[0], np.ndarray):
126
return np.asarray(data, dtype=self._dtype if self._dtype is not None else np.int64)
128
arrs = [np.asarray(ele) for ele in data]
129
original_length = [ele.shape[self._axis] for ele in arrs]
130
max_size = max(original_length)
131
ret_shape = list(arrs[0].shape)
132
ret_shape[self._axis] = max_size
133
ret_shape = (len(arrs),) + tuple(ret_shape)
135
shape=ret_shape, fill_value=self._pad_val, dtype=arrs[0].dtype if self._dtype is None else self._dtype
137
for i, arr in enumerate(arrs):
138
if arr.shape[self._axis] == max_size:
141
slices = [slice(None) for _ in range(arr.ndim)]
143
slices[self._axis] = slice(0, arr.shape[self._axis])
145
slices[self._axis] = slice(max_size - arr.shape[self._axis], max_size)
147
if slices[self._axis].start != slices[self._axis].stop:
148
slices = [slice(i, i + 1)] + slices
149
ret[tuple(slices)] = arr
151
return ret, np.asarray(original_length, dtype="int32") if self._ret_length else np.asarray(
152
original_length, self._ret_length
160
Wraps multiple batchify functions together. The input functions will be applied
161
to the corresponding input fields.
163
Each sample should be a list or tuple containing multiple fields. The i'th
164
batchify function stored in Tuple will be applied on the i'th field.
166
For example, when data sample is (nd_data, label), you can wrap two batchify
167
functions using `Tuple(DataBatchify, LabelBatchify)` to batchify nd_data and
168
label correspondingly.
170
fn (callable|list[callable]|tuple[callable]): The batchify functions to
171
wrap. It is a callable function or a list/tuple of callable functions.
172
args (tuple[callable]): The additional batchify functions to wrap.
175
def __init__(self, fn, *args):
176
if isinstance(fn, (list, tuple)):
177
assert len(args) == 0, (
178
"Input pattern not understood. The input of Tuple can be "
179
"Tuple(A, B, C) or Tuple([A, B, C]) or Tuple((A, B, C)). "
180
"Received fn=%s, args=%s" % (str(fn), str(args))
184
self._fn = (fn,) + args
185
for i, ele_fn in enumerate(self._fn):
186
assert callable(ele_fn), "Batchify functions must be callable! type(fn[%d]) = %s" % (i, str(type(ele_fn)))
188
def __call__(self, data):
190
Batchifies data samples by applying each function on the corresponding
191
data field, and each data field is produced by stacking the field data
194
data (list|tuple): The samples to batchify. Each sample in list/tuple
195
should contain `N` fields.
197
tuple: A tuple composed of the results of all included batchifying
200
.. code-block:: python
202
from paddlenlp.data import Stack, Pad, Tuple
208
batchify_fn = Tuple(Pad(pad_val=0), Stack())
209
ids, label = batchify_fn(data)
215
label: [[1], [0], [1]]
219
assert len(data[0]) == len(
221
), "The number of attributes in each data sample should contain" " {} elements".format(len(self._fn))
223
for i, ele_fn in enumerate(self._fn):
224
result = ele_fn([ele[i] for ele in data])
225
if isinstance(result, (tuple, list)):
234
Wraps multiple batchify functions together. The input functions will be
235
applied to the corresponding input fields.
237
Each sample should be a dict containing multiple fields. Each batchify
238
function with key stored in `Dict` will be applied on the field which has
241
For example, when data sample is {'tokens': tokens, 'labels': labels}, you
242
can wrap two batchify functions using
243
`Dict({'tokens': DataBatchify, 'labels': LabelBatchify})` to batchify tokens
244
and labels correspondingly.
246
fn (dict): The batchify functions to wrap. It is a dict, whose values are
250
def __init__(self, fn):
251
assert isinstance(fn, (dict)), (
252
"Input pattern not understood. The input of Dict must be a dict with key of input column name and value of collate_fn "
253
"Received fn=%s" % (str(fn))
258
for col_name, ele_fn in self._fn.items():
259
assert callable(ele_fn), "Batchify functions must be callable! type(fn[%d]) = %s" % (
264
def __call__(self, data):
266
Batchifies data samples by applying each function on the corresponding
267
data field, and each data field is produced by stacking the field data
268
with the same key as batchify functions of all samples.
270
data (list[dict]|tuple[dict]): The samples to batchify. Each sample
271
in list/tuple is a dict with `N` key-values.
274
tuple: A tuple composed of the results of all included batchifying
278
.. code-block:: python
279
from paddlenlp.data import Stack, Pad, Dict
281
{'labels':[1], 'token_ids':[1, 2, 3, 4]},
282
{'labels':[0], 'token_ids':[5, 6, 7]},
283
{'labels':[1], 'token_ids':[8, 9]},
285
batchify_fn = Dict({'token_ids':Pad(pad_val=0), 'labels':Stack()})
286
ids, label = batchify_fn(data)
292
label: [[1], [0], [1]]
297
for col_name, ele_fn in self._fn.items():
298
result = ele_fn([ele[col_name] for ele in data])
299
if isinstance(result, (tuple, list)):