Langchain-Chatchat
108 строк · 4.0 Кб
1import uuid
2from typing import Any, Dict, List, Tuple
3
4import chromadb
5from chromadb.api.types import (GetResult, QueryResult)
6from langchain.docstore.document import Document
7
8from configs import SCORE_THRESHOLD
9from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
10KBService, SupportedVSType)
11from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
12
13
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
    """Convert a Chroma ``get`` result into a list of LangChain documents.

    Args:
        get_result: Mapping returned by ``collection.get``; only its
            ``'documents'`` and ``'metadatas'`` entries are read.

    Returns:
        One ``Document`` per entry of ``get_result['documents']``, in the same
        order, or an empty list when there are no documents.
    """
    documents = get_result['documents']
    if not documents:
        return []

    # A missing or empty metadata list means "no metadata for any document".
    metadatas = get_result['metadatas'] or [None] * len(documents)

    # ``metadata or {}`` gives every document without metadata its own fresh
    # dict (never one dict shared by all documents) and also covers individual
    # ``None`` entries, the same fallback ``_results_to_docs_and_scores`` uses.
    return [
        Document(page_content=page_content, metadata=metadata or {})
        for page_content, metadata in zip(documents, metadatas)
    ]
25
26
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Pair every hit of a Chroma query result with its distance.

    The original note on this helper points at
    ``langchain_community.vectorstores.chroma`` (``Chroma``), presumably the
    implementation it was adapted from.

    Args:
        results: Query result mapping. Only the first entry of
            ``results["documents"]``, ``results["metadatas"]`` and
            ``results["distances"]`` is read.

    Returns:
        A list of ``(Document, distance)`` tuples in result order. A falsy
        metadata entry is replaced by an empty dict.
    """
    # TODO: Chroma can do batch querying,
    texts = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    docs_and_scores = []
    for text, metadata, distance in zip(texts, metadatas, distances):
        document = Document(page_content=text, metadata=metadata or {})
        docs_and_scores.append((document, distance))
    return docs_and_scores
40
41
class ChromaKBService(KBService):
    """Knowledge-base service that stores its vectors in a ChromaDB collection.

    The collection is named after ``self.kb_name`` and is opened through a
    ``chromadb.PersistentClient`` rooted at the path from ``get_vs_path``.
    """

    # Vector-store directory; assigned in ``do_init``.
    vs_path: str
    # Knowledge-base directory; assigned in ``do_init``.
    kb_path: str

    # ChromaDB client and collection handles; both are set by ``do_init``.
    client = None
    collection = None

    def vs_type(self) -> str:
        """Return the vector-store type handled by this service."""
        return SupportedVSType.CHROMADB

    def get_vs_path(self) -> str:
        """Return the vector-store path for this KB name and embedding model."""
        return get_vs_path(self.kb_name, self.embed_model)

    def get_kb_path(self) -> str:
        """Return the knowledge-base path for this KB name."""
        return get_kb_path(self.kb_name)

    def do_init(self) -> None:
        """Resolve the paths, open the client and get or create the collection."""
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        self.client = chromadb.PersistentClient(path=self.vs_path)
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_create_kb(self) -> None:
        """Create the knowledge base.

        In ChromaDB, creating a KB is equivalent to creating a collection.
        """
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_drop_kb(self):
        """Drop the knowledge base by deleting its collection.

        Raises:
            ValueError: Re-raised from the client unless its message says the
                collection does not exist.
        """
        try:
            self.client.delete_collection(self.kb_name)
        except ValueError as e:
            # Dropping a collection that is already gone is not an error;
            # every other ValueError propagates with its original traceback.
            if str(e) != f"Collection {self.kb_name} does not exist.":
                raise

    def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
            Tuple[Document, float]]:
        """Embed ``query`` and return up to ``top_k`` collection hits.

        Args:
            query: Text to embed and search for.
            top_k: Number of results requested from the collection.
            score_threshold: Accepted for interface compatibility.
                NOTE(review): it is not applied to the results here - confirm
                whether callers expect filtering.

        Returns:
            ``(Document, distance)`` tuples as produced by
            ``_results_to_docs_and_scores``.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Embed ``docs`` and add them to the collection under fresh UUID ids.

        Args:
            docs: Documents whose ``page_content`` and ``metadata`` are stored.
            **kwargs: Unused.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        if not docs:
            # Nothing to add: skip the embedding call entirely.
            return []

        embed_func = EmbeddingsFunAdapter(self.embed_model)
        texts = [doc.page_content for doc in docs]
        metadatas = [doc.metadata for doc in docs]
        embeddings = embed_func.embed_documents(texts=texts)
        ids = [str(uuid.uuid1()) for _ in texts]

        doc_infos = []
        # Documents are added with one ``add`` call each.
        for _id, text, embedding, metadata in zip(ids, texts, embeddings, metadatas):
            self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text)
            doc_infos.append({"id": _id, "metadata": metadata})
        return doc_infos

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the documents stored under ``ids`` (empty list for no ids)."""
        if not ids:
            # Never forward an empty id list: the client could read it as
            # "no id filter" instead of "no documents" (version dependent).
            return []
        get_result: GetResult = self.collection.get(ids=ids)
        return _get_result_to_documents(get_result)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the documents stored under ``ids``; always returns ``True``."""
        if not ids:
            # Nothing to delete; do not issue an unfiltered delete request.
            return True
        self.collection.delete(ids=ids)
        return True

    def do_clear_vs(self):
        """Empty the vector store by dropping and recreating the collection."""
        self.do_drop_kb()
        # Recreate the collection so ``self.collection`` does not keep
        # pointing at the collection that was just deleted.
        self.do_create_kb()

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every entry whose ``source`` metadata equals the file's path.

        Returns whatever ``collection.delete`` returns.
        """
        return self.collection.delete(where={"source": kb_file.filepath})
109