llama-index
108 строк · 3.4 Кб
1import logging2import re3from typing import TYPE_CHECKING, Any, List, Optional, Pattern4
5import numpy as np6
7_logger = logging.getLogger(__name__)8
9if TYPE_CHECKING:10from redis.client import Redis as RedisType11from redis.commands.search.query import Query12
13
class TokenEscaper:
    """Backslash-escape punctuation inside a string.

    Adapted from RedisOM Python.
    """

    # Character class listing what RediSearch requires us to escape in queries.
    # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        """Remember the pattern used to find characters that need escaping.

        Args:
            escape_chars_re: Compiled pattern to use; when omitted, a pattern
                compiled from ``DEFAULT_ESCAPED_CHARS`` is used instead.
        """
        self.escaped_chars_re = escape_chars_re or re.compile(
            self.DEFAULT_ESCAPED_CHARS
        )

    def escape(self, value: str) -> str:
        """Return ``value`` with a backslash inserted before every match."""
        return self.escaped_chars_re.sub(lambda hit: "\\" + hit.group(0), value)
36
# Redis modules (name plus minimum "ver") consulted by check_redis_modules_exist.
# NOTE(review): 20400 presumably encodes version 2.4.0 -- the error raised by
# check_redis_modules_exist asks for RediSearch >= 2.4; confirm the encoding.
REDIS_REQUIRED_MODULES = [
    {"name": module_name, "ver": 20400} for module_name in ("search", "searchlight")
]
42
43
def check_redis_modules_exist(client: "RedisType") -> None:
    """Verify that the Redis server has a supported search module loaded.

    The check passes as soon as any one entry of ``REDIS_REQUIRED_MODULES`` is
    installed at, or above, its minimum version.

    Args:
        client: Redis client whose ``module_list()`` reports the loaded modules.

    Raises:
        ValueError: If none of the required modules is installed at a
            sufficient version. The same message is logged at error level.
    """

    def _field(module_info: Any, key: str) -> Any:
        # module_list() entries are keyed by bytes on a default client but by
        # str when the client was created with decode_responses=True, so both
        # spellings are accepted; bytes values are decoded to str.
        raw_key = key.encode("utf-8")
        value = module_info[raw_key] if raw_key in module_info else module_info[key]
        return value.decode("utf-8") if isinstance(value, bytes) else value

    installed = {_field(info, "name"): info for info in client.module_list()}

    for required in REDIS_REQUIRED_MODULES:
        found = installed.get(required["name"])
        if found is None:
            continue
        minimum = int(required["ver"])  # type: ignore[call-overload]
        if int(_field(found, "ver")) >= minimum:
            return

    # No required module is present at a sufficient version.
    error_message = (
        "You must add the RediSearch (>= 2.4) module from Redis Stack. "
        "Please refer to Redis Stack docs: https://redis.io/docs/stack/"
    )
    _logger.error(error_message)
    raise ValueError(error_message)
65
def get_redis_query(
    return_fields: List[str],
    top_k: int = 20,
    vector_field: str = "vector",
    sort: bool = True,
    filters: str = "*",
) -> "Query":
    """Build a KNN vector query for use with a SearchIndex.

    The query text references a ``$vector`` parameter and aliases the KNN
    result as ``vector_score``.

    Args:
        return_fields (List[str]): Fields to return in the query results.
        top_k (int, optional): Number of results to return; also used as the
            page size. Defaults to 20.
        vector_field (str, optional): Name of the vector field in the index.
            Defaults to "vector".
        sort (bool, optional): Whether to sort the results by
            ``vector_score``. Defaults to True.
        filters (str, optional): Filter expression placed before the KNN
            clause. Defaults to "*".

    Returns:
        Query: The configured query object (dialect 2).
    """
    from redis.commands.search.query import Query

    knn_clause = f"{filters}=>[KNN {top_k} @{vector_field} $vector AS vector_score]"

    query = Query(knn_clause).return_fields(*return_fields)
    query = query.dialect(2).paging(0, top_k)

    if not sort:
        return query

    query.sort_by("vector_score")
    return query
94
def convert_bytes(data: Any) -> Any:
    """Recursively decode ``bytes`` found in ``data`` into ``str``.

    Dicts (keys and values), lists and tuples are rebuilt with their contents
    converted; any other value is returned unchanged.

    Args:
        data: A bytes object, a dict/list/tuple possibly containing bytes, or
            any other value.

    Returns:
        The same structure with every ``bytes`` replaced by its ASCII-decoded
        ``str``. Tuples come back as real tuples, not lazy ``map`` iterators.

    Raises:
        UnicodeDecodeError: If a bytes value contains non-ASCII data.
    """
    if isinstance(data, bytes):
        return data.decode("ascii")
    if isinstance(data, dict):
        return {convert_bytes(key): convert_bytes(val) for key, val in data.items()}
    if isinstance(data, list):
        return [convert_bytes(item) for item in data]
    if isinstance(data, tuple):
        # On Python 3 a bare map() is a one-shot iterator; materialize it so
        # the caller gets a tuple back for a tuple in.
        return tuple(convert_bytes(item) for item in data)
    return data
106
def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
    """Serialize a list of numbers to raw bytes.

    Args:
        array: Values to serialize.
        dtype: NumPy dtype the values are cast to before serialization.
            Defaults to ``np.float32``.

    Returns:
        The bytes of the NumPy array built from ``array`` after casting it to
        ``dtype``.
    """
    as_ndarray = np.array(array)
    typed = as_ndarray.astype(dtype)
    return typed.tobytes()