Langchain-Chatchat

Форк
0
117 строк · 4.4 Кб
1
from typing import List, Dict, Optional
2

3
from langchain.schema import Document
4
from langchain.vectorstores.milvus import Milvus
5
import os
6

7
from configs import kbs_config
8
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
9

10
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
11
    score_threshold_process
12
from server.knowledge_base.utils import KnowledgeFile
13

14

15
class MilvusKBService(KBService):
    """Knowledge-base service backed by a Milvus vector store.

    ``self.milvus`` is a LangChain ``Milvus`` wrapper whose collection is named
    after the knowledge base (``self.kb_name``); it is (re)built by
    ``_load_milvus`` using the connection/index/search settings in
    ``kbs_config``.
    """

    milvus: Milvus

    @staticmethod
    def get_collection(milvus_name):
        """Return a raw pymilvus ``Collection`` handle for ``milvus_name``."""
        from pymilvus import Collection
        return Collection(milvus_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by their Milvus primary keys.

        Args:
            ids: primary keys as strings; they are cast to ``int`` because the
                ``pk`` field is queried as an integer (PR 2725).

        Returns:
            One ``Document`` per matching row: the text field becomes
            ``page_content`` and the remaining returned fields become metadata.
            Empty when ``ids`` is empty or the collection does not exist.
        """
        if not ids or not self.milvus.col:
            return []
        pk_list = [int(_id) for _id in ids]
        rows = self.milvus.col.query(expr=f'pk in {pk_list}', output_fields=["*"])
        result = []
        for row in rows:
            # Use the wrapper's configured text field name, as do_add_doc does.
            text = row.pop(self.milvus._text_field)
            # Keep the embedding out of metadata if the server returned it for
            # "*" (do_add_doc likewise never stores it in metadata).
            row.pop(self.milvus._vector_field, None)
            result.append(Document(page_content=text, metadata=row))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by their Milvus primary keys.

        Ids are cast to ``int`` to match ``get_doc_by_ids``. Formatting the raw
        strings would build ``pk in ['1', ...]``, which compares strings
        against the integer ``pk`` field.

        Returns:
            ``True`` once the delete has been issued (or there was nothing to
            delete); ``False`` when the collection does not exist.
        """
        if not self.milvus.col:
            return False
        if not ids:
            return True
        pk_list = [int(_id) for _id in ids]
        self.milvus.col.delete(expr=f'pk in {pk_list}')
        return True

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Run a raw L2 vector search on collection ``milvus_name``.

        Searches the ``"embeddings"`` field with ``nprobe=10`` and returns the
        ``"content"`` field of up to ``limit`` hits.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        # Nothing to do here: no collection is created explicitly by this method.
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier for Milvus."""
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """(Re)build ``self.milvus`` for this KB from ``kbs_config`` settings."""
        milvus_kwargs = kbs_config.get("milvus_kwargs")
        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name,
                             connection_args=kbs_config.get("milvus"),
                             index_params=milvus_kwargs["index_params"],
                             search_params=milvus_kwargs["search_params"],
                             )

    def do_init(self):
        """Initialise the service by connecting the Milvus wrapper."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release and drop the KB's collection, if it exists."""
        if self.milvus.col:
            self.milvus.col.release()
            self.milvus.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float):
        """Embed ``query`` and return up to ``top_k`` (doc, score) pairs.

        The result is filtered with ``score_threshold_process``.
        """
        # NOTE(review): the wrapper is rebuilt on every search, presumably so a
        # collection created after do_init is picked up — confirm.
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Insert ``docs`` into Milvus and return ``[{"id", "metadata"}, ...]``.

        Metadata is normalised first:
        - every value is stringified;
        - every collection field is given a default of ``""``;
        - the text and vector fields are removed, since those come from the
          document content and embedding rather than metadata.
        """
        for doc in docs:
            # Reassigning existing keys during iteration does not resize the dict.
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)

        ids = self.milvus.add_documents(docs)
        return [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every vector recorded in the DB for ``kb_file``."""
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        if id_list and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')

        # Issue 2846 (Windows): an earlier approach looked ids up by the "source"
        # metadata field instead of the DB, escaping backslashes in the path:
        # if self.milvus.col:
        #     file_path = kb_file.filepath.replace("\\", "\\\\")
        #     file_name = os.path.basename(file_path)
        #     id_list = [item.get("pk") for item in
        #                self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
        #     self.milvus.col.delete(expr=f'pk in {id_list}')

    def do_clear_vs(self):
        """Empty the vector store by dropping the collection and reconnecting."""
        if self.milvus.col:
            self.do_drop_kb()
            self.do_init()
104

105

106
if __name__ == '__main__':
    # Manual smoke test (table creation): create the DB tables, then look up
    # one document in the "test" knowledge base by its Milvus id.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    service = MilvusKBService("test")
    # service.add_doc(KnowledgeFile("README.md", "test"))

    sample_ids = ["444022434274215486"]
    found_docs = service.get_doc_by_ids(sample_ids)
    print(found_docs)

    # Other manual checks:
    # service.delete_doc(KnowledgeFile("README.md", "test"))
    # service.do_drop_kb()
    # print(service.search_docs("如何启动api服务"))  # query: "how do I start the API service?"
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.