3
from typing import List, Union
9
from datasets import Sequence, Value
10
from datasets.features.features import Array2D, Array2DExtensionType, ClassLabel, Features, Image, get_nested_type
11
from datasets.table import (
17
_in_memory_arrow_table_from_buffer,
18
_in_memory_arrow_table_from_file,
19
_interpolation_search,
20
_memory_mapped_arrow_table_from_file,
21
cast_array_to_feature,
25
inject_arrow_table_documentation,
30
from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow
@pytest.fixture(scope="session")
def in_memory_pa_table(arrow_file) -> pa.Table:
    """Session-wide pyarrow table obtained by reading the whole ``arrow_file`` IPC stream."""
    reader = pa.ipc.open_stream(arrow_file)
    return reader.read_all()
38
def _to_testing_blocks(table: TableBlock) -> List[List[TableBlock]]:
42
[table.slice(2).drop([c for c in table.column_names if c != "tokens"]), table.slice(2).drop(["tokens"])],
@pytest.fixture(scope="session")
def in_memory_blocks(in_memory_pa_table):
    """Testing blocks (see ``_to_testing_blocks``) built from an ``InMemoryTable``."""
    return _to_testing_blocks(InMemoryTable(in_memory_pa_table))
@pytest.fixture(scope="session")
def memory_mapped_blocks(arrow_file):
    """Testing blocks (see ``_to_testing_blocks``) built from a ``MemoryMappedTable`` of ``arrow_file``."""
    return _to_testing_blocks(MemoryMappedTable.from_file(arrow_file))
@pytest.fixture(scope="session")
def mixed_in_memory_and_memory_mapped_blocks(in_memory_blocks, memory_mapped_blocks):
    """Blocks whose first entry comes from the in-memory blocks and the rest from the memory-mapped ones."""
    head = in_memory_blocks[:1]
    tail = memory_mapped_blocks[1:]
    return head + tail
def assert_deepcopy_without_bringing_data_in_memory(table: MemoryMappedTable):
    """Deep-copy ``table`` without any Arrow allocation, and check the copy is an equal MemoryMappedTable."""
    with assert_arrow_memory_doesnt_increase():
        duplicate = copy.deepcopy(table)
    assert isinstance(duplicate, MemoryMappedTable)
    assert duplicate.table == table.table
def assert_deepcopy_does_bring_data_in_memory(table: MemoryMappedTable):
    """Deep-copy ``table`` expecting Arrow memory to grow, and check the copy is an equal MemoryMappedTable."""
    with assert_arrow_memory_increases():
        duplicate = copy.deepcopy(table)
    assert isinstance(duplicate, MemoryMappedTable)
    assert duplicate.table == table.table
def assert_pickle_without_bringing_data_in_memory(table: MemoryMappedTable):
    """Pickle round-trip ``table`` without any Arrow allocation, and check the result is an equal MemoryMappedTable."""
    with assert_arrow_memory_doesnt_increase():
        payload = pickle.dumps(table)
        restored = pickle.loads(payload)
    assert isinstance(restored, MemoryMappedTable)
    assert restored.table == table.table
def assert_pickle_does_bring_data_in_memory(table: MemoryMappedTable):
    """Pickle round-trip ``table`` expecting Arrow memory to grow, and check the result is an equal MemoryMappedTable."""
    with assert_arrow_memory_increases():
        payload = pickle.dumps(table)
        restored = pickle.loads(payload)
    assert isinstance(restored, MemoryMappedTable)
    assert restored.table == table.table
def assert_index_attributes_equal(table: Table, other: Table):
    """Assert that ``table`` and ``other`` agree on their private batch/offset/schema index attributes."""
    assert table._batches == other._batches, "batches differ"
    np.testing.assert_array_equal(table._offsets, other._offsets)
    assert table._schema == other._schema, "schemas differ"
def add_suffix_to_column_names(table, suffix):
    """Return ``table.rename_columns(...)`` with ``suffix`` appended to every column name."""
    renamed = ["{}{}".format(name, suffix) for name in table.column_names]
    return table.rename_columns(renamed)
104
def test_inject_arrow_table_documentation(in_memory_pa_table):
105
method = pa.Table.slice
107
def function_to_wrap(*args):
111
wrapped_method = inject_arrow_table_documentation(method)(function_to_wrap)
112
assert method(in_memory_pa_table, *args) == wrapped_method(in_memory_pa_table, *args)
113
assert "pyarrow.Table" not in wrapped_method.__doc__
114
assert "Table" in wrapped_method.__doc__
def test_in_memory_arrow_table_from_file(arrow_file, in_memory_pa_table):
    """Loading ``arrow_file`` fully into memory allocates Arrow memory and reproduces the fixture table."""
    with assert_arrow_memory_increases():
        loaded = _in_memory_arrow_table_from_file(arrow_file)
    assert in_memory_pa_table == loaded
123
def test_in_memory_arrow_table_from_buffer(in_memory_pa_table):
124
with assert_arrow_memory_increases():
125
buf_writer = pa.BufferOutputStream()
126
writer = pa.RecordBatchStreamWriter(buf_writer, schema=in_memory_pa_table.schema)
127
writer.write_table(in_memory_pa_table)
130
pa_table = _in_memory_arrow_table_from_buffer(buf_writer.getvalue())
131
assert in_memory_pa_table == pa_table
def test_memory_mapped_arrow_table_from_file(arrow_file, in_memory_pa_table):
    """Memory-mapping ``arrow_file`` allocates no Arrow memory yet yields the same table."""
    with assert_arrow_memory_doesnt_increase():
        mapped = _memory_mapped_arrow_table_from_file(arrow_file)
    assert in_memory_pa_table == mapped
def test_table_init(in_memory_pa_table):
    """``Table`` exposes the wrapped pyarrow table via ``.table``."""
    wrapper = Table(in_memory_pa_table)
    assert wrapper.table == in_memory_pa_table
def test_table_validate(in_memory_pa_table):
    """``Table.validate`` returns what pyarrow's ``validate`` returns."""
    expected = in_memory_pa_table.validate()
    assert Table(in_memory_pa_table).validate() == expected
def test_table_equals(in_memory_pa_table):
    """A ``Table`` compares equal to the pyarrow table it wraps."""
    wrapper = Table(in_memory_pa_table)
    assert wrapper.equals(in_memory_pa_table)
def test_table_to_batches(in_memory_pa_table):
    """``Table.to_batches`` mirrors pyarrow."""
    wrapper = Table(in_memory_pa_table)
    expected = in_memory_pa_table.to_batches()
    assert wrapper.to_batches() == expected
def test_table_to_pydict(in_memory_pa_table):
    """``Table.to_pydict`` mirrors pyarrow."""
    wrapper = Table(in_memory_pa_table)
    expected = in_memory_pa_table.to_pydict()
    assert wrapper.to_pydict() == expected
def test_table_to_string(in_memory_pa_table):
    """``Table.to_string`` mirrors pyarrow."""
    wrapper = Table(in_memory_pa_table)
    expected = in_memory_pa_table.to_string()
    assert wrapper.to_string() == expected
def test_table_field(in_memory_pa_table):
    """``Table.field`` returns the same field object description as pyarrow."""
    column_name = "tokens"
    assert column_name in in_memory_pa_table.column_names
    wrapper = Table(in_memory_pa_table)
    assert wrapper.field(column_name) == in_memory_pa_table.field(column_name)
def test_table_column(in_memory_pa_table):
    """``Table.column`` returns the same column as pyarrow."""
    column_name = "tokens"
    assert column_name in in_memory_pa_table.column_names
    wrapper = Table(in_memory_pa_table)
    assert wrapper.column(column_name) == in_memory_pa_table.column(column_name)
def test_table_itercolumns(in_memory_pa_table):
    """``Table.itercolumns`` yields the same kind of iterator and the same columns as pyarrow."""
    wrapper = Table(in_memory_pa_table)
    reference_iterator = in_memory_pa_table.itercolumns()
    assert isinstance(wrapper.itercolumns(), type(reference_iterator))
    assert list(wrapper.itercolumns()) == list(in_memory_pa_table.itercolumns())
def test_table_getitem(in_memory_pa_table):
    """Indexing a ``Table`` matches indexing the pyarrow table."""
    wrapper = Table(in_memory_pa_table)
    first_index = 0
    assert wrapper[first_index] == in_memory_pa_table[first_index]
def test_table_len(in_memory_pa_table):
    """``len(Table)`` matches the wrapped pyarrow table."""
    expected_length = len(in_memory_pa_table)
    assert len(Table(in_memory_pa_table)) == expected_length
def test_table_str(in_memory_pa_table):
    """``str``/``repr`` of a ``Table`` equal pyarrow's, with the class name shown as ``Table``."""
    wrapper = Table(in_memory_pa_table)
    expected_str = str(in_memory_pa_table).replace("pyarrow.Table", "Table")
    assert str(wrapper) == expected_str
    expected_repr = repr(in_memory_pa_table).replace("pyarrow.Table", "Table")
    assert repr(wrapper) == expected_repr
204
@pytest.mark.parametrize(
205
"attribute", ["schema", "columns", "num_columns", "num_rows", "shape", "nbytes", "column_names"]
207
def test_table_attributes(in_memory_pa_table, attribute):
208
table = Table(in_memory_pa_table)
209
assert getattr(table, attribute) == getattr(in_memory_pa_table, attribute)
def test_in_memory_table_from_file(arrow_file, in_memory_pa_table):
    """``InMemoryTable.from_file`` loads the data into Arrow memory and returns an InMemoryTable."""
    with assert_arrow_memory_increases():
        loaded = InMemoryTable.from_file(arrow_file)
    assert loaded.table == in_memory_pa_table
    assert isinstance(loaded, InMemoryTable)
219
def test_in_memory_table_from_buffer(in_memory_pa_table):
220
with assert_arrow_memory_increases():
221
buf_writer = pa.BufferOutputStream()
222
writer = pa.RecordBatchStreamWriter(buf_writer, schema=in_memory_pa_table.schema)
223
writer.write_table(in_memory_pa_table)
226
table = InMemoryTable.from_buffer(buf_writer.getvalue())
227
assert table.table == in_memory_pa_table
228
assert isinstance(table, InMemoryTable)
def test_in_memory_table_from_pandas(in_memory_pa_table):
    """``InMemoryTable.from_pandas`` allocates Arrow memory; passing a schema reproduces the table exactly."""
    frame = in_memory_pa_table.to_pandas()
    with assert_arrow_memory_increases():
        # without a schema, the inferred field order may differ from the original table
        inferred = InMemoryTable.from_pandas(frame)
        assert isinstance(inferred, InMemoryTable)
        # passing the original schema pins the field order, so the result is identical
        exact = InMemoryTable.from_pandas(frame, schema=in_memory_pa_table.schema)
        assert exact.table == in_memory_pa_table
        assert isinstance(exact, InMemoryTable)
def test_in_memory_table_from_arrays(in_memory_pa_table):
    """Rebuilding from the table's own columns and names gives back the same table."""
    table = InMemoryTable.from_arrays(
        list(in_memory_pa_table.columns),
        names=list(in_memory_pa_table.column_names),
    )
    assert table.table == in_memory_pa_table
    assert isinstance(table, InMemoryTable)
def test_in_memory_table_from_pydict(in_memory_pa_table):
    """``InMemoryTable.from_pydict`` allocates Arrow memory and matches ``pa.Table.from_pydict``."""
    columns_dict = in_memory_pa_table.to_pydict()
    with assert_arrow_memory_increases():
        table = InMemoryTable.from_pydict(columns_dict)
    assert isinstance(table, InMemoryTable)
    assert table.table == pa.Table.from_pydict(columns_dict)
def test_in_memory_table_from_pylist(in_memory_pa_table):
    """A ``to_pylist`` / ``from_pylist`` round trip preserves the rows."""
    rows = InMemoryTable(in_memory_pa_table).to_pylist()
    rebuilt = InMemoryTable.from_pylist(rows)
    assert isinstance(rebuilt, InMemoryTable)
    assert rows == rebuilt.to_pylist()
def test_in_memory_table_from_batches(in_memory_pa_table):
    """Rebuilding from the table's record batches gives back the same table."""
    record_batches = list(in_memory_pa_table.to_batches())
    rebuilt = InMemoryTable.from_batches(record_batches)
    assert rebuilt.table == in_memory_pa_table
    assert isinstance(rebuilt, InMemoryTable)
def test_in_memory_table_deepcopy(in_memory_pa_table):
    """Deep-copying an InMemoryTable keeps its index and shares the underlying Arrow objects."""
    original = InMemoryTable(in_memory_pa_table)
    duplicate = copy.deepcopy(original)
    assert original.table == duplicate.table
    assert_index_attributes_equal(original, duplicate)
    # Arrow objects are immutable, so deepcopy is expected to hand back the very same ones
    assert original.table is duplicate.table
    assert all(left is right for left, right in zip(original._batches, duplicate._batches))
def test_in_memory_table_pickle(in_memory_pa_table):
    """A pickle round trip of an InMemoryTable preserves its data and index attributes."""
    original = InMemoryTable(in_memory_pa_table)
    restored = pickle.loads(pickle.dumps(original))
    assert restored.table == original.table
    assert_index_attributes_equal(original, restored)
@slow
def test_in_memory_table_pickle_big_table():
    """Pickle round-trip a ~4 GiB InMemoryTable and check its length survives.

    ``(4 * 8 << 30) // 64`` evaluates to 2**29 rows, i.e. 4 GiB of int64 data
    (presumably why the test exists: tables larger than 4 GiB used to hit pickle
    size limits - TODO confirm). Because it needs several GiB of RAM, it is marked
    ``@slow`` so it only runs when slow tests are enabled.
    """
    big_table_4GB = InMemoryTable.from_pydict({"col": [0] * ((4 * 8 << 30) // 64)})
    length = len(big_table_4GB)
    # The same name is deliberately rebound at each step so that the previous
    # object (table, then pickled bytes) loses its last reference and can be
    # freed, keeping peak memory lower.
    big_table_4GB = pickle.dumps(big_table_4GB)
    big_table_4GB = pickle.loads(big_table_4GB)
    assert len(big_table_4GB) == length
def test_in_memory_table_slice(in_memory_pa_table):
    """``slice`` matches pyarrow and keeps the InMemoryTable type."""
    offset, length = 1, 2
    sliced = InMemoryTable(in_memory_pa_table).slice(offset, length)
    assert sliced.table == in_memory_pa_table.slice(offset, length)
    assert isinstance(sliced, InMemoryTable)
def test_in_memory_table_filter(in_memory_pa_table):
    """``filter`` with an even-index boolean mask matches pyarrow and keeps the InMemoryTable type."""
    even_index_mask = pa.array([row % 2 == 0 for row in range(len(in_memory_pa_table))])
    filtered = InMemoryTable(in_memory_pa_table).filter(even_index_mask)
    assert filtered.table == in_memory_pa_table.filter(even_index_mask)
    assert isinstance(filtered, InMemoryTable)
def test_in_memory_table_flatten(in_memory_pa_table):
    """``flatten`` matches pyarrow and keeps the InMemoryTable type."""
    flattened = InMemoryTable(in_memory_pa_table).flatten()
    expected = in_memory_pa_table.flatten()
    assert flattened.table == expected
    assert isinstance(flattened, InMemoryTable)
def test_in_memory_table_combine_chunks(in_memory_pa_table):
    """``combine_chunks`` matches pyarrow and keeps the InMemoryTable type."""
    combined = InMemoryTable(in_memory_pa_table).combine_chunks()
    expected = in_memory_pa_table.combine_chunks()
    assert combined.table == expected
    assert isinstance(combined, InMemoryTable)
325
def test_in_memory_table_cast(in_memory_pa_table):
326
assert pa.list_(pa.int64()) in in_memory_pa_table.schema.types
329
k: v if v != pa.list_(pa.int64()) else pa.list_(pa.int32())
330
for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
333
table = InMemoryTable(in_memory_pa_table).cast(schema)
334
assert table.table == in_memory_pa_table.cast(schema)
335
assert isinstance(table, InMemoryTable)
338
def test_in_memory_table_cast_reorder_struct():
339
table = InMemoryTable(
340
pa.Table.from_pydict(
351
schema = pa.schema({"top": pa.struct({"bar": pa.string(), "foo": pa.string()})})
352
assert table.cast(schema).schema == schema
def test_in_memory_table_cast_with_hf_features():
    """Casting to a ``Features`` arrow schema keeps the schema and its Hugging Face feature metadata."""
    features = Features({"labels": ClassLabel(names=["neg", "pos"])})
    table = InMemoryTable(pa.Table.from_pydict({"labels": [0, 1]}))
    target_schema = features.arrow_schema
    assert table.cast(target_schema).schema == target_schema
    assert Features.from_arrow_schema(table.cast(target_schema).schema) == features
def test_in_memory_table_replace_schema_metadata(in_memory_pa_table):
    """``replace_schema_metadata`` sets the same metadata as pyarrow and keeps the InMemoryTable type."""
    metadata = {"huggingface": "{}"}
    updated = InMemoryTable(in_memory_pa_table).replace_schema_metadata(metadata)
    expected_metadata = in_memory_pa_table.replace_schema_metadata(metadata).schema.metadata
    assert updated.table.schema.metadata == expected_metadata
    assert isinstance(updated, InMemoryTable)
370
def test_in_memory_table_add_column(in_memory_pa_table):
371
i = len(in_memory_pa_table.column_names)
373
column = pa.array(list(range(len(in_memory_pa_table))))
374
table = InMemoryTable(in_memory_pa_table).add_column(i, field_, column)
375
assert table.table == in_memory_pa_table.add_column(i, field_, column)
376
assert isinstance(table, InMemoryTable)
379
def test_in_memory_table_append_column(in_memory_pa_table):
381
column = pa.array(list(range(len(in_memory_pa_table))))
382
table = InMemoryTable(in_memory_pa_table).append_column(field_, column)
383
assert table.table == in_memory_pa_table.append_column(field_, column)
384
assert isinstance(table, InMemoryTable)
def test_in_memory_table_remove_column(in_memory_pa_table):
    """``remove_column`` matches pyarrow and keeps the InMemoryTable type."""
    position = 0
    trimmed = InMemoryTable(in_memory_pa_table).remove_column(position)
    assert trimmed.table == in_memory_pa_table.remove_column(position)
    assert isinstance(trimmed, InMemoryTable)
393
def test_in_memory_table_set_column(in_memory_pa_table):
394
i = len(in_memory_pa_table.column_names)
396
column = pa.array(list(range(len(in_memory_pa_table))))
397
table = InMemoryTable(in_memory_pa_table).set_column(i, field_, column)
398
assert table.table == in_memory_pa_table.set_column(i, field_, column)
399
assert isinstance(table, InMemoryTable)
def test_in_memory_table_rename_columns(in_memory_pa_table):
    """``rename_columns`` matches pyarrow and keeps the InMemoryTable type."""
    assert "tokens" in in_memory_pa_table.column_names
    new_names = ["new_tokens" if name == "tokens" else name for name in in_memory_pa_table.column_names]
    renamed = InMemoryTable(in_memory_pa_table).rename_columns(new_names)
    assert renamed.table == in_memory_pa_table.rename_columns(new_names)
    assert isinstance(renamed, InMemoryTable)
def test_in_memory_table_drop(in_memory_pa_table):
    """Dropping the first column matches pyarrow and keeps the InMemoryTable type."""
    first_column = in_memory_pa_table.column_names[0]
    names = [first_column]
    reduced = InMemoryTable(in_memory_pa_table).drop(names)
    assert reduced.table == in_memory_pa_table.drop(names)
    assert isinstance(reduced, InMemoryTable)
def test_memory_mapped_table_init(arrow_file, in_memory_pa_table):
    """A directly constructed MemoryMappedTable holds the data and can be copied/pickled without loading it."""
    mapped = _memory_mapped_arrow_table_from_file(arrow_file)
    table = MemoryMappedTable(mapped, arrow_file)
    assert table.table == in_memory_pa_table
    assert isinstance(table, MemoryMappedTable)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_from_file(arrow_file, in_memory_pa_table):
    """``MemoryMappedTable.from_file`` allocates no Arrow memory, and copying/pickling keeps it that way."""
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file)
    assert table.table == in_memory_pa_table
    assert isinstance(table, MemoryMappedTable)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_from_file_with_replay(arrow_file, in_memory_pa_table):
    """Replays passed to ``from_file`` are applied in order, exactly as the same pyarrow calls would be."""
    replays = [("slice", (0, 1), {}), ("flatten", (), {})]
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file, replays=replays)
    assert len(table) == 1
    expected = in_memory_pa_table
    for method_name, args, kwargs in replays:
        expected = getattr(expected, method_name)(*args, **kwargs)
    assert table.table == expected
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_deepcopy(arrow_file):
    """Deep-copying a MemoryMappedTable keeps its path and index and shares the Arrow objects."""
    original = MemoryMappedTable.from_file(arrow_file)
    duplicate = copy.deepcopy(original)
    assert original.table == duplicate.table
    assert original.path == duplicate.path
    assert_index_attributes_equal(original, duplicate)
    # Arrow objects are immutable, so deepcopy is expected to hand back the very same ones
    assert original.table is duplicate.table
    assert all(left is right for left, right in zip(original._batches, duplicate._batches))
def test_memory_mapped_table_pickle(arrow_file):
    """A pickle round trip of a MemoryMappedTable preserves data, path and index attributes."""
    original = MemoryMappedTable.from_file(arrow_file)
    restored = pickle.loads(pickle.dumps(original))
    assert restored.table == original.table
    assert restored.path == original.path
    assert_index_attributes_equal(original, restored)
def test_memory_mapped_table_pickle_doesnt_fill_memory(arrow_file):
    """Opening, deep-copying and pickling a MemoryMappedTable never loads its data into Arrow memory."""
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_pickle_applies_replay(arrow_file):
    """A MemoryMappedTable opened with replays records them and survives copy/pickle without loading data."""
    replays = [("slice", (0, 1), {}), ("flatten", (), {})]
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file, replays=replays)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == replays
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_slice(arrow_file, in_memory_pa_table):
    """``slice`` matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    source = MemoryMappedTable.from_file(arrow_file)
    table = source.slice(1, 2)
    assert table.table == in_memory_pa_table.slice(1, 2)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("slice", (1, 2), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_filter(arrow_file, in_memory_pa_table):
    """``filter`` matches pyarrow and is recorded as a replay; its pickle round trip allocates memory."""
    even_index_mask = pa.array([row % 2 == 0 for row in range(len(in_memory_pa_table))])
    table = MemoryMappedTable.from_file(arrow_file).filter(even_index_mask)
    assert table.table == in_memory_pa_table.filter(even_index_mask)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("filter", (even_index_mask,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    # filter DOES increase memory, so the pickle round trip is expected to allocate
    assert_pickle_does_bring_data_in_memory(table)
def test_memory_mapped_table_flatten(arrow_file, in_memory_pa_table):
    """``flatten`` matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    source = MemoryMappedTable.from_file(arrow_file)
    table = source.flatten()
    assert table.table == in_memory_pa_table.flatten()
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("flatten", (), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_combine_chunks(arrow_file, in_memory_pa_table):
    """``combine_chunks`` matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    source = MemoryMappedTable.from_file(arrow_file)
    table = source.combine_chunks()
    assert table.table == in_memory_pa_table.combine_chunks()
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("combine_chunks", (), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
522
def test_memory_mapped_table_cast(arrow_file, in_memory_pa_table):
523
assert pa.list_(pa.int64()) in in_memory_pa_table.schema.types
526
k: v if v != pa.list_(pa.int64()) else pa.list_(pa.int32())
527
for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
530
table = MemoryMappedTable.from_file(arrow_file).cast(schema)
531
assert table.table == in_memory_pa_table.cast(schema)
532
assert isinstance(table, MemoryMappedTable)
533
assert table.replays == [("cast", (schema,), {})]
534
assert_deepcopy_without_bringing_data_in_memory(table)
535
# cast DOES increase memory when converting integers precision for example
536
# assert_pickle_without_bringing_data_in_memory(table)
537
assert_pickle_does_bring_data_in_memory(table)
def test_memory_mapped_table_replace_schema_metadata(arrow_file, in_memory_pa_table):
    """``replace_schema_metadata`` matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    metadata = {"huggingface": "{}"}
    table = MemoryMappedTable.from_file(arrow_file).replace_schema_metadata(metadata)
    expected_metadata = in_memory_pa_table.replace_schema_metadata(metadata).schema.metadata
    assert table.table.schema.metadata == expected_metadata
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("replace_schema_metadata", (metadata,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
550
def test_memory_mapped_table_add_column(arrow_file, in_memory_pa_table):
551
i = len(in_memory_pa_table.column_names)
553
column = pa.array(list(range(len(in_memory_pa_table))))
554
table = MemoryMappedTable.from_file(arrow_file).add_column(i, field_, column)
555
assert table.table == in_memory_pa_table.add_column(i, field_, column)
556
assert isinstance(table, MemoryMappedTable)
557
assert table.replays == [("add_column", (i, field_, column), {})]
558
assert_deepcopy_without_bringing_data_in_memory(table)
559
assert_pickle_without_bringing_data_in_memory(table)
562
def test_memory_mapped_table_append_column(arrow_file, in_memory_pa_table):
564
column = pa.array(list(range(len(in_memory_pa_table))))
565
table = MemoryMappedTable.from_file(arrow_file).append_column(field_, column)
566
assert table.table == in_memory_pa_table.append_column(field_, column)
567
assert isinstance(table, MemoryMappedTable)
568
assert table.replays == [("append_column", (field_, column), {})]
569
assert_deepcopy_without_bringing_data_in_memory(table)
570
assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_remove_column(arrow_file, in_memory_pa_table):
    """``remove_column`` matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    position = 0
    table = MemoryMappedTable.from_file(arrow_file).remove_column(position)
    assert table.table == in_memory_pa_table.remove_column(position)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("remove_column", (position,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
582
def test_memory_mapped_table_set_column(arrow_file, in_memory_pa_table):
583
i = len(in_memory_pa_table.column_names)
585
column = pa.array(list(range(len(in_memory_pa_table))))
586
table = MemoryMappedTable.from_file(arrow_file).set_column(i, field_, column)
587
assert table.table == in_memory_pa_table.set_column(i, field_, column)
588
assert isinstance(table, MemoryMappedTable)
589
assert table.replays == [("set_column", (i, field_, column), {})]
590
assert_deepcopy_without_bringing_data_in_memory(table)
591
assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_rename_columns(arrow_file, in_memory_pa_table):
    """``rename_columns`` matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    assert "tokens" in in_memory_pa_table.column_names
    new_names = ["new_tokens" if name == "tokens" else name for name in in_memory_pa_table.column_names]
    table = MemoryMappedTable.from_file(arrow_file).rename_columns(new_names)
    assert table.table == in_memory_pa_table.rename_columns(new_names)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("rename_columns", (new_names,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
def test_memory_mapped_table_drop(arrow_file, in_memory_pa_table):
    """Dropping the first column matches pyarrow, is recorded as a replay, and stays cheap to copy/pickle."""
    first_column = in_memory_pa_table.column_names[0]
    names = [first_column]
    table = MemoryMappedTable.from_file(arrow_file).drop(names)
    assert table.table == in_memory_pa_table.drop(names)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("drop", (names,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)
615
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
616
def test_concatenation_table_init(
617
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
621
if blocks_type == "in_memory"
622
else memory_mapped_blocks
623
if blocks_type == "memory_mapped"
624
else mixed_in_memory_and_memory_mapped_blocks
626
table = ConcatenationTable(in_memory_pa_table, blocks)
627
assert table.table == in_memory_pa_table
628
assert table.blocks == blocks
def test_concatenation_table_from_blocks(in_memory_pa_table, in_memory_blocks):
    """Every layout of purely in-memory blocks yields a ConcatenationTable with one consolidated block."""
    assert len(in_memory_pa_table) > 2
    in_memory_table = InMemoryTable(in_memory_pa_table)
    head, tail = in_memory_table.slice(0, 2), in_memory_table.slice(2)
    layouts = (
        in_memory_table,  # a single table
        [head, tail],  # a flat list of row-wise pieces
        [[head], [tail]],  # a list of block rows
        in_memory_blocks,  # the shared testing blocks
    )
    for blocks in layouts:
        table = ConcatenationTable.from_blocks(blocks)
        assert isinstance(table, ConcatenationTable)
        assert table.table == in_memory_pa_table
        assert table.blocks == [[in_memory_table]]
653
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
654
def test_concatenation_table_from_blocks_doesnt_increase_memory(
655
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
658
"in_memory": in_memory_blocks,
659
"memory_mapped": memory_mapped_blocks,
660
"mixed": mixed_in_memory_and_memory_mapped_blocks,
662
with assert_arrow_memory_doesnt_increase():
663
table = ConcatenationTable.from_blocks(blocks)
664
assert isinstance(table, ConcatenationTable)
665
assert table.table == in_memory_pa_table
666
if blocks_type == "in_memory":
667
assert table.blocks == [[InMemoryTable(in_memory_pa_table)]]
669
assert table.blocks == blocks
@pytest.mark.parametrize("axis", [0, 1])
def test_concatenation_table_from_tables(axis, in_memory_pa_table, arrow_file):
    """Concatenate pyarrow, in-memory, concatenation and memory-mapped tables along ``axis``.

    Checks that the result equals the equivalent pyarrow concatenation, that building
    it allocates no Arrow memory, and that consolidation leaves exactly one
    InMemoryTable block and one MemoryMappedTable block.
    """
    in_memory_table = InMemoryTable(in_memory_pa_table)
    concatenation_table = ConcatenationTable.from_blocks(in_memory_table)
    memory_mapped_table = MemoryMappedTable.from_file(arrow_file)
    tables = [in_memory_pa_table, in_memory_table, concatenation_table, memory_mapped_table]
    if axis == 0:
        expected_table = pa.concat_tables([in_memory_pa_table] * len(tables))
    else:
        # avoids error due to duplicate column names
        tables[1:] = [add_suffix_to_column_names(table, i) for i, table in enumerate(tables[1:], 1)]
        expected_table = in_memory_pa_table
        for table in tables[1:]:
            for name, col in zip(table.column_names, table.columns):
                expected_table = expected_table.append_column(name, col)

    with assert_arrow_memory_doesnt_increase():
        table = ConcatenationTable.from_tables(tables, axis=axis)
    assert isinstance(table, ConcatenationTable)
    assert table.table == expected_table
    # because of consolidation, we end up with 1 InMemoryTable and 1 MemoryMappedTable:
    # two rows of one block each for axis=0, one row of two blocks for axis=1.
    # The conditional expressions must be parenthesized: `assert x == 1 if c else 2`
    # parses as `assert (x == 1) if c else 2`, which asserts the truthy constant 2
    # whenever `c` is false and so never fails.
    assert len(table.blocks) == (1 if axis == 1 else 2)
    assert len(table.blocks[0]) == (1 if axis == 0 else 2)
    assert axis == 1 or len(table.blocks[1]) == 1
    assert isinstance(table.blocks[0][0], InMemoryTable)
    assert isinstance(table.blocks[1][0] if axis == 0 else table.blocks[0][1], MemoryMappedTable)
700
def test_concatenation_table_from_tables_axis1_misaligned_blocks(arrow_file):
701
table = MemoryMappedTable.from_file(arrow_file)
702
t1 = table.slice(0, 2)
703
t2 = table.slice(0, 3).rename_columns([col + "_1" for col in table.column_names])
704
concatenated = ConcatenationTable.from_tables(
706
ConcatenationTable.from_blocks([[t1], [t1], [t1]]),
707
ConcatenationTable.from_blocks([[t2], [t2]]),
711
assert len(concatenated) == 6
712
assert [len(row_blocks[0]) for row_blocks in concatenated.blocks] == [2, 1, 1, 2]
713
concatenated = ConcatenationTable.from_tables(
715
ConcatenationTable.from_blocks([[t2], [t2]]),
716
ConcatenationTable.from_blocks([[t1], [t1], [t1]]),
720
assert len(concatenated) == 6
721
assert [len(row_blocks[0]) for row_blocks in concatenated.blocks] == [2, 1, 1, 2]
724
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
725
def test_concatenation_table_deepcopy(
726
blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
729
"in_memory": in_memory_blocks,
730
"memory_mapped": memory_mapped_blocks,
731
"mixed": mixed_in_memory_and_memory_mapped_blocks,
733
table = ConcatenationTable.from_blocks(blocks)
734
copied_table = copy.deepcopy(table)
735
assert table.table == copied_table.table
736
assert table.blocks == copied_table.blocks
737
assert_index_attributes_equal(table, copied_table)
738
# deepcopy must return the exact same arrow objects since they are immutable
739
assert table.table is copied_table.table
740
assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))
743
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
744
def test_concatenation_table_pickle(
745
blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
748
"in_memory": in_memory_blocks,
749
"memory_mapped": memory_mapped_blocks,
750
"mixed": mixed_in_memory_and_memory_mapped_blocks,
752
table = ConcatenationTable.from_blocks(blocks)
753
pickled_table = pickle.dumps(table)
754
unpickled_table = pickle.loads(pickled_table)
755
assert unpickled_table.table == table.table
756
assert unpickled_table.blocks == table.blocks
757
assert_index_attributes_equal(table, unpickled_table)
760
def test_concat_tables_with_features_metadata(arrow_file, in_memory_pa_table):
761
input_features = Features.from_arrow_schema(in_memory_pa_table.schema)
762
input_features["id"] = Value("int64", id="my_id")
763
intput_schema = input_features.arrow_schema
764
t0 = in_memory_pa_table.replace_schema_metadata(intput_schema.metadata)
765
t1 = MemoryMappedTable.from_file(arrow_file)
767
concatenated_table = concat_tables(tables, axis=0)
768
output_schema = concatenated_table.schema
769
output_features = Features.from_arrow_schema(output_schema)
770
assert output_schema == intput_schema
771
assert output_schema.metadata == intput_schema.metadata
772
assert output_features == input_features
773
assert output_features["id"].id == "my_id"
776
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
777
def test_concatenation_table_slice(
778
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
781
"in_memory": in_memory_blocks,
782
"memory_mapped": memory_mapped_blocks,
783
"mixed": mixed_in_memory_and_memory_mapped_blocks,
785
table = ConcatenationTable.from_blocks(blocks).slice(1, 2)
786
assert table.table == in_memory_pa_table.slice(1, 2)
787
assert isinstance(table, ConcatenationTable)
790
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
791
def test_concatenation_table_filter(
792
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
795
"in_memory": in_memory_blocks,
796
"memory_mapped": memory_mapped_blocks,
797
"mixed": mixed_in_memory_and_memory_mapped_blocks,
799
mask = pa.array([i % 2 == 0 for i in range(len(in_memory_pa_table))])
800
table = ConcatenationTable.from_blocks(blocks).filter(mask)
801
assert table.table == in_memory_pa_table.filter(mask)
802
assert isinstance(table, ConcatenationTable)
805
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
806
def test_concatenation_table_flatten(
807
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
810
"in_memory": in_memory_blocks,
811
"memory_mapped": memory_mapped_blocks,
812
"mixed": mixed_in_memory_and_memory_mapped_blocks,
814
table = ConcatenationTable.from_blocks(blocks).flatten()
815
assert table.table == in_memory_pa_table.flatten()
816
assert isinstance(table, ConcatenationTable)
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_combine_chunks(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """``combine_chunks`` on a ConcatenationTable matches ``in_memory_pa_table.combine_chunks()``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks).combine_chunks()
    assert table.table == in_memory_pa_table.combine_chunks()
    assert isinstance(table, ConcatenationTable)
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_cast(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Casting a ConcatenationTable matches casting ``in_memory_pa_table`` to the same schema.

    Two casts are checked in turn: ``list<int64>`` columns to ``list<int32>``, then
    ``int64`` columns to ``int32``; all other columns keep their type.
    """
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    # The casts below only exercise something if the fixture has columns of these types.
    assert pa.list_(pa.int64()) in in_memory_pa_table.schema.types
    assert pa.int64() in in_memory_pa_table.schema.types
    for source_type, target_type in [
        (pa.list_(pa.int64()), pa.list_(pa.int32())),
        (pa.int64(), pa.int32()),
    ]:
        schema = pa.schema(
            {
                name: target_type if type_ == source_type else type_
                for name, type_ in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
            }
        )
        table = ConcatenationTable.from_blocks(blocks).cast(schema)
        assert table.table == in_memory_pa_table.cast(schema)
        assert isinstance(table, ConcatenationTable)
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concat_tables_cast_with_features_metadata(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Casting a ConcatenationTable to a Features-derived schema preserves the schema and its metadata.

    The ``id`` feature carries a custom ``id="my_id"``. The features rebuilt from the output
    schema must equal the input features and still expose that ``id``.
    """
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    input_features = Features.from_arrow_schema(in_memory_pa_table.schema)
    input_features["id"] = Value("int64", id="my_id")
    input_schema = input_features.arrow_schema
    concatenated_table = ConcatenationTable.from_blocks(blocks).cast(input_schema)
    output_schema = concatenated_table.schema
    output_features = Features.from_arrow_schema(output_schema)
    assert output_schema == input_schema
    assert output_schema.metadata == input_schema.metadata
    assert output_features == input_features
    assert output_features["id"].id == "my_id"
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_replace_schema_metadata(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Replacing the schema metadata of a ConcatenationTable matches doing so on ``in_memory_pa_table``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    metadata = {"huggingface": "{}"}
    table = ConcatenationTable.from_blocks(blocks).replace_schema_metadata(metadata)
    assert table.table.schema.metadata == in_memory_pa_table.replace_schema_metadata(metadata).schema.metadata
    assert isinstance(table, ConcatenationTable)
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_add_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """``ConcatenationTable.add_column`` is expected to raise ``NotImplementedError``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    with pytest.raises(NotImplementedError):
        ConcatenationTable.from_blocks(blocks).add_column(i, field_, column)
    # TODO: once supported, compare against in_memory_pa_table.add_column(i, field_, column),
    # including after a pickle round trip.
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_append_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """``ConcatenationTable.append_column`` is expected to raise ``NotImplementedError``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    with pytest.raises(NotImplementedError):
        ConcatenationTable.from_blocks(blocks).append_column(field_, column)
    # TODO: once supported, compare against in_memory_pa_table.append_column(field_, column),
    # including after a pickle round trip.
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_remove_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Removing column 0 from a ConcatenationTable matches doing so on ``in_memory_pa_table``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks).remove_column(0)
    assert table.table == in_memory_pa_table.remove_column(0)
    assert isinstance(table, ConcatenationTable)
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_set_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """``ConcatenationTable.set_column`` is expected to raise ``NotImplementedError``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    with pytest.raises(NotImplementedError):
        ConcatenationTable.from_blocks(blocks).set_column(i, field_, column)
    # TODO: once supported, compare against in_memory_pa_table.set_column(i, field_, column),
    # including after a pickle round trip.
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_rename_columns(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Renaming ``tokens`` to ``new_tokens`` on a ConcatenationTable matches doing so on ``in_memory_pa_table``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    assert "tokens" in in_memory_pa_table.column_names
    names = [name if name != "tokens" else "new_tokens" for name in in_memory_pa_table.column_names]
    table = ConcatenationTable.from_blocks(blocks).rename_columns(names)
    assert isinstance(table, ConcatenationTable)
    assert table.table == in_memory_pa_table.rename_columns(names)
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_drop(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Dropping the first column of a ConcatenationTable matches doing so on ``in_memory_pa_table``."""
    # Select the block layout fixture matching the parametrized ``blocks_type``.
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    names = [in_memory_pa_table.column_names[0]]
    table = ConcatenationTable.from_blocks(blocks).drop(names)
    assert table.table == in_memory_pa_table.drop(names)
    assert isinstance(table, ConcatenationTable)
def test_concat_tables(arrow_file, in_memory_pa_table):
    """``concat_tables`` along axis 0 and axis 1 over mixed table kinds.

    The inputs are a plain ``pa.Table``, an ``InMemoryTable``, a ``MemoryMappedTable`` and a
    ``ConcatenationTable``. The test checks the concatenated data and the resulting block layout.
    """
    t0 = in_memory_pa_table
    t1 = InMemoryTable(t0)
    t2 = MemoryMappedTable.from_file(arrow_file)
    t3 = ConcatenationTable.from_blocks(t1)
    tables = [t0, t1, t2, t3]

    # Row-wise concatenation.
    concatenated_table = concat_tables(tables, axis=0)
    assert concatenated_table.table == pa.concat_tables([t0] * 4)
    assert concatenated_table.table.shape == (40, 4)
    assert isinstance(concatenated_table, ConcatenationTable)
    assert len(concatenated_table.blocks) == 3  # t0 and t1 are consolidated as a single InMemoryTable
    assert isinstance(concatenated_table.blocks[0][0], InMemoryTable)
    assert isinstance(concatenated_table.blocks[1][0], MemoryMappedTable)
    assert isinstance(concatenated_table.blocks[2][0], InMemoryTable)

    # Column-wise concatenation; add a suffix to avoid errors due to duplicate column names.
    concatenated_table = concat_tables(
        [add_suffix_to_column_names(table, i) for i, table in enumerate(tables)], axis=1
    )
    assert concatenated_table.table.shape == (10, 16)
    assert len(concatenated_table.blocks[0]) == 3  # t0 and t1 are consolidated as a single InMemoryTable
    assert isinstance(concatenated_table.blocks[0][0], InMemoryTable)
    assert isinstance(concatenated_table.blocks[0][1], MemoryMappedTable)
    assert isinstance(concatenated_table.blocks[0][2], InMemoryTable)
1026
def _interpolation_search_ground_truth(arr: List[int], x: int) -> Union[int, IndexError]:
1027
for i in range(len(arr) - 1):
1028
if arr[i] <= x < arr[i + 1]:
1033
class _ListWithGetitemCounter(list):
1034
def __init__(self, *args, **kwargs):
1035
super().__init__(*args, **kwargs)
1036
self.unique_getitem_calls = set()
1038
def __getitem__(self, i):
1039
out = super().__getitem__(i)
1040
self.unique_getitem_calls.add(i)
1044
def getitem_unique_count(self):
1045
return len(self.unique_getitem_calls)
@pytest.mark.parametrize(
    "arr, x",
    [(np.arange(0, 14, 3), x) for x in range(-1, 22)]
    + [(list(np.arange(-5, 5)), x) for x in range(-6, 6)]
    + [([0, 1_000, 1_001, 1_003], x) for x in [-1, 0, 2, 100, 999, 1_000, 1_001, 1_002, 1_003, 1_004]]
    + [(list(range(1_000)), x) for x in [-1, 0, 1, 10, 666, 999, 1_000, 10_001]],
)
def test_interpolation_search(arr, x):
    """``_interpolation_search`` agrees with a linear-scan reference.

    When an index exists, the search must also read at most 4 distinct positions. Otherwise
    it must raise the exception class returned by the reference.
    """
    ground_truth = _interpolation_search_ground_truth(arr, x)
    if isinstance(ground_truth, int):
        arr = _ListWithGetitemCounter(arr)
        output = _interpolation_search(arr, x)
        assert ground_truth == output
        # 4 maximum unique getitem calls is expected for the cases of this test
        # but it can be bigger for large and messy arrays.
        assert arr.getitem_unique_count <= 4
    else:
        with pytest.raises(ground_truth):
            _interpolation_search(arr, x)
def test_indexed_table_mixin():
    """``Table`` precomputes chunk offsets, and ``fast_slice`` matches ``pa.Table.slice``.

    The source table is built from several equal-sized chunks so that slices span chunks.
    """
    n_rows_per_chunk = 10
    n_chunks = 10
    pa_table = pa.Table.from_pydict({"col": [0] * n_rows_per_chunk})
    pa_table = pa.concat_tables([pa_table] * n_chunks)
    table = Table(pa_table)
    # Offsets are the cumulative row counts at each chunk boundary, starting at 0.
    assert np.array_equal(table._offsets, np.cumsum([0] + [n_rows_per_chunk] * n_chunks))
    assert table.fast_slice(5) == pa_table.slice(5)
    # 2..15 crosses the first chunk boundary.
    assert table.fast_slice(2, 13) == pa_table.slice(2, 13)
def test_cast_array_to_features():
    """Integer lists cast to a string Sequence, unless ``allow_number_to_str=False``, which raises ``TypeError``."""
    arr = pa.array([[0, 1]])
    assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
    with pytest.raises(TypeError):
        cast_array_to_feature(arr, Sequence(Value("string")), allow_number_to_str=False)
def test_cast_array_to_features_nested():
    """A list of structs holding int lists casts to the matching nested string type."""
    arr = pa.array([[{"foo": [0]}]])
    assert cast_array_to_feature(arr, [{"foo": Sequence(Value("string"))}]).type == pa.list_(
        pa.struct({"foo": pa.list_(pa.string())})
    )
def test_cast_array_to_features_to_nested_with_no_fields():
    """Casting an array of empty structs to the empty feature dict ``{}`` keeps type and values."""
    arr = pa.array([{}])
    assert cast_array_to_feature(arr, {}).type == pa.struct({})
    assert cast_array_to_feature(arr, {}).to_pylist() == arr.to_pylist()
def test_cast_array_to_features_nested_with_nulls():
    """Null inner lists survive casting a nested struct, whether or not the value type changes."""
    # Same target type (int64): the array round-trips unchanged.
    arr = pa.array([{"foo": [None, [0]]}], pa.struct({"foo": pa.list_(pa.list_(pa.int64()))}))
    casted_array = cast_array_to_feature(arr, {"foo": [[Value("int64")]]})
    assert casted_array.type == pa.struct({"foo": pa.list_(pa.list_(pa.int64()))})
    assert casted_array.to_pylist() == arr.to_pylist()

    # Different target type (int64 -> int32): the null inner list is preserved.
    arr = pa.array([{"foo": [None, [0]]}], pa.struct({"foo": pa.list_(pa.list_(pa.int64()))}))
    casted_array = cast_array_to_feature(arr, {"foo": [[Value("int32")]]})
    assert casted_array.type == pa.struct({"foo": pa.list_(pa.list_(pa.int32()))})
    assert casted_array.to_pylist() == [{"foo": [None, [0]]}]
def test_cast_array_to_features_to_null_type():
    """Lists of only nulls cast to ``Value("null")``; lists with a non-null value raise ``TypeError``."""
    # Only null values: the cast succeeds.
    arr = pa.array([[None, None]])
    assert cast_array_to_feature(arr, Sequence(Value("null"))).type == pa.list_(pa.null())

    # A non-null value cannot be represented by the null type.
    arr = pa.array([[None, 1]])
    with pytest.raises(TypeError):
        cast_array_to_feature(arr, Sequence(Value("null")))
def test_cast_array_to_features_array_xd():
    """Nested fixed-size int32 lists cast to ``Array2D``, with both the same and a different dtype."""
    # same storage type
    arr = pa.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], pa.list_(pa.list_(pa.int32(), 2), 2))
    casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="int32"))
    assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="int32")
    # different storage type
    casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="float32"))
    assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="float32")
def test_cast_array_to_features_sequence_classlabel():
    """Int and string label lists cast to ``Sequence(ClassLabel)`` as ``list<int64>``; unknown labels raise ``ValueError``."""
    # Integer labels.
    arr = pa.array([[], [1], [0, 1]], pa.list_(pa.int64()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    # String labels.
    arr = pa.array([[], ["bar"], ["foo", "bar"]], pa.list_(pa.string()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    # Only empty lists, for both int and string storage.
    arr = pa.array([[], []], pa.list_(pa.int64()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    arr = pa.array([[], []], pa.list_(pa.string()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    # Test invalid class labels: out-of-range integer and unknown name.
    arr = pa.array([[2]], pa.list_(pa.int64()))
    with pytest.raises(ValueError):
        cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"])))

    arr = pa.array([["baz"]], pa.list_(pa.string()))
    with pytest.raises(ValueError):
        cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"])))
@pytest.mark.parametrize(
    "arr",
    [
        pa.array([[0, 1, 2], [3, None, 5], None, [6, 7, 8], None], pa.list_(pa.int32(), 3)),
    ],
)
@pytest.mark.parametrize("slice", [None, slice(1, None), slice(-1), slice(1, 3), slice(2, 3), slice(1, 1)])
@pytest.mark.parametrize("target_value_feature", [Value("int64")])
def test_cast_fixed_size_list_array_to_features_sequence(arr, slice, target_value_feature):
    """A (possibly sliced) fixed-size list array casts to fixed- and variable-length Sequences without changing its values.

    Casting to a fixed length that does not match the array's list size raises ``TypeError``.
    """
    arr = arr if slice is None else arr[slice]
    # Fixed size list
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature, length=arr.type.list_size))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature, length=arr.type.list_size))
    assert casted_array.to_pylist() == arr.to_pylist()
    with pytest.raises(TypeError):
        cast_array_to_feature(arr, Sequence(target_value_feature, length=arr.type.list_size + 1))
    # Variable size list
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature))
    assert casted_array.to_pylist() == arr.to_pylist()
    casted_array = cast_array_to_feature(arr, [target_value_feature])
    assert casted_array.type == get_nested_type([target_value_feature])
    assert casted_array.to_pylist() == arr.to_pylist()
@pytest.mark.parametrize(
    "arr",
    [
        pa.array([[0, 1, 2], [3, None, 5], None, [6, 7, 8], None], pa.list_(pa.int32())),
    ],
)
@pytest.mark.parametrize("slice", [None, slice(1, None), slice(-1), slice(1, 3), slice(2, 3), slice(1, 1)])
@pytest.mark.parametrize("target_value_feature", [Value("int64")])
def test_cast_list_array_to_features_sequence(arr, slice, target_value_feature):
    """A (possibly sliced) variable-size list array casts to variable- and fixed-length Sequences without changing its values."""
    arr = arr if slice is None else arr[slice]
    # Variable size list
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature))
    assert casted_array.to_pylist() == arr.to_pylist()
    casted_array = cast_array_to_feature(arr, [target_value_feature])
    assert casted_array.type == get_nested_type([target_value_feature])
    assert casted_array.to_pylist() == arr.to_pylist()
    # Fixed size list: use the length of the first non-null list; when the slice keeps no
    # non-null lists, any length is acceptable, so 2 is an arbitrary choice.
    non_null_lengths = arr.value_lengths().drop_null()
    list_size = non_null_lengths[0].as_py() if len(non_null_lengths) > 0 else 2
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature, length=list_size))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature, length=list_size))
    assert casted_array.to_pylist() == arr.to_pylist()
def test_cast_array_xd_to_features_sequence():
    """A list array of ``Array2D`` int64 items, with a null entry, casts to int32 ``Array2D`` Sequences of variable and fixed length."""
    arr = np.random.randint(0, 10, size=(8, 2, 3)).tolist()
    arr = Array2DExtensionType(shape=(2, 3), dtype="int64").wrap_array(pa.array(arr, pa.list_(pa.list_(pa.int64()))))
    # Offsets [0, None, 4, 8] give three list entries: [0:4], null, [4:8].
    arr = pa.ListArray.from_arrays([0, None, 4, 8], arr)
    # Variable size list
    casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32")))
    assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32")))
    assert casted_array.to_pylist() == arr.to_pylist()
    # Fixed size list
    casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4))
    assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4))
    assert casted_array.to_pylist() == arr.to_pylist()
def test_embed_array_storage(image_file):
    """Embedding an Image array that only has a path fills in ``bytes`` and stores the path as ``test_image_rgb.jpg``."""
    array = pa.array([{"bytes": None, "path": image_file}], type=Image.pa_type)
    embedded_images_array = embed_array_storage(array, Image())
    embedded_image = embedded_images_array.to_pylist()[0]
    assert isinstance(embedded_image["path"], str)
    assert embedded_image["path"] == "test_image_rgb.jpg"
    assert isinstance(embedded_image["bytes"], bytes)
def test_embed_array_storage_nested(image_file):
    """Embedding works for Images nested in a list feature and in a struct (dict) feature."""
    # Image inside a list.
    array = pa.array([[{"bytes": None, "path": image_file}]], type=pa.list_(Image.pa_type))
    embedded_images_array = embed_array_storage(array, [Image()])
    embedded_image = embedded_images_array.to_pylist()[0][0]
    assert isinstance(embedded_image["path"], str)
    assert isinstance(embedded_image["bytes"], bytes)
    # Image inside a struct.
    array = pa.array([{"foo": {"bytes": None, "path": image_file}}], type=pa.struct({"foo": Image.pa_type}))
    embedded_images_array = embed_array_storage(array, {"foo": Image()})
    embedded_image = embedded_images_array.to_pylist()[0]["foo"]
    assert isinstance(embedded_image["path"], str)
    assert isinstance(embedded_image["bytes"], bytes)
def test_embed_table_storage(image_file):
    """``embed_table_storage`` fills in image bytes for a table cast to an ``Image`` feature schema."""
    features = Features({"image": Image()})
    table = table_cast(pa.table({"image": [image_file]}), features.arrow_schema)
    embedded_images_table = embed_table_storage(table)
    embedded_image = embedded_images_table.to_pydict()["image"][0]
    assert isinstance(embedded_image["path"], str)
    assert isinstance(embedded_image["bytes"], bytes)
@pytest.mark.parametrize(
    "table",
    [
        InMemoryTable(pa.table({"foo": range(10)})),
        InMemoryTable(pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})])),
        InMemoryTable(pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)])),
    ],
)
@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
@pytest.mark.parametrize("drop_last_batch", [False, True])
def test_table_iter(table, batch_size, drop_last_batch):
    """``table_iter`` yields the expected number and sizes of batches, which re-concatenate to the source rows.

    With ``drop_last_batch=True`` a trailing partial batch is dropped, so every batch is full.
    Otherwise only the last batch may be smaller than ``batch_size``.
    """
    # Rows expected to be yielded: all of them, or only those in full batches.
    num_rows = len(table) if not drop_last_batch else len(table) // batch_size * batch_size
    # Ceiling division of num_rows by batch_size.
    num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
    subtables = list(table_iter(table, batch_size=batch_size, drop_last_batch=drop_last_batch))
    assert len(subtables) == num_batches
    if drop_last_batch:
        assert all(len(subtable) == batch_size for subtable in subtables)
    else:
        assert all(len(subtable) == batch_size for subtable in subtables[:-1])
        assert len(subtables[-1]) <= batch_size
    # When every batch is dropped there is nothing to concatenate.
    if num_rows > 0:
        reloaded = pa.concat_tables(subtables)
        assert table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()