llama-index
156 lines · 5.7 KB
1"""Optimization related classes and functions."""
2
3import logging4from typing import Callable, List, Optional5
6from llama_index.legacy.bridge.pydantic import Field, PrivateAttr7from llama_index.legacy.embeddings.base import BaseEmbedding8from llama_index.legacy.embeddings.openai import OpenAIEmbedding9from llama_index.legacy.indices.query.embedding_utils import get_top_k_embeddings10from llama_index.legacy.postprocessor.types import BaseNodePostprocessor11from llama_index.legacy.schema import MetadataMode, NodeWithScore, QueryBundle12
13logger = logging.getLogger(__name__)14
15
class SentenceEmbeddingOptimizer(BaseNodePostprocessor):
    """Optimization of a text chunk given the query by shortening the input text.

    Splits each node's text into sentences, embeds them, and keeps only the
    sentences most similar to the query (by percentile and/or similarity
    threshold), optionally padded with surrounding sentences for context.
    """

    percentile_cutoff: Optional[float] = Field(
        description="Percentile cutoff for the top k sentences to use."
    )
    threshold_cutoff: Optional[float] = Field(
        description="Threshold cutoff for similarity for each sentence to use."
    )

    _embed_model: BaseEmbedding = PrivateAttr()
    _tokenizer_fn: Callable[[str], List[str]] = PrivateAttr()

    context_before: Optional[int] = Field(
        description="Number of sentences before retrieved sentence for further context"
    )

    context_after: Optional[int] = Field(
        description="Number of sentences after retrieved sentence for further context"
    )

    def __init__(
        self,
        embed_model: Optional[BaseEmbedding] = None,
        percentile_cutoff: Optional[float] = None,
        threshold_cutoff: Optional[float] = None,
        tokenizer_fn: Optional[Callable[[str], List[str]]] = None,
        context_before: Optional[int] = None,
        context_after: Optional[int] = None,
    ):
        """Optimizer class that is passed into BaseGPTIndexQuery.

        Should be set like this:

        .. code-block:: python

            from llama_index.legacy.optimization.optimizer import Optimizer

            optimizer = SentenceEmbeddingOptimizer(
                # this means that the top 50% of sentences will be used.
                # Alternatively, you can set the cutoff using a threshold
                # on the similarity score. In this case only sentences with a
                # similarity score higher than the threshold will be used.
                percentile_cutoff=0.5,
                # these cutoffs can also be used together.
                threshold_cutoff=0.7,
            )

            query_engine = index.as_query_engine(
                optimizer=optimizer
            )
            response = query_engine.query("<query_str>")
        """
        self._embed_model = embed_model or OpenAIEmbedding()

        if tokenizer_fn is None:
            import nltk.data

            # Default sentence splitter: NLTK's English Punkt tokenizer.
            tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
            tokenizer_fn = tokenizer.tokenize
        self._tokenizer_fn = tokenizer_fn

        super().__init__(
            percentile_cutoff=percentile_cutoff,
            threshold_cutoff=threshold_cutoff,
            context_after=context_after,
            context_before=context_before,
        )

    @classmethod
    def class_name(cls) -> str:
        return "SentenceEmbeddingOptimizer"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Optimize a node text given the query by shortening the node text.

        Args:
            nodes: nodes whose content is rewritten in place.
            query_bundle: query to compare sentences against; when ``None``
                the nodes are returned unchanged.

        Raises:
            ValueError: if the cutoffs select zero sentences for a node.
        """
        if query_bundle is None:
            return nodes

        # FIX: default the context window via locals instead of mutating the
        # configured ``context_before``/``context_after`` fields — previously
        # a single postprocess call permanently rewrote the instance config.
        context_before = 1 if self.context_before is None else self.context_before
        context_after = 1 if self.context_after is None else self.context_after

        for node_idx in range(len(nodes)):
            text = nodes[node_idx].node.get_content(metadata_mode=MetadataMode.LLM)

            split_text = self._tokenizer_fn(text)

            # Compute the aggregated query embedding once and cache it on the
            # bundle so subsequent nodes (and callers) reuse it.
            if query_bundle.embedding is None:
                query_bundle.embedding = (
                    self._embed_model.get_agg_embedding_from_queries(
                        query_bundle.embedding_strs
                    )
                )

            text_embeddings = self._embed_model._get_text_embeddings(split_text)

            num_top_k = None
            threshold = None
            if self.percentile_cutoff is not None:
                num_top_k = int(len(split_text) * self.percentile_cutoff)
            if self.threshold_cutoff is not None:
                threshold = self.threshold_cutoff

            top_similarities, top_idxs = get_top_k_embeddings(
                query_embedding=query_bundle.embedding,
                embeddings=text_embeddings,
                similarity_fn=self._embed_model.similarity,
                similarity_top_k=num_top_k,
                embedding_ids=list(range(len(text_embeddings))),
                similarity_cutoff=threshold,
            )

            if len(top_idxs) == 0:
                raise ValueError("Optimizer returned zero sentences.")

            # Clamp each context window to valid sentence indices.
            range_min, range_max = 0, len(split_text)
            top_sentences = [
                " ".join(
                    split_text[
                        max(idx - context_before, range_min) : min(
                            idx + context_after + 1, range_max
                        )
                    ]
                )
                for idx in top_idxs
            ]

            # Guard all debug output so the string formatting is skipped when
            # DEBUG logging is disabled (the original eagerly formatted the
            # first message with an f-string regardless of the log level).
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("> Top %d sentences with scores:\n", len(top_idxs))
                for idx in range(len(top_idxs)):
                    logger.debug(
                        "%d. %s (%s)", idx, top_sentences[idx], top_similarities[idx]
                    )

            nodes[node_idx].node.set_content(" ".join(top_sentences))

        return nodes