datasets
78 строк · 2.1 Кб
1import json
2import os
3from dataclasses import dataclass
4
5import numpy as np
6import pyarrow as pa
7
8import datasets
9from utils import get_duration
10
11
# Total number of rows in the synthetic benchmark table.
SPEED_TEST_N_EXAMPLES = 100_000_000_000
# Rows per repeated Arrow chunk used to assemble the synthetic table.
SPEED_TEST_CHUNK_SIZE = 10_000

# Results are written to "<directory of this file>/results/<this file's stem>.json".
RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
# splitext only swaps the final extension; str.replace(".py", ...) would also
# rewrite any ".py" appearing elsewhere in the name (e.g. "x.pyarrow.py", "x.pyc").
RESULTS_FILE_PATH = os.path.join(
    RESULTS_BASEPATH, "results", os.path.splitext(RESULTS_FILENAME)[0] + ".json"
)
17
18
def generate_100B_dataset(num_examples: int, chunk_size: int) -> datasets.Dataset:
    """Build a dataset of ``num_examples`` zero-valued rows in a single ``"col"`` column.

    One ``chunk_size``-row Arrow table is created and the same object is repeated
    in the list handed to ``pa.concat_tables``. Huge row counts can therefore be
    represented without materializing ``num_examples`` values.

    Args:
        num_examples: Exact number of rows in the returned dataset; must be >= 0.
        chunk_size: Number of rows in each repeated chunk; must be > 0.

    Returns:
        A ``datasets.Dataset`` wrapping the concatenated table, created with the
        fixed fingerprint ``"table_100B"``.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``num_examples`` is negative.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if num_examples < 0:
        raise ValueError(f"num_examples must be non-negative, got {num_examples}")

    chunk = pa.Table.from_pydict({"col": [0] * chunk_size})
    full_chunks, remainder = divmod(num_examples, chunk_size)
    tables = [chunk] * full_chunks
    # Integer division alone would drop the trailing `remainder` rows, so add a
    # slice of the chunk for them. Doing this also when `tables` is empty keeps
    # the list non-empty, because concat_tables needs at least one table.
    if remainder or not tables:
        tables.append(chunk.slice(0, remainder))
    table = pa.concat_tables(tables)
    return datasets.Dataset(table, fingerprint="table_100B")
23
24
@dataclass
class RandIter:
    """Fixed, reproducible sequence of ``size`` random integers in ``[low, high)``.

    All values are drawn once, at construction time, from a NumPy ``Generator``
    seeded with ``seed``. They are stored as a plain Python list, so iterating the
    object repeatedly always yields the same values. ``len()`` reports ``size``.
    """

    low: int
    high: int
    size: int
    seed: int

    def __post_init__(self):
        generator = np.random.default_rng(self.seed)
        draws = generator.integers(self.low, self.high, size=self.size)
        # Convert NumPy integers to built-in ints before storing.
        self._sampled_values = draws.tolist()

    def __len__(self):
        return self.size

    def __iter__(self):
        values = self._sampled_values
        return iter(values)
41
42
@get_duration
def get_first_row(dataset: datasets.Dataset):
    """Look up row 0 of ``dataset``; the fetched row itself is not used.

    Measurement is delegated to the ``get_duration`` decorator from ``utils``,
    which presumably times this call. Confirm there.
    """
    dataset[0]  # value intentionally discarded: only the access is benchmarked
46
47
@get_duration
def get_last_row(dataset: datasets.Dataset):
    """Look up the final row of ``dataset`` through a negative index.

    The fetched row is not used. Measurement is delegated to the ``get_duration``
    decorator from ``utils``, which presumably times this call. Confirm there.
    """
    dataset[-1]  # value intentionally discarded: only the access is benchmarked
51
52
@get_duration
def get_batch_of_1024_rows(dataset: datasets.Dataset):
    """Fetch 1024 consecutive rows starting at the middle of ``dataset``.

    Rows are selected with a ``range`` key covering
    ``[len(dataset) // 2, len(dataset) // 2 + 1024)``. The fetched batch is not
    used. Measurement is delegated to the ``get_duration`` decorator from
    ``utils`` (presumably timing this call).
    """
    midpoint = len(dataset) // 2
    dataset[range(midpoint, midpoint + 1024)]  # result intentionally discarded
56
57
@get_duration
def get_batch_of_1024_random_rows(dataset: datasets.Dataset):
    """Fetch 1024 rows at reproducible random positions (seed 42) in ``dataset``.

    The indices come from ``RandIter``, which samples them when it is
    constructed. That construction happens inside this decorated function, so
    the sampling cost is part of whatever ``get_duration`` measures. The fetched
    batch is not used.
    """
    random_indices = RandIter(low=0, high=len(dataset), size=1024, seed=42)
    dataset[random_indices]  # result intentionally discarded
61
62
def benchmark_table_100B():
    """Run the row-access benchmarks on the synthetic table and save the results as JSON.

    Steps:
        1. Build the dataset with ``SPEED_TEST_N_EXAMPLES`` rows.
        2. Call each benchmark function on it.
        3. Write a flat JSON object to ``RESULTS_FILE_PATH``. The object holds
           ``"num examples"`` plus one entry per function, keyed by the
           function's ``__name__`` and holding its return value.

    The parent ``results`` directory is created if it is missing.
    """
    times = {"num examples": SPEED_TEST_N_EXAMPLES}
    functions = (get_first_row, get_last_row, get_batch_of_1024_rows, get_batch_of_1024_random_rows)
    print("generating dataset")
    dataset = generate_100B_dataset(num_examples=SPEED_TEST_N_EXAMPLES, chunk_size=SPEED_TEST_CHUNK_SIZE)
    print("Functions")
    for func in functions:
        print(func.__name__)
        # Stores whatever the `get_duration` wrapper returns; presumably the elapsed time.
        times[func.__name__] = func(dataset)

    # The results directory is not guaranteed to exist. Create it so the run
    # does not end in FileNotFoundError after all the timing work is done.
    os.makedirs(os.path.dirname(RESULTS_FILE_PATH), exist_ok=True)
    # With default arguments json.dump emits ASCII-only text and no newlines,
    # so the file bytes match the previous "wb" + encode("utf-8") approach.
    with open(RESULTS_FILE_PATH, "w", encoding="utf-8") as f:
        json.dump(times, f)
75
76
if __name__ == "__main__":
    # Script entry point. Running the module directly makes it easy to wrap
    # the benchmark in a profiler.
    benchmark_table_100B()
79