llama-index

from typing import Any, Callable, List, Protocol, Tuple, runtime_checkable

from llama_index.legacy.vector_stores.types import VectorStoreQueryResult

SparseEncoderCallable = Callable[[List[str]], Tuple[List[List[int]], List[List[float]]]]
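# A sparse encoder maps a batch of texts to two parallel lists per text:
# the non-zero token indices and their corresponding weights.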


@runtime_checkable
class HybridFusionCallable(Protocol):
    """Hybrid fusion callable protocol."""

    def __call__(
        self,
        dense_result: VectorStoreQueryResult,
        sparse_result: VectorStoreQueryResult,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Hybrid fusion callable."""
        ...
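

# Any callable with this signature can serve as the fuser for hybrid search;
# `relative_score_fusion` below is one such implementation.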


def default_sparse_encoder(model_id: str) -> SparseEncoderCallable:
    try:
        import torch
        from transformers import AutoModelForMaskedLM, AutoTokenizer
    except ImportError:
        raise ImportError(
            "Could not import transformers library. "
            'Please install transformers with `pip install "transformers[torch]"`'
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)
    if torch.cuda.is_available():
        model = model.to("cuda")

    def compute_vectors(texts: List[str]) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Computes vectors from logits and attention mask using ReLU, log, and max operations.
        """
        # TODO: compute sparse vectors in batches if max length is exceeded
        tokens = tokenizer(
            texts, truncation=True, padding=True, max_length=512, return_tensors="pt"
        )
        if torch.cuda.is_available():
            tokens = tokens.to("cuda")

        output = model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        # SPLADE-style term weighting: log-saturated ReLU of the MLM logits,
        # masked by attention, then max-pooled over the sequence dimension
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        tvecs, _ = torch.max(weighted_log, dim=1)

        # extract the vectors that are non-zero and their indices
        indices = []
        vecs = []
        for batch in tvecs:
            indices.append(batch.nonzero(as_tuple=True)[0].tolist())
            vecs.append(batch[indices[-1]].tolist())

        return indices, vecs

    return compute_vectors
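

# Illustrative usage (the checkpoint below is just an example of a SPLADE-style
# model compatible with AutoModelForMaskedLM, not a mandated default):
#
#     encode = default_sparse_encoder("naver/efficient-splade-VI-BT-large-doc")
#     indices, values = encode(["hello world", "hybrid search"])
#     # indices[0] / values[0]: non-zero token ids and weights for the first text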


def relative_score_fusion(
    dense_result: VectorStoreQueryResult,
    sparse_result: VectorStoreQueryResult,
    # NOTE: alpha is only used for hybrid search (0 for sparse-only, 1 for dense-only)
    alpha: float = 0.5,
    top_k: int = 2,
) -> VectorStoreQueryResult:
    """
    Fuse dense and sparse results using relative score fusion.
    """
    # handle empty results: both empty, or only one of the two
    if (dense_result.nodes is None or len(dense_result.nodes) == 0) and (
        sparse_result.nodes is None or len(sparse_result.nodes) == 0
    ):
        return VectorStoreQueryResult(nodes=None, similarities=None, ids=None)
    elif sparse_result.nodes is None or len(sparse_result.nodes) == 0:
        return dense_result
    elif dense_result.nodes is None or len(dense_result.nodes) == 0:
        return sparse_result

    assert dense_result.nodes is not None
    assert dense_result.similarities is not None
    assert sparse_result.nodes is not None
    assert sparse_result.similarities is not None

    # deconstruct results into (score, node) tuples sorted by score
    sparse_result_tuples = list(zip(sparse_result.similarities, sparse_result.nodes))
    sparse_result_tuples.sort(key=lambda x: x[0], reverse=True)

    dense_result_tuples = list(zip(dense_result.similarities, dense_result.nodes))
    dense_result_tuples.sort(key=lambda x: x[0], reverse=True)

    # track nodes in both results
    all_nodes_dict = {x.node_id: x for x in dense_result.nodes}
    for node in sparse_result.nodes:
        if node.node_id not in all_nodes_dict:
            all_nodes_dict[node.node_id] = node

    # normalize sparse similarities from 0 to 1
    sparse_similarities = [x[0] for x in sparse_result_tuples]

    sparse_per_node = {}
    if len(sparse_similarities) > 0:
        max_sparse_sim = max(sparse_similarities)
        min_sparse_sim = min(sparse_similarities)

        # avoid division by zero
        if max_sparse_sim == min_sparse_sim:
            sparse_similarities = [max_sparse_sim] * len(sparse_similarities)
        else:
            sparse_similarities = [
                (x - min_sparse_sim) / (max_sparse_sim - min_sparse_sim)
                for x in sparse_similarities
            ]

        sparse_per_node = {
            sparse_result_tuples[i][1].node_id: x
            for i, x in enumerate(sparse_similarities)
        }

    # normalize dense similarities from 0 to 1
    dense_similarities = [x[0] for x in dense_result_tuples]

    dense_per_node = {}
    if len(dense_similarities) > 0:
        max_dense_sim = max(dense_similarities)
        min_dense_sim = min(dense_similarities)

        # avoid division by zero
        if max_dense_sim == min_dense_sim:
            dense_similarities = [max_dense_sim] * len(dense_similarities)
        else:
            dense_similarities = [
                (x - min_dense_sim) / (max_dense_sim - min_dense_sim)
                for x in dense_similarities
            ]

        dense_per_node = {
            dense_result_tuples[i][1].node_id: x
            for i, x in enumerate(dense_similarities)
        }

    # fuse the scores: convex combination of the min-max-normalized scores
    fused_similarities = []
    for node_id in all_nodes_dict:
        sparse_sim = sparse_per_node.get(node_id, 0)
        dense_sim = dense_per_node.get(node_id, 0)
        fused_sim = (1 - alpha) * sparse_sim + alpha * dense_sim
        fused_similarities.append((fused_sim, all_nodes_dict[node_id]))

    fused_similarities.sort(key=lambda x: x[0], reverse=True)
    fused_similarities = fused_similarities[:top_k]

    # create final response object
    return VectorStoreQueryResult(
        nodes=[x[1] for x in fused_similarities],
        similarities=[x[0] for x in fused_similarities],
        ids=[x[1].node_id for x in fused_similarities],
    )
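

# Minimal usage sketch (not part of the original module): node ids, texts,
# and scores are made up for illustration. Each node's fused score is
# (1 - alpha) * sparse + alpha * dense, computed on min-max-normalized scores.
if __name__ == "__main__":
    from llama_index.legacy.schema import TextNode

    dense = VectorStoreQueryResult(
        nodes=[TextNode(id_="a", text="alpha"), TextNode(id_="b", text="beta")],
        similarities=[0.9, 0.4],
        ids=["a", "b"],
    )
    sparse = VectorStoreQueryResult(
        nodes=[TextNode(id_="b", text="beta"), TextNode(id_="c", text="gamma")],
        similarities=[12.0, 7.5],
        ids=["b", "c"],
    )

    # node "b" appears in both result sets and is deduplicated by node_id
    fused = relative_score_fusion(dense, sparse, alpha=0.5, top_k=3)
    print(fused.ids, fused.similarities)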