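"""Tests for datasets.table.

Covers the Table wrappers (InMemoryTable, MemoryMappedTable, ConcatenationTable),
the replay mechanism of memory-mapped tables, and the casting/embedding helpers
(cast_array_to_feature, table_cast, embed_array_storage, embed_table_storage, table_iter).
"""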
import copy
import pickle
from typing import List, Union

import numpy as np
import pyarrow as pa
import pytest

from datasets import Sequence, Value
from datasets.features.features import Array2D, Array2DExtensionType, ClassLabel, Features, Image, get_nested_type
from datasets.table import (
    ConcatenationTable,
    InMemoryTable,
    MemoryMappedTable,
    Table,
    TableBlock,
    _in_memory_arrow_table_from_buffer,
    _in_memory_arrow_table_from_file,
    _interpolation_search,
    _memory_mapped_arrow_table_from_file,
    cast_array_to_feature,
    concat_tables,
    embed_array_storage,
    embed_table_storage,
    inject_arrow_table_documentation,
    table_cast,
    table_iter,
)

from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, slow


@pytest.fixture(scope="session")
def in_memory_pa_table(arrow_file) -> pa.Table:
    return pa.ipc.open_stream(arrow_file).read_all()


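# Build a blocks layout for the ConcatenationTable tests: the first row group keeps the
# first two rows as a single block, and the second row group is split column-wise into a
# "tokens"-only block and a block with the remaining columns.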
def _to_testing_blocks(table: TableBlock) -> List[List[TableBlock]]:
    assert len(table) > 2
    blocks = [
        [table.slice(0, 2)],
        [table.slice(2).drop([c for c in table.column_names if c != "tokens"]), table.slice(2).drop(["tokens"])],
    ]
    return blocks


@pytest.fixture(scope="session")
def in_memory_blocks(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table)
    return _to_testing_blocks(table)


@pytest.fixture(scope="session")
def memory_mapped_blocks(arrow_file):
    table = MemoryMappedTable.from_file(arrow_file)
    return _to_testing_blocks(table)


@pytest.fixture(scope="session")
def mixed_in_memory_and_memory_mapped_blocks(in_memory_blocks, memory_mapped_blocks):
    return in_memory_blocks[:1] + memory_mapped_blocks[1:]


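# The helpers below check whether deepcopying / pickling a MemoryMappedTable keeps the data
# on disk (Arrow memory stays flat) or materializes it in memory (Arrow memory grows).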
def assert_deepcopy_without_bringing_data_in_memory(table: MemoryMappedTable):
    with assert_arrow_memory_doesnt_increase():
        copied_table = copy.deepcopy(table)
    assert isinstance(copied_table, MemoryMappedTable)
    assert copied_table.table == table.table


def assert_deepcopy_does_bring_data_in_memory(table: MemoryMappedTable):
    with assert_arrow_memory_increases():
        copied_table = copy.deepcopy(table)
    assert isinstance(copied_table, MemoryMappedTable)
    assert copied_table.table == table.table


def assert_pickle_without_bringing_data_in_memory(table: MemoryMappedTable):
    with assert_arrow_memory_doesnt_increase():
        pickled_table = pickle.dumps(table)
        unpickled_table = pickle.loads(pickled_table)
    assert isinstance(unpickled_table, MemoryMappedTable)
    assert unpickled_table.table == table.table


def assert_pickle_does_bring_data_in_memory(table: MemoryMappedTable):
    with assert_arrow_memory_increases():
        pickled_table = pickle.dumps(table)
        unpickled_table = pickle.loads(pickled_table)
    assert isinstance(unpickled_table, MemoryMappedTable)
    assert unpickled_table.table == table.table


def assert_index_attributes_equal(table: Table, other: Table):
    assert table._batches == other._batches
    np.testing.assert_array_equal(table._offsets, other._offsets)
    assert table._schema == other._schema


def add_suffix_to_column_names(table, suffix):
    return table.rename_columns([f"{name}{suffix}" for name in table.column_names])


def test_inject_arrow_table_documentation(in_memory_pa_table):
    method = pa.Table.slice

    def function_to_wrap(*args):
        return method(*args)

    args = (0, 1)
    wrapped_method = inject_arrow_table_documentation(method)(function_to_wrap)
    assert method(in_memory_pa_table, *args) == wrapped_method(in_memory_pa_table, *args)
    assert "pyarrow.Table" not in wrapped_method.__doc__
    assert "Table" in wrapped_method.__doc__


def test_in_memory_arrow_table_from_file(arrow_file, in_memory_pa_table):
    with assert_arrow_memory_increases():
        pa_table = _in_memory_arrow_table_from_file(arrow_file)
        assert in_memory_pa_table == pa_table


def test_in_memory_arrow_table_from_buffer(in_memory_pa_table):
    with assert_arrow_memory_increases():
        buf_writer = pa.BufferOutputStream()
        writer = pa.RecordBatchStreamWriter(buf_writer, schema=in_memory_pa_table.schema)
        writer.write_table(in_memory_pa_table)
        writer.close()
        buf_writer.close()
        pa_table = _in_memory_arrow_table_from_buffer(buf_writer.getvalue())
        assert in_memory_pa_table == pa_table


def test_memory_mapped_arrow_table_from_file(arrow_file, in_memory_pa_table):
    with assert_arrow_memory_doesnt_increase():
        pa_table = _memory_mapped_arrow_table_from_file(arrow_file)
        assert in_memory_pa_table == pa_table


def test_table_init(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table.table == in_memory_pa_table


def test_table_validate(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table.validate() == in_memory_pa_table.validate()


def test_table_equals(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table.equals(in_memory_pa_table)


def test_table_to_batches(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table.to_batches() == in_memory_pa_table.to_batches()


def test_table_to_pydict(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table.to_pydict() == in_memory_pa_table.to_pydict()


def test_table_to_string(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table.to_string() == in_memory_pa_table.to_string()


def test_table_field(in_memory_pa_table):
    assert "tokens" in in_memory_pa_table.column_names
    table = Table(in_memory_pa_table)
    assert table.field("tokens") == in_memory_pa_table.field("tokens")


def test_table_column(in_memory_pa_table):
    assert "tokens" in in_memory_pa_table.column_names
    table = Table(in_memory_pa_table)
    assert table.column("tokens") == in_memory_pa_table.column("tokens")


def test_table_itercolumns(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert isinstance(table.itercolumns(), type(in_memory_pa_table.itercolumns()))
    assert list(table.itercolumns()) == list(in_memory_pa_table.itercolumns())


def test_table_getitem(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert table[0] == in_memory_pa_table[0]


def test_table_len(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert len(table) == len(in_memory_pa_table)


def test_table_str(in_memory_pa_table):
    table = Table(in_memory_pa_table)
    assert str(table) == str(in_memory_pa_table).replace("pyarrow.Table", "Table")
    assert repr(table) == repr(in_memory_pa_table).replace("pyarrow.Table", "Table")


@pytest.mark.parametrize(
    "attribute", ["schema", "columns", "num_columns", "num_rows", "shape", "nbytes", "column_names"]
)
def test_table_attributes(in_memory_pa_table, attribute):
    table = Table(in_memory_pa_table)
    assert getattr(table, attribute) == getattr(in_memory_pa_table, attribute)


def test_in_memory_table_from_file(arrow_file, in_memory_pa_table):
    with assert_arrow_memory_increases():
        table = InMemoryTable.from_file(arrow_file)
        assert table.table == in_memory_pa_table
        assert isinstance(table, InMemoryTable)


def test_in_memory_table_from_buffer(in_memory_pa_table):
    with assert_arrow_memory_increases():
        buf_writer = pa.BufferOutputStream()
        writer = pa.RecordBatchStreamWriter(buf_writer, schema=in_memory_pa_table.schema)
        writer.write_table(in_memory_pa_table)
        writer.close()
        buf_writer.close()
        table = InMemoryTable.from_buffer(buf_writer.getvalue())
        assert table.table == in_memory_pa_table
        assert isinstance(table, InMemoryTable)


def test_in_memory_table_from_pandas(in_memory_pa_table):
    df = in_memory_pa_table.to_pandas()
    with assert_arrow_memory_increases():
        # with no schema it might infer another order of the fields in the schema
        table = InMemoryTable.from_pandas(df)
        assert isinstance(table, InMemoryTable)
    # by specifying schema we get the same order of features, and so the exact same table
    table = InMemoryTable.from_pandas(df, schema=in_memory_pa_table.schema)
    assert table.table == in_memory_pa_table
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_from_arrays(in_memory_pa_table):
    arrays = list(in_memory_pa_table.columns)
    names = list(in_memory_pa_table.column_names)
    table = InMemoryTable.from_arrays(arrays, names=names)
    assert table.table == in_memory_pa_table
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_from_pydict(in_memory_pa_table):
    pydict = in_memory_pa_table.to_pydict()
    with assert_arrow_memory_increases():
        table = InMemoryTable.from_pydict(pydict)
        assert isinstance(table, InMemoryTable)
        assert table.table == pa.Table.from_pydict(pydict)


def test_in_memory_table_from_pylist(in_memory_pa_table):
    pylist = InMemoryTable(in_memory_pa_table).to_pylist()
    table = InMemoryTable.from_pylist(pylist)
    assert isinstance(table, InMemoryTable)
    assert pylist == table.to_pylist()


def test_in_memory_table_from_batches(in_memory_pa_table):
    batches = list(in_memory_pa_table.to_batches())
    table = InMemoryTable.from_batches(batches)
    assert table.table == in_memory_pa_table
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_deepcopy(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table)
    copied_table = copy.deepcopy(table)
    assert table.table == copied_table.table
    assert_index_attributes_equal(table, copied_table)
    # deepcopy must return the exact same arrow objects since they are immutable
    assert table.table is copied_table.table
    assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))


def test_in_memory_table_pickle(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table)
    pickled_table = pickle.dumps(table)
    unpickled_table = pickle.loads(pickled_table)
    assert unpickled_table.table == table.table
    assert_index_attributes_equal(table, unpickled_table)


@slow
def test_in_memory_table_pickle_big_table():
    big_table_4GB = InMemoryTable.from_pydict({"col": [0] * ((4 * 8 << 30) // 64)})
    length = len(big_table_4GB)
    big_table_4GB = pickle.dumps(big_table_4GB)
    big_table_4GB = pickle.loads(big_table_4GB)
    assert len(big_table_4GB) == length


def test_in_memory_table_slice(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table).slice(1, 2)
    assert table.table == in_memory_pa_table.slice(1, 2)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_filter(in_memory_pa_table):
    mask = pa.array([i % 2 == 0 for i in range(len(in_memory_pa_table))])
    table = InMemoryTable(in_memory_pa_table).filter(mask)
    assert table.table == in_memory_pa_table.filter(mask)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_flatten(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table).flatten()
    assert table.table == in_memory_pa_table.flatten()
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_combine_chunks(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table).combine_chunks()
    assert table.table == in_memory_pa_table.combine_chunks()
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_cast(in_memory_pa_table):
    assert pa.list_(pa.int64()) in in_memory_pa_table.schema.types
    schema = pa.schema(
        {
            k: v if v != pa.list_(pa.int64()) else pa.list_(pa.int32())
            for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
        }
    )
    table = InMemoryTable(in_memory_pa_table).cast(schema)
    assert table.table == in_memory_pa_table.cast(schema)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_cast_reorder_struct():
    table = InMemoryTable(
        pa.Table.from_pydict(
            {
                "top": [
                    {
                        "foo": "a",
                        "bar": "b",
                    }
                ]
            }
        )
    )
    schema = pa.schema({"top": pa.struct({"bar": pa.string(), "foo": pa.string()})})
    assert table.cast(schema).schema == schema


def test_in_memory_table_cast_with_hf_features():
    table = InMemoryTable(pa.Table.from_pydict({"labels": [0, 1]}))
    features = Features({"labels": ClassLabel(names=["neg", "pos"])})
    schema = features.arrow_schema
    assert table.cast(schema).schema == schema
    assert Features.from_arrow_schema(table.cast(schema).schema) == features


def test_in_memory_table_replace_schema_metadata(in_memory_pa_table):
    metadata = {"huggingface": "{}"}
    table = InMemoryTable(in_memory_pa_table).replace_schema_metadata(metadata)
    assert table.table.schema.metadata == in_memory_pa_table.replace_schema_metadata(metadata).schema.metadata
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_add_column(in_memory_pa_table):
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    table = InMemoryTable(in_memory_pa_table).add_column(i, field_, column)
    assert table.table == in_memory_pa_table.add_column(i, field_, column)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_append_column(in_memory_pa_table):
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    table = InMemoryTable(in_memory_pa_table).append_column(field_, column)
    assert table.table == in_memory_pa_table.append_column(field_, column)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_remove_column(in_memory_pa_table):
    table = InMemoryTable(in_memory_pa_table).remove_column(0)
    assert table.table == in_memory_pa_table.remove_column(0)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_set_column(in_memory_pa_table):
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    table = InMemoryTable(in_memory_pa_table).set_column(i, field_, column)
    assert table.table == in_memory_pa_table.set_column(i, field_, column)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_rename_columns(in_memory_pa_table):
    assert "tokens" in in_memory_pa_table.column_names
    names = [name if name != "tokens" else "new_tokens" for name in in_memory_pa_table.column_names]
    table = InMemoryTable(in_memory_pa_table).rename_columns(names)
    assert table.table == in_memory_pa_table.rename_columns(names)
    assert isinstance(table, InMemoryTable)


def test_in_memory_table_drop(in_memory_pa_table):
    names = [in_memory_pa_table.column_names[0]]
    table = InMemoryTable(in_memory_pa_table).drop(names)
    assert table.table == in_memory_pa_table.drop(names)
    assert isinstance(table, InMemoryTable)


def test_memory_mapped_table_init(arrow_file, in_memory_pa_table):
    table = MemoryMappedTable(_memory_mapped_arrow_table_from_file(arrow_file), arrow_file)
    assert table.table == in_memory_pa_table
    assert isinstance(table, MemoryMappedTable)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_from_file(arrow_file, in_memory_pa_table):
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file)
    assert table.table == in_memory_pa_table
    assert isinstance(table, MemoryMappedTable)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


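# A MemoryMappedTable records the operations applied to it as "replays"
# ((method_name, args, kwargs) tuples) so that the same table can be rebuilt from the
# Arrow file without loading the data in memory, e.g. after unpickling.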
def test_memory_mapped_table_from_file_with_replay(arrow_file, in_memory_pa_table):
    replays = [("slice", (0, 1), {}), ("flatten", (), {})]
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file, replays=replays)
    assert len(table) == 1
    for method, args, kwargs in replays:
        in_memory_pa_table = getattr(in_memory_pa_table, method)(*args, **kwargs)
    assert table.table == in_memory_pa_table
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_deepcopy(arrow_file):
    table = MemoryMappedTable.from_file(arrow_file)
    copied_table = copy.deepcopy(table)
    assert table.table == copied_table.table
    assert table.path == copied_table.path
    assert_index_attributes_equal(table, copied_table)
    # deepcopy must return the exact same arrow objects since they are immutable
    assert table.table is copied_table.table
    assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))


def test_memory_mapped_table_pickle(arrow_file):
    table = MemoryMappedTable.from_file(arrow_file)
    pickled_table = pickle.dumps(table)
    unpickled_table = pickle.loads(pickled_table)
    assert unpickled_table.table == table.table
    assert unpickled_table.path == table.path
    assert_index_attributes_equal(table, unpickled_table)


def test_memory_mapped_table_pickle_doesnt_fill_memory(arrow_file):
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_pickle_applies_replay(arrow_file):
    replays = [("slice", (0, 1), {}), ("flatten", (), {})]
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file, replays=replays)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == replays
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_slice(arrow_file, in_memory_pa_table):
    table = MemoryMappedTable.from_file(arrow_file).slice(1, 2)
    assert table.table == in_memory_pa_table.slice(1, 2)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("slice", (1, 2), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_filter(arrow_file, in_memory_pa_table):
    mask = pa.array([i % 2 == 0 for i in range(len(in_memory_pa_table))])
    table = MemoryMappedTable.from_file(arrow_file).filter(mask)
    assert table.table == in_memory_pa_table.filter(mask)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("filter", (mask,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    # filter DOES increase memory
    # assert_pickle_without_bringing_data_in_memory(table)
    assert_pickle_does_bring_data_in_memory(table)


def test_memory_mapped_table_flatten(arrow_file, in_memory_pa_table):
    table = MemoryMappedTable.from_file(arrow_file).flatten()
    assert table.table == in_memory_pa_table.flatten()
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("flatten", (), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_combine_chunks(arrow_file, in_memory_pa_table):
    table = MemoryMappedTable.from_file(arrow_file).combine_chunks()
    assert table.table == in_memory_pa_table.combine_chunks()
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("combine_chunks", (), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_cast(arrow_file, in_memory_pa_table):
    assert pa.list_(pa.int64()) in in_memory_pa_table.schema.types
    schema = pa.schema(
        {
            k: v if v != pa.list_(pa.int64()) else pa.list_(pa.int32())
            for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
        }
    )
    table = MemoryMappedTable.from_file(arrow_file).cast(schema)
    assert table.table == in_memory_pa_table.cast(schema)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("cast", (schema,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    # cast DOES increase memory when converting integers precision for example
    # assert_pickle_without_bringing_data_in_memory(table)
    assert_pickle_does_bring_data_in_memory(table)


def test_memory_mapped_table_replace_schema_metadata(arrow_file, in_memory_pa_table):
    metadata = {"huggingface": "{}"}
    table = MemoryMappedTable.from_file(arrow_file).replace_schema_metadata(metadata)
    assert table.table.schema.metadata == in_memory_pa_table.replace_schema_metadata(metadata).schema.metadata
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("replace_schema_metadata", (metadata,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_add_column(arrow_file, in_memory_pa_table):
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    table = MemoryMappedTable.from_file(arrow_file).add_column(i, field_, column)
    assert table.table == in_memory_pa_table.add_column(i, field_, column)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("add_column", (i, field_, column), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_append_column(arrow_file, in_memory_pa_table):
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    table = MemoryMappedTable.from_file(arrow_file).append_column(field_, column)
    assert table.table == in_memory_pa_table.append_column(field_, column)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("append_column", (field_, column), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_remove_column(arrow_file, in_memory_pa_table):
    table = MemoryMappedTable.from_file(arrow_file).remove_column(0)
    assert table.table == in_memory_pa_table.remove_column(0)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("remove_column", (0,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_set_column(arrow_file, in_memory_pa_table):
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    table = MemoryMappedTable.from_file(arrow_file).set_column(i, field_, column)
    assert table.table == in_memory_pa_table.set_column(i, field_, column)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("set_column", (i, field_, column), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_rename_columns(arrow_file, in_memory_pa_table):
    assert "tokens" in in_memory_pa_table.column_names
    names = [name if name != "tokens" else "new_tokens" for name in in_memory_pa_table.column_names]
    table = MemoryMappedTable.from_file(arrow_file).rename_columns(names)
    assert table.table == in_memory_pa_table.rename_columns(names)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("rename_columns", (names,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_drop(arrow_file, in_memory_pa_table):
    names = [in_memory_pa_table.column_names[0]]
    table = MemoryMappedTable.from_file(arrow_file).drop(names)
    assert table.table == in_memory_pa_table.drop(names)
    assert isinstance(table, MemoryMappedTable)
    assert table.replays == [("drop", (names,), {})]
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


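# ConcatenationTable.blocks is a list of lists of TableBlock objects: the outer list holds
# the row groups concatenated along axis=0, and each inner list holds the column groups of
# one row group concatenated along axis=1.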
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
616
def test_concatenation_table_init(
617
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
618
):
619
    blocks = (
620
        in_memory_blocks
621
        if blocks_type == "in_memory"
622
        else memory_mapped_blocks
623
        if blocks_type == "memory_mapped"
624
        else mixed_in_memory_and_memory_mapped_blocks
625
    )
626
    table = ConcatenationTable(in_memory_pa_table, blocks)
627
    assert table.table == in_memory_pa_table
628
    assert table.blocks == blocks
629

630

631
def test_concatenation_table_from_blocks(in_memory_pa_table, in_memory_blocks):
632
    assert len(in_memory_pa_table) > 2
633
    in_memory_table = InMemoryTable(in_memory_pa_table)
634
    t1, t2 = in_memory_table.slice(0, 2), in_memory_table.slice(2)
635
    table = ConcatenationTable.from_blocks(in_memory_table)
636
    assert isinstance(table, ConcatenationTable)
637
    assert table.table == in_memory_pa_table
638
    assert table.blocks == [[in_memory_table]]
639
    table = ConcatenationTable.from_blocks([t1, t2])
640
    assert isinstance(table, ConcatenationTable)
641
    assert table.table == in_memory_pa_table
642
    assert table.blocks == [[in_memory_table]]
643
    table = ConcatenationTable.from_blocks([[t1], [t2]])
644
    assert isinstance(table, ConcatenationTable)
645
    assert table.table == in_memory_pa_table
646
    assert table.blocks == [[in_memory_table]]
647
    table = ConcatenationTable.from_blocks(in_memory_blocks)
648
    assert isinstance(table, ConcatenationTable)
649
    assert table.table == in_memory_pa_table
650
    assert table.blocks == [[in_memory_table]]
651

652

653
@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
654
def test_concatenation_table_from_blocks_doesnt_increase_memory(
655
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
656
):
657
    blocks = {
658
        "in_memory": in_memory_blocks,
659
        "memory_mapped": memory_mapped_blocks,
660
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
661
    }[blocks_type]
662
    with assert_arrow_memory_doesnt_increase():
663
        table = ConcatenationTable.from_blocks(blocks)
664
        assert isinstance(table, ConcatenationTable)
665
        assert table.table == in_memory_pa_table
666
        if blocks_type == "in_memory":
667
            assert table.blocks == [[InMemoryTable(in_memory_pa_table)]]
668
        else:
669
            assert table.blocks == blocks
670

671

672
@pytest.mark.parametrize("axis", [0, 1])
673
def test_concatenation_table_from_tables(axis, in_memory_pa_table, arrow_file):
674
    in_memory_table = InMemoryTable(in_memory_pa_table)
675
    concatenation_table = ConcatenationTable.from_blocks(in_memory_table)
676
    memory_mapped_table = MemoryMappedTable.from_file(arrow_file)
677
    tables = [in_memory_pa_table, in_memory_table, concatenation_table, memory_mapped_table]
678
    if axis == 0:
679
        expected_table = pa.concat_tables([in_memory_pa_table] * len(tables))
680
    else:
681
        # avoids error due to duplicate column names
682
        tables[1:] = [add_suffix_to_column_names(table, i) for i, table in enumerate(tables[1:], 1)]
683
        expected_table = in_memory_pa_table
684
        for table in tables[1:]:
685
            for name, col in zip(table.column_names, table.columns):
686
                expected_table = expected_table.append_column(name, col)
687

688
    with assert_arrow_memory_doesnt_increase():
689
        table = ConcatenationTable.from_tables(tables, axis=axis)
690
    assert isinstance(table, ConcatenationTable)
691
    assert table.table == expected_table
692
    # because of consolidation, we end up with 1 InMemoryTable and 1 MemoryMappedTable
693
    assert len(table.blocks) == (1 if axis == 1 else 2)
    assert len(table.blocks[0]) == (1 if axis == 0 else 2)
    assert axis == 1 or len(table.blocks[1]) == 1
    assert isinstance(table.blocks[0][0], InMemoryTable)
    assert isinstance(table.blocks[1][0] if axis == 0 else table.blocks[0][1], MemoryMappedTable)


def test_concatenation_table_from_tables_axis1_misaligned_blocks(arrow_file):
    table = MemoryMappedTable.from_file(arrow_file)
    t1 = table.slice(0, 2)
    t2 = table.slice(0, 3).rename_columns([col + "_1" for col in table.column_names])
    concatenated = ConcatenationTable.from_tables(
        [
            ConcatenationTable.from_blocks([[t1], [t1], [t1]]),
            ConcatenationTable.from_blocks([[t2], [t2]]),
        ],
        axis=1,
    )
    assert len(concatenated) == 6
    assert [len(row_blocks[0]) for row_blocks in concatenated.blocks] == [2, 1, 1, 2]
    concatenated = ConcatenationTable.from_tables(
        [
            ConcatenationTable.from_blocks([[t2], [t2]]),
            ConcatenationTable.from_blocks([[t1], [t1], [t1]]),
        ],
        axis=1,
    )
    assert len(concatenated) == 6
    assert [len(row_blocks[0]) for row_blocks in concatenated.blocks] == [2, 1, 1, 2]


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_deepcopy(
    blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks)
    copied_table = copy.deepcopy(table)
    assert table.table == copied_table.table
    assert table.blocks == copied_table.blocks
    assert_index_attributes_equal(table, copied_table)
    # deepcopy must return the exact same arrow objects since they are immutable
    assert table.table is copied_table.table
    assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_pickle(
    blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks)
    pickled_table = pickle.dumps(table)
    unpickled_table = pickle.loads(pickled_table)
    assert unpickled_table.table == table.table
    assert unpickled_table.blocks == table.blocks
    assert_index_attributes_equal(table, unpickled_table)


def test_concat_tables_with_features_metadata(arrow_file, in_memory_pa_table):
    input_features = Features.from_arrow_schema(in_memory_pa_table.schema)
    input_features["id"] = Value("int64", id="my_id")
    input_schema = input_features.arrow_schema
    t0 = in_memory_pa_table.replace_schema_metadata(input_schema.metadata)
    t1 = MemoryMappedTable.from_file(arrow_file)
    tables = [t0, t1]
    concatenated_table = concat_tables(tables, axis=0)
    output_schema = concatenated_table.schema
    output_features = Features.from_arrow_schema(output_schema)
    assert output_schema == input_schema
    assert output_schema.metadata == input_schema.metadata
    assert output_features == input_features
    assert output_features["id"].id == "my_id"


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_slice(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks).slice(1, 2)
    assert table.table == in_memory_pa_table.slice(1, 2)
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_filter(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    mask = pa.array([i % 2 == 0 for i in range(len(in_memory_pa_table))])
    table = ConcatenationTable.from_blocks(blocks).filter(mask)
    assert table.table == in_memory_pa_table.filter(mask)
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_flatten(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks).flatten()
    assert table.table == in_memory_pa_table.flatten()
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_combine_chunks(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks).combine_chunks()
    assert table.table == in_memory_pa_table.combine_chunks()
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_cast(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    assert pa.list_(pa.int64()) in in_memory_pa_table.schema.types
    assert pa.int64() in in_memory_pa_table.schema.types
    schema = pa.schema(
        {
            k: v if v != pa.list_(pa.int64()) else pa.list_(pa.int32())
            for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
        }
    )
    table = ConcatenationTable.from_blocks(blocks).cast(schema)
    assert table.table == in_memory_pa_table.cast(schema)
    assert isinstance(table, ConcatenationTable)
    schema = pa.schema(
        {
            k: v if v != pa.int64() else pa.int32()
            for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
        }
    )
    table = ConcatenationTable.from_blocks(blocks).cast(schema)
    assert table.table == in_memory_pa_table.cast(schema)
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
865
def test_concat_tables_cast_with_features_metadata(
866
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
867
):
868
    blocks = {
869
        "in_memory": in_memory_blocks,
870
        "memory_mapped": memory_mapped_blocks,
871
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
872
    }[blocks_type]
873
    input_features = Features.from_arrow_schema(in_memory_pa_table.schema)
874
    input_features["id"] = Value("int64", id="my_id")
875
    intput_schema = input_features.arrow_schema
876
    concatenated_table = ConcatenationTable.from_blocks(blocks).cast(intput_schema)
877
    output_schema = concatenated_table.schema
878
    output_features = Features.from_arrow_schema(output_schema)
879
    assert output_schema == intput_schema
880
    assert output_schema.metadata == intput_schema.metadata
881
    assert output_features == input_features
882
    assert output_features["id"].id == "my_id"


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_replace_schema_metadata(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    metadata = {"huggingface": "{}"}
    table = ConcatenationTable.from_blocks(blocks).replace_schema_metadata(metadata)
    assert table.table.schema.metadata == in_memory_pa_table.replace_schema_metadata(metadata).schema.metadata
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_add_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    with pytest.raises(NotImplementedError):
        ConcatenationTable.from_blocks(blocks).add_column(i, field_, column)
        # assert table.table == in_memory_pa_table.add_column(i, field_, column)
        # unpickled_table = pickle.loads(pickle.dumps(table))
        # assert unpickled_table.table == in_memory_pa_table.add_column(i, field_, column)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_append_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    with pytest.raises(NotImplementedError):
        ConcatenationTable.from_blocks(blocks).append_column(field_, column)
        # assert table.table == in_memory_pa_table.append_column(field_, column)
        # unpickled_table = pickle.loads(pickle.dumps(table))
        # assert unpickled_table.table == in_memory_pa_table.append_column(field_, column)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_remove_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    table = ConcatenationTable.from_blocks(blocks).remove_column(0)
    assert table.table == in_memory_pa_table.remove_column(0)
    assert isinstance(table, ConcatenationTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_set_column(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    i = len(in_memory_pa_table.column_names)
    field_ = "new_field"
    column = pa.array(list(range(len(in_memory_pa_table))))
    with pytest.raises(NotImplementedError):
        ConcatenationTable.from_blocks(blocks).set_column(i, field_, column)
        # assert table.table == in_memory_pa_table.set_column(i, field_, column)
        # unpickled_table = pickle.loads(pickle.dumps(table))
        # assert unpickled_table.table == in_memory_pa_table.set_column(i, field_, column)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_rename_columns(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    assert "tokens" in in_memory_pa_table.column_names
    names = [name if name != "tokens" else "new_tokens" for name in in_memory_pa_table.column_names]
    table = ConcatenationTable.from_blocks(blocks).rename_columns(names)
    assert isinstance(table, ConcatenationTable)
    assert table.table == in_memory_pa_table.rename_columns(names)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_drop(
    blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    blocks = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }[blocks_type]
    names = [in_memory_pa_table.column_names[0]]
    table = ConcatenationTable.from_blocks(blocks).drop(names)
    assert table.table == in_memory_pa_table.drop(names)
    assert isinstance(table, ConcatenationTable)


def test_concat_tables(arrow_file, in_memory_pa_table):
    t0 = in_memory_pa_table
    t1 = InMemoryTable(t0)
    t2 = MemoryMappedTable.from_file(arrow_file)
    t3 = ConcatenationTable.from_blocks(t1)
    tables = [t0, t1, t2, t3]
    concatenated_table = concat_tables(tables, axis=0)
    assert concatenated_table.table == pa.concat_tables([t0] * 4)
    assert concatenated_table.table.shape == (40, 4)
    assert isinstance(concatenated_table, ConcatenationTable)
    assert len(concatenated_table.blocks) == 3  # t0 and t1 are consolidated as a single InMemoryTable
    assert isinstance(concatenated_table.blocks[0][0], InMemoryTable)
    assert isinstance(concatenated_table.blocks[1][0], MemoryMappedTable)
    assert isinstance(concatenated_table.blocks[2][0], InMemoryTable)
    # add suffix to avoid error due to duplicate column names
    concatenated_table = concat_tables(
        [add_suffix_to_column_names(table, i) for i, table in enumerate(tables)], axis=1
    )
    assert concatenated_table.table.shape == (10, 16)
    assert len(concatenated_table.blocks[0]) == 3  # t0 and t1 are consolidated as a single InMemoryTable
    assert isinstance(concatenated_table.blocks[0][0], InMemoryTable)
    assert isinstance(concatenated_table.blocks[0][1], MemoryMappedTable)
    assert isinstance(concatenated_table.blocks[0][2], InMemoryTable)


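# Reference implementation for _interpolation_search: a linear scan that returns the index i
# such that arr[i] <= x < arr[i + 1], or IndexError when x is out of range. The real
# implementation must find the same index with only a few array accesses (at most 4 unique
# __getitem__ calls for the cases parametrized below).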
def _interpolation_search_ground_truth(arr: List[int], x: int) -> Union[int, IndexError]:
    for i in range(len(arr) - 1):
        if arr[i] <= x < arr[i + 1]:
            return i
    return IndexError


class _ListWithGetitemCounter(list):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.unique_getitem_calls = set()

    def __getitem__(self, i):
        out = super().__getitem__(i)
        self.unique_getitem_calls.add(i)
        return out

    @property
    def getitem_unique_count(self):
        return len(self.unique_getitem_calls)


@pytest.mark.parametrize(
    "arr, x",
    [(np.arange(0, 14, 3), x) for x in range(-1, 22)]
    + [(list(np.arange(-5, 5)), x) for x in range(-6, 6)]
    + [([0, 1_000, 1_001, 1_003], x) for x in [-1, 0, 2, 100, 999, 1_000, 1_001, 1_002, 1_003, 1_004]]
    + [(list(range(1_000)), x) for x in [-1, 0, 1, 10, 666, 999, 1_000, 1_0001]],
)
def test_interpolation_search(arr, x):
    ground_truth = _interpolation_search_ground_truth(arr, x)
    if isinstance(ground_truth, int):
        arr = _ListWithGetitemCounter(arr)
        output = _interpolation_search(arr, x)
        assert ground_truth == output
        # at most 4 unique getitem calls are expected for the cases of this test,
        # but it can be higher for large and messy arrays.
        assert arr.getitem_unique_count <= 4
    else:
        with pytest.raises(ground_truth):
            _interpolation_search(arr, x)


def test_indexed_table_mixin():
    n_rows_per_chunk = 10
    n_chunks = 4
    pa_table = pa.Table.from_pydict({"col": [0] * n_rows_per_chunk})
    pa_table = pa.concat_tables([pa_table] * n_chunks)
    table = Table(pa_table)
    assert all(table._offsets.tolist() == np.cumsum([0] + [n_rows_per_chunk] * n_chunks))
    assert table.fast_slice(5) == pa_table.slice(5)
    assert table.fast_slice(2, 13) == pa_table.slice(2, 13)


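# cast_array_to_feature converts a pyarrow array to the Arrow type of a datasets feature,
# e.g. Sequence(Value("string")) -> list<string>; casts that are not allowed (such as
# number-to-string when allow_number_to_str=False) raise a TypeError.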
def test_cast_array_to_features():
    arr = pa.array([[0, 1]])
    assert cast_array_to_feature(arr, Sequence(Value("string"))).type == pa.list_(pa.string())
    with pytest.raises(TypeError):
        cast_array_to_feature(arr, Sequence(Value("string")), allow_number_to_str=False)


def test_cast_array_to_features_nested():
    arr = pa.array([[{"foo": [0]}]])
    assert cast_array_to_feature(arr, [{"foo": Sequence(Value("string"))}]).type == pa.list_(
        pa.struct({"foo": pa.list_(pa.string())})
    )


def test_cast_array_to_features_to_nested_with_no_fields():
    arr = pa.array([{}])
    assert cast_array_to_feature(arr, {}).type == pa.struct({})
    assert cast_array_to_feature(arr, {}).to_pylist() == arr.to_pylist()


def test_cast_array_to_features_nested_with_nulls():
    # same type
    arr = pa.array([{"foo": [None, [0]]}], pa.struct({"foo": pa.list_(pa.list_(pa.int64()))}))
    casted_array = cast_array_to_feature(arr, {"foo": [[Value("int64")]]})
    assert casted_array.type == pa.struct({"foo": pa.list_(pa.list_(pa.int64()))})
    assert casted_array.to_pylist() == arr.to_pylist()
    # different type
    arr = pa.array([{"foo": [None, [0]]}], pa.struct({"foo": pa.list_(pa.list_(pa.int64()))}))
    casted_array = cast_array_to_feature(arr, {"foo": [[Value("int32")]]})
    assert casted_array.type == pa.struct({"foo": pa.list_(pa.list_(pa.int32()))})
    assert casted_array.to_pylist() == [{"foo": [None, [0]]}]


def test_cast_array_to_features_to_null_type():
    # same type
    arr = pa.array([[None, None]])
    assert cast_array_to_feature(arr, Sequence(Value("null"))).type == pa.list_(pa.null())

    # different type
    arr = pa.array([[None, 1]])
    with pytest.raises(TypeError):
        cast_array_to_feature(arr, Sequence(Value("null")))


def test_cast_array_to_features_array_xd():
    # same storage type
    arr = pa.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], pa.list_(pa.list_(pa.int32(), 2), 2))
    casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="int32"))
    assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="int32")
    # different storage type
    casted_array = cast_array_to_feature(arr, Array2D(shape=(2, 2), dtype="float32"))
    assert casted_array.type == Array2DExtensionType(shape=(2, 2), dtype="float32")


def test_cast_array_to_features_sequence_classlabel():
    arr = pa.array([[], [1], [0, 1]], pa.list_(pa.int64()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    arr = pa.array([[], ["bar"], ["foo", "bar"]], pa.list_(pa.string()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    # Test empty arrays
    arr = pa.array([[], []], pa.list_(pa.int64()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    arr = pa.array([[], []], pa.list_(pa.string()))
    assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"]))).type == pa.list_(pa.int64())

    # Test invalid class labels
    arr = pa.array([[2]], pa.list_(pa.int64()))
    with pytest.raises(ValueError):
        assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"])))

    arr = pa.array([["baz"]], pa.list_(pa.string()))
    with pytest.raises(ValueError):
        assert cast_array_to_feature(arr, Sequence(ClassLabel(names=["foo", "bar"])))


@pytest.mark.parametrize(
    "arr",
    [
        pa.array([[0, 1, 2], [3, None, 5], None, [6, 7, 8], None], pa.list_(pa.int32(), 3)),
    ],
)
@pytest.mark.parametrize("slice", [None, slice(1, None), slice(-1), slice(1, 3), slice(2, 3), slice(1, 1)])
@pytest.mark.parametrize("target_value_feature", [Value("int64")])
def test_cast_fixed_size_list_array_to_features_sequence(arr, slice, target_value_feature):
    arr = arr if slice is None else arr[slice]
    # Fixed size list
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature, length=arr.type.list_size))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature, length=arr.type.list_size))
    assert casted_array.to_pylist() == arr.to_pylist()
    with pytest.raises(TypeError):
        cast_array_to_feature(arr, Sequence(target_value_feature, length=arr.type.list_size + 1))
    # Variable size list
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature))
    assert casted_array.to_pylist() == arr.to_pylist()
    casted_array = cast_array_to_feature(arr, [target_value_feature])
    assert casted_array.type == get_nested_type([target_value_feature])
    assert casted_array.to_pylist() == arr.to_pylist()


@pytest.mark.parametrize(
    "arr",
    [
        pa.array([[0, 1, 2], [3, None, 5], None, [6, 7, 8], None], pa.list_(pa.int32())),
    ],
)
@pytest.mark.parametrize("slice", [None, slice(1, None), slice(-1), slice(1, 3), slice(2, 3), slice(1, 1)])
@pytest.mark.parametrize("target_value_feature", [Value("int64")])
def test_cast_list_array_to_features_sequence(arr, slice, target_value_feature):
    arr = arr if slice is None else arr[slice]
    # Variable size list
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature))
    assert casted_array.to_pylist() == arr.to_pylist()
    casted_array = cast_array_to_feature(arr, [target_value_feature])
    assert casted_array.type == get_nested_type([target_value_feature])
    assert casted_array.to_pylist() == arr.to_pylist()
    # Fixed size list
    list_size = arr.value_lengths().drop_null()[0].as_py() if arr.value_lengths().drop_null() else 2
    casted_array = cast_array_to_feature(arr, Sequence(target_value_feature, length=list_size))
    assert casted_array.type == get_nested_type(Sequence(target_value_feature, length=list_size))
    assert casted_array.to_pylist() == arr.to_pylist()


def test_cast_array_xd_to_features_sequence():
    arr = np.random.randint(0, 10, size=(8, 2, 3)).tolist()
    arr = Array2DExtensionType(shape=(2, 3), dtype="int64").wrap_array(pa.array(arr, pa.list_(pa.list_(pa.int64()))))
    arr = pa.ListArray.from_arrays([0, None, 4, 8], arr)
    # Variable size list
    casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32")))
    assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32")))
    assert casted_array.to_pylist() == arr.to_pylist()
    # Fixed size list
    casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4))
    assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4))
    assert casted_array.to_pylist() == arr.to_pylist()


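# embed_array_storage / embed_table_storage read the files referenced by "path" and store
# their bytes inline in the Arrow data, so the table no longer depends on external files.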
def test_embed_array_storage(image_file):
    array = pa.array([{"bytes": None, "path": image_file}], type=Image.pa_type)
    embedded_images_array = embed_array_storage(array, Image())
    assert isinstance(embedded_images_array.to_pylist()[0]["path"], str)
    assert embedded_images_array.to_pylist()[0]["path"] == "test_image_rgb.jpg"
    assert isinstance(embedded_images_array.to_pylist()[0]["bytes"], bytes)


def test_embed_array_storage_nested(image_file):
    array = pa.array([[{"bytes": None, "path": image_file}]], type=pa.list_(Image.pa_type))
    embedded_images_array = embed_array_storage(array, [Image()])
    assert isinstance(embedded_images_array.to_pylist()[0][0]["path"], str)
    assert isinstance(embedded_images_array.to_pylist()[0][0]["bytes"], bytes)
    array = pa.array([{"foo": {"bytes": None, "path": image_file}}], type=pa.struct({"foo": Image.pa_type}))
    embedded_images_array = embed_array_storage(array, {"foo": Image()})
    assert isinstance(embedded_images_array.to_pylist()[0]["foo"]["path"], str)
    assert isinstance(embedded_images_array.to_pylist()[0]["foo"]["bytes"], bytes)


def test_embed_table_storage(image_file):
    features = Features({"image": Image()})
    table = table_cast(pa.table({"image": [image_file]}), features.arrow_schema)
    embedded_images_table = embed_table_storage(table)
    assert isinstance(embedded_images_table.to_pydict()["image"][0]["path"], str)
    assert isinstance(embedded_images_table.to_pydict()["image"][0]["bytes"], bytes)


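# table_iter yields batches of `batch_size` rows; unless drop_last_batch is set, the final
# batch may be smaller, so the expected batch count is num_rows // batch_size plus one more
# batch when there is a remainder.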
@pytest.mark.parametrize(
    "table",
    [
        InMemoryTable(pa.table({"foo": range(10)})),
        InMemoryTable(pa.concat_tables([pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})])),
        InMemoryTable(pa.concat_tables([pa.table({"foo": [i]}) for i in range(10)])),
    ],
)
@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
@pytest.mark.parametrize("drop_last_batch", [False, True])
def test_table_iter(table, batch_size, drop_last_batch):
    num_rows = len(table) if not drop_last_batch else len(table) // batch_size * batch_size
    num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
    subtables = list(table_iter(table, batch_size=batch_size, drop_last_batch=drop_last_batch))
    assert len(subtables) == num_batches
    if drop_last_batch:
        assert all(len(subtable) == batch_size for subtable in subtables)
    else:
        assert all(len(subtable) == batch_size for subtable in subtables[:-1])
        assert len(subtables[-1]) <= batch_size
    if num_rows > 0:
        reloaded = pa.concat_tables(subtables)
        assert table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()