llama-index / bench_embeddings.py

import time
from functools import partial
from typing import Callable, List, Optional, Tuple

import pandas as pd
from llama_index.core import SimpleDirectoryReader
from llama_index.core.base.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
)
from llama_index.embeddings import OpenAIEmbedding, resolve_embed_model


def generate_strings(num_strings: int = 100, string_length: int = 10) -> List[str]:
    """
    Generate strings sliced from the Paul Graham essay of the following form:

    offset 0: [0:string_length], [string_length:2*string_length], ...
    offset 1: [1:1+string_length], [1+string_length:1+2*string_length], ...
    ...
    """  # noqa: D415
    content = (
        SimpleDirectoryReader("../../examples/paul_graham_essay/data")
        .load_data()[0]
        .get_content()
    )
    content_length = len(content)

    strings_per_loop = content_length / string_length
    num_loops_upper_bound = int(num_strings / strings_per_loop) + 1
    strings = []

    for offset in range(num_loops_upper_bound + 1):
        ptr = offset % string_length
        while ptr + string_length < content_length:
            strings.append(content[ptr : ptr + string_length])
            ptr += string_length
            if len(strings) == num_strings:
                # Return here rather than `break`: a bare break only exits the
                # inner while loop, so the next offset would keep appending
                # past num_strings.
                return strings

    return strings
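
# Worked example (illustrative, not part of the original script): with
# content "abcdefgh" and string_length=3, offset 0 yields "abc", "def" and
# offset 1 yields "bcd", "efg", matching the scheme in the docstring.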


def create_open_ai_embedding(batch_size: int) -> Tuple[BaseEmbedding, str, int]:
    # Returns (embedding model, display name, max sequence length).
    return (
        OpenAIEmbedding(embed_batch_size=batch_size),
        "OpenAIEmbedding",
        4096,
    )


def create_local_embedding(
    model_name: str, batch_size: int
) -> Tuple[BaseEmbedding, str, int]:
    model = resolve_embed_model(f"local:{model_name}")
    # The original left batch_size unused, which made the benchmark's
    # batch-size sweep a no-op for local models; set it on the resolved model
    # (assumes BaseEmbedding's embed_batch_size field is assignable).
    model.embed_batch_size = batch_size
    return (
        model,
        "hf/" + model_name,
        model._langchain_embedding.client.max_seq_length,  # type: ignore
    )
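
# Usage sketch (model name is one of those benchmarked below):
#   model, name, max_seq_len = create_local_embedding(
#       "sentence-transformers/all-MiniLM-L6-v2", batch_size=10
#   )
#   # name == "hf/sentence-transformers/all-MiniLM-L6-v2"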


def bench_simple_vector_store(
    embed_models: List[Callable[[int], Tuple[BaseEmbedding, str, int]]],
    num_strings: List[int] = [100],
    string_lengths: List[int] = [64, 256],
    embed_batch_sizes: List[int] = [1, DEFAULT_EMBED_BATCH_SIZE],
    torch_num_threads: Optional[int] = None,
) -> None:
    """Benchmark each embedding model across string counts, lengths, and batch sizes."""
    print("Benchmarking Embeddings\n---------------------------")

    results = []

    if torch_num_threads is not None:
        import torch  # pants: no-infer-dep

        torch.set_num_threads(torch_num_threads)

    # Generate strings once per length for the largest request, then slice.
    max_num_strings = max(num_strings)
    for string_length in string_lengths:
        generated_strings = generate_strings(
            num_strings=max_num_strings, string_length=string_length
        )

        for string_count in num_strings:
            strings = generated_strings[:string_count]

            for batch_size in embed_batch_sizes:
                models = []
                for create_model in embed_models:
                    models.append(create_model(batch_size=batch_size))  # type: ignore

                for model, model_name, max_seq_length in models:
                    time1 = time.time()
                    _ = model.get_text_embedding_batch(strings, show_progress=True)

                    time2 = time.time()
                    print(
                        f"Embedding with model {model_name} with "
                        f"batch size {batch_size} and max_seq_length {max_seq_length} "
                        f"for {string_count} strings of length {string_length} took "
                        f"{time2 - time1} seconds."
                    )
                    results.append(
                        (model_name, batch_size, string_length, time2 - time1)
                    )
                # TODO: async version

    # print final results
    print("\n\nFinal Results\n---------------------------")
    results_df = pd.DataFrame(
        results, columns=["model", "batch_size", "string_length", "time"]
    )
    print(results_df)
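

# Minimal async sketch for the TODO above, not part of the original script;
# it assumes the model implements aget_text_embedding_batch, BaseEmbedding's
# async counterpart to get_text_embedding_batch.
async def abench_embedding_model(model: BaseEmbedding, strings: List[str]) -> float:
    """Return wall-clock seconds taken to embed `strings` asynchronously."""
    time1 = time.time()
    _ = await model.aget_text_embedding_batch(strings, show_progress=True)
    return time.time() - time1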


if __name__ == "__main__":
    bench_simple_vector_store(
        embed_models=[
            # create_open_ai_embedding,
            partial(
                create_local_embedding,
                model_name="sentence-transformers/all-MiniLM-L6-v2",
            ),
            partial(
                create_local_embedding,
                model_name="sentence-transformers/all-MiniLM-L12-v2",
            ),
            partial(
                create_local_embedding,
                model_name="BAAI/bge-small-en",
            ),
            partial(
                create_local_embedding,
                model_name="sentence-transformers/all-mpnet-base-v2",
            ),
        ],
        torch_num_threads=None,
    )
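
# Running the benchmark (the data path above is relative, so this assumes the
# script is invoked from its own directory):
#   python bench_embeddings.py
# To include OpenAI, uncomment create_open_ai_embedding above and set the
# OPENAI_API_KEY environment variable, which the OpenAI integration reads by
# default.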
