llama-index

Форк
0
164 строки · 4.8 Кб
1
"""Weaviate-specific serializers for LlamaIndex data structures.
2

3
Contains conversions to and from the dataclasses that LlamaIndex uses.
4

5
"""
6

7
import logging
8
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
9

10
if TYPE_CHECKING:
11
    from weaviate import Client
12

13
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
14
from llama_index.legacy.vector_stores.utils import (
15
    DEFAULT_TEXT_KEY,
16
    legacy_metadata_dict_to_node,
17
    metadata_dict_to_node,
18
    node_to_metadata_dict,
19
)
20

21
_logger = logging.getLogger(__name__)

# Property definitions for the default Weaviate class created by
# ``create_default_schema``. Every property is stored as Weaviate "text";
# each entry is listed as (property name, human-readable description).
NODE_SCHEMA: List[Dict] = [
    {"dataType": ["text"], "description": description, "name": name}
    for name, description in (
        ("text", "Text property"),
        ("ref_doc_id", "The ref_doc_id of the Node"),
        ("node_info", "node_info (in JSON)"),
        ("relationships", "The relationships of the node (in JSON)"),
    )
]
45

46

47
def validate_client(client: Any) -> None:
    """Ensure the ``weaviate`` client library is importable.

    This does not inspect ``client`` itself; it only checks that ``weaviate``
    and ``weaviate.Client`` can be imported, so callers fail early with an
    actionable installation hint.

    Args:
        client: The object the caller intends to use as a Weaviate client.
            Accepted for API compatibility; it is not examined here.

    Raises:
        ImportError: If ``weaviate`` or ``weaviate.Client`` cannot be
            imported. The underlying import error is chained as ``__cause__``.
    """
    try:
        import weaviate  # noqa: F401
        from weaviate import Client  # noqa: F401
    except ImportError as exc:
        # Chain the original error so the real reason (missing package vs.
        # missing ``Client`` symbol) stays visible in the traceback.
        raise ImportError(
            "Weaviate is not installed. "
            "Please install it with `pip install weaviate-client`."
        ) from exc
60

61

62
def parse_get_response(response: Dict) -> Dict:
    """Extract the ``Get`` section from a Weaviate GraphQL query response.

    Args:
        response: Raw response dict returned by a Weaviate query.

    Returns:
        The value stored under ``response["data"]["Get"]``.

    Raises:
        ValueError: If the response reports errors, has no usable ``data``
            section, or does not contain a ``Get`` result.
    """
    if "errors" in response:
        raise ValueError("Invalid query, got errors: {}".format(response["errors"]))
    # ``data`` may be missing or null; treat both as "not a Get query" so the
    # caller always sees ValueError rather than a KeyError/TypeError.
    data_response = response.get("data") or {}
    if "Get" not in data_response:
        raise ValueError("Invalid query response, must be a Get query.")

    return data_response["Get"]
71

72

73
def class_schema_exists(client: Any, class_name: str) -> bool:
    """Report whether a class called ``class_name`` exists in the schema.

    Args:
        client: Weaviate client whose ``schema.get()`` result is inspected.
        class_name: Exact class name to look for.

    Returns:
        ``True`` if some entry of the schema's ``classes`` list has that
        ``class`` value, otherwise ``False``.
    """
    validate_client(client)
    known_names = set()
    for class_def in client.schema.get()["classes"]:
        known_names.add(class_def["class"])
    return class_name in known_names
80

81

82
def create_default_schema(client: Any, class_name: str) -> None:
    """Create the Weaviate class ``class_name`` with the ``NODE_SCHEMA`` properties.

    Args:
        client: Weaviate client used to call ``schema.create_class``.
        class_name: Name of the class to create.
    """
    validate_client(client)
    client.schema.create_class(
        {
            "class": class_name,
            "description": f"Class for {class_name}",
            "properties": NODE_SCHEMA,
        }
    )
91

92

93
def get_all_properties(client: Any, class_name: str) -> List[str]:
    """List the property names defined on the Weaviate class ``class_name``.

    Args:
        client: Weaviate client whose ``schema.get()`` result is inspected.
        class_name: Class whose properties are requested.

    Returns:
        The ``name`` of every entry in the class's ``properties`` list.

    Raises:
        ValueError: If no class with that name exists in the schema.
    """
    validate_client(client)
    classes_by_name = {
        class_def["class"]: class_def for class_def in client.schema.get()["classes"]
    }
    target = classes_by_name.get(class_name)
    if target is None:
        raise ValueError(f"{class_name} schema does not exist.")
    return [prop["name"] for prop in target["properties"]]
103

104

105
def get_node_similarity(entry: Dict, similarity_key: str = "distance") -> float:
    """Turn a Weaviate ``_additional`` distance value into a similarity.

    Args:
        entry: Result entry containing an ``_additional`` dict.
        similarity_key: Key inside ``_additional`` to read.

    Returns:
        ``1.0 - float(value)``. A missing key is treated as ``0.0`` and an
        explicit ``None`` yields ``1.0``.
    """
    # Conversion rationale:
    # https://forum.weaviate.io/t/distance-vs-certainty-scores/258
    # NOTE(review): the same ``1 - value`` conversion is applied whatever
    # ``similarity_key`` is; confirm that is intended for keys that are
    # already similarities (e.g. a hybrid "score").
    raw_value = entry["_additional"].get(similarity_key, 0.0)
    return 1.0 if raw_value is None else 1.0 - float(raw_value)
114

115

116
def to_node(entry: Dict, text_key: str = DEFAULT_TEXT_KEY) -> TextNode:
    """Convert one Weaviate ``Get`` result entry into a LlamaIndex node.

    The entry is first decoded with ``metadata_dict_to_node``. If that raises,
    it is decoded with ``legacy_metadata_dict_to_node`` and a ``TextNode`` is
    built from the pieces.

    Args:
        entry: One object from a parsed ``Get`` response. It must contain an
            ``_additional`` dict, from which ``vector`` is read (and ``id`` on
            the legacy path). The property named ``text_key`` holds the text.
            The caller's dict and its nested ``_additional`` dict are not
            modified.
        text_key: Name of the property holding the node text.

    Returns:
        The reconstructed node with ``text`` and ``embedding`` populated.
    """
    # Work on a shallow copy so popping keys below leaves the caller's entry
    # intact (and the same entry can safely be converted more than once).
    entry = dict(entry)
    additional = entry.pop("_additional")
    text = entry.pop(text_key, "")
    # ``get`` instead of ``pop`` so the caller's nested dict is not mutated.
    embedding = additional.get("vector")
    try:
        node = metadata_dict_to_node(entry)
        node.text = text
        node.embedding = embedding
    except Exception as e:
        # Any failure with the current metadata format means the object was
        # written in the legacy format. Pass ``e`` through a %s placeholder:
        # extra args without one make ``logging`` report a formatting error
        # instead of emitting this message.
        _logger.debug(
            "Failed to parse Node metadata, fallback to legacy logic: %s", e
        )
        metadata, node_info, relationships = legacy_metadata_dict_to_node(entry)

        node = TextNode(
            text=text,
            id_=additional["id"],
            metadata=metadata,
            start_char_idx=node_info.get("start", None),
            end_char_idx=node_info.get("end", None),
            relationships=relationships,
            embedding=embedding,
        )
    return node
139

140

141
def add_node(
    client: "Client",
    node: BaseNode,
    class_name: str,
    batch: Optional[Any] = None,
    text_key: str = DEFAULT_TEXT_KEY,
) -> None:
    """Queue ``node`` for insertion into the Weaviate class ``class_name``.

    Args:
        client: Weaviate client; ``client.batch`` is used when ``batch`` is None.
        node: Node to store. Its content (``MetadataMode.NONE``), the dict from
            ``node_to_metadata_dict``, its id and its embedding are sent.
        class_name: Target Weaviate class.
        batch: Optional batch object (e.g. from a ``with client.batch`` block)
            to add the object to instead of ``client.batch``.
        text_key: Name of the property that stores the node text.
    """
    text = node.get_content(metadata_mode=MetadataMode.NONE) or ""
    # Put the text last so a metadata key that happens to equal ``text_key``
    # cannot overwrite the node content stored in that property.
    properties = {
        **node_to_metadata_dict(node, remove_text=True, flat_metadata=False),
        text_key: text,
    }

    vector = node.get_embedding()
    node_id = node.node_id

    # if batch object is provided (via a context manager), use that instead
    target = client.batch if batch is None else batch
    target.add_data_object(properties, class_name, node_id, vector)
165

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.