Langchain-Chatchat
123 строки · 4.4 Кб
1import os2import shutil3
4from configs import SCORE_THRESHOLD5from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter6from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss7from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path8from server.utils import torch_gc9from langchain.docstore.document import Document10from typing import List, Dict, Optional, Tuple11
12
class FaissKBService(KBService):
    """Knowledge-base service backed by a local FAISS vector store.

    The store itself is obtained from ``kb_faiss_pool`` keyed by the
    knowledge-base name and vector name, and is persisted under ``vs_path``.
    """

    vs_path: str  # directory the FAISS store is saved to; set in do_init
    kb_path: str  # root directory of this knowledge base; set in do_init
    vector_name: Optional[str] = None  # falls back to embed_model in do_init

    def vs_type(self) -> str:
        """Return the vector-store type identifier for this service."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the on-disk path of this knowledge base's vector store."""
        # Calls the module-level helper imported from kb utils, not this method.
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the on-disk root path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Fetch this knowledge base's vector store from ``kb_faiss_pool``.

        NOTE(review): whether the pool caches the store or builds it on demand
        is decided by ``kb_faiss_pool.load_vector_store`` — confirm there.
        """
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self.vector_name,
                                               embed_model=self.embed_model)

    def save_vector_store(self):
        """Persist the vector store to ``self.vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Optional[Document]]:
        """Look up stored documents by id.

        Returns:
            One entry per requested id, in order; ids not present in the
            docstore yield ``None`` at their position.
        """
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by id from the loaded vector store.

        The change is not written to disk here; call ``save_vector_store`` to
        persist it. Any exception raised by ``vs.delete`` propagates.

        Returns:
            ``True`` once the deletion has been applied (previously this
            method returned ``None`` despite its ``bool`` annotation).
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Resolve the vector name and the on-disk paths for this knowledge base."""
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Ensure the vector-store directory exists, then load the store."""
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.vs_path, exist_ok=True)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the whole knowledge-base directory."""
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Deliberately best-effort: a missing or partially removable
            # directory does not abort dropping the knowledge base.
            pass

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  ) -> List[Tuple[Document, float]]:
        """Return up to ``top_k`` ``(document, score)`` pairs for ``query``.

        The query is embedded before the store is acquired, and
        ``score_threshold`` is forwarded to the FAISS similarity search.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(
                embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs`` and add them to the vector store.

        Keyword args:
            ids: optional explicit ids, forwarded to ``add_embeddings``.
            not_refresh_vs_cache: if truthy, skip saving the store to disk.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        # Computing embeddings before acquiring the store shortens how long
        # the vector store stays locked.
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every stored chunk whose ``source`` metadata matches ``kb_file``.

        Matching is case-insensitive against ``kb_file.filename``. Entries that
        carry no ``source`` metadata are skipped rather than raising. Unless
        ``not_refresh_vs_cache`` is truthy, the store is saved afterwards.

        Returns:
            The list of deleted ids (empty if nothing matched).
        """
        target = kb_file.filename.lower()  # loop-invariant, computed once
        with self.load_vector_store().acquire() as vs:
            ids = []
            for doc_id, doc in vs.docstore._dict.items():
                source = doc.metadata.get("source")
                # Previously a missing "source" crashed with None.lower().
                if source is not None and source.lower() == target:
                    ids.append(doc_id)
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict this store from ``kb_faiss_pool`` and reset its directory on disk."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Deliberately best-effort: the directory may already be gone.
            pass
        # Recreate vs_path as an empty directory.
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is present.

        Returns:
            ``"in_db"`` if the base-class check finds it, ``"in_folder"`` if it
            exists as a file under ``<kb_path>/content``, otherwise ``False``.
        """
        if super().exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        return False
117
if __name__ == '__main__':
    # Manual smoke test: round-trip README.md through the "test" knowledge
    # base, drop the knowledge base, then print a search result.
    kb_name = "test"
    query = "如何启动api服务"

    faissService = FaissKBService(kb_name)
    faissService.add_doc(KnowledgeFile("README.md", kb_name))
    faissService.delete_doc(KnowledgeFile("README.md", kb_name))
    faissService.do_drop_kb()

    results = faissService.search_docs(query)
    print(results)