Langchain-Chatchat

Форк
0
123 строки · 4.4 Кб
1
import os
2
import shutil
3

4
from configs import SCORE_THRESHOLD
5
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
6
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
7
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
8
from server.utils import torch_gc
9
from langchain.docstore.document import Document
10
from typing import List, Dict, Optional, Tuple
11

12

13
class FaissKBService(KBService):
    """Knowledge-base service that stores embeddings in a FAISS vector store.

    Stores are obtained through the shared ``kb_faiss_pool`` cache as
    ``ThreadSafeFaiss`` wrappers; every read or write below goes through the
    wrapper's ``acquire()`` context manager.
    """

    vs_path: str  # on-disk directory of the vector store; set by do_init()
    kb_path: str  # on-disk root directory of the knowledge base; set by do_init()
    # Name of the vector store within the KB. do_init() replaces a falsy value
    # with ``self.embed_model``.
    vector_name: Optional[str] = None

    def vs_type(self) -> str:
        """Return the vector-store type handled by this service."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the vector-store path for ``(kb_name, vector_name)``."""
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the root path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Return this KB's store from ``kb_faiss_pool``.

        NOTE(review): presumably the pool loads the store from disk (or creates
        it) on a cache miss -- confirm in faiss_cache.
        """
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self.vector_name,
                                               embed_model=self.embed_model)

    def save_vector_store(self):
        """Persist the cached store to ``self.vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Optional[Document]]:
        """Return the stored documents for ``ids``, in the same order.

        Ids that are not present in the docstore yield ``None`` in the
        corresponding position.
        """
        with self.load_vector_store().acquire() as vs:
            # NOTE: reads the docstore's private ``_dict`` mapping directly.
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the documents with the given ids from the vector store.

        Returns ``True`` once ``vs.delete`` has completed. Any exception raised
        by ``vs.delete`` propagates to the caller. The store is not saved to
        disk here.
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Resolve ``vector_name`` (defaulting to the embedding model) and paths."""
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Ensure ``vs_path`` exists, then load this KB's store into the pool."""
        os.makedirs(self.vs_path, exist_ok=True)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the whole KB directory.

        Removal is best-effort: any error from ``shutil.rmtree`` is ignored.
        """
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Deliberately best-effort (e.g. the directory may already be gone).
            pass

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and return up to ``top_k`` ``(Document, score)`` pairs.

        ``score_threshold`` is forwarded to the store's
        ``similarity_search_with_score_by_vector``; how it filters results is
        defined by that method.
        """
        # Embed before acquiring the store to keep the locked section short.
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs`` and add them to the vector store.

        Keyword args:
            ids: optional explicit ids forwarded to ``add_embeddings``.
            not_refresh_vs_cache: when truthy, skip saving the store to
                ``vs_path`` after the insert.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        # Vectorizing separately, before acquiring the store, reduces how long
        # the vector store stays locked.
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every stored chunk whose ``source`` metadata matches ``kb_file``.

        The comparison is case-insensitive against ``kb_file.filename``. Chunks
        whose ``source`` is missing or not a string are skipped rather than
        raising ``AttributeError``. Unless ``not_refresh_vs_cache`` is truthy,
        the store is saved to ``vs_path`` afterwards.

        Returns:
            The list of deleted ids (possibly empty).
        """
        target = kb_file.filename.lower()

        def _matches(doc: Document) -> bool:
            # Guard against chunks stored without a (string) "source" entry.
            source = doc.metadata.get("source")
            return isinstance(source, str) and source.lower() == target

        with self.load_vector_store().acquire() as vs:
            ids = [doc_id for doc_id, doc in vs.docstore._dict.items() if _matches(doc)]
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict this store from ``kb_faiss_pool`` and reset ``vs_path`` to empty."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Deliberately best-effort (e.g. the directory may not exist yet).
            pass
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is known to this knowledge base.

        Returns:
            ``"in_db"`` if the base-class check finds it, ``"in_folder"`` if it
            exists as a file under ``<kb_path>/content``, otherwise ``False``.
        """
        if super().exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        else:
            return False
116

117

118
if __name__ == '__main__':
    # Manual smoke test against a knowledge base named "test": add README.md,
    # remove it again, drop the whole KB, then run one sample search.
    kb_name = "test"
    sample_file = "README.md"
    service = FaissKBService(kb_name)
    service.add_doc(KnowledgeFile(sample_file, kb_name))
    service.delete_doc(KnowledgeFile(sample_file, kb_name))
    service.do_drop_kb()
    # The query string means "How do I start the API service?"
    print(service.search_docs("如何启动api服务"))
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.