datasets
60 строк · 1.6 Кб
1import json
2import os
3import tempfile
4
5import datasets
6from utils import generate_example_dataset, get_duration
7
8
# Number of rows in the synthetic dataset that the benchmark generates.
SPEED_TEST_N_EXAMPLES = 500_000

# Timings are written to ``results/<this script's name>.json`` beside this file.
RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
# Swap only the trailing extension: ``str.replace(".py", ".json")`` would also
# rewrite any ".py" appearing earlier in the file name.
RESULTS_FILE_PATH = os.path.join(
    RESULTS_BASEPATH,
    "results",
    os.path.splitext(RESULTS_FILENAME)[0] + ".json",
)
13
14
@get_duration
def select(dataset: datasets.Dataset):
    """Time ``Dataset.select`` over the even-numbered row indices of *dataset*."""
    every_other_row = range(0, len(dataset), 2)
    _selection = dataset.select(every_other_row)
18
19
@get_duration
def sort(dataset: datasets.Dataset):
    """Time ordering *dataset* by its ``"numbers"`` column."""
    sort_column = "numbers"
    _ordered = dataset.sort(sort_column)
23
24
@get_duration
def shuffle(dataset: datasets.Dataset):
    """Time ``Dataset.shuffle`` called with its default arguments.

    No seed is passed, so the permutation presumably differs between runs.
    """
    _permuted = dataset.shuffle()
28
29
@get_duration
def train_test_split(dataset: datasets.Dataset):
    """Time splitting *dataset* with a 0.1 test fraction."""
    _splits = dataset.train_test_split(test_size=0.1)
33
34
@get_duration
def shard(dataset: datasets.Dataset, num_shards=10):
    """Time extracting each of the *num_shards* shards of *dataset*, one after another."""
    for shard_index in range(num_shards):
        _piece = dataset.shard(num_shards, shard_index)
39
40
def benchmark_indices_mapping():
    """Run every indices-mapping benchmark and write the results to ``RESULTS_FILE_PATH``.

    Generates a synthetic dataset of ``SPEED_TEST_N_EXAMPLES`` rows (a string
    ``"text"`` column and a float32 ``"numbers"`` column) inside a temporary
    directory. Each benchmark function is called on it, and its return value is
    recorded under the function's name. The collected dict is then dumped as
    UTF-8 JSON after the temporary directory has been removed.
    """
    times = {"num examples": SPEED_TEST_N_EXAMPLES}
    functions = (select, sort, shuffle, train_test_split, shard)
    with tempfile.TemporaryDirectory() as tmp_dir:
        print("generating dataset")
        features = datasets.Features({"text": datasets.Value("string"), "numbers": datasets.Value("float32")})
        dataset = generate_example_dataset(
            os.path.join(tmp_dir, "dataset.arrow"), features, num_examples=SPEED_TEST_N_EXAMPLES
        )
        print("Functions")
        for func in functions:
            print(func.__name__)
            # Each function is wrapped by ``get_duration``; its return value is
            # presumably the elapsed time (see utils.get_duration to confirm).
            times[func.__name__] = func(dataset)

    # The "results" directory may not exist yet (e.g. on a fresh checkout).
    # Create it so the final write cannot fail after the whole benchmark has run.
    os.makedirs(os.path.dirname(RESULTS_FILE_PATH), exist_ok=True)
    with open(RESULTS_FILE_PATH, "wb") as f:
        f.write(json.dumps(times).encode("utf-8"))
57
58
if __name__ == "__main__":
    # Executing the module directly runs the full benchmark, which is handy
    # when launching it under a profiler.
    benchmark_indices_mapping()
61