llama-index

Форк
0
85 строк · 2.8 Кб
1
"""DashVector reader."""
2

3
from typing import Dict, List, Optional
4

5
from llama_index.legacy.readers.base import BaseReader
6
from llama_index.legacy.schema import Document
7

8

9
class DashVectorReader(BaseReader):
10
    """DashVector reader.
11

12
    Args:
13
        api_key (str): DashVector API key.
14
        endpoint (str): DashVector cluster endpoint.
15
    """
16

17
    def __init__(self, api_key: str, endpoint: str):
18
        """Initialize with parameters."""
19
        try:
20
            import dashvector
21
        except ImportError:
22
            raise ImportError(
23
                "`dashvector` package not found, please run `pip install dashvector`"
24
            )
25

26
        self._client = dashvector.Client(api_key=api_key, endpoint=endpoint)
27

28
    def load_data(
29
        self,
30
        collection_name: str,
31
        id_to_text_map: Dict[str, str],
32
        vector: Optional[List[float]],
33
        top_k: int,
34
        separate_documents: bool = True,
35
        filter: Optional[str] = None,
36
        include_vector: bool = True,
37
    ) -> List[Document]:
38
        """Load data from DashVector.
39

40
        Args:
41
            collection_name (str): Name of the collection.
42
            id_to_text_map (Dict[str, str]): A map from ID's to text.
43
            separate_documents (Optional[bool]): Whether to return separate
44
                documents per retrieved entry. Defaults to True.
45
            vector (List[float]): Query vector.
46
            top_k (int): Number of results to return.
47
            filter (Optional[str]): doc fields filter conditions that meet the SQL
48
                where clause specification.
49
            include_vector (bool): Whether to include the embedding in the response.
50
                Defaults to True.
51

52
        Returns:
53
            List[Document]: A list of documents.
54
        """
55
        collection = self._client.get(collection_name)
56
        if not collection:
57
            raise ValueError(
58
                f"Failed to get collection: {collection_name}," f"Error: {collection}"
59
            )
60

61
        resp = collection.query(
62
            vector=vector,
63
            topk=top_k,
64
            filter=filter,
65
            include_vector=include_vector,
66
        )
67
        if not resp:
68
            raise Exception(f"Failed to query document," f"Error: {resp}")
69

70
        documents = []
71
        for doc in resp:
72
            if doc.id not in id_to_text_map:
73
                raise ValueError("ID not found in id_to_text_map.")
74
            text = id_to_text_map[doc.id]
75
            embedding = doc.vector
76
            if len(embedding) == 0:
77
                embedding = None
78
            documents.append(Document(text=text, embedding=embedding))
79

80
        if not separate_documents:
81
            text_list = [doc.get_content() for doc in documents]
82
            text = "\n\n".join(text_list)
83
            documents = [Document(text=text)]
84

85
        return documents
86

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.