import copy
import json
import os
import pickle
import re
import tempfile
from functools import partial
from pathlib import Path
from unittest import TestCase
from unittest.mock import MagicMock, patch

import numpy as np
import numpy.testing as npt
import pandas as pd
import pyarrow as pa
import pytest
from absl.testing import parameterized
from fsspec.core import strip_protocol
from packaging import version

import datasets.arrow_dataset
from datasets import concatenate_datasets, interleave_datasets, load_from_disk
from datasets.arrow_dataset import Dataset, transmit_format, update_metadata_with_features
from datasets.dataset_dict import DatasetDict
from datasets.features import (
    Array2D,
    Array3D,
    ClassLabel,
    Features,
    Image,
    Sequence,
    Translation,
    TranslationVariableLanguages,
    Value,
)
from datasets.info import DatasetInfo
from datasets.iterable_dataset import IterableDataset
from datasets.splits import NamedSplit
from datasets.table import ConcatenationTable, InMemoryTable, MemoryMappedTable
from datasets.tasks import (
    AutomaticSpeechRecognition,
    QuestionAnsweringExtractive,
)
from datasets.utils.logging import INFO, get_logger
from datasets.utils.py_utils import temp_seed

from .utils import (
    assert_arrow_memory_doesnt_increase,
    assert_arrow_memory_increases,
    require_dill_gt_0_3_2,
    require_pil,
    require_tf,
    require_torch,
    set_current_working_directory_to_temp_dir,
)


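# Helpers used by the map/filter multiprocessing tests. They are defined at
# module level (rather than as lambdas or closures) so they can be pickled
# when `Dataset.map(..., num_proc=...)` sends them to worker processes.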
class PickableMagicMock(MagicMock):
    def __reduce__(self):
        return MagicMock.__reduce__(self)


class Unpicklable:
    def __getstate__(self):
        raise pickle.PicklingError()


def picklable_map_function(x):
    return {"id": int(x["filename"].split("_")[-1])}


def picklable_map_function_with_indices(x, i):
    return {"id": i}


def picklable_map_function_with_rank(x, r):
    return {"rank": r}


def picklable_map_function_with_indices_and_rank(x, i, r):
    return {"id": i, "rank": r}


def picklable_filter_function(x):
    return int(x["filename"].split("_")[-1]) < 10


def picklable_filter_function_with_rank(x, r):
    return r == 0


def assert_arrow_metadata_are_synced_with_dataset_features(dataset: Dataset):
    assert dataset.data.schema.metadata is not None
    assert b"huggingface" in dataset.data.schema.metadata
    metadata = json.loads(dataset.data.schema.metadata[b"huggingface"].decode())
    assert "info" in metadata
    features = DatasetInfo.from_dict(metadata["info"]).features
    assert features is not None
    assert features == dataset.features
    assert features == Features.from_arrow_schema(dataset.data.schema)
    assert list(features) == dataset.data.column_names
    assert list(features) == list(dataset.features)


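# Every test in BaseDatasetTest runs twice through these parameters: once with
# the dataset held in memory ("in_memory") and once backed by memory-mapped
# Arrow cache files on disk ("on_disk").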
IN_MEMORY_PARAMETERS = [
    {"testcase_name": name, "in_memory": im} for im, name in [(True, "in_memory"), (False, "on_disk")]
]


@parameterized.named_parameters(IN_MEMORY_PARAMETERS)
class BaseDatasetTest(TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog, set_sqlalchemy_silence_uber_warning):
        self._caplog = caplog

    def _create_dummy_dataset(
        self, in_memory: bool, tmp_dir: str, multiple_columns=False, array_features=False, nested_features=False
    ) -> Dataset:
        assert int(multiple_columns) + int(array_features) + int(nested_features) < 2
        if multiple_columns:
            data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": [False, True, False, True]}
            dset = Dataset.from_dict(data)
        elif array_features:
            data = {
                "col_1": [[[True, False], [False, True]]] * 4,  # 2D
                "col_2": [[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]] * 4,  # 3D
                "col_3": [[3, 2, 1, 0]] * 4,  # Sequence
            }
            features = Features(
                {
                    "col_1": Array2D(shape=(2, 2), dtype="bool"),
                    "col_2": Array3D(shape=(2, 2, 2), dtype="string"),
                    "col_3": Sequence(feature=Value("int64")),
                }
            )
            dset = Dataset.from_dict(data, features=features)
        elif nested_features:
            data = {"nested": [{"a": i, "x": i * 10, "c": i * 100} for i in range(1, 11)]}
            features = Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}})
            dset = Dataset.from_dict(data, features=features)
        else:
            dset = Dataset.from_dict({"filename": ["my_name-train" + "_" + str(x) for x in np.arange(30).tolist()]})
        dset = self._to(in_memory, tmp_dir, dset)
        return dset

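    # `_to` materializes one or more datasets either fully in memory or as
    # Arrow cache files under `tmp_dir`, mirroring the `in_memory` parameter
    # that `parameterized.named_parameters` injects into every test.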
    def _to(self, in_memory, tmp_dir, *datasets):
        if in_memory:
            datasets = [dataset.map(keep_in_memory=True) for dataset in datasets]
        else:
            start = 0
            while os.path.isfile(os.path.join(tmp_dir, f"dataset{start}.arrow")):
                start += 1
            datasets = [
                dataset.map(cache_file_name=os.path.join(tmp_dir, f"dataset{start + i}.arrow"))
                for i, dataset in enumerate(datasets)
            ]
        return datasets if len(datasets) > 1 else datasets[0]

    def test_dummy_dataset(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                self.assertDictEqual(
                    dset.features,
                    Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")}),
                )
                self.assertEqual(dset[0]["col_1"], 3)
                self.assertEqual(dset["col_1"][0], 3)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                self.assertDictEqual(
                    dset.features,
                    Features(
                        {
                            "col_1": Array2D(shape=(2, 2), dtype="bool"),
                            "col_2": Array3D(shape=(2, 2, 2), dtype="string"),
                            "col_3": Sequence(feature=Value("int64")),
                        }
                    ),
                )
                self.assertEqual(dset[0]["col_2"], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]])
                self.assertEqual(dset["col_2"][0], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]])

    def test_dataset_getitem(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

                self.assertEqual(dset[-1]["filename"], "my_name-train_29")
                self.assertEqual(dset["filename"][-1], "my_name-train_29")

                self.assertListEqual(dset[:2]["filename"], ["my_name-train_0", "my_name-train_1"])
                self.assertListEqual(dset["filename"][:2], ["my_name-train_0", "my_name-train_1"])

                self.assertEqual(dset[:-1]["filename"][-1], "my_name-train_28")
                self.assertEqual(dset["filename"][:-1][-1], "my_name-train_28")

                self.assertListEqual(dset[[0, -1]]["filename"], ["my_name-train_0", "my_name-train_29"])
                self.assertListEqual(dset[range(0, -2, -1)]["filename"], ["my_name-train_0", "my_name-train_29"])
                self.assertListEqual(dset[np.array([0, -1])]["filename"], ["my_name-train_0", "my_name-train_29"])
                self.assertListEqual(dset[pd.Series([0, -1])]["filename"], ["my_name-train_0", "my_name-train_29"])

                with dset.select(range(2)) as dset_subset:
                    self.assertListEqual(dset_subset[-1:]["filename"], ["my_name-train_1"])
                    self.assertListEqual(dset_subset["filename"][-1:], ["my_name-train_1"])

    def test_dummy_dataset_deepcopy(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset2 = copy.deepcopy(dset)
                # don't copy the underlying arrow data using memory
                self.assertEqual(len(dset2), 10)
                self.assertDictEqual(dset2.features, Features({"filename": Value("string")}))
                self.assertEqual(dset2[0]["filename"], "my_name-train_0")
                self.assertEqual(dset2["filename"][0], "my_name-train_0")
                del dset2

    def test_dummy_dataset_pickle(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, "dset.pt")

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(0, 10, 2)) as dset:
                with open(tmp_file, "wb") as f:
                    pickle.dump(dset, f)

            with open(tmp_file, "rb") as f:
                with pickle.load(f) as dset:
                    self.assertEqual(len(dset), 5)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertEqual(dset[0]["filename"], "my_name-train_0")
                    self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir).select(
                range(0, 10, 2), indices_cache_file_name=os.path.join(tmp_dir, "ind.arrow")
            ) as dset:
                if not in_memory:
                    dset._data.table = Unpicklable()
                dset._indices.table = Unpicklable()
                with open(tmp_file, "wb") as f:
                    pickle.dump(dset, f)

            with open(tmp_file, "rb") as f:
                with pickle.load(f) as dset:
                    self.assertEqual(len(dset), 5)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertEqual(dset[0]["filename"], "my_name-train_0")
                    self.assertEqual(dset["filename"][0], "my_name-train_0")

    def test_dummy_dataset_serialize(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with set_current_working_directory_to_temp_dir():
                with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                    dataset_path = "my_dataset"  # relative path
                    dset.save_to_disk(dataset_path)

                with Dataset.load_from_disk(dataset_path) as dset:
                    self.assertEqual(len(dset), 10)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertEqual(dset[0]["filename"], "my_name-train_0")
                    self.assertEqual(dset["filename"][0], "my_name-train_0")
                    expected = dset.to_dict()

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                dataset_path = os.path.join(tmp_dir, "my_dataset")  # absolute path
                dset.save_to_disk(dataset_path)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir).select(
                range(10), indices_cache_file_name=os.path.join(tmp_dir, "ind.arrow")
            ) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir, nested_features=True) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(
                    dset.features,
                    Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}}),
                )
                self.assertDictEqual(dset[0]["nested"], {"a": 1, "c": 100, "x": 10})
                self.assertDictEqual(dset["nested"][0], {"a": 1, "c": 100, "x": 10})

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path, num_shards=4)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 4)

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path, num_proc=2)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 2)

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path, num_shards=7, num_proc=2)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 7)

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    max_shard_size = dset._estimate_nbytes() // 2 + 1
                    dset.save_to_disk(dataset_path, max_shard_size=max_shard_size)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 2)

    def test_dummy_dataset_load_from_disk(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                dataset_path = os.path.join(tmp_dir, "my_dataset")
                dset.save_to_disk(dataset_path)

            with load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

    def test_restore_saved_format(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="numpy", columns=["col_1"], output_all_columns=True)
                dataset_path = os.path.join(tmp_dir, "my_dataset")
                dset.save_to_disk(dataset_path)

                with load_from_disk(dataset_path) as loaded_dset:
                    self.assertEqual(dset.format, loaded_dset.format)

    def test_set_format_numpy_multiple_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                dset.set_format(type="numpy", columns=["col_1"])
                self.assertEqual(len(dset[0]), 1)
                self.assertIsInstance(dset[0]["col_1"], np.int64)
                self.assertEqual(dset[0]["col_1"].item(), 3)
                self.assertIsInstance(dset["col_1"], np.ndarray)
                self.assertListEqual(list(dset["col_1"].shape), [4])
                np.testing.assert_array_equal(dset["col_1"], np.array([3, 2, 1, 0]))
                self.assertNotEqual(dset._fingerprint, fingerprint)

                dset.reset_format()
                with dset.formatted_as(type="numpy", columns=["col_1"]):
                    self.assertEqual(len(dset[0]), 1)
                    self.assertIsInstance(dset[0]["col_1"], np.int64)
                    self.assertEqual(dset[0]["col_1"].item(), 3)
                    self.assertIsInstance(dset["col_1"], np.ndarray)
                    self.assertListEqual(list(dset["col_1"].shape), [4])
                    np.testing.assert_array_equal(dset["col_1"], np.array([3, 2, 1, 0]))

                self.assertEqual(dset.format["type"], None)
                self.assertEqual(dset.format["format_kwargs"], {})
                self.assertEqual(dset.format["columns"], dset.column_names)
                self.assertEqual(dset.format["output_all_columns"], False)

                dset.set_format(type="numpy", columns=["col_1"], output_all_columns=True)
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")

                dset.set_format(type="numpy", columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0]), 2)
                self.assertIsInstance(dset[0]["col_2"], np.str_)
                self.assertEqual(dset[0]["col_2"].item(), "a")

    @require_torch
    def test_set_format_torch(self, in_memory):
        import torch

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="torch", columns=["col_1"])
                self.assertEqual(len(dset[0]), 1)
                self.assertIsInstance(dset[0]["col_1"], torch.Tensor)
                self.assertIsInstance(dset["col_1"], torch.Tensor)
                self.assertListEqual(list(dset[0]["col_1"].shape), [])
                self.assertEqual(dset[0]["col_1"].item(), 3)

                dset.set_format(type="torch", columns=["col_1"], output_all_columns=True)
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")

                dset.set_format(type="torch")
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_1"], torch.Tensor)
                self.assertIsInstance(dset["col_1"], torch.Tensor)
                self.assertListEqual(list(dset[0]["col_1"].shape), [])
                self.assertEqual(dset[0]["col_1"].item(), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")
                self.assertIsInstance(dset[0]["col_3"], torch.Tensor)
                self.assertIsInstance(dset["col_3"], torch.Tensor)
                self.assertListEqual(list(dset[0]["col_3"].shape), [])

    @require_tf
    def test_set_format_tf(self, in_memory):
        import tensorflow as tf

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="tensorflow", columns=["col_1"])
                self.assertEqual(len(dset[0]), 1)
                self.assertIsInstance(dset[0]["col_1"], tf.Tensor)
                self.assertListEqual(list(dset[0]["col_1"].shape), [])
                self.assertEqual(dset[0]["col_1"].numpy().item(), 3)

                dset.set_format(type="tensorflow", columns=["col_1"], output_all_columns=True)
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")

                dset.set_format(type="tensorflow", columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0]), 2)
                self.assertEqual(dset[0]["col_2"].numpy().decode("utf-8"), "a")

    def test_set_format_pandas(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="pandas", columns=["col_1"])
                self.assertEqual(len(dset[0].columns), 1)
                self.assertIsInstance(dset[0], pd.DataFrame)
                self.assertListEqual(list(dset[0].shape), [1, 1])
                self.assertEqual(dset[0]["col_1"].item(), 3)

                dset.set_format(type="pandas", columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0].columns), 2)
                self.assertEqual(dset[0]["col_2"].item(), "a")

    def test_set_transform(self, in_memory):
        def transform(batch):
            return {k: [str(i).upper() for i in v] for k, v in batch.items()}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_transform(transform=transform, columns=["col_1"])
                self.assertEqual(dset.format["type"], "custom")
                self.assertEqual(len(dset[0].keys()), 1)
                self.assertEqual(dset[0]["col_1"], "3")
                self.assertEqual(dset[:2]["col_1"], ["3", "2"])
                self.assertEqual(dset["col_1"][:2], ["3", "2"])

                prev_format = dset.format
                dset.set_format(**dset.format)
                self.assertEqual(prev_format, dset.format)

                dset.set_transform(transform=transform, columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0].keys()), 2)
                self.assertEqual(dset[0]["col_2"], "A")

    def test_transmit_format(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                transform = datasets.arrow_dataset.transmit_format(lambda x: x)
                # the identity transform must not change the fingerprint
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                dset.set_format(**dset.format)
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                # check with different column subsets and format types
                dset.set_format(columns=["col_1"])
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                dset.set_format(columns=["col_1", "col_2"])
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                dset.set_format("numpy", columns=["col_1", "col_2"])
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)

    def test_cast(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                features = dset.features
                features["col_1"] = Value("float64")
                features = Features({k: features[k] for k in list(features)[::-1]})
                fingerprint = dset._fingerprint
                with dset.cast(features) as casted_dset:
                    self.assertEqual(casted_dset.num_columns, 3)
                    self.assertEqual(casted_dset.features["col_1"], Value("float64"))
                    self.assertIsInstance(casted_dset[0]["col_1"], float)
                    self.assertNotEqual(casted_dset._fingerprint, fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

    def test_class_encode_column(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with self.assertRaises(ValueError):
                    dset.class_encode_column(column="does not exist")

                with dset.class_encode_column("col_1") as casted_dset:
                    self.assertIsInstance(casted_dset.features["col_1"], ClassLabel)
                    self.assertListEqual(casted_dset.features["col_1"].names, ["0", "1", "2", "3"])
                    self.assertListEqual(casted_dset["col_1"], [3, 2, 1, 0])
                    self.assertNotEqual(casted_dset._fingerprint, dset._fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

                with dset.class_encode_column("col_2") as casted_dset:
                    self.assertIsInstance(casted_dset.features["col_2"], ClassLabel)
                    self.assertListEqual(casted_dset.features["col_2"].names, ["a", "b", "c", "d"])
                    self.assertListEqual(casted_dset["col_2"], [0, 1, 2, 3])
                    self.assertNotEqual(casted_dset._fingerprint, dset._fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

                with dset.class_encode_column("col_3") as casted_dset:
                    self.assertIsInstance(casted_dset.features["col_3"], ClassLabel)
                    self.assertListEqual(casted_dset.features["col_3"].names, ["False", "True"])
                    self.assertListEqual(casted_dset["col_3"], [0, 1, 0, 1])
                    self.assertNotEqual(casted_dset._fingerprint, dset._fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

            # class_encode_column raises for array / sequence features
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                for column in dset.column_names:
                    with self.assertRaises(ValueError):
                        dset.class_encode_column(column)

    def test_remove_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.remove_columns(column_names="col_1") as new_dset:
                    self.assertEqual(new_dset.num_columns, 2)
                    self.assertListEqual(list(new_dset.column_names), ["col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.remove_columns(column_names=["col_1", "col_2", "col_3"]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 0)
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset._format_columns = ["col_1", "col_2", "col_3"]
                with dset.remove_columns(column_names=["col_1"]) as new_dset:
                    self.assertListEqual(new_dset._format_columns, ["col_2", "col_3"])
                    self.assertEqual(new_dset.num_columns, 2)
                    self.assertListEqual(list(new_dset.column_names), ["col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

    def test_rename_column(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.rename_column(original_column_name="col_1", new_column_name="new_name") as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["new_name", "col_2", "col_3"])
                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

    def test_rename_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.rename_columns({"col_1": "new_name"}) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["new_name", "col_2", "col_3"])
                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)

                with dset.rename_columns({"col_1": "new_name", "col_2": "new_name2"}) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["new_name", "new_name2", "col_3"])
                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)

                # original column not in the dataset
                with self.assertRaises(ValueError):
                    dset.rename_columns({"not_there": "new_name"})

                # empty new name
                with self.assertRaises(ValueError):
                    dset.rename_columns({"col_1": ""})

                # duplicate new names
                with self.assertRaises(ValueError):
                    dset.rename_columns({"col_1": "new_name", "col_2": "new_name"})

    def test_select_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.select_columns(column_names=[]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 0)
                    self.assertListEqual(list(new_dset.column_names), [])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.select_columns(column_names="col_1") as new_dset:
                    self.assertEqual(new_dset.num_columns, 1)
                    self.assertListEqual(list(new_dset.column_names), ["col_1"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.select_columns(column_names=["col_1", "col_2", "col_3"]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.select_columns(column_names=["col_3", "col_2", "col_1"]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["col_3", "col_2", "col_1"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset._format_columns = ["col_1", "col_2", "col_3"]
                with dset.select_columns(column_names=["col_1"]) as new_dset:
                    self.assertListEqual(new_dset._format_columns, ["col_1"])
                    self.assertEqual(new_dset.num_columns, 1)
                    self.assertListEqual(list(new_dset.column_names), ["col_1"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

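    # The concatenation tests below cover axis=0 and axis=1 concatenation,
    # merging of DatasetInfo descriptions, and how many Arrow cache files the
    # result keeps for the in-memory and on-disk variants.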
    def test_concatenate(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)

            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                self.assertListEqual(dset_concat["id"], [0, 1, 2, 3, 4, 5, 6, 7])
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 3)
                self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2")
            del dset1, dset2, dset3

    def test_concatenate_formatted(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)

            dset1.set_format("numpy")
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.format["type"], None)
            dset2.set_format("numpy")
            dset3.set_format("numpy")
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.format["type"], "numpy")
            del dset1, dset2, dset3

    def test_concatenate_with_indices(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2] * 2}, {"id": [3, 4, 5] * 2}, {"id": [6, 7, 8]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)
            dset1, dset2, dset3 = dset1.select([2, 1, 0]), dset2.select([2, 1, 0]), dset3

            with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 3))
                self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                self.assertListEqual(dset_concat["id"], [6, 7, 8, 5, 4, 3, 2, 1, 0])
                # in_memory = False:
                # 3 cache files for the dset_concat._data table
                # no cache file for the indices because they are in memory
                # in_memory = True:
                # no cache files since both dset_concat._data and dset_concat._indices are in memory
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 3)
                self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")

            dset1 = dset1.rename_columns({"id": "id1"})
            dset2 = dset2.rename_columns({"id": "id2"})
            dset3 = dset3.rename_columns({"id": "id3"})
            with concatenate_datasets([dset1, dset2, dset3], axis=1) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 3))
                self.assertEqual(len(dset_concat), len(dset1))
                self.assertListEqual(dset_concat["id1"], [2, 1, 0])
                self.assertListEqual(dset_concat["id2"], [5, 4, 3])
                self.assertListEqual(dset_concat["id3"], [6, 7, 8])
                # in_memory = False:
                # 3 cache files for the dset_concat._data table
                # no cache file for the indices because they are None
                # in_memory = True:
                # no cache files since dset_concat._data is in memory and dset_concat._indices is None
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 3)
                self.assertIsNone(dset_concat._indices)
                self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2")

            with concatenate_datasets([dset1], axis=1) as dset_concat:
                self.assertEqual(len(dset_concat), len(dset1))
                self.assertListEqual(dset_concat["id1"], [2, 1, 0])
                # in_memory = False:
                # 1 cache file for the dset_concat._data table
                # no cache file for the indices because they are in memory
                # in_memory = True:
                # no cache files since both dset_concat._data and dset_concat._indices are in memory
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 1)
                self.assertTrue(dset_concat._indices == dset1._indices)
                self.assertEqual(dset_concat.info.description, "Dataset1")
            del dset1, dset2, dset3

    def test_concatenate_with_indices_from_disk(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2] * 2}, {"id": [3, 4, 5] * 2}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)
            dset1, dset2, dset3 = (
                dset1.select([2, 1, 0], indices_cache_file_name=os.path.join(tmp_dir, "i1.arrow")),
                dset2.select([2, 1, 0], indices_cache_file_name=os.path.join(tmp_dir, "i2.arrow")),
                dset3.select([1, 0], indices_cache_file_name=os.path.join(tmp_dir, "i3.arrow")),
            )

            with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                self.assertListEqual(dset_concat["id"], [7, 6, 5, 4, 3, 2, 1, 0])
                # in_memory = False:
                # 3 cache files for the dset_concat._data table, and 1 for the indices table
                # only 1 for the indices (i1.arrow): the others are brought to
                # memory since an offset is applied to them
                # in_memory = True:
                # 1 cache file (i1.arrow) since dset_concat._data is in memory
                self.assertEqual(len(dset_concat.cache_files), 1 if in_memory else 3 + 1)
                self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
            del dset1, dset2, dset3

    def test_concatenate_pickle(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2] * 2}, {"id": [3, 4, 5] * 2}, {"id": [6, 7], "foo": ["bar", "bar"]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            # mix in-memory and on-disk datasets
            dset1, dset2 = self._to(in_memory, tmp_dir, dset1, dset2)
            dset3 = self._to(not in_memory, tmp_dir, dset3)
            dset1, dset2, dset3 = (
                dset1.select(
                    [2, 1, 0],
                    keep_in_memory=in_memory,
                    indices_cache_file_name=os.path.join(tmp_dir, "i1.arrow") if not in_memory else None,
                ),
                dset2.select(
                    [2, 1, 0],
                    keep_in_memory=in_memory,
                    indices_cache_file_name=os.path.join(tmp_dir, "i2.arrow") if not in_memory else None,
                ),
                dset3.select(
                    [1, 0],
                    keep_in_memory=in_memory,
                    indices_cache_file_name=os.path.join(tmp_dir, "i3.arrow") if not in_memory else None,
                ),
            )

            dset3 = dset3.rename_column("foo", "new_foo")
            dset3 = dset3.remove_columns("new_foo")
            if in_memory:
                dset3._data.table = Unpicklable()
            else:
                dset1._data.table, dset2._data.table = Unpicklable(), Unpicklable()
            dset1, dset2, dset3 = (pickle.loads(pickle.dumps(d)) for d in (dset1, dset2, dset3))
            with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                if not in_memory:
                    dset_concat._data.table = Unpicklable()
                with pickle.loads(pickle.dumps(dset_concat)) as dset_concat:
                    self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                    self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                    self.assertListEqual(dset_concat["id"], [7, 6, 5, 4, 3, 2, 1, 0])
                    # in_memory = True: 1 cache file for dset3
                    # in_memory = False: 2 cache files for dset1 and dset2, and 1 for i1.arrow
                    self.assertEqual(len(dset_concat.cache_files), 1 if in_memory else 2 + 1)
                    self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
            del dset1, dset2, dset3

    def test_flatten(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"b": {"c": ["text"]}}] * 10, "foo": [1] * 10},
                features=Features({"a": {"b": Sequence({"c": Value("string")})}, "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.b.c", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.b.c", "foo"])
                        self.assertDictEqual(
                            dset.features, Features({"a.b.c": Sequence(Value("string")), "foo": Value("int64")})
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"en": "Thank you", "fr": "Merci"}] * 10, "foo": [1] * 10},
                features=Features({"a": Translation(languages=["en", "fr"]), "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.en", "a.fr", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.en", "a.fr", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features({"a.en": Value("string"), "a.fr": Value("string"), "foo": Value("int64")}),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"en": "the cat", "fr": ["le chat", "la chatte"], "de": "die katze"}] * 10, "foo": [1] * 10},
                features=Features(
                    {"a": TranslationVariableLanguages(languages=["en", "fr", "de"]), "foo": Value("int64")}
                ),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.language", "a.translation", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.language", "a.translation", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features(
                                {
                                    "a.language": Sequence(Value("string")),
                                    "a.translation": Sequence(Value("string")),
                                    "foo": Value("int64"),
                                }
                            ),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

    @require_pil
    def test_flatten_complex_image(self, in_memory):
        # decoding turned on
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)] * 10, "foo": [1] * 10},
                features=Features({"a": Image(), "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a", "foo"])
                        self.assertDictEqual(dset.features, Features({"a": Image(), "foo": Value("int64")}))
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # decoding turned on + nesting
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"b": np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)}] * 10, "foo": [1] * 10},
                features=Features({"a": {"b": Image()}, "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.b", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.b", "foo"])
                        self.assertDictEqual(dset.features, Features({"a.b": Image(), "foo": Value("int64")}))
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # decoding turned off
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)] * 10, "foo": [1] * 10},
                features=Features({"a": Image(decode=False), "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.bytes", "a.path", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.bytes", "a.path", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features({"a.bytes": Value("binary"), "a.path": Value("string"), "foo": Value("int64")}),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # decoding turned off + nesting
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"b": np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)}] * 10, "foo": [1] * 10},
                features=Features({"a": {"b": Image(decode=False)}, "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.b.bytes", "a.b.path", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.b.bytes", "a.b.path", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features(
                                {"a.b.bytes": Value("binary"), "a.b.path": Value("string"), "foo": Value("int64")}
                            ),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

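    # The map tests below also check fingerprinting: a transform that adds
    # data produces a new fingerprint, a no-op map keeps the old one, and the
    # "huggingface" metadata in the Arrow schema stays in sync with features.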
    def test_map(self, in_memory):
        # standard
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                fingerprint = dset._fingerprint
                with dset.map(
                    lambda x: {"name": x["filename"][:-2], "id": int(x["filename"].split("_")[-1])}
                ) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
                    )
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        # no transform
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.map(lambda x: None) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        # with indices
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(
                    lambda x, i: {"name": x["filename"][:-2], "id": i}, with_indices=True
                ) as dset_test_with_indices:
                    self.assertEqual(len(dset_test_with_indices), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_with_indices.features,
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
                    )
                    self.assertListEqual(dset_test_with_indices["id"], list(range(30)))
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices)

        # interrupted
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:

                def func(x, i):
                    if i == 4:
                        raise KeyboardInterrupt()
                    return {"name": x["filename"][:-2], "id": i}

                tmp_file = os.path.join(tmp_dir, "test.arrow")
                self.assertRaises(
                    KeyboardInterrupt,
                    dset.map,
                    function=func,
                    with_indices=True,
                    cache_file_name=tmp_file,
                    writer_batch_size=2,
                )
                self.assertFalse(os.path.exists(tmp_file))
                with dset.map(
                    lambda x, i: {"name": x["filename"][:-2], "id": i},
                    with_indices=True,
                    cache_file_name=tmp_file,
                    writer_batch_size=2,
                ) as dset_test_with_indices:
                    self.assertTrue(os.path.exists(tmp_file))
                    self.assertEqual(len(dset_test_with_indices), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_with_indices.features,
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
                    )
                    self.assertListEqual(dset_test_with_indices["id"], list(range(30)))
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices)

        # formatted
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format("numpy", columns=["col_1"])
                with dset.map(lambda x: {"col_1_plus_one": x["col_1"] + 1}) as dset_test:
                    self.assertEqual(len(dset_test), 4)
                    self.assertEqual(dset_test.format["type"], "numpy")
                    self.assertIsInstance(dset_test["col_1"], np.ndarray)
                    self.assertIsInstance(dset_test["col_1_plus_one"], np.ndarray)
                    self.assertListEqual(sorted(dset_test[0].keys()), ["col_1", "col_1_plus_one"])
                    self.assertListEqual(sorted(dset_test.column_names), ["col_1", "col_1_plus_one", "col_2", "col_3"])
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

    def test_map_multiprocessing(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                fingerprint = dset._fingerprint
                with dset.map(picklable_map_function, num_proc=2) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "id": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
                    if not in_memory:
                        self.assertIn("_of_00002.arrow", dset_test.cache_files[0]["filename"])
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                fingerprint = dset._fingerprint
                with dset.select([0, 1], keep_in_memory=True).map(picklable_map_function, num_proc=10) as dset_test:
                    self.assertEqual(len(dset_test), 2)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "id": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
                    self.assertListEqual(dset_test["id"], list(range(2)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.map(picklable_map_function_with_indices, num_proc=3, with_indices=True) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "id": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 3)
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.map(picklable_map_function_with_rank, num_proc=3, with_rank=True) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "rank": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 3)
                    self.assertListEqual(dset_test["rank"], [0] * 10 + [1] * 10 + [2] * 10)
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.map(
                    picklable_map_function_with_indices_and_rank, num_proc=3, with_indices=True, with_rank=True
                ) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "id": Value("int64"), "rank": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 3)
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertListEqual(dset_test["rank"], [0] * 10 + [1] * 10 + [2] * 10)
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        with tempfile.TemporaryDirectory() as tmp_dir:
            new_fingerprint = "foobar"
            invalid_new_fingerprint = "foobar/hey"
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                self.assertRaises(
                    ValueError, dset.map, picklable_map_function, num_proc=2, new_fingerprint=invalid_new_fingerprint
                )
                with dset.map(picklable_map_function, num_proc=2, new_fingerprint=new_fingerprint) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "id": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    self.assertEqual(dset_test._fingerprint, new_fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
                    file_names = sorted(Path(cache_file["filename"]).name for cache_file in dset_test.cache_files)
                    for i, file_name in enumerate(file_names):
                        self.assertIn(new_fingerprint + f"_{i:05d}", file_name)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.map(lambda x: {"id": int(x["filename"].split("_")[-1])}, num_proc=2) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "id": Value("int64")}),
                    )
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

    def test_map_new_features(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                features = Features({"filename": Value("string"), "label": ClassLabel(names=["positive", "negative"])})
                with dset.map(
                    lambda x, i: {"label": i % 2}, with_indices=True, features=features
                ) as dset_test_with_indices:
                    self.assertEqual(len(dset_test_with_indices), 30)
                    self.assertDictEqual(
                        dset_test_with_indices.features,
                        features,
                    )
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices)

    def test_map_batched(self, in_memory):
        def map_batched(example):
            return {"filename_new": [x + "_extension" for x in example["filename"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(map_batched, batched=True) as dset_test_batched:
                    self.assertEqual(len(dset_test_batched), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_batched.features,
                        Features({"filename": Value("string"), "filename_new": Value("string")}),
                    )
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_batched)

        # change batch size and drop the last batch
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                batch_size = 4
                with dset.map(
                    map_batched, batched=True, batch_size=batch_size, drop_last_batch=True
                ) as dset_test_batched:
                    self.assertEqual(len(dset_test_batched), 30 // batch_size * batch_size)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_batched.features,
                        Features({"filename": Value("string"), "filename_new": Value("string")}),
                    )
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_batched)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.formatted_as("numpy", columns=["filename"]):
                    with dset.map(map_batched, batched=True) as dset_test_batched:
                        self.assertEqual(len(dset_test_batched), 30)
                        self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                        self.assertDictEqual(
                            dset_test_batched.features,
                            Features({"filename": Value("string"), "filename_new": Value("string")}),
                        )
                        assert_arrow_metadata_are_synced_with_dataset_features(dset_test_batched)

        def map_batched_with_indices(example, idx):
            return {"filename_new": [x + "_extension_" + str(idx) for x in example["filename"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(
                    map_batched_with_indices, batched=True, with_indices=True
                ) as dset_test_with_indices_batched:
                    self.assertEqual(len(dset_test_with_indices_batched), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_with_indices_batched.features,
                        Features({"filename": Value("string"), "filename_new": Value("string")}),
                    )
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices_batched)

        # check remove columns even if the function modifies input in-place
        def map_batched_modifying_inputs_inplace(example):
            result = {"filename_new": [x + "_extension" for x in example["filename"]]}
            del example["filename"]
            return result

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(
                    map_batched_modifying_inputs_inplace, batched=True, remove_columns="filename"
                ) as dset_test_modifying_inputs_inplace:
                    self.assertEqual(len(dset_test_modifying_inputs_inplace), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_modifying_inputs_inplace.features,
                        Features({"filename_new": Value("string")}),
                    )
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_modifying_inputs_inplace)

    def test_map_nested(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict({"field": ["a", "b"]}) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    with dset.map(lambda example: {"otherfield": {"capital": example["field"].capitalize()}}) as dset:
                        with dset.map(lambda example: {"otherfield": {"append_x": example["field"] + "x"}}) as dset:
                            self.assertEqual(dset[0], {"field": "a", "otherfield": {"append_x": "ax"}})

    def test_map_return_example_as_dict_value(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict({"en": ["aa", "bb"], "fr": ["cc", "dd"]}) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    with dset.map(lambda example: {"translation": example}) as dset:
                        self.assertEqual(dset[0], {"en": "aa", "fr": "cc", "translation": {"en": "aa", "fr": "cc"}})

    def test_map_fn_kwargs(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict({"id": range(10)}) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fn_kwargs = {"offset": 3}
                    with dset.map(
                        lambda example, offset: {"id+offset": example["id"] + offset}, fn_kwargs=fn_kwargs
                    ) as mapped_dset:
                        assert mapped_dset["id+offset"] == list(range(3, 13))
                    with dset.map(
                        lambda id, offset: {"id+offset": id + offset}, fn_kwargs=fn_kwargs, input_columns="id"
                    ) as mapped_dset:
                        assert mapped_dset["id+offset"] == list(range(3, 13))
                    with dset.map(
                        lambda id, i, offset: {"id+offset": i + offset},
                        fn_kwargs=fn_kwargs,
                        input_columns="id",
                        with_indices=True,
                    ) as mapped_dset:
                        assert mapped_dset["id+offset"] == list(range(3, 13))

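    # Cache reuse in test_map_caching is detected via the "Loading cached
    # processed dataset" INFO log captured by caplog, and by patching
    # Dataset._map_single / datasets.arrow_dataset.Pool to count real work.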
    def test_map_caching(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._caplog.clear()
            with self._caplog.at_level(INFO, logger=get_logger().name):
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                    with patch(
                        "datasets.arrow_dataset.Dataset._map_single",
                        autospec=Dataset._map_single,
                        side_effect=Dataset._map_single,
                    ) as mock_map_single:
                        with dset.map(lambda x: {"foo": "bar"}) as dset_test1:
                            dset_test1_data_files = list(dset_test1.cache_files)
                            self.assertEqual(mock_map_single.call_count, 1)
                        with dset.map(lambda x: {"foo": "bar"}) as dset_test2:
                            self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
                            self.assertEqual(len(dset_test2.cache_files), 1 - int(in_memory))
                            self.assertTrue(("Loading cached processed dataset" in self._caplog.text) ^ in_memory)
                            self.assertEqual(mock_map_single.call_count, 2 if in_memory else 1)

        with tempfile.TemporaryDirectory() as tmp_dir:
            self._caplog.clear()
            with self._caplog.at_level(INFO, logger=get_logger().name):
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                    with dset.map(lambda x: {"foo": "bar"}) as dset_test1:
                        dset_test1_data_files = list(dset_test1.cache_files)
                    with dset.map(lambda x: {"foo": "bar"}, load_from_cache_file=False) as dset_test2:
                        self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
                        self.assertEqual(len(dset_test2.cache_files), 1 - int(in_memory))
                        self.assertNotIn("Loading cached processed dataset", self._caplog.text)

        with tempfile.TemporaryDirectory() as tmp_dir:
            self._caplog.clear()
            with self._caplog.at_level(INFO, logger=get_logger().name):
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                    with patch(
                        "datasets.arrow_dataset.Pool",
                        new_callable=PickableMagicMock,
                        side_effect=datasets.arrow_dataset.Pool,
                    ) as mock_pool:
                        with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                            dset_test1_data_files = list(dset_test1.cache_files)
                            self.assertEqual(mock_pool.call_count, 1)
                        with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test2:
                            self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
                            self.assertTrue(
                                (len(re.findall("Loading cached processed dataset", self._caplog.text)) == 1)
                                ^ in_memory
                            )
                            self.assertEqual(mock_pool.call_count, 2 if in_memory else 1)

        with tempfile.TemporaryDirectory() as tmp_dir:
            self._caplog.clear()
            with self._caplog.at_level(INFO, logger=get_logger().name):
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                    with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                        dset_test1_data_files = list(dset_test1.cache_files)
                    with dset.map(lambda x: {"foo": "bar"}, num_proc=2, load_from_cache_file=False) as dset_test2:
                        self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
                        self.assertEqual(len(dset_test2.cache_files), (1 - int(in_memory)) * 2)
                        self.assertNotIn("Loading cached processed dataset", self._caplog.text)

        if not in_memory:
            try:
                self._caplog.clear()
                with tempfile.TemporaryDirectory() as tmp_dir:
                    with self._caplog.at_level(INFO, logger=get_logger().name):
                        with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                            datasets.disable_caching()
                            with dset.map(lambda x: {"foo": "bar"}) as dset_test1:
                                with dset.map(lambda x: {"foo": "bar"}) as dset_test2:
                                    self.assertNotEqual(dset_test1.cache_files, dset_test2.cache_files)
                                    self.assertEqual(len(dset_test1.cache_files), 1)
                                    self.assertEqual(len(dset_test2.cache_files), 1)
                                    self.assertNotIn("Loading cached processed dataset", self._caplog.text)
                                    # make sure the arrow files are going to be removed
                                    self.assertIn(
                                        Path(tempfile.gettempdir()),
                                        Path(dset_test1.cache_files[0]["filename"]).parents,
                                    )
                                    self.assertIn(
                                        Path(tempfile.gettempdir()),
                                        Path(dset_test2.cache_files[0]["filename"]).parents,
                                    )
            finally:
                datasets.enable_caching()

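    # A minimal sketch (not part of the test suite) of the caching behavior exercised
    # above, assuming a hypothetical on-disk dataset `ds`:
    #
    #     ds2 = ds.map(lambda x: {"foo": "bar"})  # computes and writes a cache file
    #     ds3 = ds.map(lambda x: {"foo": "bar"})  # identical fingerprint -> reloaded from cache
    #     assert ds2.cache_files == ds3.cache_files
    #
    # `load_from_cache_file=False` forces recomputation but still writes to the same
    # deterministic location, which is why the tests compare `cache_files` rather than data.
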
    def test_map_return_pa_table(self, in_memory):
        def func_return_single_row_pa_table(x):
            return pa.table({"id": [0], "text": ["a"]})

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func_return_single_row_pa_table) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"id": Value("int64"), "text": Value("string")}),
                    )
                    self.assertEqual(dset_test[0]["id"], 0)
                    self.assertEqual(dset_test[0]["text"], "a")

        # Batched
        def func_return_single_row_pa_table_batched(x):
            batch_size = len(x[next(iter(x))])
            return pa.table({"id": [0] * batch_size, "text": ["a"] * batch_size})

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func_return_single_row_pa_table_batched, batched=True) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"id": Value("int64"), "text": Value("string")}),
                    )
                    self.assertEqual(dset_test[0]["id"], 0)
                    self.assertEqual(dset_test[0]["text"], "a")

        # Error when returning a table with more than one row in the non-batched mode
        def func_return_multi_row_pa_table(x):
            return pa.table({"id": [0, 1], "text": ["a", "b"]})

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertRaises(ValueError, dset.map, func_return_multi_row_pa_table)

        # With an arrow-formatted dataset, the function receives a pa.Table directly
        def func_return_table_from_expression(t):
            import pyarrow.dataset as pds

            return pds.dataset(t).to_table(
                columns={"new_column": pds.field("")._call("ascii_capitalize", [pds.field("filename")])}
            )

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.with_format("arrow").map(func_return_table_from_expression, batched=True) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"new_column": Value("string")}),
                    )
                    self.assertEqual(dset_test.with_format(None)[0]["new_column"], dset[0]["filename"].capitalize())

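    # Sketch of the contract tested above: in non-batched mode a function returning a
    # `pa.Table` must contain exactly one row per example, otherwise `map` raises a
    # ValueError; in batched mode the returned table simply becomes the output batch:
    #
    #     ds.map(lambda x: pa.table({"n": [0, 1]}))  # ValueError: more than one row
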
    def test_map_return_pd_dataframe(self, in_memory):
        def func_return_single_row_pd_dataframe(x):
            return pd.DataFrame({"id": [0], "text": ["a"]})

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func_return_single_row_pd_dataframe) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"id": Value("int64"), "text": Value("string")}),
                    )
                    self.assertEqual(dset_test[0]["id"], 0)
                    self.assertEqual(dset_test[0]["text"], "a")

        # Batched
        def func_return_single_row_pd_dataframe_batched(x):
            batch_size = len(x[next(iter(x))])
            return pd.DataFrame({"id": [0] * batch_size, "text": ["a"] * batch_size})

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func_return_single_row_pd_dataframe_batched, batched=True) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"id": Value("int64"), "text": Value("string")}),
                    )
                    self.assertEqual(dset_test[0]["id"], 0)
                    self.assertEqual(dset_test[0]["text"], "a")

        # Error when returning a dataframe with more than one row in the non-batched mode
        def func_return_multi_row_pd_dataframe(x):
            return pd.DataFrame({"id": [0, 1], "text": ["a", "b"]})

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)

    @require_torch
    def test_map_torch(self, in_memory):
        import torch

        def func(example):
            return {"tensor": torch.tensor([1.0, 2, 3])}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
                    )
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

    @require_tf
    def test_map_tf(self, in_memory):
        import tensorflow as tf

        def func(example):
            return {"tensor": tf.constant([1.0, 2, 3])}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
                    )
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

    @require_jax
    def test_map_jax(self, in_memory):
        import jax.numpy as jnp

        def func(example):
            return {"tensor": jnp.asarray([1.0, 2, 3])}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
                    )
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

    def test_map_numpy(self, in_memory):
        def func(example):
            return {"tensor": np.array([1.0, 2, 3])}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float64"))}),
                    )
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

    @require_torch
    def test_map_tensor_batched(self, in_memory):
        import torch

        def func(batch):
            return {"tensor": torch.tensor([[1.0, 2, 3]] * len(batch["filename"]))}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(func, batched=True) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
                    )
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

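    # Note (sketch, not from the source): tensor outputs are written to Arrow as lists,
    # so without a format set they come back as plain Python lists; only the dtype of
    # the inferred feature differs (torch/tf/jax floats -> float32, NumPy -> float64):
    #
    #     ds2 = ds.map(lambda x: {"t": np.array([1.0, 2.0])})  # hypothetical `ds`
    #     ds2.features["t"]  # Sequence(Value("float64")); ds2[0]["t"] == [1.0, 2.0]
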
    def test_map_input_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.map(lambda col_1: {"label": col_1 % 2}, input_columns="col_1") as mapped_dset:
                    self.assertEqual(mapped_dset[0].keys(), {"col_1", "col_2", "col_3", "label"})
                    self.assertDictEqual(
                        mapped_dset.features,
                        Features(
                            {
                                "col_1": Value("int64"),
                                "col_2": Value("string"),
                                "col_3": Value("bool"),
                                "label": Value("int64"),
                            }
                        ),
                    )

    def test_map_remove_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(lambda x, i: {"name": x["filename"][:-2], "id": i}, with_indices=True) as dset:
                    self.assertTrue("id" in dset[0])
                    self.assertDictEqual(
                        dset.features,
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
                    )
                    assert_arrow_metadata_are_synced_with_dataset_features(dset)
                    with dset.map(lambda x: x, remove_columns=["id"]) as mapped_dset:
                        self.assertTrue("id" not in mapped_dset[0])
                        self.assertDictEqual(
                            mapped_dset.features, Features({"filename": Value("string"), "name": Value("string")})
                        )
                        assert_arrow_metadata_are_synced_with_dataset_features(mapped_dset)
                        with mapped_dset.with_format("numpy", columns=mapped_dset.column_names) as mapped_dset:
                            with mapped_dset.map(
                                lambda x: {"name": 1}, remove_columns=mapped_dset.column_names
                            ) as mapped_dset:
                                self.assertTrue("filename" not in mapped_dset[0])
                                self.assertTrue("name" in mapped_dset[0])
                                self.assertDictEqual(mapped_dset.features, Features({"name": Value(dtype="int64")}))
                                assert_arrow_metadata_are_synced_with_dataset_features(mapped_dset)
                    # empty dataset
                    columns_names = dset.column_names
                    with dset.select([]) as empty_dset:
                        self.assertEqual(len(empty_dset), 0)
                        with empty_dset.map(lambda x: {}, remove_columns=columns_names[0]) as mapped_dset:
                            self.assertListEqual(columns_names[1:], mapped_dset.column_names)
                            assert_arrow_metadata_are_synced_with_dataset_features(mapped_dset)

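    # Sketch of `remove_columns` semantics, assuming a hypothetical dataset `ds` with
    # columns ["a", "b"]: columns are dropped *after* the function has seen the full
    # example, so the mapped function can still read them (new columns are appended):
    #
    #     ds2 = ds.map(lambda x: {"c": x["a"] + 1}, remove_columns=["a"])
    #     ds2.column_names  # ["b", "c"]
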
    def test_map_stateful_callable(self, in_memory):
        # be sure that the state of the map callable is unaffected
        # before processing the dataset examples

        class ExampleCounter:
            def __init__(self, batched=False):
                self.batched = batched
                # state
                self.cnt = 0

            def __call__(self, example):
                if self.batched:
                    self.cnt += len(example)
                else:
                    self.cnt += 1

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                ex_cnt = ExampleCounter()
                dset.map(ex_cnt)
                self.assertEqual(ex_cnt.cnt, len(dset))

                ex_cnt = ExampleCounter(batched=True)
                dset.map(ex_cnt, batched=True, batch_size=1)
                self.assertEqual(ex_cnt.cnt, len(dset))

    @require_not_windows
    def test_map_crash_subprocess(self, in_memory):
        # be sure that a crash in one of the subprocesses will not
        # hang the dataset.map() call forever

        def do_crash(row):
            import os

            os.kill(os.getpid(), 9)
            return row

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with pytest.raises(RuntimeError) as excinfo:
                    dset.map(do_crash, num_proc=2)
                assert str(excinfo.value) == (
                    "One of the subprocesses has abruptly died during map operation."
                    "To debug the error, disable multiprocessing."
                )

    def test_filter(self, in_memory):
        # keep only the first five examples
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.filter(lambda x, i: i < 5, with_indices=True) as dset_filter_first_five:
                    self.assertEqual(len(dset_filter_first_five), 5)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_filter_first_five.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_filter_first_five._fingerprint, fingerprint)

        # filter filenames with an even id at the end + formatted
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                dset.set_format("numpy")
                fingerprint = dset._fingerprint
                with dset.filter(lambda x: (int(x["filename"][-1]) % 2 == 0)) as dset_filter_even_num:
                    self.assertEqual(len(dset_filter_even_num), 15)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_filter_even_num.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_filter_even_num._fingerprint, fingerprint)
                    self.assertEqual(dset_filter_even_num.format["type"], "numpy")

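    # Sketch: `filter` only builds an indices mapping on top of the original table, so
    # features and format are carried over unchanged (hypothetical `ds`):
    #
    #     kept = ds.filter(lambda x, i: i < 5, with_indices=True)
    #     assert kept.features == ds.features  # schema untouched, rows merely masked
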
    def test_filter_with_indices_mapping(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset = Dataset.from_dict({"col": [0, 1, 2]})
            with self._to(in_memory, tmp_dir, dset) as dset:
                with dset.filter(lambda x: x["col"] > 0) as dset:
                    self.assertListEqual(dset["col"], [1, 2])
                    with dset.filter(lambda x: x["col"] < 2) as dset:
                        self.assertListEqual(dset["col"], [1])

    def test_filter_empty(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertIsNone(dset._indices)

                tmp_file = os.path.join(tmp_dir, "test.arrow")
                with dset.filter(lambda _: False, cache_file_name=tmp_file) as dset:
                    self.assertEqual(len(dset), 0)
                    self.assertIsNotNone(dset._indices)

                    tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
                    with dset.filter(lambda _: False, cache_file_name=tmp_file_2) as dset2:
                        self.assertEqual(len(dset2), 0)
                        self.assertEqual(dset._indices, dset2._indices)

    def test_filter_batched(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset = Dataset.from_dict({"col": [0, 1, 2]})
            with self._to(in_memory, tmp_dir, dset) as dset:
                with dset.filter(lambda x: [i > 0 for i in x["col"]], batched=True) as dset:
                    self.assertListEqual(dset["col"], [1, 2])
                    with dset.filter(lambda x: [i < 2 for i in x["col"]], batched=True) as dset:
                        self.assertListEqual(dset["col"], [1])

    def test_filter_input_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset = Dataset.from_dict({"col_1": [0, 1, 2], "col_2": ["a", "b", "c"]})
            with self._to(in_memory, tmp_dir, dset) as dset:
                with dset.filter(lambda x: x > 0, input_columns=["col_1"]) as filtered_dset:
                    self.assertListEqual(filtered_dset.column_names, dset.column_names)
                    self.assertListEqual(filtered_dset["col_1"], [1, 2])
                    self.assertListEqual(filtered_dset["col_2"], ["b", "c"])

    def test_filter_fn_kwargs(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict({"id": range(10)}) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fn_kwargs = {"max_offset": 3}
                    with dset.filter(
                        lambda example, max_offset: example["id"] < max_offset, fn_kwargs=fn_kwargs
                    ) as filtered_dset:
                        assert len(filtered_dset) == 3
                    with dset.filter(
                        lambda id, max_offset: id < max_offset, fn_kwargs=fn_kwargs, input_columns="id"
                    ) as filtered_dset:
                        assert len(filtered_dset) == 3
                    with dset.filter(
                        lambda id, i, max_offset: i < max_offset,
                        fn_kwargs=fn_kwargs,
                        input_columns="id",
                        with_indices=True,
                    ) as filtered_dset:
                        assert len(filtered_dset) == 3

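    # Sketch: `fn_kwargs` are forwarded verbatim to the predicate and compose with
    # `input_columns` and `with_indices` exactly as for `map` (hypothetical `ds`):
    #
    #     ds.filter(lambda id, i, max_offset: i < max_offset,
    #               fn_kwargs={"max_offset": 3}, input_columns="id", with_indices=True)
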
    def test_filter_multiprocessing(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.filter(picklable_filter_function, num_proc=2) as dset_filter_first_ten:
                    self.assertEqual(len(dset_filter_first_ten), 10)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_filter_first_ten.features, Features({"filename": Value("string")}))
                    self.assertEqual(len(dset_filter_first_ten.cache_files), 0 if in_memory else 2)
                    self.assertNotEqual(dset_filter_first_ten._fingerprint, fingerprint)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.filter(
                    picklable_filter_function_with_rank, num_proc=2, with_rank=True
                ) as dset_filter_first_rank:
                    self.assertEqual(len(dset_filter_first_rank), min(len(dset) // 2, len(dset)))
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_filter_first_rank.features, Features({"filename": Value("string")}))
                    self.assertEqual(len(dset_filter_first_rank.cache_files), 0 if in_memory else 2)
                    self.assertNotEqual(dset_filter_first_rank._fingerprint, fingerprint)

    def test_filter_caching(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._caplog.clear()
            with self._caplog.at_level(INFO, logger=get_logger().name):
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                    with dset.filter(lambda x, i: i < 5, with_indices=True) as dset_filter_first_five1:
                        dset_test1_data_files = list(dset_filter_first_five1.cache_files)
                    with dset.filter(lambda x, i: i < 5, with_indices=True) as dset_filter_first_five2:
                        self.assertEqual(dset_test1_data_files, dset_filter_first_five2.cache_files)
                        self.assertEqual(len(dset_filter_first_five2.cache_files), 0 if in_memory else 2)
                        self.assertTrue(("Loading cached processed dataset" in self._caplog.text) ^ in_memory)

    def test_keep_features_after_transform_specified(self, in_memory):
        features = Features(
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
        )

        def invert_labels(x):
            return {"labels": [(1 - label) for label in x["labels"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    with dset.map(invert_labels, features=features) as inverted_dset:
                        self.assertEqual(inverted_dset.features.type, features.type)
                        self.assertDictEqual(inverted_dset.features, features)
                        assert_arrow_metadata_are_synced_with_dataset_features(inverted_dset)

    def test_keep_features_after_transform_unspecified(self, in_memory):
        features = Features(
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
        )

        def invert_labels(x):
            return {"labels": [(1 - label) for label in x["labels"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    with dset.map(invert_labels) as inverted_dset:
                        self.assertEqual(inverted_dset.features.type, features.type)
                        self.assertDictEqual(inverted_dset.features, features)
                        assert_arrow_metadata_are_synced_with_dataset_features(inverted_dset)

    def test_keep_features_after_transform_to_file(self, in_memory):
        features = Features(
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
        )

        def invert_labels(x):
            return {"labels": [(1 - label) for label in x["labels"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    tmp_file = os.path.join(tmp_dir, "test.arrow")
                    dset.map(invert_labels, cache_file_name=tmp_file)
                    with Dataset.from_file(tmp_file) as inverted_dset:
                        self.assertEqual(inverted_dset.features.type, features.type)
                        self.assertDictEqual(inverted_dset.features, features)

    def test_keep_features_after_transform_to_memory(self, in_memory):
        features = Features(
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
        )

        def invert_labels(x):
            return {"labels": [(1 - label) for label in x["labels"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    with dset.map(invert_labels, keep_in_memory=True) as inverted_dset:
                        self.assertEqual(inverted_dset.features.type, features.type)
                        self.assertDictEqual(inverted_dset.features, features)

    def test_keep_features_after_loading_from_cache(self, in_memory):
        features = Features(
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
        )

        def invert_labels(x):
            return {"labels": [(1 - label) for label in x["labels"]]}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    tmp_file1 = os.path.join(tmp_dir, "test1.arrow")
                    tmp_file2 = os.path.join(tmp_dir, "test2.arrow")

                    inverted_dset = dset.map(invert_labels, cache_file_name=tmp_file1)
                    inverted_dset = dset.map(invert_labels, cache_file_name=tmp_file2)
                    self.assertGreater(len(inverted_dset.cache_files), 0)
                    self.assertEqual(inverted_dset.features.type, features.type)
                    self.assertDictEqual(inverted_dset.features, features)
                    del inverted_dset

    def test_keep_features_with_new_features(self, in_memory):
        features = Features(
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
        )

        def invert_labels(x):
            return {"labels": [(1 - label) for label in x["labels"]], "labels2": x["labels"]}

        expected_features = Features(
            {
                "tokens": Sequence(Value("string")),
                "labels": Sequence(ClassLabel(names=["negative", "positive"])),
                "labels2": Sequence(Value("int64")),
            }
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    with dset.map(invert_labels) as inverted_dset:
                        self.assertEqual(inverted_dset.features.type, expected_features.type)
                        self.assertDictEqual(inverted_dset.features, expected_features)
                        assert_arrow_metadata_are_synced_with_dataset_features(inverted_dset)

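    # Sketch of the feature-propagation rule checked by the tests above: columns left
    # untouched by `map` keep their feature types (e.g. ClassLabel), while new columns
    # get types inferred from the output values, so with the `ds` of the tests above:
    #
    #     ds2 = ds.map(invert_labels)   # "labels" stays Sequence(ClassLabel(...)),
    #                                   # "labels2" is inferred as Sequence(Value("int64"))
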
    def test_select(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                # select every second example
                indices = list(range(0, len(dset), 2))
                tmp_file = os.path.join(tmp_dir, "test.arrow")
                fingerprint = dset._fingerprint
                with dset.select(indices, indices_cache_file_name=tmp_file) as dset_select_even:
                    self.assertIsNotNone(dset_select_even._indices)
                    self.assertTrue(os.path.exists(tmp_file))
                    self.assertEqual(len(dset_select_even), 15)
                    for row in dset_select_even:
                        self.assertEqual(int(row["filename"][-1]) % 2, 0)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_select_even.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_select_even._fingerprint, fingerprint)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                indices = list(range(0, len(dset)))
                with dset.select(indices) as dset_select_all:
                    # no indices mapping, since the indices are contiguous
                    # (in this case the arrow table is just sliced, which is more efficient)
                    self.assertIsNone(dset_select_all._indices)
                    self.assertEqual(len(dset_select_all), len(dset))
                    self.assertListEqual(list(dset_select_all), list(dset))
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_select_all.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_select_all._fingerprint, fingerprint)
                indices = range(0, len(dset))
                with dset.select(indices) as dset_select_all:
                    # same but with a range
                    self.assertIsNone(dset_select_all._indices)
                    self.assertEqual(len(dset_select_all), len(dset))
                    self.assertListEqual(list(dset_select_all), list(dset))
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_select_all.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_select_all._fingerprint, fingerprint)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                bad_indices = list(range(5))
                bad_indices[-1] = len(dset) + 10
                tmp_file = os.path.join(tmp_dir, "test.arrow")
                self.assertRaises(
                    Exception,
                    dset.select,
                    indices=bad_indices,
                    indices_cache_file_name=tmp_file,
                    writer_batch_size=2,
                )
                self.assertFalse(os.path.exists(tmp_file))

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                indices = iter(range(len(dset)))
                with dset.select(indices) as dset_select_all:
                    # no indices mapping, since the indices are contiguous
                    self.assertIsNone(dset_select_all._indices)
                    self.assertEqual(len(dset_select_all), len(dset))
                indices = reversed(range(len(dset)))
                tmp_file = os.path.join(tmp_dir, "test.arrow")
                with dset.select(indices, indices_cache_file_name=tmp_file) as dset_select_all:
                    # a new indices mapping is created
                    self.assertIsNotNone(dset_select_all._indices)
                    self.assertEqual(len(dset_select_all), len(dset))

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                bad_indices = list(range(5))
                bad_indices[3] = "foo"
                tmp_file = os.path.join(tmp_dir, "test.arrow")
                self.assertRaises(
                    Exception,
                    dset.select,
                    indices=bad_indices,
                    indices_cache_file_name=tmp_file,
                    writer_batch_size=2,
                )
                self.assertFalse(os.path.exists(tmp_file))
                dset.set_format("numpy")
                with dset.select(
                    range(5),
                    indices_cache_file_name=tmp_file,
                    writer_batch_size=2,
                ) as dset_select_five:
                    self.assertIsNone(dset_select_five._indices)
                    self.assertEqual(len(dset_select_five), 5)
                    self.assertEqual(dset_select_five.format["type"], "numpy")
                    for i, row in enumerate(dset_select_five):
                        self.assertEqual(int(row["filename"][-1]), i)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_select_five.features, Features({"filename": Value("string")}))

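    # Sketch: `select` with contiguous, increasing indices just slices the Arrow table
    # (no indices mapping); anything else materializes a mapping, optionally written to
    # `indices_cache_file_name` (hypothetical `ds` of length 30):
    #
    #     ds.select(range(30))._indices   # None     -> plain slice
    #     ds.select([2, 0, 1])._indices   # not None -> indices mapping created
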
    def test_select_then_map(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.select([0]) as d1:
                    with d1.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d1:
                        self.assertEqual(d1[0]["id"], 0)
                with dset.select([1]) as d2:
                    with d2.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d2:
                        self.assertEqual(d2[0]["id"], 1)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.select([0], indices_cache_file_name=os.path.join(tmp_dir, "i1.arrow")) as d1:
                    with d1.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d1:
                        self.assertEqual(d1[0]["id"], 0)
                with dset.select([1], indices_cache_file_name=os.path.join(tmp_dir, "i2.arrow")) as d2:
                    with d2.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d2:
                        self.assertEqual(d2[0]["id"], 1)

    def test_pickle_after_many_transforms_on_disk(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertEqual(len(dset.cache_files), 0 if in_memory else 1)
                with dset.rename_column("filename", "file") as dset:
                    self.assertListEqual(dset.column_names, ["file"])
                    with dset.select(range(5)) as dset:
                        self.assertEqual(len(dset), 5)
                        with dset.map(lambda x: {"id": int(x["file"][-1])}) as dset:
                            self.assertListEqual(sorted(dset.column_names), ["file", "id"])
                            with dset.rename_column("id", "number") as dset:
                                self.assertListEqual(sorted(dset.column_names), ["file", "number"])
                                with dset.select([1, 0]) as dset:
                                    self.assertEqual(dset[0]["file"], "my_name-train_1")
                                    self.assertEqual(dset[0]["number"], 1)

                                    self.assertEqual(dset._indices["indices"].to_pylist(), [1, 0])
                                    if not in_memory:
                                        self.assertIn(
                                            ("rename_columns", (["file", "number"],), {}),
                                            dset._data.replays,
                                        )
                                    if not in_memory:
                                        dset._data.table = Unpicklable()

                                    pickled = pickle.dumps(dset)
                                    with pickle.loads(pickled) as loaded:
                                        self.assertEqual(loaded[0]["file"], "my_name-train_1")
                                        self.assertEqual(loaded[0]["number"], 1)

    def test_shuffle(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                tmp_file = os.path.join(tmp_dir, "test.arrow")
                fingerprint = dset._fingerprint

                with dset.shuffle(seed=1234, keep_in_memory=True) as dset_shuffled:
                    self.assertEqual(len(dset_shuffled), 30)
                    self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28")
                    self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10")
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_shuffled._fingerprint, fingerprint)

                with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled:
                    self.assertEqual(len(dset_shuffled), 30)
                    self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28")
                    self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10")
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_shuffled._fingerprint, fingerprint)

                    # Reproducibility
                    tmp_file = os.path.join(tmp_dir, "test_2.arrow")
                    with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled_2:
                        self.assertListEqual(dset_shuffled["filename"], dset_shuffled_2["filename"])

                    # Compatible with temp_seed
                    with temp_seed(42), dset.shuffle() as d1:
                        with temp_seed(42), dset.shuffle() as d2, dset.shuffle() as d3:
                            self.assertListEqual(d1["filename"], d2["filename"])
                            self.assertEqual(d1._fingerprint, d2._fingerprint)
                            self.assertNotEqual(d3["filename"], d2["filename"])
                            self.assertNotEqual(d3._fingerprint, d2._fingerprint)

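    # Sketch: a fixed `seed` makes `shuffle` deterministic and is folded into the
    # fingerprint; without a seed, the generator state (here controlled via `temp_seed`)
    # decides reproducibility:
    #
    #     ds.shuffle(seed=1234)["filename"] == ds.shuffle(seed=1234)["filename"]  # True
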
    def test_sort(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Sort on a single key
            with self._create_dummy_dataset(in_memory=in_memory, tmp_dir=tmp_dir) as dset:
                # Keep only 10 examples and shuffle them
                tmp_file = os.path.join(tmp_dir, "test.arrow")
                with dset.select(range(10), indices_cache_file_name=tmp_file) as dset:
                    tmp_file = os.path.join(tmp_dir, "test_2.arrow")
                    with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset:
                        self.assertEqual(len(dset), 10)
                        self.assertEqual(dset[0]["filename"], "my_name-train_8")
                        self.assertEqual(dset[1]["filename"], "my_name-train_9")
                        # Sort
                        tmp_file = os.path.join(tmp_dir, "test_3.arrow")
                        fingerprint = dset._fingerprint
                        with dset.sort("filename", indices_cache_file_name=tmp_file) as dset_sorted:
                            for i, row in enumerate(dset_sorted):
                                self.assertEqual(int(row["filename"][-1]), i)
                            self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                            self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")}))
                            self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
                            # Sort reversed
                            tmp_file = os.path.join(tmp_dir, "test_4.arrow")
                            fingerprint = dset._fingerprint
                            with dset.sort("filename", indices_cache_file_name=tmp_file, reverse=True) as dset_sorted:
                                for i, row in enumerate(dset_sorted):
                                    self.assertEqual(int(row["filename"][-1]), len(dset_sorted) - 1 - i)
                                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                                self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")}))
                                self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
                            # formatted
                            dset.set_format("numpy")
                            with dset.sort("filename") as dset_sorted_formatted:
                                self.assertEqual(dset_sorted_formatted.format["type"], "numpy")
            # Sort on multiple keys
            with self._create_dummy_dataset(in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True) as dset:
                tmp_file = os.path.join(tmp_dir, "test_5.arrow")
                fingerprint = dset._fingerprint
                # Throw an error when reverse is a list of bools that does not match the number of keys
                with pytest.raises(ValueError):
                    dset.sort(["col_1", "col_2", "col_3"], reverse=[False])
                with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset:
                    # Sort
                    with dset.sort(["col_1", "col_2", "col_3"], reverse=[False, True, False]) as dset_sorted:
                        for i, row in enumerate(dset_sorted):
                            self.assertEqual(row["col_1"], i)
                        self.assertDictEqual(
                            dset.features,
                            Features(
                                {
                                    "col_1": Value("int64"),
                                    "col_2": Value("string"),
                                    "col_3": Value("bool"),
                                }
                            ),
                        )
                        self.assertDictEqual(
                            dset_sorted.features,
                            Features(
                                {
                                    "col_1": Value("int64"),
                                    "col_2": Value("string"),
                                    "col_3": Value("bool"),
                                }
                            ),
                        )
                        self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
                        # Sort reversed
                        with dset.sort(["col_1", "col_2", "col_3"], reverse=[True, False, True]) as dset_sorted:
                            for i, row in enumerate(dset_sorted):
                                self.assertEqual(row["col_1"], len(dset_sorted) - 1 - i)
                            self.assertDictEqual(
                                dset.features,
                                Features(
                                    {
                                        "col_1": Value("int64"),
                                        "col_2": Value("string"),
                                        "col_3": Value("bool"),
                                    }
                                ),
                            )
                            self.assertDictEqual(
                                dset_sorted.features,
                                Features(
                                    {
                                        "col_1": Value("int64"),
                                        "col_2": Value("string"),
                                        "col_3": Value("bool"),
                                    }
                                ),
                            )
                            self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
                        # formatted
                        dset.set_format("numpy")
                        with dset.sort(
                            ["col_1", "col_2", "col_3"], reverse=[False, True, False]
                        ) as dset_sorted_formatted:
                            self.assertEqual(dset_sorted_formatted.format["type"], "numpy")

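    # Sketch of multi-key sort: `reverse` may be a single bool or one bool per key,
    # and a mismatched list raises ValueError (hypothetical `ds`):
    #
    #     ds.sort(["col_1", "col_2"], reverse=[False, True])   # per-key directions
    #     ds.sort(["col_1", "col_2"], reverse=[False])         # ValueError
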
    @require_tf
    def test_export(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                # Export the data
                tfrecord_path = os.path.join(tmp_dir, "test.tfrecord")
                with dset.map(
                    lambda ex, i: {
                        "id": i,
                        "question": f"Question {i}",
                        "answers": {"text": [f"Answer {i}-0", f"Answer {i}-1"], "answer_start": [0, 1]},
                    },
                    with_indices=True,
                    remove_columns=["filename"],
                ) as formatted_dset:
                    with formatted_dset.flatten() as formatted_dset:
                        formatted_dset.set_format("numpy")
                        formatted_dset.export(filename=tfrecord_path, format="tfrecord")

                        # Import the data
                        import tensorflow as tf

                        tf_dset = tf.data.TFRecordDataset([tfrecord_path])
                        feature_description = {
                            "id": tf.io.FixedLenFeature([], tf.int64),
                            "question": tf.io.FixedLenFeature([], tf.string),
                            "answers.text": tf.io.VarLenFeature(tf.string),
                            "answers.answer_start": tf.io.VarLenFeature(tf.int64),
                        }
                        tf_parsed_dset = tf_dset.map(
                            lambda example_proto: tf.io.parse_single_example(example_proto, feature_description)
                        )
                        # Test that keys match the original dataset
                        for i, ex in enumerate(tf_parsed_dset):
                            self.assertEqual(ex.keys(), formatted_dset[i].keys())

                        # Test for an equal number of elements
                        self.assertEqual(i, len(formatted_dset) - 1)
                        del tf_dset, tf_parsed_dset

    def test_to_csv(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # File path argument
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.csv")
                bytes_written = dset.to_csv(path_or_buf=file_path)

                self.assertTrue(os.path.isfile(file_path))
                self.assertEqual(bytes_written, os.path.getsize(file_path))
                csv_dset = pd.read_csv(file_path)

                self.assertEqual(csv_dset.shape, dset.shape)
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))

            # File buffer argument
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_buffer.csv")
                with open(file_path, "wb+") as buffer:
                    bytes_written = dset.to_csv(path_or_buf=buffer)

                self.assertTrue(os.path.isfile(file_path))
                self.assertEqual(bytes_written, os.path.getsize(file_path))
                csv_dset = pd.read_csv(file_path)

                self.assertEqual(csv_dset.shape, dset.shape)
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))

            # After a select/shuffle transform
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset = dset.select(range(0, len(dset), 2)).shuffle()
                file_path = os.path.join(tmp_dir, "test_path.csv")
                bytes_written = dset.to_csv(path_or_buf=file_path)

                self.assertTrue(os.path.isfile(file_path))
                self.assertEqual(bytes_written, os.path.getsize(file_path))
                csv_dset = pd.read_csv(file_path)

                self.assertEqual(csv_dset.shape, dset.shape)
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))

            # With array features
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.csv")
                bytes_written = dset.to_csv(path_or_buf=file_path)

                self.assertTrue(os.path.isfile(file_path))
                self.assertEqual(bytes_written, os.path.getsize(file_path))
                csv_dset = pd.read_csv(file_path)

                self.assertEqual(csv_dset.shape, dset.shape)
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))

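    # Sketch of the export contract checked above: `to_csv` accepts either a path or an
    # open binary buffer and returns the number of bytes written, so a round-trip
    # sanity check is simply (hypothetical `ds`):
    #
    #     n = ds.to_csv("out.csv")
    #     assert n == os.path.getsize("out.csv")
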
    def test_to_dict(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                # Full
                dset_to_dict = dset.to_dict()
                self.assertIsInstance(dset_to_dict, dict)
                self.assertListEqual(sorted(dset_to_dict.keys()), sorted(dset.column_names))

                for col_name in dset.column_names:
                    self.assertLessEqual(len(dset_to_dict[col_name]), len(dset))

                # With an indices mapping
                with dset.select([1, 0, 3]) as dset:
                    dset_to_dict = dset.to_dict()
                    self.assertIsInstance(dset_to_dict, dict)
                    self.assertEqual(len(dset_to_dict), 3)
                    self.assertListEqual(sorted(dset_to_dict.keys()), sorted(dset.column_names))

                    for col_name in dset.column_names:
                        self.assertIsInstance(dset_to_dict[col_name], list)
                        self.assertEqual(len(dset_to_dict[col_name]), len(dset))

    def test_to_list(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset_to_list = dset.to_list()
                self.assertIsInstance(dset_to_list, list)
                for row in dset_to_list:
                    self.assertIsInstance(row, dict)
                    self.assertListEqual(sorted(row.keys()), sorted(dset.column_names))

                # With an indices mapping
                with dset.select([1, 0, 3]) as dset:
                    dset_to_list = dset.to_list()
                    self.assertIsInstance(dset_to_list, list)
                    self.assertEqual(len(dset_to_list), 3)
                    for row in dset_to_list:
                        self.assertIsInstance(row, dict)
                        self.assertListEqual(sorted(row.keys()), sorted(dset.column_names))

    def test_to_pandas(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Batched
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                batch_size = dset.num_rows - 1
                to_pandas_generator = dset.to_pandas(batched=True, batch_size=batch_size)

                for batch in to_pandas_generator:
                    self.assertIsInstance(batch, pd.DataFrame)
                    self.assertListEqual(sorted(batch.columns), sorted(dset.column_names))
                    for col_name in dset.column_names:
                        self.assertLessEqual(len(batch[col_name]), batch_size)

                # Full
                dset_to_pandas = dset.to_pandas()
                self.assertIsInstance(dset_to_pandas, pd.DataFrame)
                self.assertListEqual(sorted(dset_to_pandas.columns), sorted(dset.column_names))
                for col_name in dset.column_names:
                    self.assertEqual(len(dset_to_pandas[col_name]), len(dset))

                # With an indices mapping
                with dset.select([1, 0, 3]) as dset:
                    dset_to_pandas = dset.to_pandas()
                    self.assertIsInstance(dset_to_pandas, pd.DataFrame)
                    self.assertEqual(len(dset_to_pandas), 3)
                    self.assertListEqual(sorted(dset_to_pandas.columns), sorted(dset.column_names))

                    for col_name in dset.column_names:
                        self.assertEqual(len(dset_to_pandas[col_name]), dset.num_rows)

    def test_to_parquet(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # File path argument
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.parquet")
                dset.to_parquet(path_or_buf=file_path)

                self.assertTrue(os.path.isfile(file_path))

                parquet_dset = pd.read_parquet(file_path)

                self.assertEqual(parquet_dset.shape, dset.shape)
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))

            # File buffer argument
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_buffer.parquet")
                with open(file_path, "wb+") as buffer:
                    dset.to_parquet(path_or_buf=buffer)

                self.assertTrue(os.path.isfile(file_path))

                parquet_dset = pd.read_parquet(file_path)

                self.assertEqual(parquet_dset.shape, dset.shape)
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))

            # After a select/shuffle transform
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset = dset.select(range(0, len(dset), 2)).shuffle()
                file_path = os.path.join(tmp_dir, "test_path.parquet")
                dset.to_parquet(path_or_buf=file_path)

                self.assertTrue(os.path.isfile(file_path))

                parquet_dset = pd.read_parquet(file_path)

                self.assertEqual(parquet_dset.shape, dset.shape)
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))

            # With array features
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.parquet")
                dset.to_parquet(path_or_buf=file_path)

                self.assertTrue(os.path.isfile(file_path))

                parquet_dset = pd.read_parquet(file_path)

                self.assertEqual(parquet_dset.shape, dset.shape)
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))

    @require_sqlalchemy
    def test_to_sql(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Destination specified as a database URI string
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
                _ = dset.to_sql("data", "sqlite:///" + file_path)

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)

                self.assertEqual(sql_dset.shape, dset.shape)
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))

            # Destination specified as a sqlite3 connection
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                import sqlite3

                file_path = os.path.join(tmp_dir, "test_path.sqlite")
                with contextlib.closing(sqlite3.connect(file_path)) as con:
                    _ = dset.to_sql("data", con, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)

                self.assertEqual(sql_dset.shape, dset.shape)
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))

            # Writing in batches
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
                _ = dset.to_sql("data", "sqlite:///" + file_path, batch_size=1, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)

                self.assertEqual(sql_dset.shape, dset.shape)
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))

            # After a select/shuffle transform
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset = dset.select(range(0, len(dset), 2)).shuffle()
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
                _ = dset.to_sql("data", "sqlite:///" + file_path, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)

                self.assertEqual(sql_dset.shape, dset.shape)
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))

            # With array features
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
                _ = dset.to_sql("data", "sqlite:///" + file_path, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)

                self.assertEqual(sql_dset.shape, dset.shape)
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))

    def test_train_test_split(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                dset_dict = dset.train_test_split(test_size=10, shuffle=False)
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
                dset_train = dset_dict["train"]
                dset_test = dset_dict["test"]

                self.assertEqual(len(dset_train), 20)
                self.assertEqual(len(dset_test), 10)
                self.assertEqual(dset_train[0]["filename"], "my_name-train_0")
                self.assertEqual(dset_train[-1]["filename"], "my_name-train_19")
                self.assertEqual(dset_test[0]["filename"], "my_name-train_20")
                self.assertEqual(dset_test[-1]["filename"], "my_name-train_29")
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))
                self.assertNotEqual(dset_train._fingerprint, fingerprint)
                self.assertNotEqual(dset_test._fingerprint, fingerprint)
                self.assertNotEqual(dset_train._fingerprint, dset_test._fingerprint)

                dset_dict = dset.train_test_split(test_size=0.5, shuffle=False)
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
                dset_train = dset_dict["train"]
                dset_test = dset_dict["test"]

                self.assertEqual(len(dset_train), 15)
                self.assertEqual(len(dset_test), 15)
                self.assertEqual(dset_train[0]["filename"], "my_name-train_0")
                self.assertEqual(dset_train[-1]["filename"], "my_name-train_14")
                self.assertEqual(dset_test[0]["filename"], "my_name-train_15")
                self.assertEqual(dset_test[-1]["filename"], "my_name-train_29")
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))

                dset_dict = dset.train_test_split(train_size=10, shuffle=False)
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
                dset_train = dset_dict["train"]
                dset_test = dset_dict["test"]

                self.assertEqual(len(dset_train), 10)
                self.assertEqual(len(dset_test), 20)
                self.assertEqual(dset_train[0]["filename"], "my_name-train_0")
                self.assertEqual(dset_train[-1]["filename"], "my_name-train_9")
                self.assertEqual(dset_test[0]["filename"], "my_name-train_10")
                self.assertEqual(dset_test[-1]["filename"], "my_name-train_29")
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))

                dset.set_format("numpy")
                dset_dict = dset.train_test_split(train_size=10, seed=42)
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
                dset_train = dset_dict["train"]
                dset_test = dset_dict["test"]

                self.assertEqual(len(dset_train), 10)
                self.assertEqual(len(dset_test), 20)
                self.assertEqual(dset_train.format["type"], "numpy")
                self.assertEqual(dset_test.format["type"], "numpy")
                self.assertNotEqual(dset_train[0]["filename"].item(), "my_name-train_0")
                self.assertNotEqual(dset_train[-1]["filename"].item(), "my_name-train_9")
                self.assertNotEqual(dset_test[0]["filename"].item(), "my_name-train_10")
                self.assertNotEqual(dset_test[-1]["filename"].item(), "my_name-train_29")
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))
                del dset_test, dset_train, dset_dict

    def test_shard(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
            tmp_file = os.path.join(tmp_dir, "test.arrow")
            with dset.select(range(10), indices_cache_file_name=tmp_file) as dset:
                self.assertEqual(len(dset), 10)
                # Shard
                tmp_file_1 = os.path.join(tmp_dir, "test_1.arrow")
                fingerprint = dset._fingerprint
                with dset.shard(num_shards=8, index=1, indices_cache_file_name=tmp_file_1) as dset_sharded:
                    self.assertEqual(2, len(dset_sharded))
                    self.assertEqual(["my_name-train_1", "my_name-train_9"], dset_sharded["filename"])
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_sharded.features, Features({"filename": Value("string")}))
                    self.assertNotEqual(dset_sharded._fingerprint, fingerprint)
                # Shard contiguous
                tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
                with dset.shard(
                    num_shards=3, index=0, contiguous=True, indices_cache_file_name=tmp_file_2
                ) as dset_sharded_contiguous:
                    self.assertEqual([f"my_name-train_{i}" for i in (0, 1, 2, 3)], dset_sharded_contiguous["filename"])
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(dset_sharded_contiguous.features, Features({"filename": Value("string")}))
                    # Test the lengths of the contiguous shards
                    self.assertEqual(
                        [4, 3, 3],
                        [
                            len(dset.shard(3, index=i, contiguous=True, indices_cache_file_name=tmp_file_2 + str(i)))
                            for i in range(3)
                        ],
                    )
                # formatted
                dset.set_format("numpy")
                with dset.shard(num_shards=3, index=0) as dset_sharded_formatted:
                    self.assertEqual(dset_sharded_formatted.format["type"], "numpy")

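    # Sketch: with `contiguous=True` the dataset is cut into `num_shards` runs whose
    # lengths differ by at most one, so 10 rows over 3 shards gives 4/3/3; the default
    # (non-contiguous) sharding instead takes every `num_shards`-th row:
    #
    #     ds.shard(num_shards=8, index=1)["filename"]  # rows 1 and 9 of a 10-row dataset
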
    def test_flatten_indices(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertIsNone(dset._indices)

                tmp_file = os.path.join(tmp_dir, "test.arrow")
                with dset.select(range(0, 10, 2), indices_cache_file_name=tmp_file) as dset:
                    self.assertEqual(len(dset), 5)

                    self.assertIsNotNone(dset._indices)

                    tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
                    fingerprint = dset._fingerprint
                    dset.set_format("numpy")
                    with dset.flatten_indices(cache_file_name=tmp_file_2) as dset:
                        self.assertEqual(len(dset), 5)
                        self.assertEqual(len(dset.data), len(dset))
                        self.assertIsNone(dset._indices)
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        self.assertEqual(dset.format["type"], "numpy")
                        # Test that unique works
                        dset.unique(dset.column_names[0])
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # Empty indices mapping
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertIsNone(dset._indices)

                tmp_file = os.path.join(tmp_dir, "test.arrow")
                with dset.filter(lambda _: False, cache_file_name=tmp_file) as dset:
                    self.assertEqual(len(dset), 0)

                    self.assertIsNotNone(dset._indices)

                    tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
                    fingerprint = dset._fingerprint
                    dset.set_format("numpy")
                    with dset.flatten_indices(cache_file_name=tmp_file_2) as dset:
                        self.assertEqual(len(dset), 0)
                        self.assertEqual(len(dset.data), len(dset))
                        self.assertIsNone(dset._indices)
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        self.assertEqual(dset.format["type"], "numpy")
                        # Test that unique works
                        dset.unique(dset.column_names[0])
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

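    # Sketch: `flatten_indices` rewrites the selected rows into a fresh Arrow table and
    # drops the indices mapping, after which len(ds.data) == len(ds) again; operations
    # that need a flat table (like `unique`) then work directly (hypothetical `ds`):
    #
    #     flat = ds.select([4, 2, 0]).flatten_indices()
    #     assert flat._indices is None
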
    @require_tf
    @require_torch
    def test_format_vectors(self, in_memory):
        import tensorflow as tf
        import torch

        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(
            in_memory, tmp_dir
        ) as dset, dset.map(lambda ex, i: {"vec": np.ones(3) * i}, with_indices=True) as dset:
            columns = dset.column_names

            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            for col in columns:
                self.assertIsInstance(dset[0][col], (str, list))
                self.assertIsInstance(dset[:2][col], list)
            self.assertDictEqual(
                dset.features, Features({"filename": Value("string"), "vec": Sequence(Value("float64"))})
            )

            dset.set_format("tensorflow")
            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            for col in columns:
                self.assertIsInstance(dset[0][col], (tf.Tensor, tf.RaggedTensor))
                self.assertIsInstance(dset[:2][col], (tf.Tensor, tf.RaggedTensor))
                self.assertIsInstance(dset[col], (tf.Tensor, tf.RaggedTensor))
            self.assertTupleEqual(tuple(dset[:2]["vec"].shape), (2, 3))
            self.assertTupleEqual(tuple(dset["vec"][:2].shape), (2, 3))

            dset.set_format("numpy")
            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            self.assertIsInstance(dset[0]["filename"], np.str_)
            self.assertIsInstance(dset[:2]["filename"], np.ndarray)
            self.assertIsInstance(dset["filename"], np.ndarray)
            self.assertIsInstance(dset[0]["vec"], np.ndarray)
            self.assertIsInstance(dset[:2]["vec"], np.ndarray)
            self.assertIsInstance(dset["vec"], np.ndarray)
            self.assertTupleEqual(dset[:2]["vec"].shape, (2, 3))
            self.assertTupleEqual(dset["vec"][:2].shape, (2, 3))

            dset.set_format("torch", columns=["vec"])
            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            # torch.Tensor is only used for numerical columns
            self.assertIsInstance(dset[0]["vec"], torch.Tensor)
            self.assertIsInstance(dset[:2]["vec"], torch.Tensor)
            self.assertIsInstance(dset["vec"][:2], torch.Tensor)
            self.assertTupleEqual(dset[:2]["vec"].shape, (2, 3))
            self.assertTupleEqual(dset["vec"][:2].shape, (2, 3))

    @require_tf
    @require_torch
    def test_format_ragged_vectors(self, in_memory):
        import tensorflow as tf
        import torch

        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(
            in_memory, tmp_dir
        ) as dset, dset.map(lambda ex, i: {"vec": np.ones(3 + i) * i}, with_indices=True) as dset:
            columns = dset.column_names

            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            for col in columns:
                self.assertIsInstance(dset[0][col], (str, list))
                self.assertIsInstance(dset[:2][col], list)
            self.assertDictEqual(
                dset.features, Features({"filename": Value("string"), "vec": Sequence(Value("float64"))})
            )

            dset.set_format("tensorflow")
            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            for col in columns:
                self.assertIsInstance(dset[0][col], tf.Tensor)
                self.assertIsInstance(dset[:2][col], tf.RaggedTensor if col == "vec" else tf.Tensor)
                self.assertIsInstance(dset[col], tf.RaggedTensor if col == "vec" else tf.Tensor)
            # the ragged dimension is None in tensorflow
            self.assertListEqual(dset[:2]["vec"].shape.as_list(), [2, None])
            self.assertListEqual(dset["vec"][:2].shape.as_list(), [2, None])

            dset.set_format("numpy")
            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            self.assertIsInstance(dset[0]["filename"], np.str_)
            self.assertIsInstance(dset[:2]["filename"], np.ndarray)
            self.assertIsInstance(dset["filename"], np.ndarray)
            self.assertIsInstance(dset[0]["vec"], np.ndarray)
            self.assertIsInstance(dset[:2]["vec"], np.ndarray)
            self.assertIsInstance(dset["vec"], np.ndarray)
            # the array is flat for ragged vectors in numpy
            self.assertTupleEqual(dset[:2]["vec"].shape, (2,))
            self.assertTupleEqual(dset["vec"][:2].shape, (2,))

            dset.set_format("torch")
            self.assertIsNotNone(dset[0])
            self.assertIsNotNone(dset[:2])
            self.assertIsInstance(dset[0]["filename"], str)
            self.assertIsInstance(dset[:2]["filename"], list)
            self.assertIsInstance(dset["filename"], list)
            self.assertIsInstance(dset[0]["vec"], torch.Tensor)
            self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor)
            self.assertIsInstance(dset["vec"][0], torch.Tensor)
            # pytorch doesn't support ragged tensors, so we get lists of tensors
            self.assertIsInstance(dset[:2]["vec"], list)
            self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor)
            self.assertIsInstance(dset["vec"][:2], list)
            self.assertIsInstance(dset["vec"][0], torch.Tensor)

    @require_tf
    @require_torch
    def test_format_nested(self, in_memory):
        import tensorflow as tf
        import torch

        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(
            in_memory, tmp_dir
        ) as dset, dset.map(lambda ex: {"nested": [{"foo": np.ones(3)}] * len(ex["filename"])}, batched=True) as dset:
            self.assertDictEqual(
                dset.features, Features({"filename": Value("string"), "nested": {"foo": Sequence(Value("float64"))}})
            )

            dset.set_format("tensorflow")
            self.assertIsNotNone(dset[0])
            self.assertIsInstance(dset[0]["nested"]["foo"], (tf.Tensor, tf.RaggedTensor))
            self.assertIsNotNone(dset[:2])
            self.assertIsInstance(dset[:2]["nested"][0]["foo"], (tf.Tensor, tf.RaggedTensor))
            self.assertIsInstance(dset["nested"][0]["foo"], (tf.Tensor, tf.RaggedTensor))

            dset.set_format("numpy")
            self.assertIsNotNone(dset[0])
            self.assertIsInstance(dset[0]["nested"]["foo"], np.ndarray)
            self.assertIsNotNone(dset[:2])
            self.assertIsInstance(dset[:2]["nested"][0]["foo"], np.ndarray)
            self.assertIsInstance(dset["nested"][0]["foo"], np.ndarray)

            dset.set_format("torch", columns="nested")
            self.assertIsNotNone(dset[0])
            self.assertIsInstance(dset[0]["nested"]["foo"], torch.Tensor)
            self.assertIsNotNone(dset[:2])
            self.assertIsInstance(dset[:2]["nested"][0]["foo"], torch.Tensor)
            self.assertIsInstance(dset["nested"][0]["foo"], torch.Tensor)


    def test_format_pandas(self, in_memory):
        import pandas as pd

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format("pandas")
                self.assertIsInstance(dset[0], pd.DataFrame)
                self.assertIsInstance(dset[:2], pd.DataFrame)
                self.assertIsInstance(dset["col_1"], pd.Series)

    def test_transmit_format_single(self, in_memory):
        @transmit_format
        def my_single_transform(self, return_factory, *args, **kwargs):
            return return_factory()

        with tempfile.TemporaryDirectory() as tmp_dir:
            return_factory = partial(
                self._create_dummy_dataset, in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True
            )
            with return_factory() as dset:
                dset.set_format("numpy", columns=["col_1"])
                prev_format = dset.format
                with my_single_transform(dset, return_factory) as transformed_dset:
                    self.assertDictEqual(transformed_dset.format, prev_format)

    def test_transmit_format_dict(self, in_memory):
        @transmit_format
        def my_split_transform(self, return_factory, *args, **kwargs):
            return DatasetDict({"train": return_factory()})

        with tempfile.TemporaryDirectory() as tmp_dir:
            return_factory = partial(
                self._create_dummy_dataset, in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True
            )
            with return_factory() as dset:
                dset.set_format("numpy", columns=["col_1"])
                prev_format = dset.format
                transformed_dset = my_split_transform(dset, return_factory)["train"]
                self.assertDictEqual(transformed_dset.format, prev_format)

                del transformed_dset
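
    # Added note (not in the original suite): `transmit_format` copies the caller's
    # format (type, columns, output_all_columns) onto the Dataset(s) returned by the
    # wrapped function; the two tests above check this for a bare Dataset and for a
    # Dataset inside a DatasetDict.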

    def test_with_format(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.with_format("numpy", columns=["col_1"]) as dset2:
                    dset.set_format("numpy", columns=["col_1"])
                    self.assertDictEqual(dset.format, dset2.format)
                    self.assertEqual(dset._fingerprint, dset2._fingerprint)

    def test_with_transform(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                transform = lambda x: {"foo": x["col_1"]}
                with dset.with_transform(transform, columns=["col_1"]) as dset2:
                    dset.set_transform(transform, columns=["col_1"])
                    self.assertDictEqual(dset.format, dset2.format)
                    self.assertEqual(dset._fingerprint, dset2._fingerprint)
                    dset.reset_format()
                    self.assertNotEqual(dset.format, dset2.format)
                    self.assertNotEqual(dset._fingerprint, dset2._fingerprint)

    @require_tf
    def test_tf_dataset_conversion(self, in_memory):
        tmp_dir = tempfile.TemporaryDirectory()
        for num_workers in [0, 1, 2]:
            if num_workers > 0 and sys.platform == "win32" and not in_memory:
                continue  # skip multiprocessing on Windows for on-disk datasets
            with self._create_dummy_dataset(in_memory, tmp_dir.name, array_features=True) as dset:
                tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=2, num_workers=num_workers)
                batch = next(iter(tf_dataset))
                self.assertEqual(batch.shape.as_list(), [2, 4])
                self.assertEqual(batch.dtype.name, "int64")
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
                tf_dataset = dset.to_tf_dataset(columns="col_1", batch_size=2, num_workers=num_workers)
                batch = next(iter(tf_dataset))
                self.assertEqual(batch.shape.as_list(), [2])
                self.assertEqual(batch.dtype.name, "int64")
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
                # check that loading all columns at once works too
                tf_dataset = dset.to_tf_dataset(batch_size=2, num_workers=num_workers)
                batch = next(iter(tf_dataset))
                self.assertEqual(batch["col_1"].shape.as_list(), [2])
                self.assertEqual(batch["col_2"].shape.as_list(), [2])
                self.assertEqual(batch["col_1"].dtype.name, "int64")
                self.assertEqual(batch["col_2"].dtype.name, "string")
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
                # check that a column created on the fly by a transform can be loaded on its
                # own, even though it depends on columns that are not in the final output
                transform_dset = dset.with_transform(
                    lambda x: {"new_col": [val * 2 for val in x["col_1"]], "col_1": x["col_1"]}
                )
                tf_dataset = transform_dset.to_tf_dataset(columns="new_col", batch_size=2, num_workers=num_workers)
                batch = next(iter(tf_dataset))
                self.assertEqual(batch.shape.as_list(), [2])
                self.assertEqual(batch.dtype.name, "int64")
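
    # Added note (not in the original suite): a minimal `to_tf_dataset` call, assuming
    # the dummy multiple-columns dataset used above; kept as a comment so it is not
    # collected as a test:
    #
    #   tf_dataset = dset.to_tf_dataset(columns="col_1", batch_size=2)
    #   batch = next(iter(tf_dataset))  # tf.Tensor of shape [2], dtype int64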

    @require_tf
    def test_tf_index_reshuffling(self, in_memory):
        # shuffling should return each sample exactly once per epoch,
        # in a different order from one epoch to the next
        data = {"col_1": list(range(20))}
        for num_workers in [0, 1, 2, 3]:
            with Dataset.from_dict(data) as dset:
                tf_dataset = dset.to_tf_dataset(batch_size=10, shuffle=True, num_workers=num_workers)
                indices = []
                for batch in tf_dataset:
                    indices.append(batch["col_1"])
                indices = np.concatenate([arr.numpy() for arr in indices])
                second_indices = []
                for batch in tf_dataset:
                    second_indices.append(batch["col_1"])
                second_indices = np.concatenate([arr.numpy() for arr in second_indices])
                self.assertFalse(np.array_equal(indices, second_indices))
                self.assertEqual(len(indices), len(np.unique(indices)))
                self.assertEqual(len(second_indices), len(np.unique(second_indices)))

                tf_dataset = dset.to_tf_dataset(batch_size=1, shuffle=False, num_workers=num_workers)
                for i, batch in enumerate(tf_dataset):
                    # the unshuffled order should be preserved exactly
                    self.assertEqual(i, batch["col_1"].numpy())

    @require_tf
    def test_tf_label_renaming(self, in_memory):
        # keep TF-specific imports local to the test
        import tensorflow as tf

        from datasets.utils.tf_utils import minimal_tf_collate_fn_with_renaming

        tmp_dir = tempfile.TemporaryDirectory()
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            with dset.rename_columns({"col_1": "features", "col_2": "label"}) as new_dset:
                tf_dataset = new_dset.to_tf_dataset(collate_fn=minimal_tf_collate_fn_with_renaming, batch_size=4)
                batch = next(iter(tf_dataset))
                self.assertTrue("labels" in batch and "features" in batch)

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features", "labels"], collate_fn=minimal_tf_collate_fn_with_renaming, batch_size=4
                )
                batch = next(iter(tf_dataset))
                self.assertTrue("labels" in batch and "features" in batch)

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features", "label"], collate_fn=minimal_tf_collate_fn_with_renaming, batch_size=4
                )
                batch = next(iter(tf_dataset))
                self.assertTrue("labels" in batch and "features" in batch)

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features"],
                    label_cols=["labels"],
                    collate_fn=minimal_tf_collate_fn_with_renaming,
                    batch_size=4,
                )
                batch = next(iter(tf_dataset))
                self.assertEqual(len(batch), 2)
                self.assertTrue(isinstance(batch[0], tf.Tensor) and isinstance(batch[1], tf.Tensor))

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features"],
                    label_cols=["label"],
                    collate_fn=minimal_tf_collate_fn_with_renaming,
                    batch_size=4,
                )
                batch = next(iter(tf_dataset))
                self.assertEqual(len(batch), 2)
                self.assertTrue(isinstance(batch[0], tf.Tensor) and isinstance(batch[1], tf.Tensor))

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features"],
                    collate_fn=minimal_tf_collate_fn_with_renaming,
                    batch_size=4,
                )
                batch = next(iter(tf_dataset))
                # the renamed "labels" key must not appear when labels were not requested
                self.assertTrue(isinstance(batch, tf.Tensor))

    @require_tf
    def test_tf_dataset_options(self, in_memory):
        tmp_dir = tempfile.TemporaryDirectory()
        # test that the batch_size option works as expected
        with self._create_dummy_dataset(in_memory, tmp_dir.name, array_features=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=2)
            batch = next(iter(tf_dataset))
            self.assertEqual(batch.shape.as_list(), [2, 4])
            self.assertEqual(batch.dtype.name, "int64")
        # test that batch_size=None yields single unbatched examples
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=None)
            single_example = next(iter(tf_dataset))
            self.assertEqual(single_example.shape.as_list(), [])
            self.assertEqual(single_example.dtype.name, "int64")
            # an unbatched dataset can still be batched with `tf.data.Dataset.batch`
            batched_dataset = tf_dataset.batch(batch_size=2)
            batch = next(iter(batched_dataset))
            self.assertEqual(batch.shape.as_list(), [2])
            self.assertEqual(batch.dtype.name, "int64")
        # batching afterwards should be equivalent to passing batch_size directly
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            batch_size = 2
            tf_dataset_no_batch = dset.to_tf_dataset(columns="col_3")
            tf_dataset_batch = dset.to_tf_dataset(columns="col_3", batch_size=batch_size)
            self.assertEqual(tf_dataset_no_batch.element_spec, tf_dataset_batch.unbatch().element_spec)
            self.assertEqual(tf_dataset_no_batch.cardinality(), tf_dataset_batch.cardinality() * batch_size)
            for batch_1, batch_2 in zip(tf_dataset_no_batch.batch(batch_size=batch_size), tf_dataset_batch):
                self.assertEqual(batch_1.shape, batch_2.shape)
                self.assertEqual(batch_1.dtype, batch_2.dtype)
                self.assertListEqual(batch_1.numpy().tolist(), batch_2.numpy().tolist())
        # test that label_cols splits the output into (features, labels)
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_1", label_cols=["col_2", "col_3"], batch_size=4)
            batch = next(iter(tf_dataset))
            self.assertEqual(len(batch), 2)
            self.assertEqual(set(batch[1].keys()), {"col_2", "col_3"})
            self.assertEqual(batch[0].dtype.name, "int64")
            # data should come out in the original, unshuffled order
            self.assertEqual(batch[0].numpy().tolist(), [3, 2, 1, 0])
            self.assertEqual(batch[1]["col_2"].numpy().tolist(), [b"a", b"b", b"c", b"d"])
            self.assertEqual(batch[1]["col_3"].numpy().tolist(), [0, 1, 0, 1])
        # test that drop_remainder drops the last incomplete batch
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_1", batch_size=3)
            tf_dataset_with_drop = dset.to_tf_dataset(columns="col_1", batch_size=3, drop_remainder=True)
            self.assertEqual(len(tf_dataset), 2)  # one batch of 3 and one batch of 1
            self.assertEqual(len(tf_dataset_with_drop), 1)  # the incomplete batch of 1 is dropped
        # num_workers > 0 requires an explicit batch_size
        if sys.version_info >= (3, 8):
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
                with self.assertRaisesRegex(
                    NotImplementedError, "`batch_size` must be specified when using multiple workers"
                ):
                    dset.to_tf_dataset(columns="col_1", batch_size=None, num_workers=2)
        del tf_dataset
        del tf_dataset_with_drop


class MiscellaneousDatasetTest(TestCase):
    def test_from_pandas(self):
        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
        df = pd.DataFrame.from_dict(data)
        with Dataset.from_pandas(df) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))

        features = Features({"col_1": Value("int64"), "col_2": Value("string")})
        with Dataset.from_pandas(df, features=features) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))

        features = Features({"col_1": Value("int64"), "col_2": Value("string")})
        with Dataset.from_pandas(df, features=features, info=DatasetInfo(features=features)) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))

        features = Features({"col_1": Sequence(Value("string")), "col_2": Value("string")})
        self.assertRaises(TypeError, Dataset.from_pandas, df, features=features)

    def test_from_dict(self):
        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": pa.array([True, False, True, False])}
        with Dataset.from_dict(data) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], data["col_3"].to_pylist())
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
            )

        features = Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
        with Dataset.from_dict(data, features=features) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], data["col_3"].to_pylist())
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
            )

        features = Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
        with Dataset.from_dict(data, features=features, info=DatasetInfo(features=features)) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], data["col_3"].to_pylist())
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
            )

        features = Features({"col_1": Value("string"), "col_2": Value("string"), "col_3": Value("int32")})
        with Dataset.from_dict(data, features=features) as dset:
            # the col_1 integers are cast to strings and the col_3 booleans to int32
            self.assertListEqual(dset["col_1"], [str(x) for x in data["col_1"]])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], [int(x) for x in data["col_3"].to_pylist()])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("string"), "col_2": Value("string"), "col_3": Value("int32")})
            )

        features = Features({"col_1": Value("int64"), "col_2": Value("int64"), "col_3": Value("bool")})
        self.assertRaises(ValueError, Dataset.from_dict, data, features=features)

    def test_concatenate_mixed_memory_and_disk(self):
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(data1, info=info1).map(
                cache_file_name=os.path.join(tmp_dir, "d1.arrow")
            ) as dset1, Dataset.from_dict(data2, info=info2).map(
                cache_file_name=os.path.join(tmp_dir, "d2.arrow")
            ) as dset2, Dataset.from_dict(data3) as dset3:
                with concatenate_datasets([dset1, dset2, dset3]) as concatenated_dset:
                    self.assertEqual(len(concatenated_dset), len(dset1) + len(dset2) + len(dset3))
                    self.assertListEqual(concatenated_dset["id"], dset1["id"] + dset2["id"] + dset3["id"])

    @require_transformers
    @pytest.mark.integration
    def test_set_format_encode(self):
        from transformers import BertTokenizer

        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        def encode(batch):
            return tokenizer(batch["text"], padding="longest", return_tensors="np")

        with Dataset.from_dict({"text": ["hello there", "foo"]}) as dset:
            dset.set_transform(transform=encode)
            self.assertEqual(str(dset[:2]), str(encode({"text": ["hello there", "foo"]})))

    @require_tf
    def test_tf_string_encoding(self):
        data = {"col_1": ["á", "é", "í", "ó", "ú"], "col_2": ["à", "è", "ì", "ò", "ù"]}
        with Dataset.from_dict(data) as dset:
            tf_dset_wo_batch = dset.to_tf_dataset(columns=["col_1", "col_2"])
            for tf_row, row in zip(tf_dset_wo_batch, dset):
                self.assertEqual(tf_row["col_1"].numpy().decode("utf-8"), row["col_1"])
                self.assertEqual(tf_row["col_2"].numpy().decode("utf-8"), row["col_2"])

            tf_dset_w_batch = dset.to_tf_dataset(columns=["col_1", "col_2"], batch_size=2)
            for tf_row, row in zip(tf_dset_w_batch.unbatch(), dset):
                self.assertEqual(tf_row["col_1"].numpy().decode("utf-8"), row["col_1"])
                self.assertEqual(tf_row["col_2"].numpy().decode("utf-8"), row["col_2"])

            self.assertEqual(tf_dset_w_batch.unbatch().element_spec, tf_dset_wo_batch.element_spec)
            self.assertEqual(tf_dset_w_batch.element_spec, tf_dset_wo_batch.batch(2).element_spec)


def test_cast_with_sliced_list():
    old_features = Features({"foo": Sequence(Value("int64"))})
    new_features = Features({"foo": Sequence(Value("int32"))})
    dataset = Dataset.from_dict({"foo": [[i] * (i % 3) for i in range(20)]}, features=old_features)
    casted_dataset = dataset.cast(new_features, batch_size=2)
    assert dataset["foo"] == casted_dataset["foo"]
    assert casted_dataset.features == new_features
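

# Added illustration (not part of the original suite): `Dataset.cast` rewrites the Arrow
# schema, so a plain dtype swap in `Features` is enough to downcast a column. A minimal
# sketch using only public APIs:
def test_cast_int64_to_int32_sketch():
    dataset = Dataset.from_dict({"foo": [1, 2, 3]})  # inferred as int64
    casted = dataset.cast(Features({"foo": Value("int32")}))
    assert casted.features["foo"] == Value("int32")
    assert casted["foo"] == [1, 2, 3]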
@pytest.mark.parametrize("include_nulls", [False, True])
3156
def test_class_encode_column_with_none(include_nulls):
3157
dataset = Dataset.from_dict({"col_1": ["a", "b", "c", None, "d", None]})
3158
dataset = dataset.class_encode_column("col_1", include_nulls=include_nulls)
3159
class_names = ["a", "b", "c", "d"]
3161
class_names += ["None"]
3162
assert isinstance(dataset.features["col_1"], ClassLabel)
3163
assert set(dataset.features["col_1"].names) == set(class_names)
3164
assert (None in dataset.unique("col_1")) == (not include_nulls)
3167
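

# Added illustration (not part of the original suite): without nulls involved,
# `class_encode_column` simply maps the sorted unique string values to integer ids:
def test_class_encode_column_sketch():
    dataset = Dataset.from_dict({"col_1": ["a", "b", "a"]}).class_encode_column("col_1")
    assert isinstance(dataset.features["col_1"], ClassLabel)
    assert dataset.features["col_1"].names == ["a", "b"]
    assert dataset["col_1"] == [0, 1, 0]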
@pytest.mark.parametrize("null_placement", ["first", "last"])
3168
def test_sort_with_none(null_placement):
3169
dataset = Dataset.from_dict({"col_1": ["item_2", "item_3", "item_1", None, "item_4", None]})
3170
dataset = dataset.sort("col_1", null_placement=null_placement)
3171
if null_placement == "first":
3172
assert dataset["col_1"] == [None, None, "item_1", "item_2", "item_3", "item_4"]
3174
assert dataset["col_1"] == ["item_1", "item_2", "item_3", "item_4", None, None]
3177


def test_update_metadata_with_features(dataset_dict):
    table1 = pa.Table.from_pydict(dataset_dict)
    features1 = Features.from_arrow_schema(table1.schema)
    features2 = features1.copy()
    features2["col_2"] = ClassLabel(num_classes=len(table1))
    assert features1 != features2

    table2 = update_metadata_with_features(table1, features2)
    metadata = json.loads(table2.schema.metadata[b"huggingface"].decode())
    assert features2 == Features.from_dict(metadata["info"]["features"])

    with Dataset(table1) as dset1, Dataset(table2) as dset2:
        assert dset1.features == features1
        assert dset2.features == features2
@pytest.mark.parametrize("dataset_type", ["in_memory", "memory_mapped", "mixed"])
3194
@pytest.mark.parametrize("axis, expected_shape", [(0, (4, 3)), (1, (2, 6))])
3195
def test_concatenate_datasets(dataset_type, axis, expected_shape, dataset_dict, arrow_path):
3197
"in_memory": InMemoryTable.from_pydict(dataset_dict),
3198
"memory_mapped": MemoryMappedTable.from_file(arrow_path),
3201
table[dataset_type if dataset_type != "mixed" else "memory_mapped"].slice(0, 2),
3202
table[dataset_type if dataset_type != "mixed" else "in_memory"].slice(2, 4),
3205
tables[1] = tables[1].rename_columns([col + "_bis" for col in tables[1].column_names])
3206
datasets = [Dataset(table) for table in tables]
3207
dataset = concatenate_datasets(datasets, axis=axis)
3208
assert dataset.shape == expected_shape
3209
assert_arrow_metadata_are_synced_with_dataset_features(dataset)
3212


def test_concatenate_datasets_new_columns():
    dataset1 = Dataset.from_dict({"col_1": ["a", "b", "c"]})
    dataset2 = Dataset.from_dict({"col_1": ["d", "e", "f"], "col_2": [True, False, True]})
    dataset = concatenate_datasets([dataset1, dataset2])
    assert dataset.data.shape == (6, 2)
    assert dataset.features == Features({"col_1": Value("string"), "col_2": Value("bool")})
    assert dataset[:] == {"col_1": ["a", "b", "c", "d", "e", "f"], "col_2": [None, None, None, True, False, True]}
    dataset3 = Dataset.from_dict({"col_3": ["a_1"]})
    dataset = concatenate_datasets([dataset, dataset3])
    assert dataset.data.shape == (7, 3)
    assert dataset.features == Features({"col_1": Value("string"), "col_2": Value("bool"), "col_3": Value("string")})
    assert dataset[:] == {
        "col_1": ["a", "b", "c", "d", "e", "f", None],
        "col_2": [None, None, None, True, False, True, None],
        "col_3": [None, None, None, None, None, None, "a_1"],
    }
@pytest.mark.parametrize("axis", [0, 1])
3231
def test_concatenate_datasets_complex_features(axis):
3233
dataset1 = Dataset.from_dict(
3234
{"col_1": [0] * n, "col_2": list(range(n))},
3235
features=Features({"col_1": Value("int32"), "col_2": ClassLabel(num_classes=n)}),
3238
dataset2 = dataset1.rename_columns({col: col + "_" for col in dataset1.column_names})
3239
expected_features = Features({**dataset1.features, **dataset2.features})
3242
expected_features = dataset1.features
3243
assert concatenate_datasets([dataset1, dataset2], axis=axis).features == expected_features
3246
@pytest.mark.parametrize("other_dataset_type", ["in_memory", "memory_mapped", "concatenation"])
3247
@pytest.mark.parametrize("axis, expected_shape", [(0, (8, 3)), (1, (4, 6))])
3248
def test_concatenate_datasets_with_concatenation_tables(
3249
axis, expected_shape, other_dataset_type, dataset_dict, arrow_path
3251
def _create_concatenation_table(axis):
3253
concatenation_table = ConcatenationTable.from_blocks(
3256
InMemoryTable.from_pydict({"col_1": dataset_dict["col_1"]}),
3257
MemoryMappedTable.from_file(arrow_path).remove_column(0),
3262
concatenation_table = ConcatenationTable.from_blocks(
3264
[InMemoryTable.from_pydict(dataset_dict).slice(0, 1)],
3265
[MemoryMappedTable.from_file(arrow_path).slice(1, 4)],
3268
return concatenation_table
3270
concatenation_table = _create_concatenation_table(axis)
3271
assert concatenation_table.shape == (4, 3)
3273
if other_dataset_type == "in_memory":
3274
other_table = InMemoryTable.from_pydict(dataset_dict)
3275
elif other_dataset_type == "memory_mapped":
3276
other_table = MemoryMappedTable.from_file(arrow_path)
3277
elif other_dataset_type == "concatenation":
3278
other_table = _create_concatenation_table(axis)
3279
assert other_table.shape == (4, 3)
3281
tables = [concatenation_table, other_table]
3284
tables[1] = tables[1].rename_columns([col + "_bis" for col in tables[1].column_names])
3286
for tables in [tables, reversed(tables)]:
3287
datasets = [Dataset(table) for table in tables]
3288
dataset = concatenate_datasets(datasets, axis=axis)
3289
assert dataset.shape == expected_shape
3292


def test_concatenate_datasets_duplicate_columns(dataset):
    with pytest.raises(ValueError) as excinfo:
        concatenate_datasets([dataset, dataset], axis=1)
    assert "duplicated" in str(excinfo.value)


def test_interleave_datasets():
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets([d1, d2, d3])
    expected_length = 3 * min(len(d1), len(d2), len(d3))
    expected_values = [x["a"] for x in itertools.chain(*zip(d1, d2, d3))]
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert dataset._fingerprint == interleave_datasets([d1, d2, d3])._fingerprint
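

# Added illustration (not part of the original suite): the default "first_exhausted"
# strategy stops interleaving as soon as the shortest dataset runs out of samples:
def test_interleave_datasets_first_exhausted_sketch():
    d1 = Dataset.from_dict({"a": [0, 1]})
    d2 = Dataset.from_dict({"a": [10, 11, 12]})
    # alternation stops after d1's two samples, so d2's last value never appears
    assert interleave_datasets([d1, d2])["a"] == [0, 10, 1, 11]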


def test_interleave_datasets_probabilities():
    seed = 42
    probabilities = [0.3, 0.5, 0.2]
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)
    expected_length = 7
    expected_values = [10, 11, 20, 12, 0, 21, 13]
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert (
        dataset._fingerprint == interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)._fingerprint
    )


def test_interleave_datasets_oversampling_strategy():
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
    expected_length = 3 * max(len(d1), len(d2), len(d3))
    expected_values = [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert dataset._fingerprint == interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")._fingerprint
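

# Added illustration (not part of the original suite): with "all_exhausted", shorter
# datasets wrap around until every dataset has been fully seen at least once:
def test_interleave_datasets_all_exhausted_sketch():
    d1 = Dataset.from_dict({"a": [0, 1]})
    d2 = Dataset.from_dict({"a": [10, 11, 12]})
    dataset = interleave_datasets([d1, d2], stopping_strategy="all_exhausted")
    # d1 restarts (0, 1, 0) so that all three of d2's samples appear
    assert dataset["a"] == [0, 10, 1, 11, 0, 12]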


def test_interleave_datasets_probabilities_oversampling_strategy():
    seed = 42
    probabilities = [0.3, 0.5, 0.2]
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets(
        [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
    )
    expected_length = 16
    expected_values = [10, 11, 20, 12, 0, 21, 13, 10, 1, 11, 12, 22, 13, 20, 10, 2]
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert (
        dataset._fingerprint
        == interleave_datasets(
            [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
        )._fingerprint
    )
@pytest.mark.parametrize("batch_size", [4, 5])
3364
@pytest.mark.parametrize("drop_last_batch", [False, True])
3365
def test_dataset_iter_batch(batch_size, drop_last_batch):
3367
dset = Dataset.from_dict({"i": list(range(n))})
3368
all_col_values = list(range(n))
3370
for i, batch in enumerate(dset.iter(batch_size, drop_last_batch=drop_last_batch)):
3371
assert batch == {"i": all_col_values[i * batch_size : (i + 1) * batch_size]}
3372
batches.append(batch)
3374
assert all(len(batch["i"]) == batch_size for batch in batches)
3376
assert all(len(batch["i"]) == batch_size for batch in batches[:-1])
3377
assert len(batches[-1]["i"]) <= batch_size
3380
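

# Added illustration (not part of the original suite): `Dataset.iter` yields dict-shaped
# batches, and drop_last_batch=True silently skips the trailing partial batch:
def test_dataset_iter_drop_last_sketch():
    dset = Dataset.from_dict({"i": list(range(5))})
    batches = list(dset.iter(batch_size=2, drop_last_batch=True))
    assert batches == [{"i": [0, 1]}, {"i": [2, 3]}]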


@pytest.mark.parametrize(
    "column, expected_dtype",
    [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],
)
@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
    "transform",
    [
        None,
        ("shuffle", (42,), {}),
        ("with_format", ("pandas",), {}),
        ("class_encode_column", ("col_2",), {}),
        ("select", (range(3),), {}),
    ],
)
def test_dataset_add_column(column, expected_dtype, in_memory, transform, dataset_dict, arrow_path):
    column_name = "col_4"
    original_dataset = (
        Dataset(InMemoryTable.from_pydict(dataset_dict))
        if in_memory
        else Dataset(MemoryMappedTable.from_file(arrow_path))
    )
    if transform is not None:
        transform_name, args, kwargs = transform
        original_dataset: Dataset = getattr(original_dataset, transform_name)(*args, **kwargs)
    column = column[:3] if transform is not None and transform_name == "select" else column
    dataset = original_dataset.add_column(column_name, column)
    assert dataset.data.shape == (3, 4) if transform is not None and transform_name == "select" else (4, 4)
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    # the added column goes last; the other features keep the original order
    expected_features = {feature: expected_features[feature] for feature in original_dataset.features}
    expected_features[column_name] = expected_dtype
    assert dataset.data.column_names == list(expected_features.keys())
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
    assert len(dataset.data.blocks) == 1 if in_memory else 2
    assert dataset.format["type"] == original_dataset.format["type"]
    assert dataset._fingerprint != original_dataset._fingerprint
    dataset.reset_format()
    original_dataset.reset_format()
    assert all(dataset[col] == original_dataset[col] for col in original_dataset.column_names)
    assert set(dataset["col_4"]) == set(column)
    if dataset._indices is not None:
        dataset_indices = dataset._indices["indices"].to_pylist()
        expected_dataset_indices = original_dataset._indices["indices"].to_pylist()
        assert dataset_indices == expected_dataset_indices
    assert_arrow_metadata_are_synced_with_dataset_features(dataset)


@pytest.mark.parametrize(
    "transform",
    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
)
@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
    "item",
    [
        {"col_1": "2", "col_2": 2, "col_3": 2.0},
        {"col_1": "2", "col_2": "2", "col_3": "2"},
        {"col_1": 2, "col_2": 2, "col_3": 2},
        {"col_1": 2.0, "col_2": 2.0, "col_3": 2.0},
    ],
)
def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
    dataset_to_test = (
        Dataset(InMemoryTable.from_pydict(dataset_dict))
        if in_memory
        else Dataset(MemoryMappedTable.from_file(arrow_path))
    )
    if transform is not None:
        transform_name, args, kwargs = transform
        dataset_to_test: Dataset = getattr(dataset_to_test, transform_name)(*args, **kwargs)
    dataset = dataset_to_test.add_item(item)
    assert dataset.data.shape == (5, 3)
    expected_features = dataset_to_test.features
    assert sorted(dataset.data.column_names) == sorted(expected_features.keys())
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature] == expected_dtype
    assert len(dataset.data.blocks) == 1 if in_memory else 2
    assert dataset.format["type"] == dataset_to_test.format["type"]
    assert dataset._fingerprint != dataset_to_test._fingerprint
    dataset.reset_format()
    dataset_to_test.reset_format()
    assert dataset[:-1] == dataset_to_test[:]
    assert {k: int(v) for k, v in dataset[-1].items()} == {k: int(v) for k, v in item.items()}
    if dataset._indices is not None:
        dataset_indices = dataset._indices["indices"].to_pylist()
        dataset_to_test_indices = dataset_to_test._indices["indices"].to_pylist()
        assert dataset_indices == dataset_to_test_indices + [len(dataset_to_test._data)]


def test_dataset_add_item_new_columns():
    dataset = Dataset.from_dict({"col_1": [0, 1, 2]}, features=Features({"col_1": Value("uint8")}))
    dataset = dataset.add_item({"col_1": 3, "col_2": "a"})
    assert dataset.data.shape == (4, 2)
    assert dataset.features == Features({"col_1": Value("uint8"), "col_2": Value("string")})
    assert dataset[:] == {"col_1": [0, 1, 2, 3], "col_2": [None, None, None, "a"]}
    dataset = dataset.add_item({"col_3": True})
    assert dataset.data.shape == (5, 3)
    assert dataset.features == Features({"col_1": Value("uint8"), "col_2": Value("string"), "col_3": Value("bool")})
    assert dataset[:] == {
        "col_1": [0, 1, 2, 3, None],
        "col_2": [None, None, None, "a", None],
        "col_3": [None, None, None, None, True],
    }


def test_dataset_add_item_introduce_feature_type():
    dataset = Dataset.from_dict({"col_1": [None, None, None]})
    dataset = dataset.add_item({"col_1": "a"})
    assert dataset.data.shape == (4, 1)
    assert dataset.features == Features({"col_1": Value("string")})
    assert dataset[:] == {"col_1": [None, None, None, "a"]}


def test_dataset_filter_batched_indices():
    ds = Dataset.from_dict({"num": [0, 1, 2, 3]})
    ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2)
    assert all(item["num"] % 2 == 0 for item in ds)
@pytest.mark.parametrize("in_memory", [False, True])
3503
def test_dataset_from_file(in_memory, dataset, arrow_file):
3504
filename = arrow_file
3505
with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
3506
dataset_from_file = Dataset.from_file(filename, in_memory=in_memory)
3507
assert dataset_from_file.features.type == dataset.features.type
3508
assert dataset_from_file.features == dataset.features
3509
assert dataset_from_file.cache_files == ([{"filename": filename}] if not in_memory else [])
3512
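

# Added note (not part of the original suite): `Dataset.from_file` memory-maps the arrow
# file by default, which is why `cache_files` lists the file only when in_memory=False.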


def _check_csv_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
@pytest.mark.parametrize("keep_in_memory", [False, True])
3522
def test_dataset_from_csv_keep_in_memory(keep_in_memory, csv_path, tmp_path):
3523
cache_dir = tmp_path / "cache"
3524
expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
3525
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
3526
dataset = Dataset.from_csv(csv_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
3527
_check_csv_dataset(dataset, expected_features)
3530


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_csv_features(features, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_csv(csv_path, features=features, cache_dir=cache_dir)
    _check_csv_dataset(dataset, expected_features)
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
3553
def test_dataset_from_csv_split(split, csv_path, tmp_path):
3554
cache_dir = tmp_path / "cache"
3555
expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
3556
dataset = Dataset.from_csv(csv_path, cache_dir=cache_dir, split=split)
3557
_check_csv_dataset(dataset, expected_features)
3558
assert dataset.split == split if split else "train"
3561
@pytest.mark.parametrize("path_type", [str, list])
3562
def test_dataset_from_csv_path_type(path_type, csv_path, tmp_path):
3563
if issubclass(path_type, str):
3565
elif issubclass(path_type, list):
3567
cache_dir = tmp_path / "cache"
3568
expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
3569
dataset = Dataset.from_csv(path, cache_dir=cache_dir)
3570
_check_csv_dataset(dataset, expected_features)
3573


def _check_json_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
@pytest.mark.parametrize("keep_in_memory", [False, True])
3583
def test_dataset_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
3584
cache_dir = tmp_path / "cache"
3585
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3586
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
3587
dataset = Dataset.from_json(jsonl_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
3588
_check_json_dataset(dataset, expected_features)
3591


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_json_features(features, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_json(jsonl_path, features=features, cache_dir=cache_dir)
    _check_json_dataset(dataset, expected_features)


def test_dataset_from_json_with_class_label_feature(jsonl_str_path, tmp_path):
    features = Features(
        {"col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]), "col_2": Value("int64"), "col_3": Value("float64")}
    )
    cache_dir = tmp_path / "cache"
    dataset = Dataset.from_json(jsonl_str_path, features=features, cache_dir=cache_dir)
    assert dataset.features["col_1"].dtype == "int64"  # ClassLabel labels are stored as int64
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
3622
def test_dataset_from_json_split(split, jsonl_path, tmp_path):
3623
cache_dir = tmp_path / "cache"
3624
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3625
dataset = Dataset.from_json(jsonl_path, cache_dir=cache_dir, split=split)
3626
_check_json_dataset(dataset, expected_features)
3627
assert dataset.split == split if split else "train"
3630
@pytest.mark.parametrize("path_type", [str, list])
3631
def test_dataset_from_json_path_type(path_type, jsonl_path, tmp_path):
3632
if issubclass(path_type, str):
3634
elif issubclass(path_type, list):
3636
cache_dir = tmp_path / "cache"
3637
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3638
dataset = Dataset.from_json(path, cache_dir=cache_dir)
3639
_check_json_dataset(dataset, expected_features)
3642


def _check_parquet_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
@pytest.mark.parametrize("keep_in_memory", [False, True])
3652
def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
3653
cache_dir = tmp_path / "cache"
3654
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3655
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
3656
dataset = Dataset.from_parquet(parquet_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
3657
_check_parquet_dataset(dataset, expected_features)
3660


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_parquet_features(features, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_parquet(parquet_path, features=features, cache_dir=cache_dir)
    _check_parquet_dataset(dataset, expected_features)
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
3682
def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
3683
cache_dir = tmp_path / "cache"
3684
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3685
dataset = Dataset.from_parquet(parquet_path, cache_dir=cache_dir, split=split)
3686
_check_parquet_dataset(dataset, expected_features)
3687
assert dataset.split == split if split else "train"
3690
@pytest.mark.parametrize("path_type", [str, list])
3691
def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path):
3692
if issubclass(path_type, str):
3694
elif issubclass(path_type, list):
3695
path = [parquet_path]
3696
cache_dir = tmp_path / "cache"
3697
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3698
dataset = Dataset.from_parquet(path, cache_dir=cache_dir)
3699
_check_parquet_dataset(dataset, expected_features)
3702


def _check_text_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 1
    assert dataset.column_names == ["text"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
@pytest.mark.parametrize("keep_in_memory", [False, True])
3712
def test_dataset_from_text_keep_in_memory(keep_in_memory, text_path, tmp_path):
3713
cache_dir = tmp_path / "cache"
3714
expected_features = {"text": "string"}
3715
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
3716
dataset = Dataset.from_text(text_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
3717
_check_text_dataset(dataset, expected_features)
3720


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"text": "string"},
        {"text": "int32"},
        {"text": "float32"},
    ],
)
def test_dataset_from_text_features(features, text_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"text": "string"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_text(text_path, features=features, cache_dir=cache_dir)
    _check_text_dataset(dataset, expected_features)
@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
3741
def test_dataset_from_text_split(split, text_path, tmp_path):
3742
cache_dir = tmp_path / "cache"
3743
expected_features = {"text": "string"}
3744
dataset = Dataset.from_text(text_path, cache_dir=cache_dir, split=split)
3745
_check_text_dataset(dataset, expected_features)
3746
assert dataset.split == split if split else "train"
3749
@pytest.mark.parametrize("path_type", [str, list])
3750
def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
3751
if issubclass(path_type, str):
3753
elif issubclass(path_type, list):
3755
cache_dir = tmp_path / "cache"
3756
expected_features = {"text": "string"}
3757
dataset = Dataset.from_text(path, cache_dir=cache_dir)
3758
_check_text_dataset(dataset, expected_features)
3762


@pytest.fixture
def data_generator():
    def _gen():
        data = [
            {"col_1": "0", "col_2": 0, "col_3": 0.0},
            {"col_1": "1", "col_2": 1, "col_3": 1.0},
            {"col_1": "2", "col_2": 2, "col_3": 2.0},
            {"col_1": "3", "col_2": 3, "col_3": 3.0},
        ]
        for item in data:
            yield item

    return _gen


def _check_generator_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
@pytest.mark.parametrize("keep_in_memory", [False, True])
3786
def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
3787
cache_dir = tmp_path / "cache"
3788
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3789
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
3790
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
3791
_check_generator_dataset(dataset, expected_features)
3794


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_generator_features(features, data_generator, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
    _check_generator_dataset(dataset, expected_features)


@require_not_windows
@require_dill_gt_0_3_2
def test_from_spark():
    import pyspark

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    data = [
        ("0", 0, 0.0),
        ("1", 1, 1.0),
        ("2", 2, 2.0),
        ("3", 3, 3.0),
    ]
    df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
    dataset = Dataset.from_spark(df)
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]


@require_not_windows
@require_dill_gt_0_3_2
def test_from_spark_features():
    import PIL.Image
    import pyspark

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    data = [(0, np.arange(4 * 4 * 3).reshape(4, 4, 3).tolist())]
    df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
    features = Features({"idx": Value("int64"), "image": Image()})
    dataset = Dataset.from_spark(
        df,
        features=features,
    )
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 1
    assert dataset.num_columns == 2
    assert dataset.column_names == ["idx", "image"]
    assert isinstance(dataset[0]["image"], PIL.Image.Image)
    assert dataset.features == features
    assert_arrow_metadata_are_synced_with_dataset_features(dataset)


@require_not_windows
@require_dill_gt_0_3_2
def test_from_spark_different_cache():
    import pyspark

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    df = spark.createDataFrame([("0", 0)], "col_1: string, col_2: int")
    dataset = Dataset.from_spark(df)
    assert isinstance(dataset, Dataset)
    different_df = spark.createDataFrame([("1", 1)], "col_1: string, col_2: int")
    different_dataset = Dataset.from_spark(different_df)
    assert isinstance(different_dataset, Dataset)
    assert dataset[0]["col_1"] == "0"
    # the second dataset must not be read back from the first one's cache
    assert different_dataset[0]["col_1"] == "1"


def _check_sql_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
@pytest.mark.parametrize("con_type", ["string", "engine"])
3889
def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
3890
cache_dir = tmp_path / "cache"
3891
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3892
if con_type == "string":
3893
con = "sqlite:///" + sqlite_path
3894
elif con_type == "engine":
3897
con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
3909
dataset = Dataset.from_sql(
3912
cache_dir=cache_dir,
3914
_check_sql_dataset(dataset, expected_features)
3918


@require_sqlalchemy
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_sql_features(features, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_sql("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir)
    _check_sql_dataset(dataset, expected_features)
@pytest.mark.parametrize("keep_in_memory", [False, True])
3941
def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
3942
cache_dir = tmp_path / "cache"
3943
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3944
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
3945
dataset = Dataset.from_sql(
3946
"dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory
3948
_check_sql_dataset(dataset, expected_features)
3951


def test_dataset_to_json(dataset, tmp_path):
    file_path = tmp_path / "test_path.jsonl"
    bytes_written = dataset.to_json(path_or_buf=file_path)
    assert file_path.is_file()
    assert bytes_written == file_path.stat().st_size
    df = pd.read_json(file_path, orient="records", lines=True)
    assert df.shape == dataset.shape
    assert list(df.columns) == list(dataset.column_names)
@pytest.mark.parametrize("in_memory", [False, True])
3962
@pytest.mark.parametrize(
3963
"method_and_params",
3965
("rename_column", (), {"original_column_name": "labels", "new_column_name": "label"}),
3966
("remove_columns", (), {"column_names": "labels"}),
3971
"features": Features(
3973
"tokens": Sequence(Value("string")),
3974
"labels": Sequence(Value("int16")),
3975
"answers": Sequence(
3977
"text": Value("string"),
3978
"answer_start": Value("int32"),
3981
"id": Value("int32"),
3986
("flatten", (), {}),
3989
def test_pickle_dataset_after_transforming_the_table(in_memory, method_and_params, arrow_file):
    method, args, kwargs = method_and_params
    with Dataset.from_file(arrow_file, in_memory=in_memory) as dataset, Dataset.from_file(
        arrow_file, in_memory=in_memory
    ) as reference_dataset:
        out = getattr(dataset, method)(*args, **kwargs)
        dataset = out if out is not None else dataset
        pickled_dataset = pickle.dumps(dataset)
        reloaded_dataset = pickle.loads(pickled_dataset)

        assert dataset._data != reference_dataset._data
        assert dataset._data.table == reloaded_dataset._data.table


def test_dummy_dataset_serialize_fs(dataset, mockfs):
    dataset_path = "mock://my_dataset"
    dataset.save_to_disk(dataset_path, storage_options=mockfs.storage_options)
    assert mockfs.isdir(dataset_path)
    assert mockfs.glob(dataset_path + "/*")
    reloaded = dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
    assert len(reloaded) == len(dataset)
    assert reloaded.features == dataset.features
    assert reloaded.to_dict() == dataset.to_dict()


@pytest.mark.parametrize(
    "uri_or_path",
    [
        "relative/path",
        "/absolute/path",
        "s3://bucket/relative/path",
        "hdfs://relative/path",
        "hdfs:///absolute/path",
    ],
)
def test_build_local_temp_path(uri_or_path):
    extracted_path = strip_protocol(uri_or_path)
    local_temp_path = Dataset._build_local_temp_path(extracted_path).as_posix()
    extracted_path_without_anchor = Path(extracted_path).relative_to(Path(extracted_path).anchor).as_posix()

    path_relative_to_tmp_dir = Path(local_temp_path).relative_to(Path(tempfile.gettempdir())).as_posix()

    assert (
        "hdfs://" not in path_relative_to_tmp_dir
        and "s3://" not in path_relative_to_tmp_dir
        and not local_temp_path.startswith(extracted_path_without_anchor)
        and local_temp_path.endswith(extracted_path_without_anchor)
    ), f"Local temp path: {local_temp_path}"


class TaskTemplatesTest(TestCase):
    def test_task_text_classification(self):
        labels = sorted(["pos", "neg"])
        features_before_cast = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=labels),
            }
        )

        features_after_cast = Features(
            {
                "text": Value("string"),
                "labels": ClassLabel(names=labels),
            }
        )

        task_without_labels = TextClassification(text_column="input_text", label_column="input_labels")
        info1 = DatasetInfo(
            features=features_before_cast,
            task_templates=task_without_labels,
        )

        task_with_labels = TextClassification(text_column="input_text", label_column="input_labels")
        info2 = DatasetInfo(
            features=features_before_cast,
            task_templates=task_with_labels,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        # test preparation by task name with a template without label names
        with Dataset.from_dict(data, info=info1) as dset:
            self.assertSetEqual({"input_text", "input_labels"}, set(dset.column_names))
            self.assertDictEqual(features_before_cast, dset.features)
            with dset.prepare_for_task(task="text-classification") as dset:
                self.assertSetEqual({"labels", "text"}, set(dset.column_names))
                self.assertDictEqual(features_after_cast, dset.features)
        # test preparation by task name with a template with label names
        with Dataset.from_dict(data, info=info2) as dset:
            self.assertSetEqual({"input_text", "input_labels"}, set(dset.column_names))
            self.assertDictEqual(features_before_cast, dset.features)
            with dset.prepare_for_task(task="text-classification") as dset:
                self.assertSetEqual({"labels", "text"}, set(dset.column_names))
                self.assertDictEqual(features_after_cast, dset.features)
        # test preparation by passing the template object directly
        info1.task_templates = None
        with Dataset.from_dict(data, info=info1) as dset:
            with dset.prepare_for_task(task=task_with_labels) as dset:
                self.assertSetEqual({"labels", "text"}, set(dset.column_names))
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_question_answering(self):
        features_before_cast = Features(
            {
                "input_context": Value("string"),
                "input_question": Value("string"),
                "input_answers": Sequence(
                    {
                        "text": Value("string"),
                        "answer_start": Value("int32"),
                    }
                ),
            }
        )
        features_after_cast = Features(
            {
                "context": Value("string"),
                "question": Value("string"),
                "answers": Sequence(
                    {
                        "text": Value("string"),
                        "answer_start": Value("int32"),
                    }
                ),
            }
        )
        task = QuestionAnsweringExtractive(
            context_column="input_context", question_column="input_question", answers_column="input_answers"
        )
        info = DatasetInfo(features=features_before_cast, task_templates=task)
        data = {
            "input_context": ["huggingface is going to the moon!"],
            "input_question": ["where is huggingface going?"],
            "input_answers": [{"text": ["to the moon!"], "answer_start": [2]}],
        }

        with Dataset.from_dict(data, info=info) as dset:
            self.assertSetEqual(
                {"input_context", "input_question", "input_answers.text", "input_answers.answer_start"},
                set(dset.flatten().column_names),
            )
            self.assertDictEqual(features_before_cast, dset.features)
            with dset.prepare_for_task(task="question-answering-extractive") as dset:
                self.assertSetEqual(
                    {"context", "question", "answers.text", "answers.answer_start"},
                    set(dset.flatten().column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)
        # test preparation by passing the template object directly
        info.task_templates = None
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task=task) as dset:
                self.assertSetEqual(
                    {"context", "question", "answers.text", "answers.answer_start"},
                    set(dset.flatten().column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_summarization(self):
        # include a dummy extra column to check that it gets removed by the cast
        features_before_cast = Features(
            {"input_text": Value("string"), "input_summary": Value("string"), "dummy": Value("string")}
        )
        features_after_cast = Features({"text": Value("string"), "summary": Value("string")})
        task = Summarization(text_column="input_text", summary_column="input_summary")
        info = DatasetInfo(features=features_before_cast, task_templates=task)
        data = {
            "input_text": ["jack and jill took a taxi to attend a super duper party in the city."],
            "input_summary": ["jack and jill attend party"],
            "dummy": ["123456"],
        }

        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task="summarization") as dset:
                self.assertSetEqual(
                    {"text", "summary"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)
        # test preparation by passing the template object directly
        info.task_templates = None
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task=task) as dset:
                self.assertSetEqual(
                    {"text", "summary"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_automatic_speech_recognition(self):
        # include a dummy extra column to check that it gets removed by the cast
        features_before_cast = Features(
            {
                "input_audio": Audio(sampling_rate=16_000),
                "input_transcription": Value("string"),
                "dummy": Value("string"),
            }
        )
        features_after_cast = Features({"audio": Audio(sampling_rate=16_000), "transcription": Value("string")})
        task = AutomaticSpeechRecognition(audio_column="input_audio", transcription_column="input_transcription")
        info = DatasetInfo(features=features_before_cast, task_templates=task)
        data = {
            "input_audio": [{"bytes": None, "path": "path/to/some/audio/file.wav"}],
            "input_transcription": ["hello, my name is bob!"],
            "dummy": ["123456"],
        }

        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task="automatic-speech-recognition") as dset:
                self.assertSetEqual(
                    {"audio", "transcription"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)
        # test preparation by passing the template object directly
        info.task_templates = None
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task=task) as dset:
                self.assertSetEqual(
                    {"audio", "transcription"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_with_no_template(self):
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data) as dset:
            with self.assertRaises(ValueError):
                dset.prepare_for_task("text-classification")

    def test_task_with_incompatible_templates(self):
        labels = sorted(["pos", "neg"])
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=labels),
            }
        )
        task = TextClassification(text_column="input_text", label_column="input_labels")
        info = DatasetInfo(
            features=features,
            task_templates=task,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data, info=info) as dset:
            # invalid task name
            self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist")
            # invalid task type
            self.assertRaises(ValueError, dset.prepare_for_task, 1)

    def test_task_with_multiple_compatible_task_templates(self):
        features = Features(
            {
                "text1": Value("string"),
                "text2": Value("string"),
            }
        )
        task1 = LanguageModeling(text_column="text1")
        task2 = LanguageModeling(text_column="text2")
        info = DatasetInfo(
            features=features,
            task_templates=[task1, task2],
        )
        data = {"text1": ["i love transformers!"], "text2": ["i love datasets!"]}
        with Dataset.from_dict(data, info=info) as dset:
            self.assertRaises(ValueError, dset.prepare_for_task, "language-modeling", id=3)
            with dset.prepare_for_task("language-modeling") as dset1:
                self.assertEqual(dset1[0]["text"], "i love transformers!")
            with dset.prepare_for_task("language-modeling", id=1) as dset2:
                self.assertEqual(dset2[0]["text"], "i love datasets!")

    def test_task_templates_empty_after_preparation(self):
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=["pos", "neg"]),
            }
        )
        task = TextClassification(text_column="input_text", label_column="input_labels")
        info = DatasetInfo(
            features=features,
            task_templates=task,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task="text-classification") as dset:
                self.assertIsNone(dset.info.task_templates)

    def test_align_labels_with_mapping_classification(self):
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]),
            }
        )
        data = {"input_text": ["a", "a", "b", "b", "c", "c"], "input_labels": [0, 0, 1, 1, 2, 2]}
        label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1}
        id2label = {v: k for k, v in label2id.items()}
        expected_labels = [2, 2, 1, 1, 0, 0]
        expected_label_names = [id2label[idx] for idx in expected_labels]
        with Dataset.from_dict(data, features=features) as dset:
            with dset.align_labels_with_mapping(label2id, "input_labels") as dset:
                self.assertListEqual(expected_labels, dset["input_labels"])
                aligned_label_names = [dset.features["input_labels"].int2str(idx) for idx in dset["input_labels"]]
                self.assertListEqual(expected_label_names, aligned_label_names)
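
    # `align_labels_with_mapping` remaps label ids to follow an external `label2id`
    # (label names are matched case-insensitively), e.g. to agree with a model config.
    # A minimal sketch, assuming a ClassLabel column named "label":
    #
    #     dset = dset.align_labels_with_mapping(label2id={"NEGATIVE": 0, "POSITIVE": 1}, label_column="label")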

    def test_align_labels_with_mapping_ner(self):
        features = Features(
            {
                "input_text": Sequence(Value("string")),
                "input_labels": Sequence(
                    ClassLabel(
                        names=[
                            "b-per",
                            "i-per",
                            "o",
                        ]
                    )
                ),
            }
        )
        data = {"input_text": [["Optimus", "Prime", "is", "a", "Transformer"]], "input_labels": [[0, 1, 2, 2, 2]]}
        label2id = {"B-PER": 2, "I-PER": 1, "O": 0}
        id2label = {v: k for k, v in label2id.items()}
        expected_labels = [[2, 1, 0, 0, 0]]
        expected_label_names = [[id2label[idx] for idx in seq] for seq in expected_labels]
        with Dataset.from_dict(data, features=features) as dset:
            with dset.align_labels_with_mapping(label2id, "input_labels") as dset:
                self.assertListEqual(expected_labels, dset["input_labels"])
                aligned_label_names = [
                    dset.features["input_labels"].feature.int2str(idx) for idx in dset["input_labels"]
                ]
                self.assertListEqual(expected_label_names, aligned_label_names)

    def test_concatenate_with_no_task_templates(self):
        info = DatasetInfo(task_templates=None)
        data = {"text": ["i love transformers!"], "labels": [1]}
        with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
            data, info=info
        ) as dset2, Dataset.from_dict(data, info=info) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.info.task_templates, None)

    def test_concatenate_with_equal_task_templates(self):
        labels = ["neg", "pos"]
        task_template = TextClassification(text_column="text", label_column="labels")
        info = DatasetInfo(
            features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}),
            # Label names are added in `DatasetInfo.__post_init__`, so they aren't needed here
            task_templates=TextClassification(text_column="text", label_column="labels"),
        )
        data = {"text": ["i love transformers!"], "labels": [1]}
        with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
            data, info=info
        ) as dset2, Dataset.from_dict(data, info=info) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertListEqual(dset_concat.info.task_templates, [task_template])

    def test_concatenate_with_mixed_task_templates_in_common(self):
        tc_template = TextClassification(text_column="text", label_column="labels")
        qa_template = QuestionAnsweringExtractive(
            question_column="question", context_column="context", answers_column="answers"
        )
        info1 = DatasetInfo(
            task_templates=[qa_template],
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
        )
        info2 = DatasetInfo(
            task_templates=[qa_template, tc_template],
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
        )
        data = {
            "text": ["i love transformers!"],
            "labels": [1],
            "context": ["huggingface is going to the moon!"],
            "question": ["where is huggingface going?"],
            "answers": [{"text": ["to the moon!"], "answer_start": [2]}],
        }
        with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict(
            data, info=info2
        ) as dset2, Dataset.from_dict(data, info=info2) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                # only the QA template is shared by all three datasets
                self.assertListEqual(dset_concat.info.task_templates, [qa_template])

    def test_concatenate_with_no_mixed_task_templates_in_common(self):
        tc_template1 = TextClassification(text_column="text", label_column="labels")
        tc_template2 = TextClassification(text_column="text", label_column="sentiment")
        qa_template = QuestionAnsweringExtractive(
            question_column="question", context_column="context", answers_column="answers"
        )
        info1 = DatasetInfo(
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
            task_templates=[tc_template1],
        )
        info2 = DatasetInfo(
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
            task_templates=[tc_template2],
        )
        info3 = DatasetInfo(
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
            task_templates=[qa_template],
        )
        data = {
            "text": ["i love transformers!"],
            "labels": [1],
            "sentiment": [0],
            "context": ["huggingface is going to the moon!"],
            "question": ["where is huggingface going?"],
            "answers": [{"text": ["to the moon!"], "answer_start": [2]}],
        }
        with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict(
            data, info=info2
        ) as dset2, Dataset.from_dict(data, info=info3) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                # no template is shared by all three datasets, so none is kept
                self.assertEqual(dset_concat.info.task_templates, None)
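
    # As the three tests above check, `concatenate_datasets` keeps only the task templates
    # shared (by equality) by every dataset; if none is shared, `task_templates` is None.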

    def test_task_text_classification_when_columns_removed(self):
        labels = sorted(["pos", "neg"])
        features_before_map = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=labels),
            }
        )
        features_after_map = Features({"new_column": Value("int64")})
        task = TextClassification(text_column="input_text", label_column="input_labels")
        info = DatasetInfo(
            features=features_before_map,
            task_templates=task,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data, info=info) as dset:
            with dset.map(lambda x: {"new_column": 0}, remove_columns=dset.column_names) as dset:
                self.assertDictEqual(dset.features, features_after_map)


class StratifiedTest(TestCase):
    def test_errors_train_test_split_stratify(self):
        ys = [
            np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]),
            np.array([0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
            np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
            np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
            np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
        ]
        for i in range(len(ys)):
            features = Features({"text": Value("int64"), "label": ClassLabel(len(np.unique(ys[i])))})
            data = {"text": np.ones(len(ys[i])), "label": ys[i]}
            d1 = Dataset.from_dict(data, features=features)

            # `stratify_by_column` must be an existing column name
            if i == 0:
                self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="labl")

            # the least populated class has too few members to be split
            elif i == 1:
                self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="label")

            # `stratify_by_column` must be a ClassLabel column
            elif i == 2:
                d1 = Dataset.from_dict(data)
                self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="label")

            # the test split must contain at least as many samples as there are classes
            elif i == 3:
                self.assertRaises(ValueError, d1.train_test_split, 0.30, stratify_by_column="label")

            # the train split must contain at least as many samples as there are classes
            elif i == 4:
                self.assertRaises(ValueError, d1.train_test_split, 0.60, stratify_by_column="label")

    def test_train_test_split_stratify(self):
        ys = [
            np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]),
            np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
            np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
            np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]),
            np.array([0] * 800 + [1] * 50),
        ]
        for y in ys:
            features = Features({"text": Value("int64"), "label": ClassLabel(len(np.unique(y)))})
            data = {"text": np.ones(len(y)), "label": y}
            d1 = Dataset.from_dict(data, features=features)
            d1 = d1.train_test_split(test_size=0.33, stratify_by_column="label")
            y = np.asanyarray(y)
            test_size = np.ceil(0.33 * len(y))
            train_size = len(y) - test_size
            npt.assert_array_equal(np.unique(d1["train"]["label"]), np.unique(d1["test"]["label"]))

            # checking that the class proportions are preserved (up to one decimal place)
            p_train = np.bincount(np.unique(d1["train"]["label"], return_inverse=True)[1]) / float(
                len(d1["train"]["label"])
            )
            p_test = np.bincount(np.unique(d1["test"]["label"], return_inverse=True)[1]) / float(
                len(d1["test"]["label"])
            )
            npt.assert_array_almost_equal(p_train, p_test, 1)
            assert len(d1["train"]["text"]) + len(d1["test"]["text"]) == y.size
            assert len(d1["train"]["text"]) == train_size
            assert len(d1["test"]["text"]) == test_size
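

# Stratified `train_test_split` keeps the label proportions of the two splits close to the
# full dataset's. A minimal usage sketch (assuming `ds` has a ClassLabel column "label"):
#
#     splits = ds.train_test_split(test_size=0.2, stratify_by_column="label")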


def test_dataset_estimate_nbytes():
    ds = Dataset.from_dict({"a": ["0" * 100] * 100})
    assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than full dataset size"

    ds = Dataset.from_dict({"a": ["0" * 100] * 100}).select([0])
    assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than one chunk"

    ds = Dataset.from_dict({"a": ["0" * 100] * 100})
    ds = concatenate_datasets([ds] * 100)
    assert 0.9 * ds._estimate_nbytes() < 100 * 100 * 100, "must be smaller than full dataset size"
    assert 1.1 * ds._estimate_nbytes() > 100 * 100 * 100, "must be bigger than full dataset size"

    ds = Dataset.from_dict({"a": ["0" * 100] * 100})
    ds = concatenate_datasets([ds] * 100).select([0])
    assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than one chunk"


def test_dataset_to_iterable_dataset(dataset: Dataset):
    iterable_dataset = dataset.to_iterable_dataset()
    assert isinstance(iterable_dataset, IterableDataset)
    assert list(iterable_dataset) == list(dataset)
    assert iterable_dataset.features == dataset.features
    iterable_dataset = dataset.to_iterable_dataset(num_shards=3)
    assert isinstance(iterable_dataset, IterableDataset)
    assert list(iterable_dataset) == list(dataset)
    assert iterable_dataset.features == dataset.features
    assert iterable_dataset.n_shards == 3
    with pytest.raises(ValueError):
        dataset.to_iterable_dataset(num_shards=len(dataset) + 1)
    with pytest.raises(NotImplementedError):
        dataset.with_format("torch").to_iterable_dataset()
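

# `to_iterable_dataset` exposes the Arrow-backed dataset through the streaming API.
# A minimal sharding sketch (assuming `ds` has at least 4 rows; each shard can then
# feed a separate dataloader worker):
#
#     it = ds.to_iterable_dataset(num_shards=4)
#     assert it.n_shards == 4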


@require_pil
def test_dataset_format_with_unformatted_image():
    import PIL.Image

    ds = Dataset.from_dict(
        {"a": [np.arange(4 * 4 * 3).reshape(4, 4, 3)] * 10, "b": [[0, 1]] * 10},
        Features({"a": Image(), "b": Sequence(Value("int64"))}),
    )
    ds.set_format("np", columns=["b"], output_all_columns=True)
    # only "b" is numpy-formatted; "a" falls back to its feature's decoded type
    assert isinstance(ds[0]["a"], PIL.Image.Image)
    assert isinstance(ds[0]["b"], np.ndarray)


@pytest.mark.parametrize("batch_size", [1, 4])
@require_torch
def test_dataset_with_torch_dataloader(dataset, batch_size):
    from torch.utils.data import DataLoader

    from datasets import config

    dataloader = DataLoader(dataset, batch_size=batch_size)
    with patch.object(dataset, "_getitem", wraps=dataset._getitem) as mock_getitem:
        out = list(dataloader)
        getitem_call_count = mock_getitem.call_count
    assert len(out) == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)
    # calling dataset[list_of_indices] is more efficient than one call per index
    if config.TORCH_VERSION >= version.parse("1.13.0"):
        assert getitem_call_count == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)
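

# Recent torch versions fetch whole batches through `__getitems__`, so `Dataset._getitem`
# runs once per batch instead of once per row; the call-count assertion above relies on
# that batched path (available in torch >= 1.13).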


@pytest.mark.parametrize("return_lazy_dict", [True, False, "mix"])
def test_map_cases(return_lazy_dict):
    def f(x):
        """May return a mix of LazyDict and regular Dict"""
        if x["a"] < 2:
            x["a"] = -1
            return dict(x) if return_lazy_dict is False else x
        else:
            return x if return_lazy_dict is True else {}

    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
    ds = ds.map(f)
    outputs = ds[:]
    assert outputs == {"a": [-1, -1, 2, 3]}

    def f(x):
        """May return a mix of LazyDict and regular Dict, but sometimes with None values"""
        if x["a"] < 2:
            x["a"] = None
            return dict(x) if return_lazy_dict is False else x
        else:
            return x if return_lazy_dict is True else {}

    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
    ds = ds.map(f)
    outputs = ds[:]
    assert outputs == {"a": [None, None, 2, 3]}

    def f(x):
        """Return a LazyDict, but we remove a lazy column and add a new one"""
        if x["a"] < 2:
            x["b"] = -1
        else:
            x["b"] = x["a"]
        return x

    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
    ds = ds.map(f, remove_columns=["a"])
    outputs = ds[:]
    assert outputs == {"b": [-1, -1, 2, 3]}

    # formatted rows are not lazy, so every column ends up in the output and survives `remove_columns`
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
    ds = ds.with_format("numpy")
    ds = ds.map(f, remove_columns=["a"])
    ds = ds.with_format(None)
    outputs = ds[:]
    assert outputs == {"a": [0, 1, 2, 3], "b": [-1, -1, 2, 3]}

    def f(x):
        """May return a mix of LazyDict and regular Dict, but we replace a lazy column"""
        if x["a"] < 2:
            x["a"] = -1
            return dict(x) if return_lazy_dict is False else x
        else:
            return x if return_lazy_dict is True else {"a": x["a"]}

    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
    ds = ds.map(f, remove_columns=["a"])
    outputs = ds[:]
    assert outputs == ({"a": [-1, -1, 2, 3]} if return_lazy_dict is False else {})

    def f(x):
        """May return a mix of LazyDict and regular Dict, but we modify a nested lazy column in-place"""
        if x["a"]["b"] < 2:
            x["a"]["c"] = -1
            return dict(x) if return_lazy_dict is False else x
        else:
            x["a"]["c"] = x["a"]["b"]
            return x if return_lazy_dict is True else {}

    ds = Dataset.from_dict({"a": [{"b": 0}, {"b": 1}, {"b": 2}, {"b": 3}]})
    ds = ds.map(f)
    outputs = ds[:]
    assert outputs == {"a": [{"b": 0, "c": -1}, {"b": 1, "c": -1}, {"b": 2, "c": 2}, {"b": 3, "c": 3}]}

    def f(x):
        """May return a mix of LazyDict and regular Dict, but using an extension type"""
        if x["a"][0][0] < 2:
            x["a"] = [[-1]]
            return dict(x) if return_lazy_dict is False else x
        else:
            return x if return_lazy_dict is True else {}

    features = Features({"a": Array2D(shape=(1, 1), dtype="int32")})
    ds = Dataset.from_dict({"a": [[[i]] for i in [0, 1, 2, 3]]}, features=features)
    ds = ds.map(f)
    outputs = ds[:]
    assert outputs == {"a": [[[i]] for i in [-1, -1, 2, 3]]}

    def f(x):
        """May return a mix of LazyDict and regular Dict, but using a nested extension type"""
        if x["a"]["nested"][0][0] < 2:
            x["a"] = {"nested": [[-1]]}
            return dict(x) if return_lazy_dict is False else x
        else:
            return x if return_lazy_dict is True else {}

    features = Features({"a": {"nested": Array2D(shape=(1, 1), dtype="int64")}})
    ds = Dataset.from_dict({"a": [{"nested": [[i]]} for i in [0, 1, 2, 3]]}, features=features)
    ds = ds.map(f)
    outputs = ds[:]
    assert outputs == {"a": [{"nested": [[i]]} for i in [-1, -1, 2, 3]]}
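

# Inside `map`, unformatted examples are lazy dicts: columns are only materialized on
# access, and (roughly speaking) only the keys the function touched or returned make it
# into its output, which is what the cases above exercise for plain, extension and
# nested column types.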


def test_dataset_getitem_raises():
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
    with pytest.raises(TypeError):
        ds[False]
    with pytest.raises(TypeError):
        ds._getitem(True)