Langchain-Chatchat

Форк
0
108 строк · 4.0 Кб
1
import uuid
2
from typing import Any, Dict, List, Tuple
3

4
import chromadb
5
from chromadb.api.types import (GetResult, QueryResult)
6
from langchain.docstore.document import Document
7

8
from configs import SCORE_THRESHOLD
9
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
10
                                                   KBService, SupportedVSType)
11
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
12

13

14
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
    """Convert a Chroma ``Collection.get`` result into LangChain documents.

    Args:
        get_result: Mapping returned by ``Collection.get``; only the
            ``'documents'`` and ``'metadatas'`` keys are read.

    Returns:
        One ``Document`` per entry of ``get_result['documents']``, paired
        positionally with its metadata. An empty list when there are no
        documents.
    """
    texts = get_result['documents']
    if not texts:
        return []

    # 'metadatas' may be empty/None for the whole result, and individual
    # entries may be None. ``metadata or {}`` builds a FRESH dict per document;
    # the previous ``[{}] * n`` shared one dict object between all Documents,
    # so mutating one Document's metadata silently changed every other one.
    metadatas = get_result['metadatas'] or [None] * len(texts)
    return [
        Document(page_content=text, metadata=metadata or {})
        for text, metadata in zip(texts, metadatas)
    ]
25

26

27
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Convert a Chroma ``Collection.query`` result into ``(Document, distance)`` pairs.

    Adapted from the helper of the same name in
    ``langchain_community.vectorstores.chroma``.

    Chroma returns one inner list per query embedding. Only the FIRST query's
    hits are converted, so this assumes a single query was issued.

    Args:
        results: Mapping returned by ``Collection.query``; the ``"documents"``,
            ``"metadatas"`` and ``"distances"`` keys are read.

    Returns:
        ``(Document, distance)`` tuples in the order Chroma returned them, or
        an empty list when the result contains no per-query lists.
    """
    # TODO: Chroma can do batch querying; only the first query is handled here.
    per_query_texts = results["documents"]
    if not per_query_texts:
        return []

    texts = per_query_texts[0]
    distances = results["distances"][0]
    # Metadatas can be missing for the whole result or for single hits; the
    # per-hit ``or {}`` below also gives each Document its own dict.
    per_query_metadatas = results["metadatas"]
    metadatas = per_query_metadatas[0] if per_query_metadatas else [None] * len(texts)

    return [
        (Document(page_content=text, metadata=metadata or {}), distance)
        for text, metadata, distance in zip(texts, metadatas, distances)
    ]
40

41

42
class ChromaKBService(KBService):
    """Knowledge-base service backed by a persistent ChromaDB collection.

    Each knowledge base maps to one Chroma collection named ``kb_name``,
    persisted under the path returned by ``get_vs_path()``.
    """

    vs_path: str
    kb_path: str

    # Set by do_init(); None until then.
    client = None
    collection = None

    def vs_type(self) -> str:
        """Return the vector-store type identifier for this service."""
        return SupportedVSType.CHROMADB

    def get_vs_path(self) -> str:
        """Return the vector-store path for this KB and embedding model."""
        return get_vs_path(self.kb_name, self.embed_model)

    def get_kb_path(self) -> str:
        """Return the on-disk path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def do_init(self) -> None:
        """Resolve paths, open the persistent Chroma client and bind the collection."""
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        self.client = chromadb.PersistentClient(path=self.vs_path)
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_create_kb(self) -> None:
        """Create (or reuse) the collection backing this knowledge base."""
        # In ChromaDB, creating a KB is equivalent to creating a collection.
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_drop_kb(self):
        """Delete the collection backing this KB; a missing collection is not an error."""
        # Dropping a KB is equivalent to deleting a collection in ChromaDB.
        try:
            self.client.delete_collection(self.kb_name)
        except ValueError as e:
            # NOTE(review): matching on the exact message text is brittle across
            # chromadb releases (newer ones may raise a different type/message).
            if str(e) != f"Collection {self.kb_name} does not exist.":
                raise

    def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
        Tuple[Document, float]]:
        """Return up to ``top_k`` ``(Document, distance)`` pairs nearest to ``query``.

        NOTE(review): ``score_threshold`` is accepted for interface compatibility
        but is NOT applied; results are returned unfiltered. Confirm whether
        callers expect filtering and which direction (distance vs similarity).
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Embed ``docs`` and add them to the collection in batches.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per document, in input order.
        """
        if not docs:
            # Nothing to embed; also avoids sending an empty add to Chroma.
            return []

        embed_func = EmbeddingsFunAdapter(self.embed_model)
        texts = [doc.page_content for doc in docs]
        metadatas = [doc.metadata for doc in docs]
        embeddings = embed_func.embed_documents(texts=texts)
        ids = [str(uuid.uuid1()) for _ in texts]

        # Batched adds replace the former one-round-trip-per-document loop.
        # Newer chromadb clients expose a max batch size; respect it when present.
        batch_size = getattr(self.client, "max_batch_size", None) or len(ids)
        for start in range(0, len(ids), batch_size):
            end = start + batch_size
            self.collection.add(
                ids=ids[start:end],
                embeddings=embeddings[start:end],
                metadatas=metadatas[start:end],
                documents=texts[start:end],
            )
        return [{"id": _id, "metadata": metadata} for _id, metadata in zip(ids, metadatas)]

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch the documents stored under ``ids``."""
        get_result: GetResult = self.collection.get(ids=ids)
        return _get_result_to_documents(get_result)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the vectors stored under ``ids``; always returns True."""
        self.collection.delete(ids=ids)
        return True

    def do_clear_vs(self):
        """Remove every vector from this KB, leaving an empty collection in place."""
        # Drop, then recreate: without the recreate step self.collection would
        # still reference the deleted collection and later adds/searches on
        # this service would target it.
        self.do_drop_kb()
        self.do_create_kb()

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every vector whose ``source`` metadata equals ``kb_file.filepath``."""
        # Assumes documents were added with metadata["source"] set to the file
        # path — TODO confirm against the loader that builds the metadata.
        return self.collection.delete(where={"source": kb_file.filepath})
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.