datasets

Форк
0
578 строк · 17.3 Кб
1
import contextlib
2
import csv
3
import json
4
import os
5
import sqlite3
6
import tarfile
7
import textwrap
8
import zipfile
9

10
import pandas as pd
11
import pyarrow as pa
12
import pyarrow.parquet as pq
13
import pytest
14

15
import datasets
16
import datasets.config
17

18

19
# dataset + arrow_file
20

21

22
@pytest.fixture(scope="session")
def dataset():
    """Build a small in-memory ``datasets.Dataset`` with nested features.

    Every one of the ``n`` rows holds five ``"foo"`` tokens, five labels with
    class index 1 (``"positive"``), a single answer (``text="1976"``,
    ``answer_start=97``) and an integer ``id`` running from 0 to ``n - 1``.
    """
    n = 10
    features = datasets.Features(
        {
            "tokens": datasets.Sequence(datasets.Value("string")),
            "labels": datasets.Sequence(datasets.ClassLabel(names=["negative", "positive"])),
            "answers": datasets.Sequence(
                {
                    "text": datasets.Value("string"),
                    "answer_start": datasets.Value("int32"),
                }
            ),
            "id": datasets.Value("int64"),
        }
    )
    # Every column is sized by ``n``; "answers" previously hard-coded 10 and
    # would have produced a ragged table as soon as ``n`` changed.
    # Comprehensions give each row its own list/dict instead of ``n`` aliases
    # of a single shared object.
    dataset = datasets.Dataset.from_dict(
        {
            "tokens": [["foo"] * 5 for _ in range(n)],
            "labels": [[1] * 5 for _ in range(n)],
            "answers": [{"answer_start": [97], "text": ["1976"]} for _ in range(n)],
            "id": list(range(n)),
        },
        features=features,
    )
    return dataset
48

49

50
@pytest.fixture(scope="session")
def arrow_file(tmp_path_factory, dataset):
    """Persist the ``dataset`` fixture to ``file.arrow`` and return that path as a ``str``."""
    cache_path = str(tmp_path_factory.mktemp("data").joinpath("file.arrow"))
    # ``map`` without a function is an identity pass; the explicit cache file
    # name makes it write the resulting table to ``cache_path``.
    dataset.map(cache_file_name=cache_path)
    return cache_path
55

56

57
# FILE_CONTENT + files
58

59

60
# Two-line sample text shared by the plain and compressed ``*_file`` fixtures.
# Both lines keep a four-space indent and there is no trailing newline.
FILE_CONTENT = "    Text data.\n" "    Second line of data."
63

64

65
@pytest.fixture(scope="session")
def text_file(tmp_path_factory):
    """Write ``FILE_CONTENT`` to a plain ``file.txt`` and return it as a ``pathlib.Path``."""
    txt_path = tmp_path_factory.mktemp("data").joinpath("file.txt")
    txt_path.write_text(FILE_CONTENT)
    return txt_path
72

73

74
@pytest.fixture(scope="session")
def bz2_file(tmp_path_factory):
    """Compress UTF-8 encoded ``FILE_CONTENT`` into ``file.txt.bz2``; return its ``pathlib.Path``."""
    import bz2

    target = tmp_path_factory.mktemp("data").joinpath("file.txt.bz2")
    payload = FILE_CONTENT.encode("utf-8")
    with bz2.open(target, "wb") as compressed:
        compressed.write(payload)
    return target
83

84

85
@pytest.fixture(scope="session")
def gz_file(tmp_path_factory):
    """Gzip UTF-8 encoded ``FILE_CONTENT`` into ``file.txt.gz``; return the path as a ``str``."""
    import gzip

    target = str(tmp_path_factory.mktemp("data").joinpath("file.txt.gz"))
    with gzip.open(target, "wb") as compressed:
        compressed.write(FILE_CONTENT.encode("utf-8"))
    return target
94

95

96
@pytest.fixture(scope="session")
def lz4_file(tmp_path_factory):
    """Compress ``FILE_CONTENT`` into ``file.txt.lz4`` using the LZ4 frame format.

    Returns the archive's ``pathlib.Path``, or ``None`` when
    ``datasets.config.LZ4_AVAILABLE`` is false. In that case no temporary
    directory is created and callers must handle the ``None`` value.
    """
    if not datasets.config.LZ4_AVAILABLE:
        return None

    import lz4.frame

    target = tmp_path_factory.mktemp("data") / "file.txt.lz4"
    with lz4.frame.open(target, "wb") as compressed:
        compressed.write(FILE_CONTENT.encode("utf-8"))
    return target
106

107

108
@pytest.fixture(scope="session")
def seven_zip_file(tmp_path_factory, text_file):
    """Pack ``text_file`` (under its base name) into a 7z archive ``file.txt.7z``.

    Returns the archive's ``pathlib.Path``, or ``None`` when
    ``datasets.config.PY7ZR_AVAILABLE`` is false.
    """
    if not datasets.config.PY7ZR_AVAILABLE:
        return None

    import py7zr

    target = tmp_path_factory.mktemp("data") / "file.txt.7z"
    member_name = os.path.basename(text_file)
    with py7zr.SevenZipFile(target, "w") as archive:
        archive.write(text_file, arcname=member_name)
    return target
117

118

119
@pytest.fixture(scope="session")
def tar_file(tmp_path_factory, text_file):
    """Pack ``text_file`` (under its base name) into an uncompressed ``file.txt.tar``.

    Returns the archive's ``pathlib.Path``. Uses the module-level ``tarfile`` import.
    """
    archive_path = tmp_path_factory.mktemp("data") / "file.txt.tar"
    member_name = os.path.basename(text_file)
    with tarfile.TarFile(archive_path, "w") as archive:
        archive.add(text_file, arcname=member_name)
    return archive_path
127

128

129
@pytest.fixture(scope="session")
def xz_file(tmp_path_factory):
    """Compress UTF-8 encoded ``FILE_CONTENT`` with LZMA/xz into ``file.txt.xz``; return its ``pathlib.Path``."""
    import lzma

    target = tmp_path_factory.mktemp("data").joinpath("file.txt.xz")
    payload = FILE_CONTENT.encode("utf-8")
    with lzma.open(target, "wb") as compressed:
        compressed.write(payload)
    return target
138

139

140
@pytest.fixture(scope="session")
def zip_file(tmp_path_factory, text_file):
    """Store ``text_file`` (under its base name) in a ZIP archive ``file.txt.zip``.

    Returns the archive's ``pathlib.Path``. Uses the module-level ``zipfile`` import.
    """
    archive_path = tmp_path_factory.mktemp("data") / "file.txt.zip"
    member_name = os.path.basename(text_file)
    with zipfile.ZipFile(archive_path, "w") as archive:
        archive.write(text_file, arcname=member_name)
    return archive_path
148

149

150
@pytest.fixture(scope="session")
def zstd_file(tmp_path_factory):
    """Compress ``FILE_CONTENT`` with Zstandard into ``file.txt.zst``.

    Returns the archive's ``pathlib.Path``, or ``None`` when
    ``datasets.config.ZSTANDARD_AVAILABLE`` is false.
    """
    if not datasets.config.ZSTANDARD_AVAILABLE:
        return None

    import zstandard as zstd

    target = tmp_path_factory.mktemp("data") / "file.txt.zst"
    with zstd.open(target, "wb") as compressed:
        compressed.write(FILE_CONTENT.encode("utf-8"))
    return target
160

161

162
# xml_file
163

164

165
@pytest.fixture(scope="session")
def xml_file(tmp_path_factory):
    """Write a small TMX (translation memory) XML document to ``file.xml``.

    The document holds five ``<tu>`` units. Each pairs a Catalan (``ca``)
    segment "Contingut N" with an English (``en``) segment "Content N".
    Returns the file's ``pathlib.Path``.
    """
    xml_path = tmp_path_factory.mktemp("data").joinpath("file.xml")
    xml_path.write_text(
        textwrap.dedent(
            """\
    <?xml version="1.0" encoding="UTF-8" ?>
    <tmx version="1.4">
      <header segtype="sentence" srclang="ca" />
      <body>
        <tu>
          <tuv xml:lang="ca"><seg>Contingut 1</seg></tuv>
          <tuv xml:lang="en"><seg>Content 1</seg></tuv>
        </tu>
        <tu>
          <tuv xml:lang="ca"><seg>Contingut 2</seg></tuv>
          <tuv xml:lang="en"><seg>Content 2</seg></tuv>
        </tu>
        <tu>
          <tuv xml:lang="ca"><seg>Contingut 3</seg></tuv>
          <tuv xml:lang="en"><seg>Content 3</seg></tuv>
        </tu>
        <tu>
          <tuv xml:lang="ca"><seg>Contingut 4</seg></tuv>
          <tuv xml:lang="en"><seg>Content 4</seg></tuv>
        </tu>
        <tu>
          <tuv xml:lang="ca"><seg>Contingut 5</seg></tuv>
          <tuv xml:lang="en"><seg>Content 5</seg></tuv>
        </tu>
      </body>
    </tmx>"""
        )
    )
    return xml_path
200

201

202
# Shared tabular fixture rows: a string "col_1", an int "col_2" and a float
# "col_3", all derived from the row index.
DATA = [{"col_1": str(i), "col_2": i, "col_3": float(i)} for i in range(4)]
# Two more rows continuing DATA's sequence (indices 4 and 5).
DATA2 = [{"col_1": str(i), "col_2": i, "col_3": float(i)} for i in range(4, 6)]
# Column-oriented copy of DATA: one list per column, in DATA's column order.
DATA_DICT_OF_LISTS = {column: [row[column] for row in DATA] for column in ("col_1", "col_2", "col_3")}

# Same columns as DATA but with keys ordered col_3, col_1, col_2
# (presumably to exercise column-order handling in readers).
DATA_312 = [{"col_3": float(i), "col_1": str(i), "col_2": i} for i in range(2)]

# Like DATA, except "col_1" holds non-numeric strings "s0" .. "s3".
DATA_STR = [{"col_1": f"s{i}", "col_2": i, "col_3": float(i)} for i in range(4)]
229

230

231
@pytest.fixture(scope="session")
def dataset_dict():
    """Expose the module-level ``DATA_DICT_OF_LISTS`` mapping itself, not a copy.

    Every consumer receives the same dict object, so tests must not mutate it.
    """
    columns = DATA_DICT_OF_LISTS
    return columns
234

235

236
@pytest.fixture(scope="session")
def arrow_path(tmp_path_factory):
    """Persist ``DATA_DICT_OF_LISTS`` as an Arrow cache file ``dataset.arrow``; return its path as a ``str``."""
    table = datasets.Dataset.from_dict(DATA_DICT_OF_LISTS)
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.arrow"))
    # Identity ``map``; ``cache_file_name`` makes it write the table to ``target``.
    table.map(cache_file_name=target)
    return target
242

243

244
@pytest.fixture(scope="session")
def sqlite_path(tmp_path_factory):
    """Create an SQLite database ``dataset.sqlite`` holding the rows of ``DATA``.

    The single table ``dataset`` has columns ``col_1 text``, ``col_2 int`` and
    ``col_3 real``. Returns the database path as a ``str``.
    """
    path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite")
    columns = ("col_1", "col_2", "col_3")
    # Pick values by column name instead of trusting each dict's key order to
    # match the INSERT column list (compare DATA_312, which is ordered differently).
    rows = [tuple(item[column] for column in columns) for item in DATA]
    # closing() guarantees the connection is released even if a statement fails.
    with contextlib.closing(sqlite3.connect(path)) as con:
        con.execute("CREATE TABLE dataset(col_1 text, col_2 int, col_3 real)")
        con.executemany("INSERT INTO dataset(col_1, col_2, col_3) VALUES (?, ?, ?)", rows)
        con.commit()
    return path
254

255

256
@pytest.fixture(scope="session")
def csv_path(tmp_path_factory):
    """Write ``DATA`` (with a header row) to ``dataset.csv``; return the path as a ``str``."""
    target = str(tmp_path_factory.mktemp("data") / "dataset.csv")
    # newline="" lets the csv module manage line terminators itself.
    with open(target, "w", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=["col_1", "col_2", "col_3"])
        csv_writer.writeheader()
        csv_writer.writerows(DATA)
    return target
265

266

267
@pytest.fixture(scope="session")
def csv2_path(tmp_path_factory):
    """Write ``DATA`` (with a header row) to a second file ``dataset2.csv``; return the path as a ``str``.

    NOTE(review): this writes ``DATA``, not ``DATA2``, so its content matches
    ``csv_path``. Confirm that is intended before changing it, because tests
    may count the combined rows.
    """
    target = str(tmp_path_factory.mktemp("data") / "dataset2.csv")
    # newline="" lets the csv module manage line terminators itself.
    with open(target, "w", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=["col_1", "col_2", "col_3"])
        csv_writer.writeheader()
        csv_writer.writerows(DATA)
    return target
276

277

278
@pytest.fixture(scope="session")
def bz2_csv_path(csv_path, tmp_path_factory):
    """Compress the ``csv_path`` file with bzip2 into ``dataset.csv.bz2``.

    Returns the archive's ``pathlib.Path``.
    """
    import bz2

    path = tmp_path_factory.mktemp("data") / "dataset.csv.bz2"
    with open(csv_path, "rb") as f:
        data = f.read()
    with bz2.open(path, "wb") as f:
        f.write(data)
    return path
289

290

291
@pytest.fixture(scope="session")
def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
    """Zip both CSV fixtures, flat and under their base names, into ``csv-dataset.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("zip_csv_path") / "csv-dataset.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (csv_path, csv2_path):
            archive.write(member, arcname=os.path.basename(member))
    return archive_path
298

299

300
@pytest.fixture(scope="session")
def zip_uppercase_csv_path(csv_path, csv2_path, tmp_path_factory):
    """Zip both CSV fixtures into ``dataset.csv.zip``, renaming ``.csv`` to ``.CSV`` in member names.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (csv_path, csv2_path):
            # ".csv" contains no path separator, so upper-casing the base name
            # is the same as upper-casing the full path first.
            archive.write(member, arcname=os.path.basename(member).replace(".csv", ".CSV"))
    return archive_path
307

308

309
@pytest.fixture(scope="session")
def zip_csv_with_dir_path(csv_path, csv2_path, tmp_path_factory):
    """Zip both CSV fixtures under a ``main_dir/`` folder into ``dataset_with_dir.csv.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset_with_dir.csv.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (csv_path, csv2_path):
            archive.write(member, arcname=os.path.join("main_dir", os.path.basename(member)))
    return archive_path
316

317

318
@pytest.fixture(scope="session")
def parquet_path(tmp_path_factory):
    """Write ``DATA`` to ``dataset.parquet`` using an explicit string/int64/float64 schema.

    Returns the file path as a ``str``.
    """
    path = str(tmp_path_factory.mktemp("data") / "dataset.parquet")
    schema = pa.schema(
        {
            "col_1": pa.string(),
            "col_2": pa.int64(),
            "col_3": pa.float64(),
        }
    )
    # Pivot the row-oriented DATA into one value list per schema column.
    columns = {name: [row[name] for row in DATA] for name in schema.names}
    pa_table = pa.Table.from_pydict(columns, schema=schema)
    # Using the writer as a context manager closes it (and finalizes the file)
    # even if write_table raises. Previously close() ran only on success.
    with open(path, "wb") as f, pq.ParquetWriter(f, schema=schema) as writer:
        writer.write_table(pa_table)
    return path
334

335

336
@pytest.fixture(scope="session")
def geoparquet_path(tmp_path_factory):
    """Re-save the GeoParquet v1.0.0 example file as ``dataset.geoparquet``; return its path as a ``str``.

    NOTE(review): the source is read from GitHub at fixture setup, so any test
    using this fixture needs network access.
    """
    source_url = "https://github.com/opengeospatial/geoparquet/raw/v1.0.0/examples/example.parquet"
    frame = pd.read_parquet(path=source_url)
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.geoparquet"))
    frame.to_parquet(path=target)
    return target
342

343

344
@pytest.fixture(scope="session")
def json_list_of_dicts_path(tmp_path_factory):
    """Write ``{"data": DATA}`` (rows as a list of dicts) to ``dataset.json``; return the path as a ``str``."""
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.json"))
    with open(target, "w") as handle:
        json.dump({"data": DATA}, handle)
    return target
351

352

353
@pytest.fixture(scope="session")
def json_dict_of_lists_path(tmp_path_factory):
    """Write ``{"data": DATA_DICT_OF_LISTS}`` (column lists) to ``dataset.json``; return the path as a ``str``."""
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.json"))
    with open(target, "w") as handle:
        json.dump({"data": DATA_DICT_OF_LISTS}, handle)
    return target
360

361

362
@pytest.fixture(scope="session")
def jsonl_path(tmp_path_factory):
    """Write ``DATA`` as JSON Lines (one object per line) to ``dataset.jsonl``; return the path as a ``str``."""
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.jsonl"))
    with open(target, "w") as handle:
        handle.writelines(f"{json.dumps(row)}\n" for row in DATA)
    return target
369

370

371
@pytest.fixture(scope="session")
def jsonl2_path(tmp_path_factory):
    """Write ``DATA`` as JSON Lines to a second file ``dataset2.jsonl``; return the path as a ``str``.

    NOTE(review): writes ``DATA`` (not ``DATA2``), so the content matches ``jsonl_path``.
    """
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset2.jsonl"))
    with open(target, "w") as handle:
        handle.writelines(f"{json.dumps(row)}\n" for row in DATA)
    return target
378

379

380
@pytest.fixture(scope="session")
def jsonl_312_path(tmp_path_factory):
    """Write ``DATA_312`` (keys ordered col_3, col_1, col_2) as JSON Lines to ``dataset_312.jsonl``.

    Returns the path as a ``str``.
    """
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset_312.jsonl"))
    with open(target, "w") as handle:
        handle.writelines(f"{json.dumps(row)}\n" for row in DATA_312)
    return target
387

388

389
@pytest.fixture(scope="session")
def jsonl_str_path(tmp_path_factory):
    """Write ``DATA_STR`` (non-numeric ``col_1`` strings) as JSON Lines to ``dataset-str.jsonl``.

    Returns the path as a ``str``.
    """
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset-str.jsonl"))
    with open(target, "w") as handle:
        handle.writelines(f"{json.dumps(row)}\n" for row in DATA_STR)
    return target
396

397

398
@pytest.fixture(scope="session")
def text_gz_path(tmp_path_factory, text_path):
    """Gzip the ``text_path`` file into ``dataset.txt.gz``; return the archive path as a ``str``."""
    import gzip

    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.txt.gz"))
    with open(text_path, "rb") as source, gzip.open(target, "wb") as sink:
        sink.writelines(source)
    return target
407

408

409
@pytest.fixture(scope="session")
def jsonl_gz_path(tmp_path_factory, jsonl_path):
    """Gzip the ``jsonl_path`` file into ``dataset.jsonl.gz``; return the archive path as a ``str``."""
    import gzip

    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.jsonl.gz"))
    with open(jsonl_path, "rb") as source, gzip.open(target, "wb") as sink:
        sink.writelines(source)
    return target
418

419

420
@pytest.fixture(scope="session")
def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
    """Zip both JSON Lines fixtures, flat and under their base names, into ``dataset.jsonl.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset.jsonl.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (jsonl_path, jsonl2_path):
            archive.write(member, arcname=os.path.basename(member))
    return archive_path
427

428

429
@pytest.fixture(scope="session")
def zip_nested_jsonl_path(zip_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory):
    """Wrap the ``zip_jsonl_path`` archive in an outer ZIP ``dataset_nested.jsonl.zip`` under ``nested/``.

    ``jsonl_path`` and ``jsonl2_path`` are requested but not used directly here.
    Returns the outer archive's ``pathlib.Path``.
    """
    outer_path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.zip"
    inner_name = os.path.join("nested", os.path.basename(zip_jsonl_path))
    with zipfile.ZipFile(outer_path, "w") as archive:
        archive.write(zip_jsonl_path, arcname=inner_name)
    return outer_path
435

436

437
@pytest.fixture(scope="session")
def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory):
    """Zip both JSON Lines fixtures under a ``main_dir/`` folder into ``dataset_with_dir.jsonl.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (jsonl_path, jsonl2_path):
            archive.write(member, arcname=os.path.join("main_dir", os.path.basename(member)))
    return archive_path
444

445

446
@pytest.fixture(scope="session")
def tar_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
    """Pack both JSON Lines fixtures, flat and under their base names, into ``dataset.jsonl.tar``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset.jsonl.tar"
    with tarfile.TarFile(archive_path, "w") as archive:
        for member in (jsonl_path, jsonl2_path):
            archive.add(member, arcname=os.path.basename(member))
    return archive_path
453

454

455
@pytest.fixture(scope="session")
def tar_nested_jsonl_path(tar_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory):
    """Wrap the ``tar_jsonl_path`` archive in an outer tar ``dataset_nested.jsonl.tar`` under ``nested/``.

    ``jsonl_path`` and ``jsonl2_path`` are requested but not used directly here.
    Returns the outer archive's ``pathlib.Path``.
    """
    outer_path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.tar"
    inner_name = os.path.join("nested", os.path.basename(tar_jsonl_path))
    with tarfile.TarFile(outer_path, "w") as archive:
        archive.add(tar_jsonl_path, arcname=inner_name)
    return outer_path
461

462

463
@pytest.fixture(scope="session")
def text_path(tmp_path_factory):
    """Write the lines "0" to "3" (each newline-terminated) to ``dataset.txt``; return the path as a ``str``."""
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset.txt"))
    with open(target, "w") as handle:
        handle.write("".join(f"{digit}\n" for digit in range(4)))
    return target
471

472

473
@pytest.fixture(scope="session")
def text2_path(tmp_path_factory):
    """Write the lines "0" to "3" (each newline-terminated) to ``dataset2.txt``; return the path as a ``str``."""
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset2.txt"))
    with open(target, "w") as handle:
        handle.write("".join(f"{digit}\n" for digit in range(4)))
    return target
481

482

483
@pytest.fixture(scope="session")
def text_dir(tmp_path_factory):
    """Create a fresh directory holding a single ``dataset.txt`` (lines "0" to "3").

    Returns the directory as a ``pathlib.Path``.
    """
    directory = tmp_path_factory.mktemp("data_text_dir")
    directory.joinpath("dataset.txt").write_text("".join(f"{digit}\n" for digit in range(4)))
    return directory
491

492

493
@pytest.fixture(scope="session")
def text_dir_with_unsupported_extension(tmp_path_factory):
    """Write the lines "0" to "3" to a file with the unrecognised extension ``dataset.abc``.

    Despite the fixture's name, this returns the ``pathlib.Path`` of the file
    itself, not of its parent directory.
    NOTE(review): confirm callers expect a file path before renaming or changing it.
    """
    target = tmp_path_factory.mktemp("data").joinpath("dataset.abc")
    target.write_text("".join(f"{digit}\n" for digit in range(4)))
    return target
501

502

503
@pytest.fixture(scope="session")
def zip_text_path(text_path, text2_path, tmp_path_factory):
    """Zip both text fixtures, flat and under their base names, into ``dataset.text.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset.text.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (text_path, text2_path):
            archive.write(member, arcname=os.path.basename(member))
    return archive_path
510

511

512
@pytest.fixture(scope="session")
def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
    """Zip both text fixtures under a ``main_dir/`` folder into ``dataset_with_dir.text.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset_with_dir.text.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member in (text_path, text2_path):
            archive.write(member, arcname=os.path.join("main_dir", os.path.basename(member)))
    return archive_path
519

520

521
@pytest.fixture(scope="session")
def zip_unsupported_ext_path(text_path, text2_path, tmp_path_factory):
    """Zip both text fixtures renamed to ``unsupported.ext`` / ``unsupported_2.ext`` into ``dataset.ext.zip``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset.ext.zip"
    members = [(text_path, "unsupported.ext"), (text2_path, "unsupported_2.ext")]
    with zipfile.ZipFile(archive_path, "w") as archive:
        for source, member_name in members:
            archive.write(source, arcname=member_name)
    return archive_path
528

529

530
@pytest.fixture(scope="session")
def text_path_with_unicode_new_lines(tmp_path_factory):
    """Write three '\\n'-separated lines, the second containing U+2029 (PARAGRAPH SEPARATOR).

    ``str.splitlines`` treats U+2029 as a line boundary, whereas splitting on
    ``"\\n"`` does not. The file is UTF-8 encoded, has no trailing newline, and
    its path is returned as a ``str``.
    """
    lines = ["First", "Second\u2029with Unicode new line", "Third"]
    target = str(tmp_path_factory.mktemp("data").joinpath("dataset_with_unicode_new_lines.txt"))
    with open(target, "w", encoding="utf-8") as handle:
        handle.write("\n".join(lines))
    return target
537

538

539
@pytest.fixture(scope="session")
def image_file():
    """Return the relative path of the bundled RGB JPEG test image.

    NOTE(review): the path is relative to the current working directory, so
    tests presumably must run from the repository root.
    """
    data_dir = ("tests", "features", "data")
    return os.path.join(*data_dir, "test_image_rgb.jpg")
542

543

544
@pytest.fixture(scope="session")
def audio_file():
    """Return the relative path of the bundled 44.1 kHz WAV test clip.

    NOTE(review): the path is relative to the current working directory, so
    tests presumably must run from the repository root. The sample rate is
    inferred from the file name only.
    """
    data_dir = ("tests", "features", "data")
    return os.path.join(*data_dir, "test_audio_44100.wav")
547

548

549
@pytest.fixture(scope="session")
def zip_image_path(image_file, tmp_path_factory):
    """Store ``image_file`` twice in ``dataset.img.zip``: under its base name and with ``2`` before ``.jpg``.

    Returns the archive's ``pathlib.Path``.
    """
    archive_path = tmp_path_factory.mktemp("data") / "dataset.img.zip"
    base_name = os.path.basename(image_file)
    with zipfile.ZipFile(archive_path, "w") as archive:
        for member_name in (base_name, base_name.replace(".jpg", "2.jpg")):
            archive.write(image_file, arcname=member_name)
    return archive_path
556

557

558
@pytest.fixture(scope="session")
def data_dir_with_hidden_files(tmp_path_factory):
    """Create a data directory mixing visible and dot-prefixed (hidden) entries.

    Layout:
      subdir/train.txt, subdir/test.txt, subdir/.test.txt  (one hidden file)
      .subdir/train.txt, .subdir/test.txt                  (hidden directory)

    Every ``train.txt`` holds ten "foo" lines; every ``test.txt`` / ``.test.txt``
    holds ten "bar" lines. Returns the root directory as a ``pathlib.Path``.
    """
    data_dir = tmp_path_factory.mktemp("data_dir")
    layout = {
        "subdir": {
            "train.txt": "foo\n" * 10,
            "test.txt": "bar\n" * 10,
            ".test.txt": "bar\n" * 10,
        },
        ".subdir": {
            "train.txt": "foo\n" * 10,
            "test.txt": "bar\n" * 10,
        },
    }
    for folder_name, files in layout.items():
        folder = data_dir / folder_name
        folder.mkdir()
        for file_name, content in files.items():
            (folder / file_name).write_text(content)
    return data_dir
579

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.