# test_arrow_dataset.py
import contextlib
import copy
import itertools
import json
import os
import pickle
import re
import sys
import tempfile
from functools import partial
from pathlib import Path
from unittest import TestCase
from unittest.mock import MagicMock, patch

import numpy as np
import numpy.testing as npt
import pandas as pd
import pyarrow as pa
import pytest
from absl.testing import parameterized
from fsspec.core import strip_protocol
from packaging import version

import datasets.arrow_dataset
from datasets import concatenate_datasets, interleave_datasets, load_from_disk
from datasets.arrow_dataset import Dataset, transmit_format, update_metadata_with_features
from datasets.dataset_dict import DatasetDict
from datasets.features import (
    Array2D,
    Array3D,
    Audio,
    ClassLabel,
    Features,
    Image,
    Sequence,
    Translation,
    TranslationVariableLanguages,
    Value,
)
from datasets.info import DatasetInfo
from datasets.iterable_dataset import IterableDataset
from datasets.splits import NamedSplit
from datasets.table import ConcatenationTable, InMemoryTable, MemoryMappedTable
from datasets.tasks import (
    AutomaticSpeechRecognition,
    LanguageModeling,
    QuestionAnsweringExtractive,
    Summarization,
    TextClassification,
)
from datasets.utils.logging import INFO, get_logger
from datasets.utils.py_utils import temp_seed

from .utils import (
    assert_arrow_memory_doesnt_increase,
    assert_arrow_memory_increases,
    require_dill_gt_0_3_2,
    require_jax,
    require_not_windows,
    require_pil,
    require_pyspark,
    require_sqlalchemy,
    require_tf,
    require_torch,
    require_transformers,
    set_current_working_directory_to_temp_dir,
)


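# Helpers shared by the tests below. The map/filter functions are defined at module level
# so that they can be pickled; PickableMagicMock is a MagicMock that survives pickling,
# while Unpicklable stands in for objects the tests expect never to be serialized directly.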
class PickableMagicMock(MagicMock):
    def __reduce__(self):
        return MagicMock, ()


class Unpicklable:
    def __getstate__(self):
        raise pickle.PicklingError()


def picklable_map_function(x):
    return {"id": int(x["filename"].split("_")[-1])}


def picklable_map_function_with_indices(x, i):
    return {"id": i}


def picklable_map_function_with_rank(x, r):
    return {"rank": r}


def picklable_map_function_with_indices_and_rank(x, i, r):
    return {"id": i, "rank": r}


def picklable_filter_function(x):
    return int(x["filename"].split("_")[-1]) < 10


def picklable_filter_function_with_rank(x, r):
    return r == 0


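# Verify that the "huggingface" metadata embedded in the Arrow schema stays in sync
# with `dataset.features` after a transform.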
def assert_arrow_metadata_are_synced_with_dataset_features(dataset: Dataset):
    assert dataset.data.schema.metadata is not None
    assert b"huggingface" in dataset.data.schema.metadata
    metadata = json.loads(dataset.data.schema.metadata[b"huggingface"].decode())
    assert "info" in metadata
    features = DatasetInfo.from_dict(metadata["info"]).features
    assert features is not None
    assert features == dataset.features
    assert features == Features.from_arrow_schema(dataset.data.schema)
    assert list(features) == dataset.data.column_names
    assert list(features) == list(dataset.features)


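# Every test in BaseDatasetTest runs twice via absl parameterization: once with the Arrow
# table held in memory ("in_memory") and once backed by a cache file on disk ("on_disk").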
IN_MEMORY_PARAMETERS = [
    {"testcase_name": name, "in_memory": im} for im, name in [(True, "in_memory"), (False, "on_disk")]
]


@parameterized.named_parameters(IN_MEMORY_PARAMETERS)
class BaseDatasetTest(TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog, set_sqlalchemy_silence_uber_warning):
        self._caplog = caplog

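    # Build a small fixture dataset: by default a single "filename" column with 30 rows,
    # or (exclusively) plain multi-column data, array/sequence features, or nested features.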
    def _create_dummy_dataset(
        self, in_memory: bool, tmp_dir: str, multiple_columns=False, array_features=False, nested_features=False
    ) -> Dataset:
        assert int(multiple_columns) + int(array_features) + int(nested_features) < 2
        if multiple_columns:
            data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": [False, True, False, True]}
            dset = Dataset.from_dict(data)
        elif array_features:
            data = {
                "col_1": [[[True, False], [False, True]]] * 4,  # 2D
                "col_2": [[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]] * 4,  # 3D array
                "col_3": [[3, 2, 1, 0]] * 4,  # Sequence
            }
            features = Features(
                {
                    "col_1": Array2D(shape=(2, 2), dtype="bool"),
                    "col_2": Array3D(shape=(2, 2, 2), dtype="string"),
                    "col_3": Sequence(feature=Value("int64")),
                }
            )
            dset = Dataset.from_dict(data, features=features)
        elif nested_features:
            data = {"nested": [{"a": i, "x": i * 10, "c": i * 100} for i in range(1, 11)]}
            features = Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}})
            dset = Dataset.from_dict(data, features=features)
        else:
            dset = Dataset.from_dict({"filename": ["my_name-train" + "_" + str(x) for x in np.arange(30).tolist()]})
        if not in_memory:
            dset = self._to(in_memory, tmp_dir, dset)
        return dset

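    # Materialize the given datasets either in memory or as numbered cache files in tmp_dir,
    # depending on `in_memory`.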
    def _to(self, in_memory, tmp_dir, *datasets):
        if in_memory:
            datasets = [dataset.map(keep_in_memory=True) for dataset in datasets]
        else:
            start = 0
            while os.path.isfile(os.path.join(tmp_dir, f"dataset{start}.arrow")):
                start += 1
            datasets = [
                dataset.map(cache_file_name=os.path.join(tmp_dir, f"dataset{start + i}.arrow"))
                for i, dataset in enumerate(datasets)
            ]
        return datasets if len(datasets) > 1 else datasets[0]

    def test_dummy_dataset(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                self.assertDictEqual(
                    dset.features,
                    Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")}),
                )
                self.assertEqual(dset[0]["col_1"], 3)
                self.assertEqual(dset["col_1"][0], 3)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                self.assertDictEqual(
                    dset.features,
                    Features(
                        {
                            "col_1": Array2D(shape=(2, 2), dtype="bool"),
                            "col_2": Array3D(shape=(2, 2, 2), dtype="string"),
                            "col_3": Sequence(feature=Value("int64")),
                        }
                    ),
                )
                self.assertEqual(dset[0]["col_2"], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]])
                self.assertEqual(dset["col_2"][0], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]])

    def test_dataset_getitem(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

                self.assertEqual(dset[-1]["filename"], "my_name-train_29")
                self.assertEqual(dset["filename"][-1], "my_name-train_29")

                self.assertListEqual(dset[:2]["filename"], ["my_name-train_0", "my_name-train_1"])
                self.assertListEqual(dset["filename"][:2], ["my_name-train_0", "my_name-train_1"])

                self.assertEqual(dset[:-1]["filename"][-1], "my_name-train_28")
                self.assertEqual(dset["filename"][:-1][-1], "my_name-train_28")

                self.assertListEqual(dset[[0, -1]]["filename"], ["my_name-train_0", "my_name-train_29"])
                self.assertListEqual(dset[range(0, -2, -1)]["filename"], ["my_name-train_0", "my_name-train_29"])
                self.assertListEqual(dset[np.array([0, -1])]["filename"], ["my_name-train_0", "my_name-train_29"])
                self.assertListEqual(dset[pd.Series([0, -1])]["filename"], ["my_name-train_0", "my_name-train_29"])

                with dset.select(range(2)) as dset_subset:
                    self.assertListEqual(dset_subset[-1:]["filename"], ["my_name-train_1"])
                    self.assertListEqual(dset_subset["filename"][-1:], ["my_name-train_1"])

    def test_dummy_dataset_deepcopy(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset2 = copy.deepcopy(dset)
                # don't copy the underlying arrow data using memory
                self.assertEqual(len(dset2), 10)
                self.assertDictEqual(dset2.features, Features({"filename": Value("string")}))
                self.assertEqual(dset2[0]["filename"], "my_name-train_0")
                self.assertEqual(dset2["filename"][0], "my_name-train_0")
                del dset2

    def test_dummy_dataset_pickle(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, "dset.pt")

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(0, 10, 2)) as dset:
                with open(tmp_file, "wb") as f:
                    pickle.dump(dset, f)

            with open(tmp_file, "rb") as f:
                with pickle.load(f) as dset:
                    self.assertEqual(len(dset), 5)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertEqual(dset[0]["filename"], "my_name-train_0")
                    self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir).select(
                range(0, 10, 2), indices_cache_file_name=os.path.join(tmp_dir, "ind.arrow")
            ) as dset:
                if not in_memory:
                    dset._data.table = Unpicklable()
                dset._indices.table = Unpicklable()
                with open(tmp_file, "wb") as f:
                    pickle.dump(dset, f)

            with open(tmp_file, "rb") as f:
                with pickle.load(f) as dset:
                    self.assertEqual(len(dset), 5)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertEqual(dset[0]["filename"], "my_name-train_0")
                    self.assertEqual(dset["filename"][0], "my_name-train_0")

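    # Round-trip save_to_disk / load_from_disk: relative and absolute paths, indices mappings,
    # nested features, and sharding via num_shards, num_proc and max_shard_size.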
    def test_dummy_dataset_serialize(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with set_current_working_directory_to_temp_dir():
                with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                    dataset_path = "my_dataset"  # rel path
                    dset.save_to_disk(dataset_path)

                with Dataset.load_from_disk(dataset_path) as dset:
                    self.assertEqual(len(dset), 10)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertEqual(dset[0]["filename"], "my_name-train_0")
                    self.assertEqual(dset["filename"][0], "my_name-train_0")
                    expected = dset.to_dict()

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                dataset_path = os.path.join(tmp_dir, "my_dataset")  # abs path
                dset.save_to_disk(dataset_path)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir).select(
                range(10), indices_cache_file_name=os.path.join(tmp_dir, "ind.arrow")
            ) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

            with self._create_dummy_dataset(in_memory, tmp_dir, nested_features=True) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(
                    dset.features,
                    Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}}),
                )
                self.assertDictEqual(dset[0]["nested"], {"a": 1, "c": 100, "x": 10})
                self.assertDictEqual(dset["nested"][0], {"a": 1, "c": 100, "x": 10})

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path, num_shards=4)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 4)

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path, num_proc=2)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 2)

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    dset.save_to_disk(dataset_path, num_shards=7, num_proc=2)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 7)

            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                with assert_arrow_memory_doesnt_increase():
                    max_shard_size = dset._estimate_nbytes() // 2 + 1
                    dset.save_to_disk(dataset_path, max_shard_size=max_shard_size)

            with Dataset.load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertDictEqual(dset.to_dict(), expected)
                self.assertEqual(len(dset.cache_files), 2)

    def test_dummy_dataset_load_from_disk(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir).select(range(10)) as dset:
                dataset_path = os.path.join(tmp_dir, "my_dataset")
                dset.save_to_disk(dataset_path)

            with load_from_disk(dataset_path) as dset:
                self.assertEqual(len(dset), 10)
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                self.assertEqual(dset[0]["filename"], "my_name-train_0")
                self.assertEqual(dset["filename"][0], "my_name-train_0")

    def test_restore_saved_format(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="numpy", columns=["col_1"], output_all_columns=True)
                dataset_path = os.path.join(tmp_dir, "my_dataset")
                dset.save_to_disk(dataset_path)

                with load_from_disk(dataset_path) as loaded_dset:
                    self.assertEqual(dset.format, loaded_dset.format)

    def test_set_format_numpy_multiple_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                dset.set_format(type="numpy", columns=["col_1"])
                self.assertEqual(len(dset[0]), 1)
                self.assertIsInstance(dset[0]["col_1"], np.int64)
                self.assertEqual(dset[0]["col_1"].item(), 3)
                self.assertIsInstance(dset["col_1"], np.ndarray)
                self.assertListEqual(list(dset["col_1"].shape), [4])
                np.testing.assert_array_equal(dset["col_1"], np.array([3, 2, 1, 0]))
                self.assertNotEqual(dset._fingerprint, fingerprint)

                dset.reset_format()
                with dset.formatted_as(type="numpy", columns=["col_1"]):
                    self.assertEqual(len(dset[0]), 1)
                    self.assertIsInstance(dset[0]["col_1"], np.int64)
                    self.assertEqual(dset[0]["col_1"].item(), 3)
                    self.assertIsInstance(dset["col_1"], np.ndarray)
                    self.assertListEqual(list(dset["col_1"].shape), [4])
                    np.testing.assert_array_equal(dset["col_1"], np.array([3, 2, 1, 0]))

                self.assertEqual(dset.format["type"], None)
                self.assertEqual(dset.format["format_kwargs"], {})
                self.assertEqual(dset.format["columns"], dset.column_names)
                self.assertEqual(dset.format["output_all_columns"], False)

                dset.set_format(type="numpy", columns=["col_1"], output_all_columns=True)
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")

                dset.set_format(type="numpy", columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0]), 2)
                self.assertIsInstance(dset[0]["col_2"], np.str_)
                self.assertEqual(dset[0]["col_2"].item(), "a")

    @require_torch
    def test_set_format_torch(self, in_memory):
        import torch

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="torch", columns=["col_1"])
                self.assertEqual(len(dset[0]), 1)
                self.assertIsInstance(dset[0]["col_1"], torch.Tensor)
                self.assertIsInstance(dset["col_1"], torch.Tensor)
                self.assertListEqual(list(dset[0]["col_1"].shape), [])
                self.assertEqual(dset[0]["col_1"].item(), 3)

                dset.set_format(type="torch", columns=["col_1"], output_all_columns=True)
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")

                dset.set_format(type="torch")
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_1"], torch.Tensor)
                self.assertIsInstance(dset["col_1"], torch.Tensor)
                self.assertListEqual(list(dset[0]["col_1"].shape), [])
                self.assertEqual(dset[0]["col_1"].item(), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")
                self.assertIsInstance(dset[0]["col_3"], torch.Tensor)
                self.assertIsInstance(dset["col_3"], torch.Tensor)
                self.assertListEqual(list(dset[0]["col_3"].shape), [])

    @require_tf
    def test_set_format_tf(self, in_memory):
        import tensorflow as tf

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="tensorflow", columns=["col_1"])
                self.assertEqual(len(dset[0]), 1)
                self.assertIsInstance(dset[0]["col_1"], tf.Tensor)
                self.assertListEqual(list(dset[0]["col_1"].shape), [])
                self.assertEqual(dset[0]["col_1"].numpy().item(), 3)

                dset.set_format(type="tensorflow", columns=["col_1"], output_all_columns=True)
                self.assertEqual(len(dset[0]), 3)
                self.assertIsInstance(dset[0]["col_2"], str)
                self.assertEqual(dset[0]["col_2"], "a")

                dset.set_format(type="tensorflow", columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0]), 2)
                self.assertEqual(dset[0]["col_2"].numpy().decode("utf-8"), "a")

    def test_set_format_pandas(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_format(type="pandas", columns=["col_1"])
                self.assertEqual(len(dset[0].columns), 1)
                self.assertIsInstance(dset[0], pd.DataFrame)
                self.assertListEqual(list(dset[0].shape), [1, 1])
                self.assertEqual(dset[0]["col_1"].item(), 3)

                dset.set_format(type="pandas", columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0].columns), 2)
                self.assertEqual(dset[0]["col_2"].item(), "a")

    def test_set_transform(self, in_memory):
        def transform(batch):
            return {k: [str(i).upper() for i in v] for k, v in batch.items()}

        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset.set_transform(transform=transform, columns=["col_1"])
                self.assertEqual(dset.format["type"], "custom")
                self.assertEqual(len(dset[0].keys()), 1)
                self.assertEqual(dset[0]["col_1"], "3")
                self.assertEqual(dset[:2]["col_1"], ["3", "2"])
                self.assertEqual(dset["col_1"][:2], ["3", "2"])

                prev_format = dset.format
                dset.set_format(**dset.format)
                self.assertEqual(prev_format, dset.format)

                dset.set_transform(transform=transform, columns=["col_1", "col_2"])
                self.assertEqual(len(dset[0].keys()), 2)
                self.assertEqual(dset[0]["col_2"], "A")

    def test_transmit_format(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                transform = datasets.arrow_dataset.transmit_format(lambda x: x)
                # make sure identity transform doesn't apply unnecessary format
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                dset.set_format(**dset.format)
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                # check lists comparisons
                dset.set_format(columns=["col_1"])
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                dset.set_format(columns=["col_1", "col_2"])
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)
                dset.set_format("numpy", columns=["col_1", "col_2"])
                self.assertEqual(dset._fingerprint, transform(dset)._fingerprint)

    def test_cast(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                features = dset.features
                features["col_1"] = Value("float64")
                features = Features({k: features[k] for k in list(features)[::-1]})
                fingerprint = dset._fingerprint
                # TODO: with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
                with dset.cast(features) as casted_dset:
                    self.assertEqual(casted_dset.num_columns, 3)
                    self.assertEqual(casted_dset.features["col_1"], Value("float64"))
                    self.assertIsInstance(casted_dset[0]["col_1"], float)
                    self.assertNotEqual(casted_dset._fingerprint, fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

    def test_class_encode_column(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with self.assertRaises(ValueError):
                    dset.class_encode_column(column="does not exist")

                with dset.class_encode_column("col_1") as casted_dset:
                    self.assertIsInstance(casted_dset.features["col_1"], ClassLabel)
                    self.assertListEqual(casted_dset.features["col_1"].names, ["0", "1", "2", "3"])
                    self.assertListEqual(casted_dset["col_1"], [3, 2, 1, 0])
                    self.assertNotEqual(casted_dset._fingerprint, dset._fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

                with dset.class_encode_column("col_2") as casted_dset:
                    self.assertIsInstance(casted_dset.features["col_2"], ClassLabel)
                    self.assertListEqual(casted_dset.features["col_2"].names, ["a", "b", "c", "d"])
                    self.assertListEqual(casted_dset["col_2"], [0, 1, 2, 3])
                    self.assertNotEqual(casted_dset._fingerprint, dset._fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

                with dset.class_encode_column("col_3") as casted_dset:
                    self.assertIsInstance(casted_dset.features["col_3"], ClassLabel)
                    self.assertListEqual(casted_dset.features["col_3"].names, ["False", "True"])
                    self.assertListEqual(casted_dset["col_3"], [0, 1, 0, 1])
                    self.assertNotEqual(casted_dset._fingerprint, dset._fingerprint)
                    self.assertNotEqual(casted_dset, dset)
                    assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)

            # Test raises if feature is an array / sequence
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                for column in dset.column_names:
                    with self.assertRaises(ValueError):
                        dset.class_encode_column(column)

    def test_remove_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.remove_columns(column_names="col_1") as new_dset:
                    self.assertEqual(new_dset.num_columns, 2)
                    self.assertListEqual(list(new_dset.column_names), ["col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.remove_columns(column_names=["col_1", "col_2", "col_3"]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 0)
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset._format_columns = ["col_1", "col_2", "col_3"]
                with dset.remove_columns(column_names=["col_1"]) as new_dset:
                    self.assertListEqual(new_dset._format_columns, ["col_2", "col_3"])
                    self.assertEqual(new_dset.num_columns, 2)
                    self.assertListEqual(list(new_dset.column_names), ["col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

    def test_rename_column(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.rename_column(original_column_name="col_1", new_column_name="new_name") as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["new_name", "col_2", "col_3"])
                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

    def test_rename_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.rename_columns({"col_1": "new_name"}) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["new_name", "col_2", "col_3"])
                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)

                with dset.rename_columns({"col_1": "new_name", "col_2": "new_name2"}) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["new_name", "new_name2", "col_3"])
                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)

                # Original column not in dataset
                with self.assertRaises(ValueError):
                    dset.rename_columns({"not_there": "new_name"})

                # Empty new name
                with self.assertRaises(ValueError):
                    dset.rename_columns({"col_1": ""})

                # Duplicates
                with self.assertRaises(ValueError):
                    dset.rename_columns({"col_1": "new_name", "col_2": "new_name"})

    def test_select_columns(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.select_columns(column_names=[]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 0)
                    self.assertListEqual(list(new_dset.column_names), [])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                fingerprint = dset._fingerprint
                with dset.select_columns(column_names="col_1") as new_dset:
                    self.assertEqual(new_dset.num_columns, 1)
                    self.assertListEqual(list(new_dset.column_names), ["col_1"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.select_columns(column_names=["col_1", "col_2", "col_3"]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["col_1", "col_2", "col_3"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                with dset.select_columns(column_names=["col_3", "col_2", "col_1"]) as new_dset:
                    self.assertEqual(new_dset.num_columns, 3)
                    self.assertListEqual(list(new_dset.column_names), ["col_3", "col_2", "col_1"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset._format_columns = ["col_1", "col_2", "col_3"]
                with dset.select_columns(column_names=["col_1"]) as new_dset:
                    self.assertListEqual(new_dset._format_columns, ["col_1"])
                    self.assertEqual(new_dset.num_columns, 1)
                    self.assertListEqual(list(new_dset.column_names), ["col_1"])
                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

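    # Concatenation tests: axis 0 and axis 1, formatted datasets, datasets with an indices
    # mapping (in memory and on disk), and a pickle round-trip of the concatenated result.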
    def test_concatenate(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)

            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                self.assertListEqual(dset_concat["id"], [0, 1, 2, 3, 4, 5, 6, 7])
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 3)
                self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2")
            del dset1, dset2, dset3

    def test_concatenate_formatted(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)

            dset1.set_format("numpy")
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.format["type"], None)
            dset2.set_format("numpy")
            dset3.set_format("numpy")
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.format["type"], "numpy")
            del dset1, dset2, dset3

    def test_concatenate_with_indices(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2] * 2}, {"id": [3, 4, 5] * 2}, {"id": [6, 7, 8]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)
            dset1, dset2, dset3 = dset1.select([2, 1, 0]), dset2.select([2, 1, 0]), dset3

            with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 3))
                self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                self.assertListEqual(dset_concat["id"], [6, 7, 8, 5, 4, 3, 2, 1, 0])
                # in_memory = False:
                # 3 cache files for the dset_concat._data table
                # no cache file for the indices because it's in memory
                # in_memory = True:
                # no cache files since both dset_concat._data and dset_concat._indices are in memory
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 3)
                self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")

            dset1 = dset1.rename_columns({"id": "id1"})
            dset2 = dset2.rename_columns({"id": "id2"})
            dset3 = dset3.rename_columns({"id": "id3"})
            with concatenate_datasets([dset1, dset2, dset3], axis=1) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 3))
                self.assertEqual(len(dset_concat), len(dset1))
                self.assertListEqual(dset_concat["id1"], [2, 1, 0])
                self.assertListEqual(dset_concat["id2"], [5, 4, 3])
                self.assertListEqual(dset_concat["id3"], [6, 7, 8])
                # in_memory = False:
                # 3 cache files for the dset_concat._data table
                # no cache file for the indices because it's None
                # in_memory = True:
                # no cache files since dset_concat._data is in memory and dset_concat._indices is None
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 3)
                self.assertIsNone(dset_concat._indices)
                self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2")

            with concatenate_datasets([dset1], axis=1) as dset_concat:
                self.assertEqual(len(dset_concat), len(dset1))
                self.assertListEqual(dset_concat["id1"], [2, 1, 0])
                # in_memory = False:
                # 1 cache file for the dset_concat._data table
                # no cache file for the indices because it's in memory
                # in_memory = True:
                # no cache files since both dset_concat._data and dset_concat._indices are in memory
                self.assertEqual(len(dset_concat.cache_files), 0 if in_memory else 1)
                self.assertTrue(dset_concat._indices == dset1._indices)
                self.assertEqual(dset_concat.info.description, "Dataset1")
            del dset1, dset2, dset3

    def test_concatenate_with_indices_from_disk(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2] * 2}, {"id": [3, 4, 5] * 2}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            dset1, dset2, dset3 = self._to(in_memory, tmp_dir, dset1, dset2, dset3)
            dset1, dset2, dset3 = (
                dset1.select([2, 1, 0], indices_cache_file_name=os.path.join(tmp_dir, "i1.arrow")),
                dset2.select([2, 1, 0], indices_cache_file_name=os.path.join(tmp_dir, "i2.arrow")),
                dset3.select([1, 0], indices_cache_file_name=os.path.join(tmp_dir, "i3.arrow")),
            )

            with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                self.assertListEqual(dset_concat["id"], [7, 6, 5, 4, 3, 2, 1, 0])
                # in_memory = False:
                # 3 cache files for the dset_concat._data table, and 1 for the dset_concat._indices_table
                # There is only 1 for the indices tables (i1.arrow)
                # Indeed, the others are brought to memory since an offset is applied to them.
                # in_memory = True:
                # 1 cache file for i1.arrow since both dset_concat._data and dset_concat._indices are in memory
                self.assertEqual(len(dset_concat.cache_files), 1 if in_memory else 3 + 1)
                self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
            del dset1, dset2, dset3

    def test_concatenate_pickle(self, in_memory):
        data1, data2, data3 = {"id": [0, 1, 2] * 2}, {"id": [3, 4, 5] * 2}, {"id": [6, 7], "foo": ["bar", "bar"]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            dset1, dset2, dset3 = (
                Dataset.from_dict(data1, info=info1),
                Dataset.from_dict(data2, info=info2),
                Dataset.from_dict(data3),
            )
            # mix from in-memory and on-disk datasets
            dset1, dset2 = self._to(in_memory, tmp_dir, dset1, dset2)
            dset3 = self._to(not in_memory, tmp_dir, dset3)
            dset1, dset2, dset3 = (
                dset1.select(
                    [2, 1, 0],
                    keep_in_memory=in_memory,
                    indices_cache_file_name=os.path.join(tmp_dir, "i1.arrow") if not in_memory else None,
                ),
                dset2.select(
                    [2, 1, 0],
                    keep_in_memory=in_memory,
                    indices_cache_file_name=os.path.join(tmp_dir, "i2.arrow") if not in_memory else None,
                ),
                dset3.select(
                    [1, 0],
                    keep_in_memory=in_memory,
                    indices_cache_file_name=os.path.join(tmp_dir, "i3.arrow") if not in_memory else None,
                ),
            )

            dset3 = dset3.rename_column("foo", "new_foo")
            dset3 = dset3.remove_columns("new_foo")
            if in_memory:
                dset3._data.table = Unpicklable()
            else:
                dset1._data.table, dset2._data.table = Unpicklable(), Unpicklable()
            dset1, dset2, dset3 = (pickle.loads(pickle.dumps(d)) for d in (dset1, dset2, dset3))
            with concatenate_datasets([dset3, dset2, dset1]) as dset_concat:
                if not in_memory:
                    dset_concat._data.table = Unpicklable()
                with pickle.loads(pickle.dumps(dset_concat)) as dset_concat:
                    self.assertTupleEqual((len(dset1), len(dset2), len(dset3)), (3, 3, 2))
                    self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
                    self.assertListEqual(dset_concat["id"], [7, 6, 5, 4, 3, 2, 1, 0])
                    # in_memory = True: 1 cache file for dset3
                    # in_memory = False: 2 cache files for dset1 and dset2, and 1 cache file for i1.arrow
                    self.assertEqual(len(dset_concat.cache_files), 1 if in_memory else 2 + 1)
                    self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
            del dset1, dset2, dset3

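    # flatten() tests over nested dict features and Translation features;
    # test_flatten_complex_image below covers Image features with decoding on or off.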
    def test_flatten(self, in_memory):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"b": {"c": ["text"]}}] * 10, "foo": [1] * 10},
                features=Features({"a": {"b": Sequence({"c": Value("string")})}, "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.b.c", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.b.c", "foo"])
                        self.assertDictEqual(
                            dset.features, Features({"a.b.c": Sequence(Value("string")), "foo": Value("int64")})
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"en": "Thank you", "fr": "Merci"}] * 10, "foo": [1] * 10},
                features=Features({"a": Translation(languages=["en", "fr"]), "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.en", "a.fr", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.en", "a.fr", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features({"a.en": Value("string"), "a.fr": Value("string"), "foo": Value("int64")}),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"en": "the cat", "fr": ["le chat", "la chatte"], "de": "die katze"}] * 10, "foo": [1] * 10},
                features=Features(
                    {"a": TranslationVariableLanguages(languages=["en", "fr", "de"]), "foo": Value("int64")}
                ),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.language", "a.translation", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.language", "a.translation", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features(
                                {
                                    "a.language": Sequence(Value("string")),
                                    "a.translation": Sequence(Value("string")),
                                    "foo": Value("int64"),
                                }
                            ),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

    @require_pil
    def test_flatten_complex_image(self, in_memory):
        # decoding turned on
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)] * 10, "foo": [1] * 10},
                features=Features({"a": Image(), "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a", "foo"])
                        self.assertDictEqual(dset.features, Features({"a": Image(), "foo": Value("int64")}))
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # decoding turned on + nesting
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"b": np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)}] * 10, "foo": [1] * 10},
                features=Features({"a": {"b": Image()}, "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.b", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.b", "foo"])
                        self.assertDictEqual(dset.features, Features({"a.b": Image(), "foo": Value("int64")}))
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # decoding turned off
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)] * 10, "foo": [1] * 10},
                features=Features({"a": Image(decode=False), "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.bytes", "a.path", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.bytes", "a.path", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features({"a.bytes": Value("binary"), "a.path": Value("string"), "foo": Value("int64")}),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

        # decoding turned off + nesting
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(
                {"a": [{"b": np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)}] * 10, "foo": [1] * 10},
                features=Features({"a": {"b": Image(decode=False)}, "foo": Value("int64")}),
            ) as dset:
                with self._to(in_memory, tmp_dir, dset) as dset:
                    fingerprint = dset._fingerprint
                    with dset.flatten() as dset:
                        self.assertListEqual(sorted(dset.column_names), ["a.b.bytes", "a.b.path", "foo"])
                        self.assertListEqual(sorted(dset.features.keys()), ["a.b.bytes", "a.b.path", "foo"])
                        self.assertDictEqual(
                            dset.features,
                            Features(
                                {"a.b.bytes": Value("binary"), "a.b.path": Value("string"), "foo": Value("int64")}
                            ),
                        )
                        self.assertNotEqual(dset._fingerprint, fingerprint)
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)

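    # map() tests: adding new columns, a no-op function that leaves the fingerprint unchanged,
    # with_indices=True, and an interrupted run that must not leave a partial cache file behind.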
    def test_map(self, in_memory):
        # standard
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                fingerprint = dset._fingerprint
                with dset.map(
                    lambda x: {"name": x["filename"][:-2], "id": int(x["filename"].split("_")[-1])}
                ) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test.features,
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
                    )
                    self.assertListEqual(dset_test["id"], list(range(30)))
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        # no transform
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                fingerprint = dset._fingerprint
                with dset.map(lambda x: None) as dset_test:
                    self.assertEqual(len(dset_test), 30)
                    self.assertEqual(dset_test._fingerprint, fingerprint)
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

        # with indices
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                with dset.map(
                    lambda x, i: {"name": x["filename"][:-2], "id": i}, with_indices=True
                ) as dset_test_with_indices:
                    self.assertEqual(len(dset_test_with_indices), 30)
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
                    self.assertDictEqual(
                        dset_test_with_indices.features,
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
                    )
                    self.assertListEqual(dset_test_with_indices["id"], list(range(30)))
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices)

        # interrupted
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:

                def func(x, i):
                    if i == 4:
                        raise KeyboardInterrupt()
                    return {"name": x["filename"][:-2], "id": i}

                tmp_file = os.path.join(tmp_dir, "test.arrow")
                self.assertRaises(
                    KeyboardInterrupt,
                    dset.map,
                    function=func,
                    with_indices=True,
                    cache_file_name=tmp_file,
                    writer_batch_size=2,
                )
                self.assertFalse(os.path.exists(tmp_file))
                with dset.map(
                    lambda x, i: {"name": x["filename"][:-2], "id": i},
                    with_indices=True,
                    cache_file_name=tmp_file,
                    writer_batch_size=2,
                ) as dset_test_with_indices:
                    self.assertTrue(os.path.exists(tmp_file))
1053
                    self.assertEqual(len(dset_test_with_indices), 30)
1054
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1055
                    self.assertDictEqual(
1056
                        dset_test_with_indices.features,
1057
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
1058
                    )
1059
                    self.assertListEqual(dset_test_with_indices["id"], list(range(30)))
1060
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices)
1061

1062
        # formatted
1063
        with tempfile.TemporaryDirectory() as tmp_dir:
1064
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
1065
                dset.set_format("numpy", columns=["col_1"])
1066
                with dset.map(lambda x: {"col_1_plus_one": x["col_1"] + 1}) as dset_test:
1067
                    self.assertEqual(len(dset_test), 4)
1068
                    self.assertEqual(dset_test.format["type"], "numpy")
1069
                    self.assertIsInstance(dset_test["col_1"], np.ndarray)
1070
                    self.assertIsInstance(dset_test["col_1_plus_one"], np.ndarray)
1071
                    self.assertListEqual(sorted(dset_test[0].keys()), ["col_1", "col_1_plus_one"])
1072
                    self.assertListEqual(sorted(dset_test.column_names), ["col_1", "col_1_plus_one", "col_2", "col_3"])
1073
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1074

1075
    def test_map_multiprocessing(self, in_memory):
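        # Sketch of the pattern under test: map() with num_proc splits the table into one shard
        # per worker, e.g. ds.map(picklable_map_function, num_proc=2). For an on-disk dataset each
        # worker writes its own cache file (names ending in "_00000_of_00002.arrow" and so on);
        # an in-memory dataset produces no cache files.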
1076
        with tempfile.TemporaryDirectory() as tmp_dir:  # standard
1077
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1078
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1079
                fingerprint = dset._fingerprint
1080
                with dset.map(picklable_map_function, num_proc=2) as dset_test:
1081
                    self.assertEqual(len(dset_test), 30)
1082
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1083
                    self.assertDictEqual(
1084
                        dset_test.features,
1085
                        Features({"filename": Value("string"), "id": Value("int64")}),
1086
                    )
1087
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
1088
                    if not in_memory:
1089
                        self.assertIn("_of_00002.arrow", dset_test.cache_files[0]["filename"])
1090
                    self.assertListEqual(dset_test["id"], list(range(30)))
1091
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1092
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1093

1094
        with tempfile.TemporaryDirectory() as tmp_dir:  # num_proc > num rows
1095
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1096
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1097
                fingerprint = dset._fingerprint
1098
                with dset.select([0, 1], keep_in_memory=True).map(picklable_map_function, num_proc=10) as dset_test:
1099
                    self.assertEqual(len(dset_test), 2)
1100
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1101
                    self.assertDictEqual(
1102
                        dset_test.features,
1103
                        Features({"filename": Value("string"), "id": Value("int64")}),
1104
                    )
1105
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
1106
                    self.assertListEqual(dset_test["id"], list(range(2)))
1107
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1108
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1109

1110
        with tempfile.TemporaryDirectory() as tmp_dir:  # with_indices
1111
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1112
                fingerprint = dset._fingerprint
1113
                with dset.map(picklable_map_function_with_indices, num_proc=3, with_indices=True) as dset_test:
1114
                    self.assertEqual(len(dset_test), 30)
1115
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1116
                    self.assertDictEqual(
1117
                        dset_test.features,
1118
                        Features({"filename": Value("string"), "id": Value("int64")}),
1119
                    )
1120
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 3)
1121
                    self.assertListEqual(dset_test["id"], list(range(30)))
1122
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1123
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1124

1125
        with tempfile.TemporaryDirectory() as tmp_dir:  # with_rank
1126
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1127
                fingerprint = dset._fingerprint
1128
                with dset.map(picklable_map_function_with_rank, num_proc=3, with_rank=True) as dset_test:
1129
                    self.assertEqual(len(dset_test), 30)
1130
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1131
                    self.assertDictEqual(
1132
                        dset_test.features,
1133
                        Features({"filename": Value("string"), "rank": Value("int64")}),
1134
                    )
1135
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 3)
1136
                    self.assertListEqual(dset_test["rank"], [0] * 10 + [1] * 10 + [2] * 10)
1137
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1138
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1139

1140
        with tempfile.TemporaryDirectory() as tmp_dir:  # with_indices AND with_rank
1141
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1142
                fingerprint = dset._fingerprint
1143
                with dset.map(
1144
                    picklable_map_function_with_indices_and_rank, num_proc=3, with_indices=True, with_rank=True
1145
                ) as dset_test:
1146
                    self.assertEqual(len(dset_test), 30)
1147
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1148
                    self.assertDictEqual(
1149
                        dset_test.features,
1150
                        Features({"filename": Value("string"), "id": Value("int64"), "rank": Value("int64")}),
1151
                    )
1152
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 3)
1153
                    self.assertListEqual(dset_test["id"], list(range(30)))
1154
                    self.assertListEqual(dset_test["rank"], [0] * 10 + [1] * 10 + [2] * 10)
1155
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1156
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1157

1158
        with tempfile.TemporaryDirectory() as tmp_dir:  # new_fingerprint
1159
            new_fingerprint = "foobar"
1160
            invalid_new_fingerprint = "foobar/hey"
1161
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1162
                fingerprint = dset._fingerprint
1163
                self.assertRaises(
1164
                    ValueError, dset.map, picklable_map_function, num_proc=2, new_fingerprint=invalid_new_fingerprint
1165
                )
1166
                with dset.map(picklable_map_function, num_proc=2, new_fingerprint=new_fingerprint) as dset_test:
1167
                    self.assertEqual(len(dset_test), 30)
1168
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1169
                    self.assertDictEqual(
1170
                        dset_test.features,
1171
                        Features({"filename": Value("string"), "id": Value("int64")}),
1172
                    )
1173
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
1174
                    self.assertListEqual(dset_test["id"], list(range(30)))
1175
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1176
                    self.assertEqual(dset_test._fingerprint, new_fingerprint)
1177
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1178
                    file_names = sorted(Path(cache_file["filename"]).name for cache_file in dset_test.cache_files)
1179
                    for i, file_name in enumerate(file_names):
1180
                        self.assertIn(new_fingerprint + f"_{i:05d}", file_name)
1181

1182
        with tempfile.TemporaryDirectory() as tmp_dir:  # lambda (requires multiprocess from pathos)
1183
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1184
                fingerprint = dset._fingerprint
1185
                with dset.map(lambda x: {"id": int(x["filename"].split("_")[-1])}, num_proc=2) as dset_test:
1186
                    self.assertEqual(len(dset_test), 30)
1187
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1188
                    self.assertDictEqual(
1189
                        dset_test.features,
1190
                        Features({"filename": Value("string"), "id": Value("int64")}),
1191
                    )
1192
                    self.assertEqual(len(dset_test.cache_files), 0 if in_memory else 2)
1193
                    self.assertListEqual(dset_test["id"], list(range(30)))
1194
                    self.assertNotEqual(dset_test._fingerprint, fingerprint)
1195
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
1196

1197
    def test_map_new_features(self, in_memory):
1198
        with tempfile.TemporaryDirectory() as tmp_dir:
1199
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1200
                features = Features({"filename": Value("string"), "label": ClassLabel(names=["positive", "negative"])})
1201
                with dset.map(
1202
                    lambda x, i: {"label": i % 2}, with_indices=True, features=features
1203
                ) as dset_test_with_indices:
1204
                    self.assertEqual(len(dset_test_with_indices), 30)
1205
                    self.assertDictEqual(
1206
                        dset_test_with_indices.features,
1207
                        features,
1208
                    )
1209
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices)
1210

1211
    def test_map_batched(self, in_memory):
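        # With batched=True the function receives a dict mapping column names to lists of values
        # for each batch (default batch_size=1000) and returns columns in the same batched form,
        # e.g. (sketch): ds.map(lambda b: {"n_chars": [len(f) for f in b["filename"]]}, batched=True)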
1212
        def map_batched(example):
1213
            return {"filename_new": [x + "_extension" for x in example["filename"]]}
1214

1215
        with tempfile.TemporaryDirectory() as tmp_dir:
1216
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1217
                with dset.map(map_batched, batched=True) as dset_test_batched:
1218
                    self.assertEqual(len(dset_test_batched), 30)
1219
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1220
                    self.assertDictEqual(
1221
                        dset_test_batched.features,
1222
                        Features({"filename": Value("string"), "filename_new": Value("string")}),
1223
                    )
1224
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_batched)
1225

1226
        # change batch size and drop the last batch
1227
        with tempfile.TemporaryDirectory() as tmp_dir:
1228
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1229
                batch_size = 4
1230
                with dset.map(
1231
                    map_batched, batched=True, batch_size=batch_size, drop_last_batch=True
1232
                ) as dset_test_batched:
1233
                    self.assertEqual(len(dset_test_batched), 30 // batch_size * batch_size)
1234
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1235
                    self.assertDictEqual(
1236
                        dset_test_batched.features,
1237
                        Features({"filename": Value("string"), "filename_new": Value("string")}),
1238
                    )
1239
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_batched)
1240

1241
        with tempfile.TemporaryDirectory() as tmp_dir:
1242
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1243
                with dset.formatted_as("numpy", columns=["filename"]):
1244
                    with dset.map(map_batched, batched=True) as dset_test_batched:
1245
                        self.assertEqual(len(dset_test_batched), 30)
1246
                        self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1247
                        self.assertDictEqual(
1248
                            dset_test_batched.features,
1249
                            Features({"filename": Value("string"), "filename_new": Value("string")}),
1250
                        )
1251
                        assert_arrow_metadata_are_synced_with_dataset_features(dset_test_batched)
1252

1253
        def map_batched_with_indices(example, idx):
1254
            return {"filename_new": [x + "_extension_" + str(idx) for x in example["filename"]]}
1255

1256
        with tempfile.TemporaryDirectory() as tmp_dir:
1257
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1258
                with dset.map(
1259
                    map_batched_with_indices, batched=True, with_indices=True
1260
                ) as dset_test_with_indices_batched:
1261
                    self.assertEqual(len(dset_test_with_indices_batched), 30)
1262
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1263
                    self.assertDictEqual(
1264
                        dset_test_with_indices_batched.features,
1265
                        Features({"filename": Value("string"), "filename_new": Value("string")}),
1266
                    )
1267
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices_batched)
1268

1269
        # check that remove_columns works even if the function modifies its input in-place
1270
        def map_batched_modifying_inputs_inplace(example):
1271
            result = {"filename_new": [x + "_extension" for x in example["filename"]]}
1272
            del example["filename"]
1273
            return result
1274

1275
        with tempfile.TemporaryDirectory() as tmp_dir:
1276
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1277
                with dset.map(
1278
                    map_batched_modifying_inputs_inplace, batched=True, remove_columns="filename"
1279
                ) as dset_test_modifying_inputs_inplace:
1280
                    self.assertEqual(len(dset_test_modifying_inputs_inplace), 30)
1281
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1282
                    self.assertDictEqual(
1283
                        dset_test_modifying_inputs_inplace.features,
1284
                        Features({"filename_new": Value("string")}),
1285
                    )
1286
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test_modifying_inputs_inplace)
1287

1288
    def test_map_nested(self, in_memory):
1289
        with tempfile.TemporaryDirectory() as tmp_dir:
1290
            with Dataset.from_dict({"field": ["a", "b"]}) as dset:
1291
                with self._to(in_memory, tmp_dir, dset) as dset:
1292
                    with dset.map(lambda example: {"otherfield": {"capital": example["field"].capitalize()}}) as dset:
1293
                        with dset.map(lambda example: {"otherfield": {"append_x": example["field"] + "x"}}) as dset:
1294
                            self.assertEqual(dset[0], {"field": "a", "otherfield": {"append_x": "ax"}})
1295

1296
    def test_map_return_example_as_dict_value(self, in_memory):
1297
        with tempfile.TemporaryDirectory() as tmp_dir:
1298
            with Dataset.from_dict({"en": ["aa", "bb"], "fr": ["cc", "dd"]}) as dset:
1299
                with self._to(in_memory, tmp_dir, dset) as dset:
1300
                    with dset.map(lambda example: {"translation": example}) as dset:
1301
                        self.assertEqual(dset[0], {"en": "aa", "fr": "cc", "translation": {"en": "aa", "fr": "cc"}})
1302

1303
    def test_map_fn_kwargs(self, in_memory):
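        # fn_kwargs forwards extra keyword arguments to the mapped function, e.g. (sketch):
        #   ds.map(lambda ex, offset: {"y": ex["id"] + offset}, fn_kwargs={"offset": 3})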
1304
        with tempfile.TemporaryDirectory() as tmp_dir:
1305
            with Dataset.from_dict({"id": range(10)}) as dset:
1306
                with self._to(in_memory, tmp_dir, dset) as dset:
1307
                    fn_kwargs = {"offset": 3}
1308
                    with dset.map(
1309
                        lambda example, offset: {"id+offset": example["id"] + offset}, fn_kwargs=fn_kwargs
1310
                    ) as mapped_dset:
1311
                        assert mapped_dset["id+offset"] == list(range(3, 13))
1312
                    with dset.map(
1313
                        lambda id, offset: {"id+offset": id + offset}, fn_kwargs=fn_kwargs, input_columns="id"
1314
                    ) as mapped_dset:
1315
                        assert mapped_dset["id+offset"] == list(range(3, 13))
1316
                    with dset.map(
1317
                        lambda id, i, offset: {"id+offset": i + offset},
1318
                        fn_kwargs=fn_kwargs,
1319
                        input_columns="id",
1320
                        with_indices=True,
1321
                    ) as mapped_dset:
1322
                        assert mapped_dset["id+offset"] == list(range(3, 13))
1323

1324
    def test_map_caching(self, in_memory):
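        # map() results are cached in an Arrow file keyed by the transform and the dataset
        # fingerprint, so an identical second call reloads the cache instead of recomputing.
        # Caching only applies to on-disk datasets; in-memory datasets are recomputed each time,
        # which is what the in_memory branches below check.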
1325
        with tempfile.TemporaryDirectory() as tmp_dir:
1326
            self._caplog.clear()
1327
            with self._caplog.at_level(INFO, logger=get_logger().name):
1328
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1329
                    with patch(
1330
                        "datasets.arrow_dataset.Dataset._map_single",
1331
                        autospec=Dataset._map_single,
1332
                        side_effect=Dataset._map_single,
1333
                    ) as mock_map_single:
1334
                        with dset.map(lambda x: {"foo": "bar"}) as dset_test1:
1335
                            dset_test1_data_files = list(dset_test1.cache_files)
1336
                        self.assertEqual(mock_map_single.call_count, 1)
1337
                        with dset.map(lambda x: {"foo": "bar"}) as dset_test2:
1338
                            self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
1339
                            self.assertEqual(len(dset_test2.cache_files), 1 - int(in_memory))
1340
                            self.assertTrue(("Loading cached processed dataset" in self._caplog.text) ^ in_memory)
1341
                        self.assertEqual(mock_map_single.call_count, 2 if in_memory else 1)
1342

1343
        with tempfile.TemporaryDirectory() as tmp_dir:
1344
            self._caplog.clear()
1345
            with self._caplog.at_level(INFO, logger=get_logger().name):
1346
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1347
                    with dset.map(lambda x: {"foo": "bar"}) as dset_test1:
1348
                        dset_test1_data_files = list(dset_test1.cache_files)
1349
                    with dset.map(lambda x: {"foo": "bar"}, load_from_cache_file=False) as dset_test2:
1350
                        self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
1351
                        self.assertEqual(len(dset_test2.cache_files), 1 - int(in_memory))
1352
                        self.assertNotIn("Loading cached processed dataset", self._caplog.text)
1353

1354
        with tempfile.TemporaryDirectory() as tmp_dir:
1355
            self._caplog.clear()
1356
            with self._caplog.at_level(INFO, logger=get_logger().name):
1357
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1358
                    with patch(
1359
                        "datasets.arrow_dataset.Pool",
1360
                        new_callable=PickableMagicMock,
1361
                        side_effect=datasets.arrow_dataset.Pool,
1362
                    ) as mock_pool:
1363
                        with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
1364
                            dset_test1_data_files = list(dset_test1.cache_files)
1365
                        self.assertEqual(mock_pool.call_count, 1)
1366
                        with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test2:
1367
                            self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
1368
                            self.assertTrue(
1369
                                (len(re.findall("Loading cached processed dataset", self._caplog.text)) == 1)
1370
                                ^ in_memory
1371
                            )
1372
                        self.assertEqual(mock_pool.call_count, 2 if in_memory else 1)
1373

1374
        with tempfile.TemporaryDirectory() as tmp_dir:
1375
            self._caplog.clear()
1376
            with self._caplog.at_level(INFO, logger=get_logger().name):
1377
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1378
                    with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
1379
                        dset_test1_data_files = list(dset_test1.cache_files)
1380
                    with dset.map(lambda x: {"foo": "bar"}, num_proc=2, load_from_cache_file=False) as dset_test2:
1381
                        self.assertEqual(dset_test1_data_files, dset_test2.cache_files)
1382
                        self.assertEqual(len(dset_test2.cache_files), (1 - int(in_memory)) * 2)
1383
                        self.assertNotIn("Loading cached processed dataset", self._caplog.text)
1384

1385
        if not in_memory:
1386
            try:
1387
                self._caplog.clear()
1388
                with tempfile.TemporaryDirectory() as tmp_dir:
1389
                    with self._caplog.at_level(INFO, logger=get_logger().name):
1390
                        with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1391
                            datasets.disable_caching()
1392
                            with dset.map(lambda x: {"foo": "bar"}) as dset_test1:
1393
                                with dset.map(lambda x: {"foo": "bar"}) as dset_test2:
1394
                                    self.assertNotEqual(dset_test1.cache_files, dset_test2.cache_files)
1395
                                    self.assertEqual(len(dset_test1.cache_files), 1)
1396
                                    self.assertEqual(len(dset_test2.cache_files), 1)
1397
                                    self.assertNotIn("Loading cached processed dataset", self._caplog.text)
1398
                                    # make sure the arrow files are going to be removed
1399
                                    self.assertIn(
1400
                                        Path(tempfile.gettempdir()),
1401
                                        Path(dset_test1.cache_files[0]["filename"]).parents,
1402
                                    )
1403
                                    self.assertIn(
1404
                                        Path(tempfile.gettempdir()),
1405
                                        Path(dset_test2.cache_files[0]["filename"]).parents,
1406
                                    )
1407
            finally:
1408
                datasets.enable_caching()
1409

1410
    def test_map_return_pa_table(self, in_memory):
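        # The mapped function may return a pyarrow.Table instead of a dict; in non-batched mode
        # the returned table must contain exactly one row per input example, otherwise map()
        # raises a ValueError (checked below).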
1411
        def func_return_single_row_pa_table(x):
1412
            return pa.table({"id": [0], "text": ["a"]})
1413

1414
        with tempfile.TemporaryDirectory() as tmp_dir:
1415
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1416
                with dset.map(func_return_single_row_pa_table) as dset_test:
1417
                    self.assertEqual(len(dset_test), 30)
1418
                    self.assertDictEqual(
1419
                        dset_test.features,
1420
                        Features({"id": Value("int64"), "text": Value("string")}),
1421
                    )
1422
                    self.assertEqual(dset_test[0]["id"], 0)
1423
                    self.assertEqual(dset_test[0]["text"], "a")
1424

1425
        # Batched
1426
        def func_return_single_row_pa_table_batched(x):
1427
            batch_size = len(x[next(iter(x))])
1428
            return pa.table({"id": [0] * batch_size, "text": ["a"] * batch_size})
1429

1430
        with tempfile.TemporaryDirectory() as tmp_dir:
1431
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1432
                with dset.map(func_return_single_row_pa_table_batched, batched=True) as dset_test:
1433
                    self.assertEqual(len(dset_test), 30)
1434
                    self.assertDictEqual(
1435
                        dset_test.features,
1436
                        Features({"id": Value("int64"), "text": Value("string")}),
1437
                    )
1438
                    self.assertEqual(dset_test[0]["id"], 0)
1439
                    self.assertEqual(dset_test[0]["text"], "a")
1440

1441
        # Error when returning a table with more than one row in the non-batched mode
1442
        def func_return_multi_row_pa_table(x):
1443
            return pa.table({"id": [0, 1], "text": ["a", "b"]})
1444

1445
        with tempfile.TemporaryDirectory() as tmp_dir:
1446
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1447
                self.assertRaises(ValueError, dset.map, func_return_multi_row_pa_table)
1448

1449
        # arrow formatted dataset
1450
        def func_return_table_from_expression(t):
1451
            import pyarrow.dataset as pds
1452

1453
            return pds.dataset(t).to_table(
1454
                columns={"new_column": pds.field("")._call("ascii_capitalize", [pds.field("filename")])}
1455
            )
1456

1457
        with tempfile.TemporaryDirectory() as tmp_dir:
1458
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1459
                with dset.with_format("arrow").map(func_return_table_from_expression, batched=True) as dset_test:
1460
                    self.assertEqual(len(dset_test), 30)
1461
                    self.assertDictEqual(
1462
                        dset_test.features,
1463
                        Features({"new_column": Value("string")}),
1464
                    )
1465
                    self.assertEqual(dset_test.with_format(None)[0]["new_column"], dset[0]["filename"].capitalize())
1466

1467
    def test_map_return_pd_dataframe(self, in_memory):
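        # A pandas DataFrame is also accepted as the return value, with the same
        # one-row-per-example constraint in non-batched mode as for pyarrow Tables above.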
1468
        def func_return_single_row_pd_dataframe(x):
1469
            return pd.DataFrame({"id": [0], "text": ["a"]})
1470

1471
        with tempfile.TemporaryDirectory() as tmp_dir:
1472
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1473
                with dset.map(func_return_single_row_pd_dataframe) as dset_test:
1474
                    self.assertEqual(len(dset_test), 30)
1475
                    self.assertDictEqual(
1476
                        dset_test.features,
1477
                        Features({"id": Value("int64"), "text": Value("string")}),
1478
                    )
1479
                    self.assertEqual(dset_test[0]["id"], 0)
1480
                    self.assertEqual(dset_test[0]["text"], "a")
1481

1482
        # Batched
1483
        def func_return_single_row_pd_dataframe_batched(x):
1484
            batch_size = len(x[next(iter(x))])
1485
            return pd.DataFrame({"id": [0] * batch_size, "text": ["a"] * batch_size})
1486

1487
        with tempfile.TemporaryDirectory() as tmp_dir:
1488
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1489
                with dset.map(func_return_single_row_pd_dataframe_batched, batched=True) as dset_test:
1490
                    self.assertEqual(len(dset_test), 30)
1491
                    self.assertDictEqual(
1492
                        dset_test.features,
1493
                        Features({"id": Value("int64"), "text": Value("string")}),
1494
                    )
1495
                    self.assertEqual(dset_test[0]["id"], 0)
1496
                    self.assertEqual(dset_test[0]["text"], "a")
1497

1498
        # Error when returning a table with more than one row in the non-batched mode
1499
        def func_return_multi_row_pd_dataframe(x):
1500
            return pd.DataFrame({"id": [0, 1], "text": ["a", "b"]})
1501

1502
        with tempfile.TemporaryDirectory() as tmp_dir:
1503
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1504
                self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)
1505

1506
    @require_torch
1507
    def test_map_torch(self, in_memory):
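        # Framework tensors (torch/tf/jax) returned by the mapped function are converted to Arrow
        # lists, so the resulting feature is Sequence(Value("float32")); the NumPy variant below
        # yields float64 because np.array([1.0, 2, 3]) defaults to float64.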
1508
        import torch
1509

1510
        def func(example):
1511
            return {"tensor": torch.tensor([1.0, 2, 3])}
1512

1513
        with tempfile.TemporaryDirectory() as tmp_dir:
1514
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1515
                with dset.map(func) as dset_test:
1516
                    self.assertEqual(len(dset_test), 30)
1517
                    self.assertDictEqual(
1518
                        dset_test.features,
1519
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
1520
                    )
1521
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
1522

1523
    @require_tf
1524
    def test_map_tf(self, in_memory):
1525
        import tensorflow as tf
1526

1527
        def func(example):
1528
            return {"tensor": tf.constant([1.0, 2, 3])}
1529

1530
        with tempfile.TemporaryDirectory() as tmp_dir:
1531
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1532
                with dset.map(func) as dset_test:
1533
                    self.assertEqual(len(dset_test), 30)
1534
                    self.assertDictEqual(
1535
                        dset_test.features,
1536
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
1537
                    )
1538
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
1539

1540
    @require_jax
1541
    def test_map_jax(self, in_memory):
1542
        import jax.numpy as jnp
1543

1544
        def func(example):
1545
            return {"tensor": jnp.asarray([1.0, 2, 3])}
1546

1547
        with tempfile.TemporaryDirectory() as tmp_dir:
1548
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1549
                with dset.map(func) as dset_test:
1550
                    self.assertEqual(len(dset_test), 30)
1551
                    self.assertDictEqual(
1552
                        dset_test.features,
1553
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
1554
                    )
1555
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
1556

1557
    def test_map_numpy(self, in_memory):
1558
        def func(example):
1559
            return {"tensor": np.array([1.0, 2, 3])}
1560

1561
        with tempfile.TemporaryDirectory() as tmp_dir:
1562
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1563
                with dset.map(func) as dset_test:
1564
                    self.assertEqual(len(dset_test), 30)
1565
                    self.assertDictEqual(
1566
                        dset_test.features,
1567
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float64"))}),
1568
                    )
1569
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
1570

1571
    @require_torch
1572
    def test_map_tensor_batched(self, in_memory):
1573
        import torch
1574

1575
        def func(batch):
1576
            return {"tensor": torch.tensor([[1.0, 2, 3]] * len(batch["filename"]))}
1577

1578
        with tempfile.TemporaryDirectory() as tmp_dir:
1579
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1580
                with dset.map(func, batched=True) as dset_test:
1581
                    self.assertEqual(len(dset_test), 30)
1582
                    self.assertDictEqual(
1583
                        dset_test.features,
1584
                        Features({"filename": Value("string"), "tensor": Sequence(Value("float32"))}),
1585
                    )
1586
                    self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
1587

1588
    def test_map_input_columns(self, in_memory):
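        # input_columns passes only the named columns to the function, as positional arguments;
        # all other columns are carried over unchanged into the output dataset.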
1589
        with tempfile.TemporaryDirectory() as tmp_dir:
1590
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
1591
                with dset.map(lambda col_1: {"label": col_1 % 2}, input_columns="col_1") as mapped_dset:
1592
                    self.assertEqual(mapped_dset[0].keys(), {"col_1", "col_2", "col_3", "label"})
1593
                    self.assertEqual(
1594
                        mapped_dset.features,
1595
                        Features(
1596
                            {
1597
                                "col_1": Value("int64"),
1598
                                "col_2": Value("string"),
1599
                                "col_3": Value("bool"),
1600
                                "label": Value("int64"),
1601
                            }
1602
                        ),
1603
                    )
1604

1605
    def test_map_remove_columns(self, in_memory):
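        # remove_columns drops the listed columns from the output after the function has run,
        # so the function itself can still read the columns that end up removed.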
1606
        with tempfile.TemporaryDirectory() as tmp_dir:
1607
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1608
                with dset.map(lambda x, i: {"name": x["filename"][:-2], "id": i}, with_indices=True) as dset:
1609
                    self.assertTrue("id" in dset[0])
1610
                    self.assertDictEqual(
1611
                        dset.features,
1612
                        Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}),
1613
                    )
1614
                    assert_arrow_metadata_are_synced_with_dataset_features(dset)
1615
                    with dset.map(lambda x: x, remove_columns=["id"]) as mapped_dset:
1616
                        self.assertTrue("id" not in mapped_dset[0])
1617
                        self.assertDictEqual(
1618
                            mapped_dset.features, Features({"filename": Value("string"), "name": Value("string")})
1619
                        )
1620
                        assert_arrow_metadata_are_synced_with_dataset_features(mapped_dset)
1621
                        with mapped_dset.with_format("numpy", columns=mapped_dset.column_names) as mapped_dset:
1622
                            with mapped_dset.map(
1623
                                lambda x: {"name": 1}, remove_columns=mapped_dset.column_names
1624
                            ) as mapped_dset:
1625
                                self.assertTrue("filename" not in mapped_dset[0])
1626
                                self.assertTrue("name" in mapped_dset[0])
1627
                                self.assertDictEqual(mapped_dset.features, Features({"name": Value(dtype="int64")}))
1628
                                assert_arrow_metadata_are_synced_with_dataset_features(mapped_dset)
1629
                    # empty dataset
1630
                    columns_names = dset.column_names
1631
                    with dset.select([]) as empty_dset:
1632
                        self.assertEqual(len(empty_dset), 0)
1633
                        with empty_dset.map(lambda x: {}, remove_columns=columns_names[0]) as mapped_dset:
1634
                            self.assertListEqual(columns_names[1:], mapped_dset.column_names)
1635
                            assert_arrow_metadata_are_synced_with_dataset_features(mapped_dset)
1636

1637
    def test_map_stateful_callable(self, in_memory):
1638
        # make sure that the preliminary call map() makes to inspect the function's output does not
1639
        # alter the callable's state before the dataset examples are actually processed
1640

1641
        class ExampleCounter:
1642
            def __init__(self, batched=False):
1643
                self.batched = batched
1644
                # state
1645
                self.cnt = 0
1646

1647
            def __call__(self, example):
1648
                if self.batched:
1649
                    self.cnt += len(example)
1650
                else:
1651
                    self.cnt += 1
1652

1653
        with tempfile.TemporaryDirectory() as tmp_dir:
1654
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1655
                ex_cnt = ExampleCounter()
1656
                dset.map(ex_cnt)
1657
                self.assertEqual(ex_cnt.cnt, len(dset))
1658

1659
                ex_cnt = ExampleCounter(batched=True)
1660
                dset.map(ex_cnt)
1661
                self.assertEqual(ex_cnt.cnt, len(dset))
1662

1663
    @require_not_windows
1664
    def test_map_crash_subprocess(self, in_memory):
1665
        # make sure that a crash in one of the subprocesses does not
1666
        # hang the dataset.map() call forever
1667

1668
        def do_crash(row):
1669
            import os
1670

1671
            os.kill(os.getpid(), 9)
1672
            return row
1673

1674
        with tempfile.TemporaryDirectory() as tmp_dir:
1675
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1676
                with pytest.raises(RuntimeError) as excinfo:
1677
                    dset.map(do_crash, num_proc=2)
1678
                assert str(excinfo.value) == (
1679
                    "One of the subprocesses has abruptly died during map operation."
1680
                    "To debug the error, disable multiprocessing."
1681
                )
1682

1683
    def test_filter(self, in_memory):
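        # filter() keeps only the rows for which the predicate returns True, e.g. (sketch):
        #   kept = ds.filter(lambda x, i: i < 5, with_indices=True)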
1684
        # keep only the first five examples
1685

1686
        with tempfile.TemporaryDirectory() as tmp_dir:
1687
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1688
                fingerprint = dset._fingerprint
1689
                with dset.filter(lambda x, i: i < 5, with_indices=True) as dset_filter_first_five:
1690
                    self.assertEqual(len(dset_filter_first_five), 5)
1691
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1692
                    self.assertDictEqual(dset_filter_first_five.features, Features({"filename": Value("string")}))
1693
                    self.assertNotEqual(dset_filter_first_five._fingerprint, fingerprint)
1694

1695
        # filter filenames whose trailing id digit is even, on a numpy-formatted dataset
1696
        with tempfile.TemporaryDirectory() as tmp_dir:
1697
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1698
                dset.set_format("numpy")
1699
                fingerprint = dset._fingerprint
1700
                with dset.filter(lambda x: (int(x["filename"][-1]) % 2 == 0)) as dset_filter_even_num:
1701
                    self.assertEqual(len(dset_filter_even_num), 15)
1702
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1703
                    self.assertDictEqual(dset_filter_even_num.features, Features({"filename": Value("string")}))
1704
                    self.assertNotEqual(dset_filter_even_num._fingerprint, fingerprint)
1705
                    self.assertEqual(dset_filter_even_num.format["type"], "numpy")
1706

1707
    def test_filter_with_indices_mapping(self, in_memory):
1708
        with tempfile.TemporaryDirectory() as tmp_dir:
1709
            dset = Dataset.from_dict({"col": [0, 1, 2]})
1710
            with self._to(in_memory, tmp_dir, dset) as dset:
1711
                with dset.filter(lambda x: x["col"] > 0) as dset:
1712
                    self.assertListEqual(dset["col"], [1, 2])
1713
                    with dset.filter(lambda x: x["col"] < 2) as dset:
1714
                        self.assertListEqual(dset["col"], [1])
1715

1716
    def test_filter_empty(self, in_memory):
1717
        with tempfile.TemporaryDirectory() as tmp_dir:
1718
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1719
                self.assertIsNone(dset._indices)
1720

1721
                tmp_file = os.path.join(tmp_dir, "test.arrow")
1722
                with dset.filter(lambda _: False, cache_file_name=tmp_file) as dset:
1723
                    self.assertEqual(len(dset), 0)
1724
                    self.assertIsNotNone(dset._indices)
1725

1726
                    tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
1727
                    with dset.filter(lambda _: False, cache_file_name=tmp_file_2) as dset2:
1728
                        self.assertEqual(len(dset2), 0)
1729
                        self.assertEqual(dset._indices, dset2._indices)
1730

1731
    def test_filter_batched(self, in_memory):
1732
        with tempfile.TemporaryDirectory() as tmp_dir:
1733
            dset = Dataset.from_dict({"col": [0, 1, 2]})
1734
            with self._to(in_memory, tmp_dir, dset) as dset:
1735
                with dset.filter(lambda x: [i > 0 for i in x["col"]], batched=True) as dset:
1736
                    self.assertListEqual(dset["col"], [1, 2])
1737
                    with dset.filter(lambda x: [i < 2 for i in x["col"]], batched=True) as dset:
1738
                        self.assertListEqual(dset["col"], [1])
1739

1740
    def test_filter_input_columns(self, in_memory):
1741
        with tempfile.TemporaryDirectory() as tmp_dir:
1742
            dset = Dataset.from_dict({"col_1": [0, 1, 2], "col_2": ["a", "b", "c"]})
1743
            with self._to(in_memory, tmp_dir, dset) as dset:
1744
                with dset.filter(lambda x: x > 0, input_columns=["col_1"]) as filtered_dset:
1745
                    self.assertListEqual(filtered_dset.column_names, dset.column_names)
1746
                    self.assertListEqual(filtered_dset["col_1"], [1, 2])
1747
                    self.assertListEqual(filtered_dset["col_2"], ["b", "c"])
1748

1749
    def test_filter_fn_kwargs(self, in_memory):
1750
        with tempfile.TemporaryDirectory() as tmp_dir:
1751
            with Dataset.from_dict({"id": range(10)}) as dset:
1752
                with self._to(in_memory, tmp_dir, dset) as dset:
1753
                    fn_kwargs = {"max_offset": 3}
1754
                    with dset.filter(
1755
                        lambda example, max_offset: example["id"] < max_offset, fn_kwargs=fn_kwargs
1756
                    ) as filtered_dset:
1757
                        assert len(filtered_dset) == 3
1758
                    with dset.filter(
1759
                        lambda id, max_offset: id < max_offset, fn_kwargs=fn_kwargs, input_columns="id"
1760
                    ) as filtered_dset:
1761
                        assert len(filtered_dset) == 3
1762
                    with dset.filter(
1763
                        lambda id, i, max_offset: i < max_offset,
1764
                        fn_kwargs=fn_kwargs,
1765
                        input_columns="id",
1766
                        with_indices=True,
1767
                    ) as filtered_dset:
1768
                        assert len(filtered_dset) == 3
1769

1770
    def test_filter_multiprocessing(self, in_memory):
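        # Like map(), filter() accepts num_proc; with the rank-aware predicate below only the rows
        # processed by rank 0 (roughly half of the dataset for num_proc=2) are kept.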
1771
        with tempfile.TemporaryDirectory() as tmp_dir:
1772
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1773
                fingerprint = dset._fingerprint
1774
                with dset.filter(picklable_filter_function, num_proc=2) as dset_filter_first_ten:
1775
                    self.assertEqual(len(dset_filter_first_ten), 10)
1776
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1777
                    self.assertDictEqual(dset_filter_first_ten.features, Features({"filename": Value("string")}))
1778
                    self.assertEqual(len(dset_filter_first_ten.cache_files), 0 if in_memory else 2)
1779
                    self.assertNotEqual(dset_filter_first_ten._fingerprint, fingerprint)
1780

1781
        with tempfile.TemporaryDirectory() as tmp_dir:  # with_rank
1782
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1783
                fingerprint = dset._fingerprint
1784
                with dset.filter(
1785
                    picklable_filter_function_with_rank, num_proc=2, with_rank=True
1786
                ) as dset_filter_first_rank:
1787
                    self.assertEqual(len(dset_filter_first_rank), min(len(dset) // 2, len(dset)))
1788
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1789
                    self.assertDictEqual(dset_filter_first_rank.features, Features({"filename": Value("string")}))
1790
                    self.assertEqual(len(dset_filter_first_rank.cache_files), 0 if in_memory else 2)
1791
                    self.assertNotEqual(dset_filter_first_rank._fingerprint, fingerprint)
1792

1793
    def test_filter_caching(self, in_memory):
1794
        with tempfile.TemporaryDirectory() as tmp_dir:
1795
            self._caplog.clear()
1796
            with self._caplog.at_level(INFO, logger=get_logger().name):
1797
                with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1798
                    with dset.filter(lambda x, i: i < 5, with_indices=True) as dset_filter_first_five1:
1799
                        dset_test1_data_files = list(dset_filter_first_five1.cache_files)
1800
                    with dset.filter(lambda x, i: i < 5, with_indices=True) as dset_filter_first_five2:
1801
                        self.assertEqual(dset_test1_data_files, dset_filter_first_five2.cache_files)
1802
                        self.assertEqual(len(dset_filter_first_five2.cache_files), 0 if in_memory else 2)
1803
                        self.assertTrue(("Loading cached processed dataset" in self._caplog.text) ^ in_memory)
1804

1805
    def test_keep_features_after_transform_specified(self, in_memory):
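        # Passing features= to map() pins the output schema, so rich types like ClassLabel are
        # preserved instead of being re-inferred from the transformed values; the tests that follow
        # cover the unspecified, to-file, in-memory and cache-reload variants.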
1806
        features = Features(
1807
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
1808
        )
1809

1810
        def invert_labels(x):
1811
            return {"labels": [(1 - label) for label in x["labels"]]}
1812

1813
        with tempfile.TemporaryDirectory() as tmp_dir:
1814
            with Dataset.from_dict(
1815
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
1816
            ) as dset:
1817
                with self._to(in_memory, tmp_dir, dset) as dset:
1818
                    with dset.map(invert_labels, features=features) as inverted_dset:
1819
                        self.assertEqual(inverted_dset.features.type, features.type)
1820
                        self.assertDictEqual(inverted_dset.features, features)
1821
                        assert_arrow_metadata_are_synced_with_dataset_features(inverted_dset)
1822

1823
    def test_keep_features_after_transform_unspecified(self, in_memory):
1824
        features = Features(
1825
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
1826
        )
1827

1828
        def invert_labels(x):
1829
            return {"labels": [(1 - label) for label in x["labels"]]}
1830

1831
        with tempfile.TemporaryDirectory() as tmp_dir:
1832
            with Dataset.from_dict(
1833
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
1834
            ) as dset:
1835
                with self._to(in_memory, tmp_dir, dset) as dset:
1836
                    with dset.map(invert_labels) as inverted_dset:
1837
                        self.assertEqual(inverted_dset.features.type, features.type)
1838
                        self.assertDictEqual(inverted_dset.features, features)
1839
                        assert_arrow_metadata_are_synced_with_dataset_features(inverted_dset)
1840

1841
    def test_keep_features_after_transform_to_file(self, in_memory):
1842
        features = Features(
1843
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
1844
        )
1845

1846
        def invert_labels(x):
1847
            return {"labels": [(1 - label) for label in x["labels"]]}
1848

1849
        with tempfile.TemporaryDirectory() as tmp_dir:
1850
            with Dataset.from_dict(
1851
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
1852
            ) as dset:
1853
                with self._to(in_memory, tmp_dir, dset) as dset:
1854
                    tmp_file = os.path.join(tmp_dir, "test.arrow")
1855
                    dset.map(invert_labels, cache_file_name=tmp_file)
1856
                    with Dataset.from_file(tmp_file) as inverted_dset:
1857
                        self.assertEqual(inverted_dset.features.type, features.type)
1858
                        self.assertDictEqual(inverted_dset.features, features)
1859

1860
    def test_keep_features_after_transform_to_memory(self, in_memory):
1861
        features = Features(
1862
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
1863
        )
1864

1865
        def invert_labels(x):
1866
            return {"labels": [(1 - label) for label in x["labels"]]}
1867

1868
        with tempfile.TemporaryDirectory() as tmp_dir:
1869
            with Dataset.from_dict(
1870
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
1871
            ) as dset:
1872
                with self._to(in_memory, tmp_dir, dset) as dset:
1873
                    with dset.map(invert_labels, keep_in_memory=True) as inverted_dset:
1874
                        self.assertEqual(inverted_dset.features.type, features.type)
1875
                        self.assertDictEqual(inverted_dset.features, features)
1876

1877
    def test_keep_features_after_loading_from_cache(self, in_memory):
1878
        features = Features(
1879
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
1880
        )
1881

1882
        def invert_labels(x):
1883
            return {"labels": [(1 - label) for label in x["labels"]]}
1884

1885
        with tempfile.TemporaryDirectory() as tmp_dir:
1886
            with Dataset.from_dict(
1887
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
1888
            ) as dset:
1889
                with self._to(in_memory, tmp_dir, dset) as dset:
1890
                    tmp_file1 = os.path.join(tmp_dir, "test1.arrow")
1891
                    tmp_file2 = os.path.join(tmp_dir, "test2.arrow")
1892
                    # TODO: Why mapped twice?
1893
                    inverted_dset = dset.map(invert_labels, cache_file_name=tmp_file1)
1894
                    inverted_dset = dset.map(invert_labels, cache_file_name=tmp_file2)
1895
                    self.assertGreater(len(inverted_dset.cache_files), 0)
1896
                    self.assertEqual(inverted_dset.features.type, features.type)
1897
                    self.assertDictEqual(inverted_dset.features, features)
1898
                    del inverted_dset
1899

1900
    def test_keep_features_with_new_features(self, in_memory):
1901
        features = Features(
1902
            {"tokens": Sequence(Value("string")), "labels": Sequence(ClassLabel(names=["negative", "positive"]))}
1903
        )
1904

1905
        def invert_labels(x):
1906
            return {"labels": [(1 - label) for label in x["labels"]], "labels2": x["labels"]}
1907

1908
        expected_features = Features(
1909
            {
1910
                "tokens": Sequence(Value("string")),
1911
                "labels": Sequence(ClassLabel(names=["negative", "positive"])),
1912
                "labels2": Sequence(Value("int64")),
1913
            }
1914
        )
1915

1916
        with tempfile.TemporaryDirectory() as tmp_dir:
1917
            with Dataset.from_dict(
1918
                {"tokens": [["foo"] * 5] * 10, "labels": [[1] * 5] * 10}, features=features
1919
            ) as dset:
1920
                with self._to(in_memory, tmp_dir, dset) as dset:
1921
                    with dset.map(invert_labels) as inverted_dset:
1922
                        self.assertEqual(inverted_dset.features.type, expected_features.type)
1923
                        self.assertDictEqual(inverted_dset.features, expected_features)
1924
                        assert_arrow_metadata_are_synced_with_dataset_features(inverted_dset)
1925

1926
    def test_select(self, in_memory):
1927
        with tempfile.TemporaryDirectory() as tmp_dir:
1928
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1929
                # select every other example
1930
                indices = list(range(0, len(dset), 2))
1931
                tmp_file = os.path.join(tmp_dir, "test.arrow")
1932
                fingerprint = dset._fingerprint
1933
                with dset.select(indices, indices_cache_file_name=tmp_file) as dset_select_even:
1934
                    self.assertIsNotNone(dset_select_even._indices)  # an indices mapping is created
1935
                    self.assertTrue(os.path.exists(tmp_file))
1936
                    self.assertEqual(len(dset_select_even), 15)
1937
                    for row in dset_select_even:
1938
                        self.assertEqual(int(row["filename"][-1]) % 2, 0)
1939
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1940
                    self.assertDictEqual(dset_select_even.features, Features({"filename": Value("string")}))
1941
                    self.assertNotEqual(dset_select_even._fingerprint, fingerprint)
1942

1943
        with tempfile.TemporaryDirectory() as tmp_dir:
1944
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1945
                indices = list(range(0, len(dset)))
1946
                with dset.select(indices) as dset_select_all:
1947
                    # no indices mapping, since the indices are contiguous
1948
                    # (in this case the arrow table is simply sliced, which is more efficient)
1949
                    self.assertIsNone(dset_select_all._indices)
1950
                    self.assertEqual(len(dset_select_all), len(dset))
1951
                    self.assertListEqual(list(dset_select_all), list(dset))
1952
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1953
                    self.assertDictEqual(dset_select_all.features, Features({"filename": Value("string")}))
1954
                    self.assertNotEqual(dset_select_all._fingerprint, fingerprint)
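                # Illustrative sketch (comments only, based on the assertions above): contiguous
                # indices avoid an indices mapping, while non-contiguous indices create one, e.g.:
                #   dset.select(range(5))._indices   -> None (the arrow table is simply sliced)
                #   dset.select([4, 2, 0])._indices  -> an indices mapping table is written
                # `_indices` is an internal attribute used here for testing, not public API.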
1955
                indices = range(0, len(dset))
1956
                with dset.select(indices) as dset_select_all:
1957
                    # same but with range
1958
                    self.assertIsNone(dset_select_all._indices)
1959
                    self.assertEqual(len(dset_select_all), len(dset))
1960
                    self.assertListEqual(list(dset_select_all), list(dset))
1961
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1962
                    self.assertDictEqual(dset_select_all.features, Features({"filename": Value("string")}))
1963
                    self.assertNotEqual(dset_select_all._fingerprint, fingerprint)
1964

1965
        with tempfile.TemporaryDirectory() as tmp_dir:
1966
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1967
                bad_indices = list(range(5))
1968
                bad_indices[-1] = len(dset) + 10  # out of bounds
1969
                tmp_file = os.path.join(tmp_dir, "test.arrow")
1970
                self.assertRaises(
1971
                    Exception,
1972
                    dset.select,
1973
                    indices=bad_indices,
1974
                    indices_cache_file_name=tmp_file,
1975
                    writer_batch_size=2,
1976
                )
1977
                self.assertFalse(os.path.exists(tmp_file))
1978

1979
        with tempfile.TemporaryDirectory() as tmp_dir:
1980
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1981
                indices = iter(range(len(dset)))  # iterator of contiguous indices
1982
                with dset.select(indices) as dset_select_all:
1983
                    # no indices mapping, since the indices are contiguous
1984
                    self.assertIsNone(dset_select_all._indices)
1985
                    self.assertEqual(len(dset_select_all), len(dset))
1986
                indices = reversed(range(len(dset)))  # iterator of non-contiguous indices
1987
                tmp_file = os.path.join(tmp_dir, "test.arrow")
1988
                with dset.select(indices, indices_cache_file_name=tmp_file) as dset_select_all:
1989
                    # new indices mapping, since the indices are not contiguous
1990
                    self.assertIsNotNone(dset_select_all._indices)
1991
                    self.assertEqual(len(dset_select_all), len(dset))
1992

1993
        with tempfile.TemporaryDirectory() as tmp_dir:
1994
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1995
                bad_indices = list(range(5))
1996
                bad_indices[3] = "foo"  # wrong type
1997
                tmp_file = os.path.join(tmp_dir, "test.arrow")
1998
                self.assertRaises(
1999
                    Exception,
2000
                    dset.select,
2001
                    indices=bad_indices,
2002
                    indices_cache_file_name=tmp_file,
2003
                    writer_batch_size=2,
2004
                )
2005
                self.assertFalse(os.path.exists(tmp_file))
2006
                dset.set_format("numpy")
2007
                with dset.select(
2008
                    range(5),
2009
                    indices_cache_file_name=tmp_file,
2010
                    writer_batch_size=2,
2011
                ) as dset_select_five:
2012
                    self.assertIsNone(dset_select_five._indices)
2013
                    self.assertEqual(len(dset_select_five), 5)
2014
                    self.assertEqual(dset_select_five.format["type"], "numpy")
2015
                    for i, row in enumerate(dset_select_five):
2016
                        self.assertEqual(int(row["filename"][-1]), i)
2017
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2018
                    self.assertDictEqual(dset_select_five.features, Features({"filename": Value("string")}))
2019

2020
    def test_select_then_map(self, in_memory):
2021
        with tempfile.TemporaryDirectory() as tmp_dir:
2022
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2023
                with dset.select([0]) as d1:
2024
                    with d1.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d1:
2025
                        self.assertEqual(d1[0]["id"], 0)
2026
                with dset.select([1]) as d2:
2027
                    with d2.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d2:
2028
                        self.assertEqual(d2[0]["id"], 1)
2029

2030
        with tempfile.TemporaryDirectory() as tmp_dir:
2031
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2032
                with dset.select([0], indices_cache_file_name=os.path.join(tmp_dir, "i1.arrow")) as d1:
2033
                    with d1.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d1:
2034
                        self.assertEqual(d1[0]["id"], 0)
2035
                with dset.select([1], indices_cache_file_name=os.path.join(tmp_dir, "i2.arrow")) as d2:
2036
                    with d2.map(lambda x: {"id": int(x["filename"].split("_")[-1])}) as d2:
2037
                        self.assertEqual(d2[0]["id"], 1)
2038

2039
    def test_pickle_after_many_transforms_on_disk(self, in_memory):
2040
        with tempfile.TemporaryDirectory() as tmp_dir:
2041
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2042
                self.assertEqual(len(dset.cache_files), 0 if in_memory else 1)
2043
                with dset.rename_column("filename", "file") as dset:
2044
                    self.assertListEqual(dset.column_names, ["file"])
2045
                    with dset.select(range(5)) as dset:
2046
                        self.assertEqual(len(dset), 5)
2047
                        with dset.map(lambda x: {"id": int(x["file"][-1])}) as dset:
2048
                            self.assertListEqual(sorted(dset.column_names), ["file", "id"])
2049
                            with dset.rename_column("id", "number") as dset:
2050
                                self.assertListEqual(sorted(dset.column_names), ["file", "number"])
2051
                                with dset.select([1, 0]) as dset:
2052
                                    self.assertEqual(dset[0]["file"], "my_name-train_1")
2053
                                    self.assertEqual(dset[0]["number"], 1)
2054

2055
                                    self.assertEqual(dset._indices["indices"].to_pylist(), [1, 0])
2056
                                    if not in_memory:
2057
                                        self.assertIn(
2058
                                            ("rename_columns", (["file", "number"],), {}),
2059
                                            dset._data.replays,
2060
                                        )
2061
                                    if not in_memory:
2062
                                        dset._data.table = Unpicklable()  # check that we don't pickle the entire table
2063

2064
                                    pickled = pickle.dumps(dset)
2065
                                    with pickle.loads(pickled) as loaded:
2066
                                        self.assertEqual(loaded[0]["file"], "my_name-train_1")
2067
                                        self.assertEqual(loaded[0]["number"], 1)
2068

2069
    def test_shuffle(self, in_memory):
2070
        with tempfile.TemporaryDirectory() as tmp_dir:
2071
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2072
                tmp_file = os.path.join(tmp_dir, "test.arrow")
2073
                fingerprint = dset._fingerprint
2074

2075
                with dset.shuffle(seed=1234, keep_in_memory=True) as dset_shuffled:
2076
                    self.assertEqual(len(dset_shuffled), 30)
2077
                    self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28")
2078
                    self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10")
2079
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2080
                    self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")}))
2081
                    self.assertNotEqual(dset_shuffled._fingerprint, fingerprint)
2082

2083
                with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled:
2084
                    self.assertEqual(len(dset_shuffled), 30)
2085
                    self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28")
2086
                    self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10")
2087
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2088
                    self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")}))
2089
                    self.assertNotEqual(dset_shuffled._fingerprint, fingerprint)
2090

2091
                    # Reproducibility
2092
                    tmp_file = os.path.join(tmp_dir, "test_2.arrow")
2093
                    with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled_2:
2094
                        self.assertListEqual(dset_shuffled["filename"], dset_shuffled_2["filename"])
2095

2096
                # Compatible with temp_seed
2097
                with temp_seed(42), dset.shuffle() as d1:
2098
                    with temp_seed(42), dset.shuffle() as d2, dset.shuffle() as d3:
2099
                        self.assertListEqual(d1["filename"], d2["filename"])
2100
                        self.assertEqual(d1._fingerprint, d2._fingerprint)
2101
                        self.assertNotEqual(d3["filename"], d2["filename"])
2102
                        self.assertNotEqual(d3._fingerprint, d2._fingerprint)
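                # Minimal usage sketch (comments only, based on the assertions above):
                #   dset.shuffle(seed=1234)              -> deterministic, reproducible order
                #   with temp_seed(42): dset.shuffle()   -> reproducible across identical seed contexts,
                # but consecutive shuffle() calls inside the same context still differ from each other,
                # and each unseeded shuffle gets its own fingerprint.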
2103

2104
    def test_sort(self, in_memory):
2105
        with tempfile.TemporaryDirectory() as tmp_dir:
2106
            # Sort on a single key
2107
            with self._create_dummy_dataset(in_memory=in_memory, tmp_dir=tmp_dir) as dset:
2108
                # Keep only 10 examples
2109
                tmp_file = os.path.join(tmp_dir, "test.arrow")
2110
                with dset.select(range(10), indices_cache_file_name=tmp_file) as dset:
2111
                    tmp_file = os.path.join(tmp_dir, "test_2.arrow")
2112
                    with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset:
2113
                        self.assertEqual(len(dset), 10)
2114
                        self.assertEqual(dset[0]["filename"], "my_name-train_8")
2115
                        self.assertEqual(dset[1]["filename"], "my_name-train_9")
2116
                        # Sort
2117
                        tmp_file = os.path.join(tmp_dir, "test_3.arrow")
2118
                        fingerprint = dset._fingerprint
2119
                        with dset.sort("filename", indices_cache_file_name=tmp_file) as dset_sorted:
2120
                            for i, row in enumerate(dset_sorted):
2121
                                self.assertEqual(int(row["filename"][-1]), i)
2122
                            self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2123
                            self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")}))
2124
                            self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
2125
                            # Sort reversed
2126
                            tmp_file = os.path.join(tmp_dir, "test_4.arrow")
2127
                            fingerprint = dset._fingerprint
2128
                            with dset.sort("filename", indices_cache_file_name=tmp_file, reverse=True) as dset_sorted:
2129
                                for i, row in enumerate(dset_sorted):
2130
                                    self.assertEqual(int(row["filename"][-1]), len(dset_sorted) - 1 - i)
2131
                                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2132
                                self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")}))
2133
                                self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
2134
                            # formatted
2135
                            dset.set_format("numpy")
2136
                            with dset.sort("filename") as dset_sorted_formatted:
2137
                                self.assertEqual(dset_sorted_formatted.format["type"], "numpy")
2138
            # Sort on multiple keys
2139
            with self._create_dummy_dataset(in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True) as dset:
2140
                tmp_file = os.path.join(tmp_dir, "test_5.arrow")
2141
                fingerprint = dset._fingerprint
2142
                # Raise an error when reverse is a list of bools whose length does not match column_names
2143
                with pytest.raises(ValueError):
2144
                    dset.sort(["col_1", "col_2", "col_3"], reverse=[False])
2145
                with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset:
2146
                    # Sort
2147
                    with dset.sort(["col_1", "col_2", "col_3"], reverse=[False, True, False]) as dset_sorted:
2148
                        for i, row in enumerate(dset_sorted):
2149
                            self.assertEqual(row["col_1"], i)
2150
                        self.assertDictEqual(
2151
                            dset.features,
2152
                            Features(
2153
                                {
2154
                                    "col_1": Value("int64"),
2155
                                    "col_2": Value("string"),
2156
                                    "col_3": Value("bool"),
2157
                                }
2158
                            ),
2159
                        )
2160
                        self.assertDictEqual(
2161
                            dset_sorted.features,
2162
                            Features(
2163
                                {
2164
                                    "col_1": Value("int64"),
2165
                                    "col_2": Value("string"),
2166
                                    "col_3": Value("bool"),
2167
                                }
2168
                            ),
2169
                        )
2170
                        self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
2171
                        # Sort reversed
2172
                        with dset.sort(["col_1", "col_2", "col_3"], reverse=[True, False, True]) as dset_sorted:
2173
                            for i, row in enumerate(dset_sorted):
2174
                                self.assertEqual(row["col_1"], len(dset_sorted) - 1 - i)
2175
                            self.assertDictEqual(
2176
                                dset.features,
2177
                                Features(
2178
                                    {
2179
                                        "col_1": Value("int64"),
2180
                                        "col_2": Value("string"),
2181
                                        "col_3": Value("bool"),
2182
                                    }
2183
                                ),
2184
                            )
2185
                            self.assertDictEqual(
2186
                                dset_sorted.features,
2187
                                Features(
2188
                                    {
2189
                                        "col_1": Value("int64"),
2190
                                        "col_2": Value("string"),
2191
                                        "col_3": Value("bool"),
2192
                                    }
2193
                                ),
2194
                            )
2195
                            self.assertNotEqual(dset_sorted._fingerprint, fingerprint)
2196
                            # formatted
2197
                            dset.set_format("numpy")
2198
                            with dset.sort(
2199
                                ["col_1", "col_2", "col_3"], reverse=[False, True, False]
2200
                            ) as dset_sorted_formatted:
2201
                                self.assertEqual(dset_sorted_formatted.format["type"], "numpy")
2202

2203
    @require_tf
2204
    def test_export(self, in_memory):
2205
        with tempfile.TemporaryDirectory() as tmp_dir:
2206
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2207
                # Export the data
2208
                tfrecord_path = os.path.join(tmp_dir, "test.tfrecord")
2209
                with dset.map(
2210
                    lambda ex, i: {
2211
                        "id": i,
2212
                        "question": f"Question {i}",
2213
                        "answers": {"text": [f"Answer {i}-0", f"Answer {i}-1"], "answer_start": [0, 1]},
2214
                    },
2215
                    with_indices=True,
2216
                    remove_columns=["filename"],
2217
                ) as formatted_dset:
2218
                    with formatted_dset.flatten() as formatted_dset:
2219
                        formatted_dset.set_format("numpy")
2220
                        formatted_dset.export(filename=tfrecord_path, format="tfrecord")
2221

2222
                        # Import the data
2223
                        import tensorflow as tf
2224

2225
                        tf_dset = tf.data.TFRecordDataset([tfrecord_path])
2226
                        feature_description = {
2227
                            "id": tf.io.FixedLenFeature([], tf.int64),
2228
                            "question": tf.io.FixedLenFeature([], tf.string),
2229
                            "answers.text": tf.io.VarLenFeature(tf.string),
2230
                            "answers.answer_start": tf.io.VarLenFeature(tf.int64),
2231
                        }
2232
                        tf_parsed_dset = tf_dset.map(
2233
                            lambda example_proto: tf.io.parse_single_example(example_proto, feature_description)
2234
                        )
2235
                        # Test that keys match original dataset
2236
                        for i, ex in enumerate(tf_parsed_dset):
2237
                            self.assertEqual(ex.keys(), formatted_dset[i].keys())
2238
                        # Test for equal number of elements
2239
                        self.assertEqual(i, len(formatted_dset) - 1)
2240

2241
    def test_to_csv(self, in_memory):
2242
        with tempfile.TemporaryDirectory() as tmp_dir:
2243
            # File path argument
2244
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2245
                file_path = os.path.join(tmp_dir, "test_path.csv")
2246
                bytes_written = dset.to_csv(path_or_buf=file_path)
2247

2248
                self.assertTrue(os.path.isfile(file_path))
2249
                self.assertEqual(bytes_written, os.path.getsize(file_path))
2250
                csv_dset = pd.read_csv(file_path)
2251

2252
                self.assertEqual(csv_dset.shape, dset.shape)
2253
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))
2254

2255
            # File buffer argument
2256
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2257
                file_path = os.path.join(tmp_dir, "test_buffer.csv")
2258
                with open(file_path, "wb+") as buffer:
2259
                    bytes_written = dset.to_csv(path_or_buf=buffer)
2260

2261
                self.assertTrue(os.path.isfile(file_path))
2262
                self.assertEqual(bytes_written, os.path.getsize(file_path))
2263
                csv_dset = pd.read_csv(file_path)
2264

2265
                self.assertEqual(csv_dset.shape, dset.shape)
2266
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))
2267

2268
            # After a select/shuffle transform
2269
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2270
                dset = dset.select(range(0, len(dset), 2)).shuffle()
2271
                file_path = os.path.join(tmp_dir, "test_path.csv")
2272
                bytes_written = dset.to_csv(path_or_buf=file_path)
2273

2274
                self.assertTrue(os.path.isfile(file_path))
2275
                self.assertEqual(bytes_written, os.path.getsize(file_path))
2276
                csv_dset = pd.read_csv(file_path)
2277

2278
                self.assertEqual(csv_dset.shape, dset.shape)
2279
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))
2280

2281
            # With array features
2282
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
2283
                file_path = os.path.join(tmp_dir, "test_path.csv")
2284
                bytes_written = dset.to_csv(path_or_buf=file_path)
2285

2286
                self.assertTrue(os.path.isfile(file_path))
2287
                self.assertEqual(bytes_written, os.path.getsize(file_path))
2288
                csv_dset = pd.read_csv(file_path)
2289

2290
                self.assertEqual(csv_dset.shape, dset.shape)
2291
                self.assertListEqual(list(csv_dset.columns), list(dset.column_names))
2292

2293
    def test_to_dict(self, in_memory):
2294
        with tempfile.TemporaryDirectory() as tmp_dir:
2295
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2296
                # Full
2297
                dset_to_dict = dset.to_dict()
2298
                self.assertIsInstance(dset_to_dict, dict)
2299
                self.assertListEqual(sorted(dset_to_dict.keys()), sorted(dset.column_names))
2300

2301
                for col_name in dset.column_names:
2302
                    self.assertLessEqual(len(dset_to_dict[col_name]), len(dset))
2303

2304
                # With index mapping
2305
                with dset.select([1, 0, 3]) as dset:
2306
                    dset_to_dict = dset.to_dict()
2307
                    self.assertIsInstance(dset_to_dict, dict)
2308
                    self.assertEqual(len(dset_to_dict), 3)
2309
                    self.assertListEqual(sorted(dset_to_dict.keys()), sorted(dset.column_names))
2310

2311
                    for col_name in dset.column_names:
2312
                        self.assertIsInstance(dset_to_dict[col_name], list)
2313
                        self.assertEqual(len(dset_to_dict[col_name]), len(dset))
2314

2315
    def test_to_list(self, in_memory):
2316
        with tempfile.TemporaryDirectory() as tmp_dir:
2317
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2318
                dset_to_list = dset.to_list()
2319
                self.assertIsInstance(dset_to_list, list)
2320
                for row in dset_to_list:
2321
                    self.assertIsInstance(row, dict)
2322
                    self.assertListEqual(sorted(row.keys()), sorted(dset.column_names))
2323

2324
                # With index mapping
2325
                with dset.select([1, 0, 3]) as dset:
2326
                    dset_to_list = dset.to_list()
2327
                    self.assertIsInstance(dset_to_list, list)
2328
                    self.assertEqual(len(dset_to_list), 3)
2329
                    for row in dset_to_list:
2330
                        self.assertIsInstance(row, dict)
2331
                        self.assertListEqual(sorted(row.keys()), sorted(dset.column_names))
2332

2333
    def test_to_pandas(self, in_memory):
2334
        with tempfile.TemporaryDirectory() as tmp_dir:
2335
            # Batched
2336
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2337
                batch_size = dset.num_rows - 1
2338
                to_pandas_generator = dset.to_pandas(batched=True, batch_size=batch_size)
2339

2340
                for batch in to_pandas_generator:
2341
                    self.assertIsInstance(batch, pd.DataFrame)
2342
                    self.assertListEqual(sorted(batch.columns), sorted(dset.column_names))
2343
                    for col_name in dset.column_names:
2344
                        self.assertLessEqual(len(batch[col_name]), batch_size)
2345

2346
                # Full
2347
                dset_to_pandas = dset.to_pandas()
2348
                self.assertIsInstance(dset_to_pandas, pd.DataFrame)
2349
                self.assertListEqual(sorted(dset_to_pandas.columns), sorted(dset.column_names))
2350
                for col_name in dset.column_names:
2351
                    self.assertEqual(len(dset_to_pandas[col_name]), len(dset))
2352

2353
                # With index mapping
2354
                with dset.select([1, 0, 3]) as dset:
2355
                    dset_to_pandas = dset.to_pandas()
2356
                    self.assertIsInstance(dset_to_pandas, pd.DataFrame)
2357
                    self.assertEqual(len(dset_to_pandas), 3)
2358
                    self.assertListEqual(sorted(dset_to_pandas.columns), sorted(dset.column_names))
2359

2360
                    for col_name in dset.column_names:
2361
                        self.assertEqual(len(dset_to_pandas[col_name]), dset.num_rows)
2362

2363
    def test_to_parquet(self, in_memory):
2364
        with tempfile.TemporaryDirectory() as tmp_dir:
2365
            # File path argument
2366
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2367
                file_path = os.path.join(tmp_dir, "test_path.parquet")
2368
                dset.to_parquet(path_or_buf=file_path)
2369

2370
                self.assertTrue(os.path.isfile(file_path))
2371
                # self.assertEqual(bytes_written, os.path.getsize(file_path))  # because of compression, the number of bytes doesn't match
2372
                parquet_dset = pd.read_parquet(file_path)
2373

2374
                self.assertEqual(parquet_dset.shape, dset.shape)
2375
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))
2376

2377
            # File buffer argument
2378
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2379
                file_path = os.path.join(tmp_dir, "test_buffer.parquet")
2380
                with open(file_path, "wb+") as buffer:
2381
                    dset.to_parquet(path_or_buf=buffer)
2382

2383
                self.assertTrue(os.path.isfile(file_path))
2384
                # self.assertEqual(bytes_written, os.path.getsize(file_path))  # because of compression, the number of bytes doesn't match
2385
                parquet_dset = pd.read_parquet(file_path)
2386

2387
                self.assertEqual(parquet_dset.shape, dset.shape)
2388
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))
2389

2390
            # After a select/shuffle transform
2391
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2392
                dset = dset.select(range(0, len(dset), 2)).shuffle()
2393
                file_path = os.path.join(tmp_dir, "test_path.parquet")
2394
                dset.to_parquet(path_or_buf=file_path)
2395

2396
                self.assertTrue(os.path.isfile(file_path))
2397
                # self.assertEqual(bytes_written, os.path.getsize(file_path))  # because of compression, the number of bytes doesn't match
2398
                parquet_dset = pd.read_parquet(file_path)
2399

2400
                self.assertEqual(parquet_dset.shape, dset.shape)
2401
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))
2402

2403
            # With array features
2404
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
2405
                file_path = os.path.join(tmp_dir, "test_path.parquet")
2406
                dset.to_parquet(path_or_buf=file_path)
2407

2408
                self.assertTrue(os.path.isfile(file_path))
2409
                # self.assertEqual(bytes_written, os.path.getsize(file_path))  # because of compression, the number of bytes doesn't match
2410
                parquet_dset = pd.read_parquet(file_path)
2411

2412
                self.assertEqual(parquet_dset.shape, dset.shape)
2413
                self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))
2414

2415
    @require_sqlalchemy
2416
    def test_to_sql(self, in_memory):
2417
        with tempfile.TemporaryDirectory() as tmp_dir:
2418
            # Destination specified as a database URI string
2419
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2420
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
2421
                _ = dset.to_sql("data", "sqlite:///" + file_path)
2422

2423
                self.assertTrue(os.path.isfile(file_path))
2424
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
2425

2426
                self.assertEqual(sql_dset.shape, dset.shape)
2427
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
2428

2429
            # Destination specified as a sqlite3 connection
2430
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2431
                import sqlite3
2432

2433
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
2434
                with contextlib.closing(sqlite3.connect(file_path)) as con:
2435
                    _ = dset.to_sql("data", con, if_exists="replace")
2436

2437
                self.assertTrue(os.path.isfile(file_path))
2438
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
2439

2440
                self.assertEqual(sql_dset.shape, dset.shape)
2441
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
2442

2443
            # Test writing to a database in chunks
2444
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2445
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
2446
                _ = dset.to_sql("data", "sqlite:///" + file_path, batch_size=1, if_exists="replace")
2447

2448
                self.assertTrue(os.path.isfile(file_path))
2449
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
2450

2451
                self.assertEqual(sql_dset.shape, dset.shape)
2452
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
2453

2454
            # After a select/shuffle transform
2455
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2456
                dset = dset.select(range(0, len(dset), 2)).shuffle()
2457
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
2458
                _ = dset.to_sql("data", "sqlite:///" + file_path, if_exists="replace")
2459

2460
                self.assertTrue(os.path.isfile(file_path))
2461
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
2462

2463
                self.assertEqual(sql_dset.shape, dset.shape)
2464
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
2465

2466
            # With array features
2467
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
2468
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
2469
                _ = dset.to_sql("data", "sqlite:///" + file_path, if_exists="replace")
2470

2471
                self.assertTrue(os.path.isfile(file_path))
2472
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
2473

2474
                self.assertEqual(sql_dset.shape, dset.shape)
2475
                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
2476

2477
    def test_train_test_split(self, in_memory):
2478
        with tempfile.TemporaryDirectory() as tmp_dir:
2479
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2480
                fingerprint = dset._fingerprint
2481
                dset_dict = dset.train_test_split(test_size=10, shuffle=False)
2482
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
2483
                dset_train = dset_dict["train"]
2484
                dset_test = dset_dict["test"]
2485

2486
                self.assertEqual(len(dset_train), 20)
2487
                self.assertEqual(len(dset_test), 10)
2488
                self.assertEqual(dset_train[0]["filename"], "my_name-train_0")
2489
                self.assertEqual(dset_train[-1]["filename"], "my_name-train_19")
2490
                self.assertEqual(dset_test[0]["filename"], "my_name-train_20")
2491
                self.assertEqual(dset_test[-1]["filename"], "my_name-train_29")
2492
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2493
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
2494
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))
2495
                self.assertNotEqual(dset_train._fingerprint, fingerprint)
2496
                self.assertNotEqual(dset_test._fingerprint, fingerprint)
2497
                self.assertNotEqual(dset_train._fingerprint, dset_test._fingerprint)
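                # Sketch of the non-shuffled behaviour checked above: with shuffle=False the split is
                # contiguous, e.g. roughly:
                #   dset.train_test_split(test_size=10, shuffle=False)
                #   -> train = rows [0, 20), test = rows [20, 30)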
2498

2499
                dset_dict = dset.train_test_split(test_size=0.5, shuffle=False)
2500
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
2501
                dset_train = dset_dict["train"]
2502
                dset_test = dset_dict["test"]
2503

2504
                self.assertEqual(len(dset_train), 15)
2505
                self.assertEqual(len(dset_test), 15)
2506
                self.assertEqual(dset_train[0]["filename"], "my_name-train_0")
2507
                self.assertEqual(dset_train[-1]["filename"], "my_name-train_14")
2508
                self.assertEqual(dset_test[0]["filename"], "my_name-train_15")
2509
                self.assertEqual(dset_test[-1]["filename"], "my_name-train_29")
2510
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2511
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
2512
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))
2513

2514
                dset_dict = dset.train_test_split(train_size=10, shuffle=False)
2515
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
2516
                dset_train = dset_dict["train"]
2517
                dset_test = dset_dict["test"]
2518

2519
                self.assertEqual(len(dset_train), 10)
2520
                self.assertEqual(len(dset_test), 20)
2521
                self.assertEqual(dset_train[0]["filename"], "my_name-train_0")
2522
                self.assertEqual(dset_train[-1]["filename"], "my_name-train_9")
2523
                self.assertEqual(dset_test[0]["filename"], "my_name-train_10")
2524
                self.assertEqual(dset_test[-1]["filename"], "my_name-train_29")
2525
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2526
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
2527
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))
2528

2529
                dset.set_format("numpy")
2530
                dset_dict = dset.train_test_split(train_size=10, seed=42)
2531
                self.assertListEqual(list(dset_dict.keys()), ["train", "test"])
2532
                dset_train = dset_dict["train"]
2533
                dset_test = dset_dict["test"]
2534

2535
                self.assertEqual(len(dset_train), 10)
2536
                self.assertEqual(len(dset_test), 20)
2537
                self.assertEqual(dset_train.format["type"], "numpy")
2538
                self.assertEqual(dset_test.format["type"], "numpy")
2539
                self.assertNotEqual(dset_train[0]["filename"].item(), "my_name-train_0")
2540
                self.assertNotEqual(dset_train[-1]["filename"].item(), "my_name-train_9")
2541
                self.assertNotEqual(dset_test[0]["filename"].item(), "my_name-train_10")
2542
                self.assertNotEqual(dset_test[-1]["filename"].item(), "my_name-train_29")
2543
                self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2544
                self.assertDictEqual(dset_train.features, Features({"filename": Value("string")}))
2545
                self.assertDictEqual(dset_test.features, Features({"filename": Value("string")}))
2546
                del dset_test, dset_train, dset_dict  # DatasetDict
2547

2548
    def test_shard(self, in_memory):
2549
        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2550
            tmp_file = os.path.join(tmp_dir, "test.arrow")
2551
            with dset.select(range(10), indices_cache_file_name=tmp_file) as dset:
2552
                self.assertEqual(len(dset), 10)
2553
                # Shard
2554
                tmp_file_1 = os.path.join(tmp_dir, "test_1.arrow")
2555
                fingerprint = dset._fingerprint
2556
                with dset.shard(num_shards=8, index=1, indices_cache_file_name=tmp_file_1) as dset_sharded:
2557
                    self.assertEqual(2, len(dset_sharded))
2558
                    self.assertEqual(["my_name-train_1", "my_name-train_9"], dset_sharded["filename"])
2559
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2560
                    self.assertDictEqual(dset_sharded.features, Features({"filename": Value("string")}))
2561
                    self.assertNotEqual(dset_sharded._fingerprint, fingerprint)
2562
                # Shard contiguous
2563
                tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
2564
                with dset.shard(
2565
                    num_shards=3, index=0, contiguous=True, indices_cache_file_name=tmp_file_2
2566
                ) as dset_sharded_contiguous:
2567
                    self.assertEqual([f"my_name-train_{i}" for i in (0, 1, 2, 3)], dset_sharded_contiguous["filename"])
2568
                    self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
2569
                    self.assertDictEqual(dset_sharded_contiguous.features, Features({"filename": Value("string")}))
2570
                    # Test lengths of sharded contiguous
2571
                    self.assertEqual(
2572
                        [4, 3, 3],
2573
                        [
2574
                            len(dset.shard(3, index=i, contiguous=True, indices_cache_file_name=tmp_file_2 + str(i)))
2575
                            for i in range(3)
2576
                        ],
2577
                    )
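                    # Why [4, 3, 3] (sketch of the arithmetic): 10 rows over 3 contiguous shards gives
                    # 10 // 3 = 3 rows per shard with 10 % 3 = 1 left over; the remainder goes to the
                    # first shard(s), so shard 0 gets 4 rows and shards 1-2 get 3 each.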
2578
                # formatted
2579
                dset.set_format("numpy")
2580
                with dset.shard(num_shards=3, index=0) as dset_sharded_formatted:
2581
                    self.assertEqual(dset_sharded_formatted.format["type"], "numpy")
2582

2583
    def test_flatten_indices(self, in_memory):
2584
        with tempfile.TemporaryDirectory() as tmp_dir:
2585
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2586
                self.assertIsNone(dset._indices)
2587

2588
                tmp_file = os.path.join(tmp_dir, "test.arrow")
2589
                with dset.select(range(0, 10, 2), indices_cache_file_name=tmp_file) as dset:
2590
                    self.assertEqual(len(dset), 5)
2591

2592
                    self.assertIsNotNone(dset._indices)
2593

2594
                    tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
2595
                    fingerprint = dset._fingerprint
2596
                    dset.set_format("numpy")
2597
                    with dset.flatten_indices(cache_file_name=tmp_file_2) as dset:
2598
                        self.assertEqual(len(dset), 5)
2599
                        self.assertEqual(len(dset.data), len(dset))
2600
                        self.assertIsNone(dset._indices)
2601
                        self.assertNotEqual(dset._fingerprint, fingerprint)
2602
                        self.assertEqual(dset.format["type"], "numpy")
2603
                        # Test unique works
2604
                        dset.unique(dset.column_names[0])
2605
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)
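                        # Rough summary of what is asserted above: flatten_indices() writes the selected
                        # rows into a new arrow table (here cached in tmp_file_2), so the indices mapping
                        # is dropped (_indices is None) and len(dset.data) matches len(dset) again.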
2606

2607
        # Empty indices mapping
2608
        with tempfile.TemporaryDirectory() as tmp_dir:
2609
            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
2610
                self.assertIsNone(dset._indices)
2611

2612
                tmp_file = os.path.join(tmp_dir, "test.arrow")
2613
                with dset.filter(lambda _: False, cache_file_name=tmp_file) as dset:
2614
                    self.assertEqual(len(dset), 0)
2615

2616
                    self.assertIsNotNone(dset._indices)
2617

2618
                    tmp_file_2 = os.path.join(tmp_dir, "test_2.arrow")
2619
                    fingerprint = dset._fingerprint
2620
                    dset.set_format("numpy")
2621
                    with dset.flatten_indices(cache_file_name=tmp_file_2) as dset:
2622
                        self.assertEqual(len(dset), 0)
2623
                        self.assertEqual(len(dset.data), len(dset))
2624
                        self.assertIsNone(dset._indices)
2625
                        self.assertNotEqual(dset._fingerprint, fingerprint)
2626
                        self.assertEqual(dset.format["type"], "numpy")
2627
                        # Test unique works
2628
                        dset.unique(dset.column_names[0])
2629
                        assert_arrow_metadata_are_synced_with_dataset_features(dset)
2630

2631
    @require_tf
2632
    @require_torch
2633
    def test_format_vectors(self, in_memory):
2634
        import numpy as np
2635
        import tensorflow as tf
2636
        import torch
2637

2638
        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(
2639
            in_memory, tmp_dir
2640
        ) as dset, dset.map(lambda ex, i: {"vec": np.ones(3) * i}, with_indices=True) as dset:
2641
            columns = dset.column_names
2642

2643
            self.assertIsNotNone(dset[0])
2644
            self.assertIsNotNone(dset[:2])
2645
            for col in columns:
2646
                self.assertIsInstance(dset[0][col], (str, list))
2647
                self.assertIsInstance(dset[:2][col], list)
2648
            self.assertDictEqual(
2649
                dset.features, Features({"filename": Value("string"), "vec": Sequence(Value("float64"))})
2650
            )
2651

2652
            dset.set_format("tensorflow")
2653
            self.assertIsNotNone(dset[0])
2654
            self.assertIsNotNone(dset[:2])
2655
            for col in columns:
2656
                self.assertIsInstance(dset[0][col], (tf.Tensor, tf.RaggedTensor))
2657
                self.assertIsInstance(dset[:2][col], (tf.Tensor, tf.RaggedTensor))
2658
                self.assertIsInstance(dset[col], (tf.Tensor, tf.RaggedTensor))
2659
            self.assertTupleEqual(tuple(dset[:2]["vec"].shape), (2, 3))
2660
            self.assertTupleEqual(tuple(dset["vec"][:2].shape), (2, 3))
2661

2662
            dset.set_format("numpy")
2663
            self.assertIsNotNone(dset[0])
2664
            self.assertIsNotNone(dset[:2])
2665
            self.assertIsInstance(dset[0]["filename"], np.str_)
2666
            self.assertIsInstance(dset[:2]["filename"], np.ndarray)
2667
            self.assertIsInstance(dset["filename"], np.ndarray)
2668
            self.assertIsInstance(dset[0]["vec"], np.ndarray)
2669
            self.assertIsInstance(dset[:2]["vec"], np.ndarray)
2670
            self.assertIsInstance(dset["vec"], np.ndarray)
2671
            self.assertTupleEqual(dset[:2]["vec"].shape, (2, 3))
2672
            self.assertTupleEqual(dset["vec"][:2].shape, (2, 3))
2673

2674
            dset.set_format("torch", columns=["vec"])
2675
            self.assertIsNotNone(dset[0])
2676
            self.assertIsNotNone(dset[:2])
2677
            # torch.Tensor is only for numerical columns
2678
            self.assertIsInstance(dset[0]["vec"], torch.Tensor)
2679
            self.assertIsInstance(dset[:2]["vec"], torch.Tensor)
2680
            self.assertIsInstance(dset["vec"][:2], torch.Tensor)
2681
            self.assertTupleEqual(dset[:2]["vec"].shape, (2, 3))
2682
            self.assertTupleEqual(dset["vec"][:2].shape, (2, 3))
2683

2684
    @require_tf
2685
    @require_torch
2686
    def test_format_ragged_vectors(self, in_memory):
2687
        import numpy as np
2688
        import tensorflow as tf
2689
        import torch
2690

2691
        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(
2692
            in_memory, tmp_dir
2693
        ) as dset, dset.map(lambda ex, i: {"vec": np.ones(3 + i) * i}, with_indices=True) as dset:
2694
            columns = dset.column_names
2695

2696
            self.assertIsNotNone(dset[0])
2697
            self.assertIsNotNone(dset[:2])
2698
            for col in columns:
2699
                self.assertIsInstance(dset[0][col], (str, list))
2700
                self.assertIsInstance(dset[:2][col], list)
2701
            self.assertDictEqual(
2702
                dset.features, Features({"filename": Value("string"), "vec": Sequence(Value("float64"))})
2703
            )
2704

2705
            dset.set_format("tensorflow")
2706
            self.assertIsNotNone(dset[0])
2707
            self.assertIsNotNone(dset[:2])
2708
            for col in columns:
2709
                self.assertIsInstance(dset[0][col], tf.Tensor)
2710
                self.assertIsInstance(dset[:2][col], tf.RaggedTensor if col == "vec" else tf.Tensor)
2711
                self.assertIsInstance(dset[col], tf.RaggedTensor if col == "vec" else tf.Tensor)
2712
            # dim is None for ragged vectors in tensorflow
2713
            self.assertListEqual(dset[:2]["vec"].shape.as_list(), [2, None])
2714
            self.assertListEqual(dset["vec"][:2].shape.as_list(), [2, None])
2715

2716
            dset.set_format("numpy")
2717
            self.assertIsNotNone(dset[0])
2718
            self.assertIsNotNone(dset[:2])
2719
            self.assertIsInstance(dset[0]["filename"], np.str_)
2720
            self.assertIsInstance(dset[:2]["filename"], np.ndarray)
2721
            self.assertIsInstance(dset["filename"], np.ndarray)
2722
            self.assertIsInstance(dset[0]["vec"], np.ndarray)
2723
            self.assertIsInstance(dset[:2]["vec"], np.ndarray)
2724
            self.assertIsInstance(dset["vec"], np.ndarray)
2725
            # array is flat for ragged vectors in numpy
2726
            self.assertTupleEqual(dset[:2]["vec"].shape, (2,))
2727
            self.assertTupleEqual(dset["vec"][:2].shape, (2,))
2728

2729
            dset.set_format("torch")
2730
            self.assertIsNotNone(dset[0])
2731
            self.assertIsNotNone(dset[:2])
2732
            self.assertIsInstance(dset[0]["filename"], str)
2733
            self.assertIsInstance(dset[:2]["filename"], list)
2734
            self.assertIsInstance(dset["filename"], list)
2735
            self.assertIsInstance(dset[0]["vec"], torch.Tensor)
2736
            self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor)
2737
            self.assertIsInstance(dset["vec"][0], torch.Tensor)
2738
            # pytorch doesn't support ragged tensors, so we should have lists
2739
            self.assertIsInstance(dset[:2]["vec"], list)
2740
            self.assertIsInstance(dset[:2]["vec"][0], torch.Tensor)
2741
            self.assertIsInstance(dset["vec"][:2], list)
2742
            self.assertIsInstance(dset["vec"][0], torch.Tensor)
2743

2744
    @require_tf
2745
    @require_torch
2746
    def test_format_nested(self, in_memory):
2747
        import numpy as np
2748
        import tensorflow as tf
2749
        import torch
2750

2751
        with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(
2752
            in_memory, tmp_dir
2753
        ) as dset, dset.map(lambda ex: {"nested": [{"foo": np.ones(3)}] * len(ex["filename"])}, batched=True) as dset:
2754
            self.assertDictEqual(
2755
                dset.features, Features({"filename": Value("string"), "nested": {"foo": Sequence(Value("float64"))}})
2756
            )
2757

2758
            dset.set_format("tensorflow")
2759
            self.assertIsNotNone(dset[0])
2760
            self.assertIsInstance(dset[0]["nested"]["foo"], (tf.Tensor, tf.RaggedTensor))
2761
            self.assertIsNotNone(dset[:2])
2762
            self.assertIsInstance(dset[:2]["nested"][0]["foo"], (tf.Tensor, tf.RaggedTensor))
2763
            self.assertIsInstance(dset["nested"][0]["foo"], (tf.Tensor, tf.RaggedTensor))
2764

2765
            dset.set_format("numpy")
2766
            self.assertIsNotNone(dset[0])
2767
            self.assertIsInstance(dset[0]["nested"]["foo"], np.ndarray)
2768
            self.assertIsNotNone(dset[:2])
2769
            self.assertIsInstance(dset[:2]["nested"][0]["foo"], np.ndarray)
2770
            self.assertIsInstance(dset["nested"][0]["foo"], np.ndarray)
2771

2772
            dset.set_format("torch", columns="nested")
2773
            self.assertIsNotNone(dset[0])
2774
            self.assertIsInstance(dset[0]["nested"]["foo"], torch.Tensor)
2775
            self.assertIsNotNone(dset[:2])
2776
            self.assertIsInstance(dset[:2]["nested"][0]["foo"], torch.Tensor)
2777
            self.assertIsInstance(dset["nested"][0]["foo"], torch.Tensor)
2778

2779
    def test_format_pandas(self, in_memory):
2780
        import pandas as pd
2781

2782
        with tempfile.TemporaryDirectory() as tmp_dir:
2783
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2784
                dset.set_format("pandas")
2785
                self.assertIsInstance(dset[0], pd.DataFrame)
2786
                self.assertIsInstance(dset[:2], pd.DataFrame)
2787
                self.assertIsInstance(dset["col_1"], pd.Series)
2788

2789
    def test_transmit_format_single(self, in_memory):
2790
        @transmit_format
2791
        def my_single_transform(self, return_factory, *args, **kwargs):
2792
            return return_factory()
2793

2794
        with tempfile.TemporaryDirectory() as tmp_dir:
2795
            return_factory = partial(
2796
                self._create_dummy_dataset, in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True
2797
            )
2798
            with return_factory() as dset:
2799
                dset.set_format("numpy", columns=["col_1"])
2800
                prev_format = dset.format
2801
                with my_single_transform(dset, return_factory) as transformed_dset:
2802
                    self.assertDictEqual(transformed_dset.format, prev_format)
2803

2804
    def test_transmit_format_dict(self, in_memory):
2805
        @transmit_format
2806
        def my_split_transform(self, return_factory, *args, **kwargs):
2807
            return DatasetDict({"train": return_factory()})
2808

2809
        with tempfile.TemporaryDirectory() as tmp_dir:
2810
            return_factory = partial(
2811
                self._create_dummy_dataset, in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True
2812
            )
2813
            with return_factory() as dset:
2814
                dset.set_format("numpy", columns=["col_1"])
2815
                prev_format = dset.format
2816
                transformed_dset = my_split_transform(dset, return_factory)["train"]
2817
                self.assertDictEqual(transformed_dset.format, prev_format)
2818

2819
                del transformed_dset  # DatasetDict
2820

2821
    def test_with_format(self, in_memory):
2822
        with tempfile.TemporaryDirectory() as tmp_dir:
2823
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2824
                with dset.with_format("numpy", columns=["col_1"]) as dset2:
2825
                    dset.set_format("numpy", columns=["col_1"])
2826
                    self.assertDictEqual(dset.format, dset2.format)
2827
                    self.assertEqual(dset._fingerprint, dset2._fingerprint)
2828
                    # dset.reset_format()
2829
                    # self.assertNotEqual(dset.format, dset2.format)
2830
                    # self.assertNotEqual(dset._fingerprint, dset2._fingerprint)
2831

2832
    def test_with_transform(self, in_memory):
2833
        with tempfile.TemporaryDirectory() as tmp_dir:
2834
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
2835
                transform = lambda x: {"foo": x["col_1"]}  # noqa: E731
2836
                with dset.with_transform(transform, columns=["col_1"]) as dset2:
2837
                    dset.set_transform(transform, columns=["col_1"])
2838
                    self.assertDictEqual(dset.format, dset2.format)
2839
                    self.assertEqual(dset._fingerprint, dset2._fingerprint)
2840
                    dset.reset_format()
2841
                    self.assertNotEqual(dset.format, dset2.format)
2842
                    self.assertNotEqual(dset._fingerprint, dset2._fingerprint)
2843

2844
    @require_tf
2845
    def test_tf_dataset_conversion(self, in_memory):
2846
        tmp_dir = tempfile.TemporaryDirectory()
2847
        for num_workers in [0, 1, 2]:
2848
            if num_workers > 0 and sys.platform == "win32" and not in_memory:
2849
                continue  # This test hangs on the Py3.10 test worker, but runs fine locally on a Windows machine
2850
            with self._create_dummy_dataset(in_memory, tmp_dir.name, array_features=True) as dset:
2851
                tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=2, num_workers=num_workers)
2852
                batch = next(iter(tf_dataset))
2853
                self.assertEqual(batch.shape.as_list(), [2, 4])
2854
                self.assertEqual(batch.dtype.name, "int64")
2855
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
2856
                tf_dataset = dset.to_tf_dataset(columns="col_1", batch_size=2, num_workers=num_workers)
2857
                batch = next(iter(tf_dataset))
2858
                self.assertEqual(batch.shape.as_list(), [2])
2859
                self.assertEqual(batch.dtype.name, "int64")
2860
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
2861
                # Check that it works with all default options (except batch_size because the dummy dataset only has 4)
2862
                tf_dataset = dset.to_tf_dataset(batch_size=2, num_workers=num_workers)
2863
                batch = next(iter(tf_dataset))
2864
                self.assertEqual(batch["col_1"].shape.as_list(), [2])
2865
                self.assertEqual(batch["col_2"].shape.as_list(), [2])
2866
                self.assertEqual(batch["col_1"].dtype.name, "int64")
2867
                self.assertEqual(batch["col_2"].dtype.name, "string")  # Assert that we're converting strings properly
2868
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
2869
                # Check that when we use a transform that creates a new column from existing column values
2870
                # but don't load the old columns that the new column depends on in the final dataset,
2871
                # that they're still kept around long enough to be used in the transform
2872
                transform_dset = dset.with_transform(
2873
                    lambda x: {"new_col": [val * 2 for val in x["col_1"]], "col_1": x["col_1"]}
2874
                )
2875
                tf_dataset = transform_dset.to_tf_dataset(columns="new_col", batch_size=2, num_workers=num_workers)
2876
                batch = next(iter(tf_dataset))
2877
                self.assertEqual(batch.shape.as_list(), [2])
2878
                self.assertEqual(batch.dtype.name, "int64")
2879
                del transform_dset
2880
        del tf_dataset  # For correct cleanup
2881

2882
    @require_tf
    def test_tf_index_reshuffling(self, in_memory):
        # This test checks that two epochs over a tf.data.Dataset from to_tf_dataset
        # yield a different shuffle order each time.
        # It also checks that when we aren't shuffling, the dataset order is fully preserved
        # even when loading is split across multiple workers
        data = {"col_1": list(range(20))}
        for num_workers in [0, 1, 2, 3]:
            with Dataset.from_dict(data) as dset:
                tf_dataset = dset.to_tf_dataset(batch_size=10, shuffle=True, num_workers=num_workers)
                indices = []
                for batch in tf_dataset:
                    indices.append(batch["col_1"])
                indices = np.concatenate([arr.numpy() for arr in indices])
                second_indices = []
                for batch in tf_dataset:
                    second_indices.append(batch["col_1"])
                second_indices = np.concatenate([arr.numpy() for arr in second_indices])
                self.assertFalse(np.array_equal(indices, second_indices))
                self.assertEqual(len(indices), len(np.unique(indices)))
                self.assertEqual(len(second_indices), len(np.unique(second_indices)))

                tf_dataset = dset.to_tf_dataset(batch_size=1, shuffle=False, num_workers=num_workers)
                for i, batch in enumerate(tf_dataset):
                    # Assert that the unshuffled order is fully preserved even when multiprocessing
                    self.assertEqual(i, batch["col_1"].numpy())

    @require_tf
    def test_tf_label_renaming(self, in_memory):
        # Protect TF-specific imports in here
        import tensorflow as tf

        from datasets.utils.tf_utils import minimal_tf_collate_fn_with_renaming

        tmp_dir = tempfile.TemporaryDirectory()
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            with dset.rename_columns({"col_1": "features", "col_2": "label"}) as new_dset:
                tf_dataset = new_dset.to_tf_dataset(collate_fn=minimal_tf_collate_fn_with_renaming, batch_size=4)
                batch = next(iter(tf_dataset))
                self.assertTrue("labels" in batch and "features" in batch)

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features", "labels"], collate_fn=minimal_tf_collate_fn_with_renaming, batch_size=4
                )
                batch = next(iter(tf_dataset))
                self.assertTrue("labels" in batch and "features" in batch)

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features", "label"], collate_fn=minimal_tf_collate_fn_with_renaming, batch_size=4
                )
                batch = next(iter(tf_dataset))
                self.assertTrue("labels" in batch and "features" in batch)  # Assert renaming was handled correctly

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features"],
                    label_cols=["labels"],
                    collate_fn=minimal_tf_collate_fn_with_renaming,
                    batch_size=4,
                )
                batch = next(iter(tf_dataset))
                self.assertEqual(len(batch), 2)
                # Assert that we don't have any empty entries here
                self.assertTrue(isinstance(batch[0], tf.Tensor) and isinstance(batch[1], tf.Tensor))

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features"],
                    label_cols=["label"],
                    collate_fn=minimal_tf_collate_fn_with_renaming,
                    batch_size=4,
                )
                batch = next(iter(tf_dataset))
                self.assertEqual(len(batch), 2)
                # Assert that we don't have any empty entries here
                self.assertTrue(isinstance(batch[0], tf.Tensor) and isinstance(batch[1], tf.Tensor))

                tf_dataset = new_dset.to_tf_dataset(
                    columns=["features"],
                    collate_fn=minimal_tf_collate_fn_with_renaming,
                    batch_size=4,
                )
                batch = next(iter(tf_dataset))
                # Assert that labels didn't creep in when we don't ask for them
                # just because the collate_fn added them
                self.assertTrue(isinstance(batch, tf.Tensor))

        del tf_dataset  # For correct cleanup

    @require_tf
    def test_tf_dataset_options(self, in_memory):
        tmp_dir = tempfile.TemporaryDirectory()
        # Test that batch_size option works as expected
        with self._create_dummy_dataset(in_memory, tmp_dir.name, array_features=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=2)
            batch = next(iter(tf_dataset))
            self.assertEqual(batch.shape.as_list(), [2, 4])
            self.assertEqual(batch.dtype.name, "int64")
        # Test that batch_size=None (optional) works as expected
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_3", batch_size=None)
            single_example = next(iter(tf_dataset))
            self.assertEqual(single_example.shape.as_list(), [])
            self.assertEqual(single_example.dtype.name, "int64")
            # Assert that we can batch it with `tf.data.Dataset.batch` method
            batched_dataset = tf_dataset.batch(batch_size=2)
            batch = next(iter(batched_dataset))
            self.assertEqual(batch.shape.as_list(), [2])
            self.assertEqual(batch.dtype.name, "int64")
        # Test that batching a batch_size=None dataset produces the same results as using batch_size arg
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            batch_size = 2
            tf_dataset_no_batch = dset.to_tf_dataset(columns="col_3")
            tf_dataset_batch = dset.to_tf_dataset(columns="col_3", batch_size=batch_size)
            self.assertEqual(tf_dataset_no_batch.element_spec, tf_dataset_batch.unbatch().element_spec)
            self.assertEqual(tf_dataset_no_batch.cardinality(), tf_dataset_batch.cardinality() * batch_size)
            for batch_1, batch_2 in zip(tf_dataset_no_batch.batch(batch_size=batch_size), tf_dataset_batch):
                self.assertEqual(batch_1.shape, batch_2.shape)
                self.assertEqual(batch_1.dtype, batch_2.dtype)
                self.assertListEqual(batch_1.numpy().tolist(), batch_2.numpy().tolist())
        # Test that requesting label_cols works as expected
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_1", label_cols=["col_2", "col_3"], batch_size=4)
            batch = next(iter(tf_dataset))
            self.assertEqual(len(batch), 2)
            self.assertEqual(set(batch[1].keys()), {"col_2", "col_3"})
            self.assertEqual(batch[0].dtype.name, "int64")
            # Assert data comes out as expected and isn't shuffled
            self.assertEqual(batch[0].numpy().tolist(), [3, 2, 1, 0])
            self.assertEqual(batch[1]["col_2"].numpy().tolist(), [b"a", b"b", b"c", b"d"])
            self.assertEqual(batch[1]["col_3"].numpy().tolist(), [0, 1, 0, 1])
        # Check that incomplete batches are dropped if requested
        with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
            tf_dataset = dset.to_tf_dataset(columns="col_1", batch_size=3)
            tf_dataset_with_drop = dset.to_tf_dataset(columns="col_1", batch_size=3, drop_remainder=True)
            self.assertEqual(len(tf_dataset), 2)  # One batch of 3 and one batch of 1
            self.assertEqual(len(tf_dataset_with_drop), 1)  # Incomplete batch of 1 is dropped
        # Test that `NotImplementedError` is raised when `batch_size` is None and `num_workers` is > 0
        if sys.version_info >= (3, 8):
            with self._create_dummy_dataset(in_memory, tmp_dir.name, multiple_columns=True) as dset:
                with self.assertRaisesRegex(
                    NotImplementedError, "`batch_size` must be specified when using multiple workers"
                ):
                    dset.to_tf_dataset(columns="col_1", batch_size=None, num_workers=2)
        del tf_dataset  # For correct cleanup
        del tf_dataset_with_drop


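# The tests below build small datasets directly and are not parametrized over
# in-memory vs. on-disk storage like the class above.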
class MiscellaneousDatasetTest(TestCase):
    def test_from_pandas(self):
        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
        df = pd.DataFrame.from_dict(data)
        with Dataset.from_pandas(df) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))

        features = Features({"col_1": Value("int64"), "col_2": Value("string")})
        with Dataset.from_pandas(df, features=features) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))

        features = Features({"col_1": Value("int64"), "col_2": Value("string")})
        with Dataset.from_pandas(df, features=features, info=DatasetInfo(features=features)) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")}))

        features = Features({"col_1": Sequence(Value("string")), "col_2": Value("string")})
        self.assertRaises(TypeError, Dataset.from_pandas, df, features=features)

    def test_from_dict(self):
        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": pa.array([True, False, True, False])}
        with Dataset.from_dict(data) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], data["col_3"].to_pylist())
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
            )

        features = Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
        with Dataset.from_dict(data, features=features) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], data["col_3"].to_pylist())
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
            )

        features = Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
        with Dataset.from_dict(data, features=features, info=DatasetInfo(features=features)) as dset:
            self.assertListEqual(dset["col_1"], data["col_1"])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], data["col_3"].to_pylist())
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("int64"), "col_2": Value("string"), "col_3": Value("bool")})
            )

        features = Features({"col_1": Value("string"), "col_2": Value("string"), "col_3": Value("int32")})
        with Dataset.from_dict(data, features=features) as dset:
            # the integers are converted to strings
            self.assertListEqual(dset["col_1"], [str(x) for x in data["col_1"]])
            self.assertListEqual(dset["col_2"], data["col_2"])
            self.assertListEqual(dset["col_3"], [int(x) for x in data["col_3"].to_pylist()])
            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2", "col_3"])
            self.assertDictEqual(
                dset.features, Features({"col_1": Value("string"), "col_2": Value("string"), "col_3": Value("int32")})
            )

        features = Features({"col_1": Value("int64"), "col_2": Value("int64"), "col_3": Value("bool")})
        self.assertRaises(ValueError, Dataset.from_dict, data, features=features)

    def test_concatenate_mixed_memory_and_disk(self):
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            with Dataset.from_dict(data1, info=info1).map(
                cache_file_name=os.path.join(tmp_dir, "d1.arrow")
            ) as dset1, Dataset.from_dict(data2, info=info2).map(
                cache_file_name=os.path.join(tmp_dir, "d2.arrow")
            ) as dset2, Dataset.from_dict(data3) as dset3:
                with concatenate_datasets([dset1, dset2, dset3]) as concatenated_dset:
                    self.assertEqual(len(concatenated_dset), len(dset1) + len(dset2) + len(dset3))
                    self.assertListEqual(concatenated_dset["id"], dset1["id"] + dset2["id"] + dset3["id"])

    @require_transformers
    @pytest.mark.integration
    def test_set_format_encode(self):
        from transformers import BertTokenizer

        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        def encode(batch):
            return tokenizer(batch["text"], padding="longest", return_tensors="np")

        with Dataset.from_dict({"text": ["hello there", "foo"]}) as dset:
            dset.set_transform(transform=encode)
            self.assertEqual(str(dset[:2]), str(encode({"text": ["hello there", "foo"]})))

    @require_tf
    def test_tf_string_encoding(self):
        data = {"col_1": ["á", "é", "í", "ó", "ú"], "col_2": ["à", "è", "ì", "ò", "ù"]}
        with Dataset.from_dict(data) as dset:
            tf_dset_wo_batch = dset.to_tf_dataset(columns=["col_1", "col_2"])
            for tf_row, row in zip(tf_dset_wo_batch, dset):
                self.assertEqual(tf_row["col_1"].numpy().decode("utf-8"), row["col_1"])
                self.assertEqual(tf_row["col_2"].numpy().decode("utf-8"), row["col_2"])

            tf_dset_w_batch = dset.to_tf_dataset(columns=["col_1", "col_2"], batch_size=2)
            for tf_row, row in zip(tf_dset_w_batch.unbatch(), dset):
                self.assertEqual(tf_row["col_1"].numpy().decode("utf-8"), row["col_1"])
                self.assertEqual(tf_row["col_2"].numpy().decode("utf-8"), row["col_2"])

            self.assertEqual(tf_dset_w_batch.unbatch().element_spec, tf_dset_wo_batch.element_spec)
            self.assertEqual(tf_dset_w_batch.element_spec, tf_dset_wo_batch.batch(2).element_spec)


def test_cast_with_sliced_list():
    old_features = Features({"foo": Sequence(Value("int64"))})
    new_features = Features({"foo": Sequence(Value("int32"))})
    dataset = Dataset.from_dict({"foo": [[i] * (i % 3) for i in range(20)]}, features=old_features)
    casted_dataset = dataset.cast(new_features, batch_size=2)  # small batch size to slice the ListArray
    assert dataset["foo"] == casted_dataset["foo"]
    assert casted_dataset.features == new_features


@pytest.mark.parametrize("include_nulls", [False, True])
def test_class_encode_column_with_none(include_nulls):
    dataset = Dataset.from_dict({"col_1": ["a", "b", "c", None, "d", None]})
    dataset = dataset.class_encode_column("col_1", include_nulls=include_nulls)
    class_names = ["a", "b", "c", "d"]
    if include_nulls:
        class_names += ["None"]
    assert isinstance(dataset.features["col_1"], ClassLabel)
    assert set(dataset.features["col_1"].names) == set(class_names)
    assert (None in dataset.unique("col_1")) == (not include_nulls)


@pytest.mark.parametrize("null_placement", ["first", "last"])
def test_sort_with_none(null_placement):
    dataset = Dataset.from_dict({"col_1": ["item_2", "item_3", "item_1", None, "item_4", None]})
    dataset = dataset.sort("col_1", null_placement=null_placement)
    if null_placement == "first":
        assert dataset["col_1"] == [None, None, "item_1", "item_2", "item_3", "item_4"]
    else:
        assert dataset["col_1"] == ["item_1", "item_2", "item_3", "item_4", None, None]


def test_update_metadata_with_features(dataset_dict):
    table1 = pa.Table.from_pydict(dataset_dict)
    features1 = Features.from_arrow_schema(table1.schema)
    features2 = features1.copy()
    features2["col_2"] = ClassLabel(num_classes=len(table1))
    assert features1 != features2

    table2 = update_metadata_with_features(table1, features2)
    metadata = json.loads(table2.schema.metadata[b"huggingface"].decode())
    assert features2 == Features.from_dict(metadata["info"]["features"])

    with Dataset(table1) as dset1, Dataset(table2) as dset2:
        assert dset1.features == features1
        assert dset2.features == features2


@pytest.mark.parametrize("dataset_type", ["in_memory", "memory_mapped", "mixed"])
@pytest.mark.parametrize("axis, expected_shape", [(0, (4, 3)), (1, (2, 6))])
def test_concatenate_datasets(dataset_type, axis, expected_shape, dataset_dict, arrow_path):
    table = {
        "in_memory": InMemoryTable.from_pydict(dataset_dict),
        "memory_mapped": MemoryMappedTable.from_file(arrow_path),
    }
    tables = [
        table[dataset_type if dataset_type != "mixed" else "memory_mapped"].slice(0, 2),  # shape = (2, 3)
        table[dataset_type if dataset_type != "mixed" else "in_memory"].slice(2, 4),  # shape = (2, 3)
    ]
    if axis == 1:  # don't duplicate columns
        tables[1] = tables[1].rename_columns([col + "_bis" for col in tables[1].column_names])
    datasets = [Dataset(table) for table in tables]
    dataset = concatenate_datasets(datasets, axis=axis)
    assert dataset.shape == expected_shape
    assert_arrow_metadata_are_synced_with_dataset_features(dataset)


def test_concatenate_datasets_new_columns():
    dataset1 = Dataset.from_dict({"col_1": ["a", "b", "c"]})
    dataset2 = Dataset.from_dict({"col_1": ["d", "e", "f"], "col_2": [True, False, True]})
    dataset = concatenate_datasets([dataset1, dataset2])
    assert dataset.data.shape == (6, 2)
    assert dataset.features == Features({"col_1": Value("string"), "col_2": Value("bool")})
    assert dataset[:] == {"col_1": ["a", "b", "c", "d", "e", "f"], "col_2": [None, None, None, True, False, True]}
    dataset3 = Dataset.from_dict({"col_3": ["a_1"]})
    dataset = concatenate_datasets([dataset, dataset3])
    assert dataset.data.shape == (7, 3)
    assert dataset.features == Features({"col_1": Value("string"), "col_2": Value("bool"), "col_3": Value("string")})
    assert dataset[:] == {
        "col_1": ["a", "b", "c", "d", "e", "f", None],
        "col_2": [None, None, None, True, False, True, None],
        "col_3": [None, None, None, None, None, None, "a_1"],
    }


@pytest.mark.parametrize("axis", [0, 1])
def test_concatenate_datasets_complex_features(axis):
    n = 5
    dataset1 = Dataset.from_dict(
        {"col_1": [0] * n, "col_2": list(range(n))},
        features=Features({"col_1": Value("int32"), "col_2": ClassLabel(num_classes=n)}),
    )
    if axis == 1:
        dataset2 = dataset1.rename_columns({col: col + "_" for col in dataset1.column_names})
        expected_features = Features({**dataset1.features, **dataset2.features})
    else:
        dataset2 = dataset1
        expected_features = dataset1.features
    assert concatenate_datasets([dataset1, dataset2], axis=axis).features == expected_features


@pytest.mark.parametrize("other_dataset_type", ["in_memory", "memory_mapped", "concatenation"])
@pytest.mark.parametrize("axis, expected_shape", [(0, (8, 3)), (1, (4, 6))])
def test_concatenate_datasets_with_concatenation_tables(
    axis, expected_shape, other_dataset_type, dataset_dict, arrow_path
):
    def _create_concatenation_table(axis):
        if axis == 0:  # shape: (4, 3) = (4, 1) + (4, 2)
            concatenation_table = ConcatenationTable.from_blocks(
                [
                    [
                        InMemoryTable.from_pydict({"col_1": dataset_dict["col_1"]}),
                        MemoryMappedTable.from_file(arrow_path).remove_column(0),
                    ]
                ]
            )
        elif axis == 1:  # shape: (4, 3) = (1, 3) + (3, 3)
            concatenation_table = ConcatenationTable.from_blocks(
                [
                    [InMemoryTable.from_pydict(dataset_dict).slice(0, 1)],
                    [MemoryMappedTable.from_file(arrow_path).slice(1, 4)],
                ]
            )
        return concatenation_table

    concatenation_table = _create_concatenation_table(axis)
    assert concatenation_table.shape == (4, 3)

    if other_dataset_type == "in_memory":
        other_table = InMemoryTable.from_pydict(dataset_dict)
    elif other_dataset_type == "memory_mapped":
        other_table = MemoryMappedTable.from_file(arrow_path)
    elif other_dataset_type == "concatenation":
        other_table = _create_concatenation_table(axis)
    assert other_table.shape == (4, 3)

    tables = [concatenation_table, other_table]

    if axis == 1:  # don't duplicate columns
        tables[1] = tables[1].rename_columns([col + "_bis" for col in tables[1].column_names])

    for tables in [tables, reversed(tables)]:
        datasets = [Dataset(table) for table in tables]
        dataset = concatenate_datasets(datasets, axis=axis)
        assert dataset.shape == expected_shape


def test_concatenate_datasets_duplicate_columns(dataset):
    with pytest.raises(ValueError) as excinfo:
        concatenate_datasets([dataset, dataset], axis=1)
    assert "duplicated" in str(excinfo.value)


def test_interleave_datasets():
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets([d1, d2, d3])
    expected_length = 3 * min(len(d1), len(d2), len(d3))
    expected_values = [x["a"] for x in itertools.chain(*zip(d1, d2, d3))]
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert dataset._fingerprint == interleave_datasets([d1, d2, d3])._fingerprint


def test_interleave_datasets_probabilities():
    seed = 42
    probabilities = [0.3, 0.5, 0.2]
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)
    expected_length = 7  # hardcoded
    expected_values = [10, 11, 20, 12, 0, 21, 13]  # hardcoded
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert (
        dataset._fingerprint == interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)._fingerprint
    )


def test_interleave_datasets_oversampling_strategy():
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
    expected_length = 3 * max(len(d1), len(d2), len(d3))
    expected_values = [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]  # hardcoded
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert dataset._fingerprint == interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")._fingerprint


def test_interleave_datasets_probabilities_oversampling_strategy():
    seed = 42
    probabilities = [0.3, 0.5, 0.2]
    d1 = Dataset.from_dict({"a": [0, 1, 2]})
    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
    dataset = interleave_datasets(
        [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
    )
    expected_length = 16  # hardcoded
    expected_values = [10, 11, 20, 12, 0, 21, 13, 10, 1, 11, 12, 22, 13, 20, 10, 2]  # hardcoded
    assert isinstance(dataset, Dataset)
    assert len(dataset) == expected_length
    assert dataset["a"] == expected_values
    assert (
        dataset._fingerprint
        == interleave_datasets(
            [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
        )._fingerprint
    )


@pytest.mark.parametrize("batch_size", [4, 5])
@pytest.mark.parametrize("drop_last_batch", [False, True])
def test_dataset_iter_batch(batch_size, drop_last_batch):
    n = 25
    dset = Dataset.from_dict({"i": list(range(n))})
    all_col_values = list(range(n))
    batches = []
    for i, batch in enumerate(dset.iter(batch_size, drop_last_batch=drop_last_batch)):
        assert batch == {"i": all_col_values[i * batch_size : (i + 1) * batch_size]}
        batches.append(batch)
    if drop_last_batch:
        assert all(len(batch["i"]) == batch_size for batch in batches)
    else:
        assert all(len(batch["i"]) == batch_size for batch in batches[:-1])
        assert len(batches[-1]["i"]) <= batch_size


@pytest.mark.parametrize(
    "column, expected_dtype",
    [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],
)
@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
    "transform",
    [
        None,
        ("shuffle", (42,), {}),
        ("with_format", ("pandas",), {}),
        ("class_encode_column", ("col_2",), {}),
        ("select", (range(3),), {}),
    ],
)
def test_dataset_add_column(column, expected_dtype, in_memory, transform, dataset_dict, arrow_path):
    column_name = "col_4"
    original_dataset = (
        Dataset(InMemoryTable.from_pydict(dataset_dict))
        if in_memory
        else Dataset(MemoryMappedTable.from_file(arrow_path))
    )
    if transform is not None:
        transform_name, args, kwargs = transform
        original_dataset: Dataset = getattr(original_dataset, transform_name)(*args, **kwargs)
    column = column[:3] if transform is not None and transform_name == "select" else column
    dataset = original_dataset.add_column(column_name, column)
    assert dataset.data.shape == ((3, 4) if transform is not None and transform_name == "select" else (4, 4))
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    # Sort expected features as in the original dataset
    expected_features = {feature: expected_features[feature] for feature in original_dataset.features}
    # Add new column feature
    expected_features[column_name] = expected_dtype
    assert dataset.data.column_names == list(expected_features.keys())
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype
    assert len(dataset.data.blocks) == 1 if in_memory else 2  # multiple InMemoryTables are consolidated as one
    assert dataset.format["type"] == original_dataset.format["type"]
    assert dataset._fingerprint != original_dataset._fingerprint
    dataset.reset_format()
    original_dataset.reset_format()
    assert all(dataset[col] == original_dataset[col] for col in original_dataset.column_names)
    assert set(dataset["col_4"]) == set(column)
    if dataset._indices is not None:
        dataset_indices = dataset._indices["indices"].to_pylist()
        expected_dataset_indices = original_dataset._indices["indices"].to_pylist()
        assert dataset_indices == expected_dataset_indices
    assert_arrow_metadata_are_synced_with_dataset_features(dataset)


@pytest.mark.parametrize(
    "transform",
    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
)
@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
    "item",
    [
        {"col_1": "2", "col_2": 2, "col_3": 2.0},
        {"col_1": "2", "col_2": "2", "col_3": "2"},
        {"col_1": 2, "col_2": 2, "col_3": 2},
        {"col_1": 2.0, "col_2": 2.0, "col_3": 2.0},
    ],
)
def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
    dataset_to_test = (
        Dataset(InMemoryTable.from_pydict(dataset_dict))
        if in_memory
        else Dataset(MemoryMappedTable.from_file(arrow_path))
    )
    if transform is not None:
        transform_name, args, kwargs = transform
        dataset_to_test: Dataset = getattr(dataset_to_test, transform_name)(*args, **kwargs)
    dataset = dataset_to_test.add_item(item)
    assert dataset.data.shape == (5, 3)
    expected_features = dataset_to_test.features
    assert sorted(dataset.data.column_names) == sorted(expected_features.keys())
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature] == expected_dtype
    assert len(dataset.data.blocks) == 1 if in_memory else 2  # multiple InMemoryTables are consolidated as one
    assert dataset.format["type"] == dataset_to_test.format["type"]
    assert dataset._fingerprint != dataset_to_test._fingerprint
    dataset.reset_format()
    dataset_to_test.reset_format()
    assert dataset[:-1] == dataset_to_test[:]
    assert {k: int(v) for k, v in dataset[-1].items()} == {k: int(v) for k, v in item.items()}
    if dataset._indices is not None:
        dataset_indices = dataset._indices["indices"].to_pylist()
        dataset_to_test_indices = dataset_to_test._indices["indices"].to_pylist()
        assert dataset_indices == dataset_to_test_indices + [len(dataset_to_test._data)]


def test_dataset_add_item_new_columns():
    dataset = Dataset.from_dict({"col_1": [0, 1, 2]}, features=Features({"col_1": Value("uint8")}))
    dataset = dataset.add_item({"col_1": 3, "col_2": "a"})
    assert dataset.data.shape == (4, 2)
    assert dataset.features == Features({"col_1": Value("uint8"), "col_2": Value("string")})
    assert dataset[:] == {"col_1": [0, 1, 2, 3], "col_2": [None, None, None, "a"]}
    dataset = dataset.add_item({"col_3": True})
    assert dataset.data.shape == (5, 3)
    assert dataset.features == Features({"col_1": Value("uint8"), "col_2": Value("string"), "col_3": Value("bool")})
    assert dataset[:] == {
        "col_1": [0, 1, 2, 3, None],
        "col_2": [None, None, None, "a", None],
        "col_3": [None, None, None, None, True],
    }


def test_dataset_add_item_introduce_feature_type():
    dataset = Dataset.from_dict({"col_1": [None, None, None]})
    dataset = dataset.add_item({"col_1": "a"})
    assert dataset.data.shape == (4, 1)
    assert dataset.features == Features({"col_1": Value("string")})
    assert dataset[:] == {"col_1": [None, None, None, "a"]}


def test_dataset_filter_batched_indices():
    ds = Dataset.from_dict({"num": [0, 1, 2, 3]})
    ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2)
    assert all(item["num"] % 2 == 0 for item in ds)


@pytest.mark.parametrize("in_memory", [False, True])
def test_dataset_from_file(in_memory, dataset, arrow_file):
    filename = arrow_file
    with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
        dataset_from_file = Dataset.from_file(filename, in_memory=in_memory)
    assert dataset_from_file.features.type == dataset.features.type
    assert dataset_from_file.features == dataset.features
    assert dataset_from_file.cache_files == ([{"filename": filename}] if not in_memory else [])


def _check_csv_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_csv_keep_in_memory(keep_in_memory, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = Dataset.from_csv(csv_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
    _check_csv_dataset(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_csv_features(features, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    # CSV file loses col_1 string dtype information: default now is "int64" instead of "string"
    default_expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_csv(csv_path, features=features, cache_dir=cache_dir)
    _check_csv_dataset(dataset, expected_features)


@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_csv_split(split, csv_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = Dataset.from_csv(csv_path, cache_dir=cache_dir, split=split)
    _check_csv_dataset(dataset, expected_features)
    assert dataset.split == (split if split else "train")


@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_csv_path_type(path_type, csv_path, tmp_path):
    if issubclass(path_type, str):
        path = csv_path
    elif issubclass(path_type, list):
        path = [csv_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "int64", "col_2": "int64", "col_3": "float64"}
    dataset = Dataset.from_csv(path, cache_dir=cache_dir)
    _check_csv_dataset(dataset, expected_features)


def _check_json_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = Dataset.from_json(jsonl_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
    _check_json_dataset(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_json_features(features, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_json(jsonl_path, features=features, cache_dir=cache_dir)
    _check_json_dataset(dataset, expected_features)


def test_dataset_from_json_with_class_label_feature(jsonl_str_path, tmp_path):
    features = Features(
        {"col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]), "col_2": Value("int64"), "col_3": Value("float64")}
    )
    cache_dir = tmp_path / "cache"
    dataset = Dataset.from_json(jsonl_str_path, features=features, cache_dir=cache_dir)
    assert dataset.features["col_1"].dtype == "int64"


@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_json_split(split, jsonl_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = Dataset.from_json(jsonl_path, cache_dir=cache_dir, split=split)
    _check_json_dataset(dataset, expected_features)
    assert dataset.split == (split if split else "train")


@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_json_path_type(path_type, jsonl_path, tmp_path):
    if issubclass(path_type, str):
        path = jsonl_path
    elif issubclass(path_type, list):
        path = [jsonl_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = Dataset.from_json(path, cache_dir=cache_dir)
    _check_json_dataset(dataset, expected_features)


def _check_parquet_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = Dataset.from_parquet(parquet_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
    _check_parquet_dataset(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_parquet_features(features, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_parquet(parquet_path, features=features, cache_dir=cache_dir)
    _check_parquet_dataset(dataset, expected_features)


@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_parquet_split(split, parquet_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = Dataset.from_parquet(parquet_path, cache_dir=cache_dir, split=split)
    _check_parquet_dataset(dataset, expected_features)
    assert dataset.split == (split if split else "train")


@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path):
    if issubclass(path_type, str):
        path = parquet_path
    elif issubclass(path_type, list):
        path = [parquet_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    dataset = Dataset.from_parquet(path, cache_dir=cache_dir)
    _check_parquet_dataset(dataset, expected_features)


def _check_text_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 1
    assert dataset.column_names == ["text"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_text_keep_in_memory(keep_in_memory, text_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"text": "string"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = Dataset.from_text(text_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
    _check_text_dataset(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"text": "string"},
        {"text": "int32"},
        {"text": "float32"},
    ],
)
def test_dataset_from_text_features(features, text_path, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"text": "string"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_text(text_path, features=features, cache_dir=cache_dir)
    _check_text_dataset(dataset, expected_features)


@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_text_split(split, text_path, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"text": "string"}
    dataset = Dataset.from_text(text_path, cache_dir=cache_dir, split=split)
    _check_text_dataset(dataset, expected_features)
    assert dataset.split == (split if split else "train")


@pytest.mark.parametrize("path_type", [str, list])
def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
    if issubclass(path_type, str):
        path = text_path
    elif issubclass(path_type, list):
        path = [text_path]
    cache_dir = tmp_path / "cache"
    expected_features = {"text": "string"}
    dataset = Dataset.from_text(path, cache_dir=cache_dir)
    _check_text_dataset(dataset, expected_features)


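# Generator fixture for the Dataset.from_generator tests below; it yields four rows
# with string/int/float columns (col_1, col_2, col_3).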
@pytest.fixture
def data_generator():
    def _gen():
        data = [
            {"col_1": "0", "col_2": 0, "col_3": 0.0},
            {"col_1": "1", "col_2": 1, "col_3": 1.0},
            {"col_1": "2", "col_2": 2, "col_3": 2.0},
            {"col_1": "3", "col_2": 3, "col_3": 3.0},
        ]
        for item in data:
            yield item

    return _gen


def _check_generator_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
    _check_generator_dataset(dataset, expected_features)


@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_generator_features(features, data_generator, tmp_path):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
    _check_generator_dataset(dataset, expected_features)


@require_not_windows
@require_dill_gt_0_3_2
@require_pyspark
def test_from_spark():
    import pyspark

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    data = [
        ("0", 0, 0.0),
        ("1", 1, 1.0),
        ("2", 2, 2.0),
        ("3", 3, 3.0),
    ]
    df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
    dataset = Dataset.from_spark(df)
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]


@require_not_windows
@require_dill_gt_0_3_2
@require_pyspark
def test_from_spark_features():
    import PIL.Image
    import pyspark

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    data = [(0, np.arange(4 * 4 * 3).reshape(4, 4, 3).tolist())]
    df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
    features = Features({"idx": Value("int64"), "image": Image()})
    dataset = Dataset.from_spark(
        df,
        features=features,
    )
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 1
    assert dataset.num_columns == 2
    assert dataset.column_names == ["idx", "image"]
    assert isinstance(dataset[0]["image"], PIL.Image.Image)
    assert dataset.features == features
    assert_arrow_metadata_are_synced_with_dataset_features(dataset)


@require_not_windows
@require_dill_gt_0_3_2
@require_pyspark
def test_from_spark_different_cache():
    import pyspark

    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    df = spark.createDataFrame([("0", 0)], "col_1: string, col_2: int")
    dataset = Dataset.from_spark(df)
    assert isinstance(dataset, Dataset)
    different_df = spark.createDataFrame([("1", 1)], "col_1: string, col_2: int")
    different_dataset = Dataset.from_spark(different_df)
    assert isinstance(different_dataset, Dataset)
    assert dataset[0]["col_1"] == "0"
    # Check to make sure that the second dataset wasn't read from the cache.
    assert different_dataset[0]["col_1"] == "1"


def _check_sql_dataset(dataset, expected_features):
    assert isinstance(dataset, Dataset)
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for feature, expected_dtype in expected_features.items():
        assert dataset.features[feature].dtype == expected_dtype


@require_sqlalchemy
@pytest.mark.parametrize("con_type", ["string", "engine"])
def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    if con_type == "string":
        con = "sqlite:///" + sqlite_path
    elif con_type == "engine":
        import sqlalchemy

        con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
    # # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work
    # with caplog.at_level(INFO):
    #     dataset = Dataset.from_sql(
    #         "dataset",
    #         con,
    #         cache_dir=cache_dir,
    #     )
    # if con_type == "string":
    #     assert "couldn't be hashed properly" not in caplog.text
    # elif con_type == "engine":
    #     assert "couldn't be hashed properly" in caplog.text
    dataset = Dataset.from_sql(
        "dataset",
        con,
        cache_dir=cache_dir,
    )
    _check_sql_dataset(dataset, expected_features)


@require_sqlalchemy
@pytest.mark.parametrize(
    "features",
    [
        None,
        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
        {"col_1": "string", "col_2": "string", "col_3": "string"},
        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
    ],
)
def test_dataset_from_sql_features(features, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    expected_features = features.copy() if features else default_expected_features
    features = (
        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
    )
    dataset = Dataset.from_sql("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir)
    _check_sql_dataset(dataset, expected_features)


@require_sqlalchemy
@pytest.mark.parametrize("keep_in_memory", [False, True])
def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
        dataset = Dataset.from_sql(
            "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory
        )
    _check_sql_dataset(dataset, expected_features)


def test_dataset_to_json(dataset, tmp_path):
3952
    file_path = tmp_path / "test_path.jsonl"
3953
    bytes_written = dataset.to_json(path_or_buf=file_path)
3954
    assert file_path.is_file()
3955
    assert bytes_written == file_path.stat().st_size
3956
    df = pd.read_json(file_path, orient="records", lines=True)
3957
    assert df.shape == dataset.shape
3958
    assert list(df.columns) == list(dataset.column_names)
3959

3960

3961
@pytest.mark.parametrize("in_memory", [False, True])
3962
@pytest.mark.parametrize(
3963
    "method_and_params",
3964
    [
3965
        ("rename_column", (), {"original_column_name": "labels", "new_column_name": "label"}),
3966
        ("remove_columns", (), {"column_names": "labels"}),
3967
        (
3968
            "cast",
3969
            (),
3970
            {
3971
                "features": Features(
3972
                    {
3973
                        "tokens": Sequence(Value("string")),
3974
                        "labels": Sequence(Value("int16")),
3975
                        "answers": Sequence(
3976
                            {
3977
                                "text": Value("string"),
3978
                                "answer_start": Value("int32"),
3979
                            }
3980
                        ),
3981
                        "id": Value("int32"),
3982
                    }
3983
                )
3984
            },
3985
        ),
3986
        ("flatten", (), {}),
3987
    ],
3988
)
3989
def test_pickle_dataset_after_transforming_the_table(in_memory, method_and_params, arrow_file):
3990
    method, args, kwargs = method_and_params
3991
    with Dataset.from_file(arrow_file, in_memory=in_memory) as dataset, Dataset.from_file(
3992
        arrow_file, in_memory=in_memory
3993
    ) as reference_dataset:
3994
        out = getattr(dataset, method)(*args, **kwargs)
3995
        dataset = out if out is not None else dataset
3996
        pickled_dataset = pickle.dumps(dataset)
3997
        reloaded_dataset = pickle.loads(pickled_dataset)
3998

3999
        assert dataset._data != reference_dataset._data
4000
        assert dataset._data.table == reloaded_dataset._data.table
4001

4002

4003
def test_dummy_dataset_serialize_fs(dataset, mockfs):
4004
    dataset_path = "mock://my_dataset"
4005
    dataset.save_to_disk(dataset_path, storage_options=mockfs.storage_options)
4006
    assert mockfs.isdir(dataset_path)
4007
    assert mockfs.glob(dataset_path + "/*")
4008
    reloaded = dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
4009
    assert len(reloaded) == len(dataset)
4010
    assert reloaded.features == dataset.features
4011
    assert reloaded.to_dict() == dataset.to_dict()
4012

4013

4014
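# `Dataset._build_local_temp_path` is expected to strip the URI protocol and path anchor and recreate the
# remaining relative path under the system temp directory; the test below checks exactly that.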
@pytest.mark.parametrize(
    "uri_or_path",
    [
        "relative/path",
        "/absolute/path",
        "s3://bucket/relative/path",
        "hdfs://relative/path",
        "hdfs:///absolute/path",
    ],
)
def test_build_local_temp_path(uri_or_path):
    extracted_path = strip_protocol(uri_or_path)
    local_temp_path = Dataset._build_local_temp_path(extracted_path).as_posix()
    extracted_path_without_anchor = Path(extracted_path).relative_to(Path(extracted_path).anchor).as_posix()
    # Check that the local temp path is relative to the system temp dir
    path_relative_to_tmp_dir = Path(local_temp_path).relative_to(Path(tempfile.gettempdir())).as_posix()

    assert (
        "hdfs://" not in path_relative_to_tmp_dir
        and "s3://" not in path_relative_to_tmp_dir
        and not local_temp_path.startswith(extracted_path_without_anchor)
        and local_temp_path.endswith(extracted_path_without_anchor)
    ), f"Local temp path: {local_temp_path}"


class TaskTemplatesTest(TestCase):
    def test_task_text_classification(self):
        labels = sorted(["pos", "neg"])
        features_before_cast = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=labels),
            }
        )
        # Labels are cast to tuple during `TextClassification.__post_init__`, so we do the same here
        features_after_cast = Features(
            {
                "text": Value("string"),
                "labels": ClassLabel(names=labels),
            }
        )
        # Label names are added in `DatasetInfo.__post_init__` so not needed here
        task_without_labels = TextClassification(text_column="input_text", label_column="input_labels")
        info1 = DatasetInfo(
            features=features_before_cast,
            task_templates=task_without_labels,
        )
        # Label names are required when passing a TextClassification template directly to `Dataset.prepare_for_task`
        # However they also can be used to define `DatasetInfo` so we include a test for this too
        task_with_labels = TextClassification(text_column="input_text", label_column="input_labels")
        info2 = DatasetInfo(
            features=features_before_cast,
            task_templates=task_with_labels,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        # Test we can load from task name when label names not included in template (default behaviour)
        with Dataset.from_dict(data, info=info1) as dset:
            self.assertSetEqual({"input_text", "input_labels"}, set(dset.column_names))
            self.assertDictEqual(features_before_cast, dset.features)
            with dset.prepare_for_task(task="text-classification") as dset:
                self.assertSetEqual({"labels", "text"}, set(dset.column_names))
                self.assertDictEqual(features_after_cast, dset.features)
        # Test we can load from task name when label names included in template
        with Dataset.from_dict(data, info=info2) as dset:
            self.assertSetEqual({"input_text", "input_labels"}, set(dset.column_names))
            self.assertDictEqual(features_before_cast, dset.features)
            with dset.prepare_for_task(task="text-classification") as dset:
                self.assertSetEqual({"labels", "text"}, set(dset.column_names))
                self.assertDictEqual(features_after_cast, dset.features)
        # Test we can load from TextClassification template
        info1.task_templates = None
        with Dataset.from_dict(data, info=info1) as dset:
            with dset.prepare_for_task(task=task_with_labels) as dset:
                self.assertSetEqual({"labels", "text"}, set(dset.column_names))
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_question_answering(self):
        features_before_cast = Features(
            {
                "input_context": Value("string"),
                "input_question": Value("string"),
                "input_answers": Sequence(
                    {
                        "text": Value("string"),
                        "answer_start": Value("int32"),
                    }
                ),
            }
        )
        features_after_cast = Features(
            {
                "context": Value("string"),
                "question": Value("string"),
                "answers": Sequence(
                    {
                        "text": Value("string"),
                        "answer_start": Value("int32"),
                    }
                ),
            }
        )
        task = QuestionAnsweringExtractive(
            context_column="input_context", question_column="input_question", answers_column="input_answers"
        )
        info = DatasetInfo(features=features_before_cast, task_templates=task)
        data = {
            "input_context": ["huggingface is going to the moon!"],
            "input_question": ["where is huggingface going?"],
            "input_answers": [{"text": ["to the moon!"], "answer_start": [2]}],
        }
        # Test we can load from task name
        with Dataset.from_dict(data, info=info) as dset:
            self.assertSetEqual(
                {"input_context", "input_question", "input_answers.text", "input_answers.answer_start"},
                set(dset.flatten().column_names),
            )
            self.assertDictEqual(features_before_cast, dset.features)
            with dset.prepare_for_task(task="question-answering-extractive") as dset:
                self.assertSetEqual(
                    {"context", "question", "answers.text", "answers.answer_start"},
                    set(dset.flatten().column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)
        # Test we can load from QuestionAnsweringExtractive template
        info.task_templates = None
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task=task) as dset:
                self.assertSetEqual(
                    {"context", "question", "answers.text", "answers.answer_start"},
                    set(dset.flatten().column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_summarization(self):
        # Include a dummy extra column `dummy` to test we drop it correctly
        features_before_cast = Features(
            {"input_text": Value("string"), "input_summary": Value("string"), "dummy": Value("string")}
        )
        features_after_cast = Features({"text": Value("string"), "summary": Value("string")})
        task = Summarization(text_column="input_text", summary_column="input_summary")
        info = DatasetInfo(features=features_before_cast, task_templates=task)
        data = {
            "input_text": ["jack and jill took a taxi to attend a super duper party in the city."],
            "input_summary": ["jack and jill attend party"],
            "dummy": ["123456"],
        }
        # Test we can load from task name
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task="summarization") as dset:
                self.assertSetEqual(
                    {"text", "summary"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)
        # Test we can load from Summarization template
        info.task_templates = None
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task=task) as dset:
                self.assertSetEqual(
                    {"text", "summary"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_automatic_speech_recognition(self):
        # Include a dummy extra column `dummy` to test we drop it correctly
        features_before_cast = Features(
            {
                "input_audio": Audio(sampling_rate=16_000),
                "input_transcription": Value("string"),
                "dummy": Value("string"),
            }
        )
        features_after_cast = Features({"audio": Audio(sampling_rate=16_000), "transcription": Value("string")})
        task = AutomaticSpeechRecognition(audio_column="input_audio", transcription_column="input_transcription")
        info = DatasetInfo(features=features_before_cast, task_templates=task)
        data = {
            "input_audio": [{"bytes": None, "path": "path/to/some/audio/file.wav"}],
            "input_transcription": ["hello, my name is bob!"],
            "dummy": ["123456"],
        }
        # Test we can load from task name
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task="automatic-speech-recognition") as dset:
                self.assertSetEqual(
                    {"audio", "transcription"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)
        # Test we can load from AutomaticSpeechRecognition template
        info.task_templates = None
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task=task) as dset:
                self.assertSetEqual(
                    {"audio", "transcription"},
                    set(dset.column_names),
                )
                self.assertDictEqual(features_after_cast, dset.features)

    def test_task_with_no_template(self):
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data) as dset:
            with self.assertRaises(ValueError):
                dset.prepare_for_task("text-classification")

    def test_task_with_incompatible_templates(self):
        labels = sorted(["pos", "neg"])
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=labels),
            }
        )
        task = TextClassification(text_column="input_text", label_column="input_labels")
        info = DatasetInfo(
            features=features,
            task_templates=task,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data, info=info) as dset:
            # Invalid task name
            self.assertRaises(ValueError, dset.prepare_for_task, "this-task-does-not-exist")
            # Invalid task type
            self.assertRaises(ValueError, dset.prepare_for_task, 1)

    def test_task_with_multiple_compatible_task_templates(self):
        features = Features(
            {
                "text1": Value("string"),
                "text2": Value("string"),
            }
        )
        task1 = LanguageModeling(text_column="text1")
        task2 = LanguageModeling(text_column="text2")
        info = DatasetInfo(
            features=features,
            task_templates=[task1, task2],
        )
        data = {"text1": ["i love transformers!"], "text2": ["i love datasets!"]}
        with Dataset.from_dict(data, info=info) as dset:
            self.assertRaises(ValueError, dset.prepare_for_task, "language-modeling", id=3)
            with dset.prepare_for_task("language-modeling") as dset1:
                self.assertEqual(dset1[0]["text"], "i love transformers!")
            with dset.prepare_for_task("language-modeling", id=1) as dset2:
                self.assertEqual(dset2[0]["text"], "i love datasets!")

    def test_task_templates_empty_after_preparation(self):
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=["pos", "neg"]),
            }
        )
        task = TextClassification(text_column="input_text", label_column="input_labels")
        info = DatasetInfo(
            features=features,
            task_templates=task,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data, info=info) as dset:
            with dset.prepare_for_task(task="text-classification") as dset:
                self.assertIsNone(dset.info.task_templates)

    def test_align_labels_with_mapping_classification(self):
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(num_classes=3, names=["entailment", "neutral", "contradiction"]),
            }
        )
        data = {"input_text": ["a", "a", "b", "b", "c", "c"], "input_labels": [0, 0, 1, 1, 2, 2]}
        label2id = {"CONTRADICTION": 0, "ENTAILMENT": 2, "NEUTRAL": 1}
        id2label = {v: k for k, v in label2id.items()}
        expected_labels = [2, 2, 1, 1, 0, 0]
        expected_label_names = [id2label[idx] for idx in expected_labels]
        with Dataset.from_dict(data, features=features) as dset:
            with dset.align_labels_with_mapping(label2id, "input_labels") as dset:
                self.assertListEqual(expected_labels, dset["input_labels"])
                aligned_label_names = [dset.features["input_labels"].int2str(idx) for idx in dset["input_labels"]]
                self.assertListEqual(expected_label_names, aligned_label_names)

    def test_align_labels_with_mapping_ner(self):
        features = Features(
            {
                "input_text": Value("string"),
                "input_labels": Sequence(
                    ClassLabel(
                        names=[
                            "b-per",
                            "i-per",
                            "o",
                        ]
                    )
                ),
            }
        )
        data = {"input_text": [["Optimus", "Prime", "is", "a", "Transformer"]], "input_labels": [[0, 1, 2, 2, 2]]}
        label2id = {"B-PER": 2, "I-PER": 1, "O": 0}
        id2label = {v: k for k, v in label2id.items()}
        expected_labels = [[2, 1, 0, 0, 0]]
        expected_label_names = [[id2label[idx] for idx in seq] for seq in expected_labels]
        with Dataset.from_dict(data, features=features) as dset:
            with dset.align_labels_with_mapping(label2id, "input_labels") as dset:
                self.assertListEqual(expected_labels, dset["input_labels"])
                aligned_label_names = [
                    dset.features["input_labels"].feature.int2str(idx) for idx in dset["input_labels"]
                ]
                self.assertListEqual(expected_label_names, aligned_label_names)

    def test_concatenate_with_no_task_templates(self):
        info = DatasetInfo(task_templates=None)
        data = {"text": ["i love transformers!"], "labels": [1]}
        with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
            data, info=info
        ) as dset2, Dataset.from_dict(data, info=info) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.info.task_templates, None)

    def test_concatenate_with_equal_task_templates(self):
        labels = ["neg", "pos"]
        task_template = TextClassification(text_column="text", label_column="labels")
        info = DatasetInfo(
            features=Features({"text": Value("string"), "labels": ClassLabel(names=labels)}),
            # Label names are added in `DatasetInfo.__post_init__` so not included here
            task_templates=TextClassification(text_column="text", label_column="labels"),
        )
        data = {"text": ["i love transformers!"], "labels": [1]}
        with Dataset.from_dict(data, info=info) as dset1, Dataset.from_dict(
            data, info=info
        ) as dset2, Dataset.from_dict(data, info=info) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertListEqual(dset_concat.info.task_templates, [task_template])

    def test_concatenate_with_mixed_task_templates_in_common(self):
        tc_template = TextClassification(text_column="text", label_column="labels")
        qa_template = QuestionAnsweringExtractive(
            question_column="question", context_column="context", answers_column="answers"
        )
        info1 = DatasetInfo(
            task_templates=[qa_template],
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
        )
        info2 = DatasetInfo(
            task_templates=[qa_template, tc_template],
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
        )
        data = {
            "text": ["i love transformers!"],
            "labels": [1],
            "context": ["huggingface is going to the moon!"],
            "question": ["where is huggingface going?"],
            "answers": [{"text": ["to the moon!"], "answer_start": [2]}],
        }
        with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict(
            data, info=info2
        ) as dset2, Dataset.from_dict(data, info=info2) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertListEqual(dset_concat.info.task_templates, [qa_template])

    def test_concatenate_with_no_mixed_task_templates_in_common(self):
        tc_template1 = TextClassification(text_column="text", label_column="labels")
        tc_template2 = TextClassification(text_column="text", label_column="sentiment")
        qa_template = QuestionAnsweringExtractive(
            question_column="question", context_column="context", answers_column="answers"
        )
        info1 = DatasetInfo(
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
            task_templates=[tc_template1],
        )
        info2 = DatasetInfo(
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
            task_templates=[tc_template2],
        )
        info3 = DatasetInfo(
            features=Features(
                {
                    "text": Value("string"),
                    "labels": ClassLabel(names=["pos", "neg"]),
                    "sentiment": ClassLabel(names=["pos", "neg", "neutral"]),
                    "context": Value("string"),
                    "question": Value("string"),
                    "answers": Sequence(
                        {
                            "text": Value("string"),
                            "answer_start": Value("int32"),
                        }
                    ),
                }
            ),
            task_templates=[qa_template],
        )
        data = {
            "text": ["i love transformers!"],
            "labels": [1],
            "sentiment": [0],
            "context": ["huggingface is going to the moon!"],
            "question": ["where is huggingface going?"],
            "answers": [{"text": ["to the moon!"], "answer_start": [2]}],
        }
        with Dataset.from_dict(data, info=info1) as dset1, Dataset.from_dict(
            data, info=info2
        ) as dset2, Dataset.from_dict(data, info=info3) as dset3:
            with concatenate_datasets([dset1, dset2, dset3]) as dset_concat:
                self.assertEqual(dset_concat.info.task_templates, None)

    def test_task_text_classification_when_columns_removed(self):
        labels = sorted(["pos", "neg"])
        features_before_map = Features(
            {
                "input_text": Value("string"),
                "input_labels": ClassLabel(names=labels),
            }
        )
        features_after_map = Features({"new_column": Value("int64")})
        # Label names are added in `DatasetInfo.__post_init__` so not needed here
        task = TextClassification(text_column="input_text", label_column="input_labels")
        info = DatasetInfo(
            features=features_before_map,
            task_templates=task,
        )
        data = {"input_text": ["i love transformers!"], "input_labels": [1]}
        with Dataset.from_dict(data, info=info) as dset:
            with dset.map(lambda x: {"new_column": 0}, remove_columns=dset.column_names) as dset:
                self.assertDictEqual(dset.features, features_after_map)


class StratifiedTest(TestCase):
    def test_errors_train_test_split_stratify(self):
        ys = [
            np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]),
            np.array([0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
            np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
            np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
            np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
        ]
        for i in range(len(ys)):
            features = Features({"text": Value("int64"), "label": ClassLabel(len(np.unique(ys[i])))})
            data = {"text": np.ones(len(ys[i])), "label": ys[i]}
            d1 = Dataset.from_dict(data, features=features)

            # Check that stratify_by_column exists as a key in the dataset features
            if i == 0:
                self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="labl")

            # Check the minimum class count error
            elif i == 1:
                self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="label")

            # Check that the label column is of type ClassLabel
            elif i == 2:
                d1 = Dataset.from_dict(data)
                self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="label")

            # Check that test_size is greater than or equal to the number of classes
            elif i == 3:
                self.assertRaises(ValueError, d1.train_test_split, 0.30, stratify_by_column="label")

            # Check that train_size is greater than or equal to the number of classes
            elif i == 4:
                self.assertRaises(ValueError, d1.train_test_split, 0.60, stratify_by_column="label")

    def test_train_test_split_stratify(self):
        ys = [
            np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]),
            np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
            np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
            np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]),
            np.array([0] * 800 + [1] * 50),
        ]
        for y in ys:
            features = Features({"text": Value("int64"), "label": ClassLabel(len(np.unique(y)))})
            data = {"text": np.ones(len(y)), "label": y}
            d1 = Dataset.from_dict(data, features=features)
            d1 = d1.train_test_split(test_size=0.33, stratify_by_column="label")
            y = np.asanyarray(y)  # To make it indexable for y[train]
            test_size = np.ceil(0.33 * len(y))
            train_size = len(y) - test_size
            npt.assert_array_equal(np.unique(d1["train"]["label"]), np.unique(d1["test"]["label"]))

            # Check that class proportions are preserved across the splits
            p_train = np.bincount(np.unique(d1["train"]["label"], return_inverse=True)[1]) / float(
                len(d1["train"]["label"])
            )
            p_test = np.bincount(np.unique(d1["test"]["label"], return_inverse=True)[1]) / float(
                len(d1["test"]["label"])
            )
            npt.assert_array_almost_equal(p_train, p_test, 1)
            assert len(d1["train"]["text"]) + len(d1["test"]["text"]) == y.size
            assert len(d1["train"]["text"]) == train_size
            assert len(d1["test"]["text"]) == test_size


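# The assertions below assume that `Dataset._estimate_nbytes` scales with the number of selected rows
# rather than the size of the full underlying table, and they allow roughly 10% tolerance.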
def test_dataset_estimate_nbytes():
    ds = Dataset.from_dict({"a": ["0" * 100] * 100})
    assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than full dataset size"

    ds = Dataset.from_dict({"a": ["0" * 100] * 100}).select([0])
    assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than one chunk"

    ds = Dataset.from_dict({"a": ["0" * 100] * 100})
    ds = concatenate_datasets([ds] * 100)
    assert 0.9 * ds._estimate_nbytes() < 100 * 100 * 100, "must be smaller than full dataset size"
    assert 1.1 * ds._estimate_nbytes() > 100 * 100 * 100, "must be bigger than full dataset size"

    ds = Dataset.from_dict({"a": ["0" * 100] * 100})
    ds = concatenate_datasets([ds] * 100).select([0])
    assert 0.9 * ds._estimate_nbytes() < 100 * 100, "must be smaller than one chunk"


def test_dataset_to_iterable_dataset(dataset: Dataset):
    iterable_dataset = dataset.to_iterable_dataset()
    assert isinstance(iterable_dataset, IterableDataset)
    assert list(iterable_dataset) == list(dataset)
    assert iterable_dataset.features == dataset.features
    iterable_dataset = dataset.to_iterable_dataset(num_shards=3)
    assert isinstance(iterable_dataset, IterableDataset)
    assert list(iterable_dataset) == list(dataset)
    assert iterable_dataset.features == dataset.features
    assert iterable_dataset.n_shards == 3
    with pytest.raises(ValueError):
        dataset.to_iterable_dataset(num_shards=len(dataset) + 1)
    with pytest.raises(NotImplementedError):
        dataset.with_format("torch").to_iterable_dataset()


@require_pil
def test_dataset_format_with_unformatted_image():
    import PIL

    ds = Dataset.from_dict(
        {"a": [np.arange(4 * 4 * 3).reshape(4, 4, 3)] * 10, "b": [[0, 1]] * 10},
        Features({"a": Image(), "b": Sequence(Value("int64"))}),
    )
    ds.set_format("np", columns=["b"], output_all_columns=True)
    assert isinstance(ds[0]["a"], PIL.Image.Image)
    assert isinstance(ds[0]["b"], np.ndarray)


@pytest.mark.parametrize("batch_size", [1, 4])
@require_torch
def test_dataset_with_torch_dataloader(dataset, batch_size):
    from torch.utils.data import DataLoader

    from datasets import config

    dataloader = DataLoader(dataset, batch_size=batch_size)
    with patch.object(dataset, "_getitem", wraps=dataset._getitem) as mock_getitem:
        out = list(dataloader)
        getitem_call_count = mock_getitem.call_count
    assert len(out) == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)
    # calling dataset[list_of_indices] is much more efficient than [dataset[idx] for idx in list_of_indices]
    if config.TORCH_VERSION >= version.parse("1.13.0"):
        assert getitem_call_count == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)


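# The cases below cover `map` functions that return the lazy row dict itself, a plain dict copy, or an
# empty dict, and check that each output is merged consistently with the original columns.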
@pytest.mark.parametrize("return_lazy_dict", [True, False, "mix"])
4624
def test_map_cases(return_lazy_dict):
4625
    def f(x):
4626
        """May return a mix of LazyDict and regular Dict"""
4627
        if x["a"] < 2:
4628
            x["a"] = -1
4629
            return dict(x) if return_lazy_dict is False else x
4630
        else:
4631
            return x if return_lazy_dict is True else {}
4632

4633
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4634
    ds = ds.map(f)
4635
    outputs = ds[:]
4636
    assert outputs == {"a": [-1, -1, 2, 3]}
4637

4638
    def f(x):
4639
        """May return a mix of LazyDict and regular Dict, but sometimes with None values"""
4640
        if x["a"] < 2:
4641
            x["a"] = None
4642
            return dict(x) if return_lazy_dict is False else x
4643
        else:
4644
            return x if return_lazy_dict is True else {}
4645

4646
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4647
    ds = ds.map(f)
4648
    outputs = ds[:]
4649
    assert outputs == {"a": [None, None, 2, 3]}
4650

4651
    def f(x):
4652
        """Return a LazyDict, but we remove a lazy column and add a new one"""
4653
        if x["a"] < 2:
4654
            x["b"] = -1
4655
            return x
4656
        else:
4657
            x["b"] = x["a"]
4658
            return x
4659

4660
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4661
    ds = ds.map(f, remove_columns=["a"])
4662
    outputs = ds[:]
4663
    assert outputs == {"b": [-1, -1, 2, 3]}
4664

4665
    # The formatted dataset version removes the lazy column from a different dictionary, hence it should be preserved in the output
4666
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4667
    ds = ds.with_format("numpy")
4668
    ds = ds.map(f, remove_columns=["a"])
4669
    ds = ds.with_format(None)
4670
    outputs = ds[:]
4671
    assert outputs == {"a": [0, 1, 2, 3], "b": [-1, -1, 2, 3]}
4672

4673
    def f(x):
4674
        """May return a mix of LazyDict and regular Dict, but we replace a lazy column"""
4675
        if x["a"] < 2:
4676
            x["a"] = -1
4677
            return dict(x) if return_lazy_dict is False else x
4678
        else:
4679
            x["a"] = x["a"]
4680
            return x if return_lazy_dict is True else {"a": x["a"]}
4681

4682
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4683
    ds = ds.map(f, remove_columns=["a"])
4684
    outputs = ds[:]
4685
    assert outputs == ({"a": [-1, -1, 2, 3]} if return_lazy_dict is False else {})
4686

4687
    def f(x):
4688
        """May return a mix of LazyDict and regular Dict, but we modify a nested lazy column in-place"""
4689
        if x["a"]["b"] < 2:
4690
            x["a"]["c"] = -1
4691
            return dict(x) if return_lazy_dict is False else x
4692
        else:
4693
            x["a"]["c"] = x["a"]["b"]
4694
            return x if return_lazy_dict is True else {}
4695

4696
    ds = Dataset.from_dict({"a": [{"b": 0}, {"b": 1}, {"b": 2}, {"b": 3}]})
4697
    ds = ds.map(f)
4698
    outputs = ds[:]
4699
    assert outputs == {"a": [{"b": 0, "c": -1}, {"b": 1, "c": -1}, {"b": 2, "c": 2}, {"b": 3, "c": 3}]}
4700

4701
    def f(x):
4702
        """May return a mix of LazyDict and regular Dict, but using an extension type"""
4703
        if x["a"][0][0] < 2:
4704
            x["a"] = [[-1]]
4705
            return dict(x) if return_lazy_dict is False else x
4706
        else:
4707
            return x if return_lazy_dict is True else {}
4708

4709
    features = Features({"a": Array2D(shape=(1, 1), dtype="int32")})
4710
    ds = Dataset.from_dict({"a": [[[i]] for i in [0, 1, 2, 3]]}, features=features)
4711
    ds = ds.map(f)
4712
    outputs = ds[:]
4713
    assert outputs == {"a": [[[i]] for i in [-1, -1, 2, 3]]}
4714

4715
    def f(x):
4716
        """May return a mix of LazyDict and regular Dict, but using a nested extension type"""
4717
        if x["a"]["nested"][0][0] < 2:
4718
            x["a"] = {"nested": [[-1]]}
4719
            return dict(x) if return_lazy_dict is False else x
4720
        else:
4721
            return x if return_lazy_dict is True else {}
4722

4723
    features = Features({"a": {"nested": Array2D(shape=(1, 1), dtype="int64")}})
4724
    ds = Dataset.from_dict({"a": [{"nested": [[i]]} for i in [0, 1, 2, 3]]}, features=features)
4725
    ds = ds.map(f)
4726
    outputs = ds[:]
4727
    assert outputs == {"a": [{"nested": [[i]]} for i in [-1, -1, 2, 3]]}
4728

4729

4730
def test_dataset_getitem_raises():
4731
    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4732
    with pytest.raises(TypeError):
4733
        ds[False]
4734
    with pytest.raises(TypeError):
4735
        ds._getitem(True)
4736
