llama-index
116 строк · 3.9 Кб
1"""Weaviate reader."""
2
3from typing import Any, List, Optional4
5from llama_index.legacy.readers.base import BaseReader6from llama_index.legacy.schema import Document7
8
class WeaviateReader(BaseReader):
    """Weaviate reader.

    Retrieves documents from Weaviate by running a GraphQL ``Get`` query.
    Allows option to concatenate retrieved documents into one Document, or to
    return separate Document objects per document.

    Args:
        host (str): host, passed to ``weaviate.Client`` as its first argument.
        auth_client_secret (Optional[weaviate.auth.AuthCredentials]):
            auth_client_secret, forwarded to ``weaviate.Client`` unchanged.
    """

    def __init__(
        self,
        host: str,
        auth_client_secret: Optional[Any] = None,
    ) -> None:
        """Initialize with parameters.

        Raises:
            ImportError: If the ``weaviate`` package cannot be imported.
        """
        try:
            # Imported lazily so the dependency is only required when this
            # reader is actually instantiated.
            import weaviate  # noqa
            from weaviate import Client
            from weaviate.auth import AuthCredentials  # noqa
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "`weaviate` package not found, please run `pip install weaviate-client`"
            ) from err

        self.client: Client = Client(host, auth_client_secret=auth_client_secret)

    def load_data(
        self,
        class_name: Optional[str] = None,
        properties: Optional[List[str]] = None,
        graphql_query: Optional[str] = None,
        separate_documents: Optional[bool] = True,
    ) -> List[Document]:
        """Load data from Weaviate.

        When both `class_name` and `properties` are given, a ``Get`` query is
        generated from them and any `graphql_query` argument is ignored.
        Otherwise `graphql_query` is sent as-is.

        Each returned entry becomes one Document whose text is the entry's
        properties rendered as ``<property>: <value>`` lines. The
        ``_additional`` field is left out of the text; if it contains a
        ``vector``, that vector is used as the Document's embedding.

        Args:
            class_name (Optional[str]): class_name to retrieve documents from.
                If omitted, the first class in the ``Get`` response is used.
            properties (Optional[List[str]]): properties to retrieve from
                documents.
            graphql_query (Optional[str]): Raw GraphQL Query.
                We assume that the query is a Get query.
            separate_documents (Optional[bool]): Whether to return separate
                documents. Defaults to True. When falsy, all documents are
                merged into a single Document (which carries no embedding).

        Returns:
            List[Document]: A list of documents.

        Raises:
            ValueError: If neither (`class_name` and `properties`) nor
                `graphql_query` is given, if the response reports errors, if
                the response is not a ``Get`` response, or if no class can be
                inferred from an empty ``Get`` response.

        """
        if class_name is not None and properties is not None:
            props_txt = "\n".join(properties)
            graphql_query = f"""
            {{
                Get {{
                    {class_name} {{
                        {props_txt}
                    }}
                }}
            }}
            """
        elif graphql_query is None:
            raise ValueError(
                "Either `class_name` and `properties` must be specified, "
                "or `graphql_query` must be specified."
            )

        response = self.client.query.raw(graphql_query)
        if "errors" in response:
            raise ValueError(f"Invalid query, got errors: {response['errors']}")

        data_response = response["data"]
        if "Get" not in data_response:
            raise ValueError("Invalid query response, must be a Get query.")

        get_response = data_response["Get"]
        if class_name is None:
            # infer class_name if only graphql_query was provided
            class_name = next(iter(get_response), None)
            if class_name is None:
                # An empty `Get` payload would otherwise leak a bare
                # StopIteration out of this method.
                raise ValueError(
                    "Invalid query response, `Get` contains no classes."
                )
        entries = get_response[class_name]

        documents: List[Document] = []
        for entry in entries:
            embedding: Optional[List[float]] = None
            # for each entry, join properties into <property>: <value>
            # separated by newlines
            text_list = []
            for k, v in entry.items():
                if k == "_additional":
                    if "vector" in v:
                        embedding = v["vector"]
                    continue
                text_list.append(f"{k}: {v}")

            text = "\n".join(text_list)
            documents.append(Document(text=text, embedding=embedding))

        if not separate_documents:
            # join all documents into one
            text = "\n\n".join(doc.get_content() for doc in documents)
            documents = [Document(text=text)]

        return documents