datasets

Форк
0
/
benchmark_getitem_100B.py 
78 строк · 2.1 Кб
1
import json
2
import os
3
from dataclasses import dataclass
4

5
import numpy as np
6
import pyarrow as pa
7

8
import datasets
9
from utils import get_duration
10

11

12
SPEED_TEST_N_EXAMPLES = 100_000_000_000
13
SPEED_TEST_CHUNK_SIZE = 10_000
14

15
RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
16
RESULTS_FILE_PATH = os.path.join(RESULTS_BASEPATH, "results", RESULTS_FILENAME.replace(".py", ".json"))
17

18

19
def generate_100B_dataset(num_examples: int, chunk_size: int) -> datasets.Dataset:
20
    table = pa.Table.from_pydict({"col": [0] * chunk_size})
21
    table = pa.concat_tables([table] * (num_examples // chunk_size))
22
    return datasets.Dataset(table, fingerprint="table_100B")
23

24

25
@dataclass
26
class RandIter:
27
    low: int
28
    high: int
29
    size: int
30
    seed: int
31

32
    def __post_init__(self):
33
        rng = np.random.default_rng(self.seed)
34
        self._sampled_values = rng.integers(low=self.low, high=self.high, size=self.size).tolist()
35

36
    def __iter__(self):
37
        return iter(self._sampled_values)
38

39
    def __len__(self):
40
        return self.size
41

42

43
@get_duration
44
def get_first_row(dataset: datasets.Dataset):
45
    _ = dataset[0]
46

47

48
@get_duration
49
def get_last_row(dataset: datasets.Dataset):
50
    _ = dataset[-1]
51

52

53
@get_duration
54
def get_batch_of_1024_rows(dataset: datasets.Dataset):
55
    _ = dataset[range(len(dataset) // 2, len(dataset) // 2 + 1024)]
56

57

58
@get_duration
59
def get_batch_of_1024_random_rows(dataset: datasets.Dataset):
60
    _ = dataset[RandIter(0, len(dataset), 1024, seed=42)]
61

62

63
def benchmark_table_100B():
64
    times = {"num examples": SPEED_TEST_N_EXAMPLES}
65
    functions = (get_first_row, get_last_row, get_batch_of_1024_rows, get_batch_of_1024_random_rows)
66
    print("generating dataset")
67
    dataset = generate_100B_dataset(num_examples=SPEED_TEST_N_EXAMPLES, chunk_size=SPEED_TEST_CHUNK_SIZE)
68
    print("Functions")
69
    for func in functions:
70
        print(func.__name__)
71
        times[func.__name__] = func(dataset)
72

73
    with open(RESULTS_FILE_PATH, "wb") as f:
74
        f.write(json.dumps(times).encode("utf-8"))
75

76

77
if __name__ == "__main__":  # useful to run the profiler
78
    benchmark_table_100B()
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.