Langchain-Chatchat
117 строк · 4.4 Кб
1from typing import List, Dict, Optional
2
3from langchain.schema import Document
4from langchain.vectorstores.milvus import Milvus
5import os
6
7from configs import kbs_config
8from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
9
10from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
11score_threshold_process
12from server.knowledge_base.utils import KnowledgeFile
13
14
class MilvusKBService(KBService):
    """Knowledge-base service that stores document vectors in Milvus.

    The LangChain ``Milvus`` vector store is built by ``_load_milvus`` and kept
    on ``self.milvus``. Its ``col`` attribute is checked for truthiness before
    any collection-level operation (presumably it is unset until the
    collection exists -- confirm against the LangChain Milvus wrapper).
    """

    milvus: Milvus

    @staticmethod
    def get_collection(milvus_name):
        """Return the pymilvus ``Collection`` named *milvus_name*."""
        # Imported lazily so pymilvus is only required when this is called.
        from pymilvus import Collection
        return Collection(milvus_name)

    @staticmethod
    def _pk_in_expr(ids) -> str:
        """Build a ``pk in [...]`` filter expression from *ids*.

        Every id is cast to ``int`` so that ids handed around as strings
        produce the same expression as integer ids.

        Raises:
            ValueError: if an id cannot be converted to ``int``.
        """
        return f'pk in {[int(_id) for _id in ids]}'

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch the documents whose primary keys are listed in *ids*.

        Returns an empty list when *ids* is empty or the collection is not
        available. The ``"text"`` field of each row becomes the page content;
        every other returned field is kept as metadata.
        """
        result = []
        if not ids or not self.milvus.col:
            return result
        data_list = self.milvus.col.query(expr=self._pk_in_expr(ids), output_fields=["*"])
        for data in data_list:
            text = data.pop("text")
            result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the rows whose primary keys are listed in *ids*.

        Ids are cast to ``int`` exactly as in ``get_doc_by_ids`` so both
        methods address the same ``pk`` values.

        Returns:
            True when a delete request was issued, False when there was
            nothing to delete (empty *ids*) or no collection is available.
        """
        if not ids or not self.milvus.col:
            return False
        self.milvus.col.delete(expr=self._pk_in_expr(ids))
        return True

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Run a raw pymilvus search on the collection *milvus_name*.

        Searches the ``"embeddings"`` field with the L2 metric and
        ``nprobe=10``, returning at most *limit* hits with the ``"content"``
        output field.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op: nothing is created here for a Milvus knowledge base."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier for Milvus."""
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """(Re)build ``self.milvus`` from the ``kbs_config`` settings.

        Connection arguments come from the ``"milvus"`` entry; index and
        search parameters come from the ``"milvus_kwargs"`` entry.
        """
        milvus_kwargs = kbs_config.get("milvus_kwargs")
        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name,
                             connection_args=kbs_config.get("milvus"),
                             index_params=milvus_kwargs["index_params"],
                             search_params=milvus_kwargs["search_params"]
                             )

    def do_init(self):
        """Initialise the service by loading the Milvus vector store."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release and drop the underlying collection, if there is one."""
        if self.milvus.col:
            self.milvus.col.release()
            self.milvus.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float):
        """Embed *query* and return the *top_k* most similar documents.

        The vector store is reloaded before every search; the scored hits are
        then passed through ``score_threshold_process``.
        """
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Insert *docs* into Milvus and return their ids with metadata.

        Each document's metadata is normalised in place before insertion:
        values are stringified, every field known to the store is given an
        empty-string default, and the store's own text and vector fields are
        removed from the metadata.

        Returns:
            A list of ``{"id": ..., "metadata": ...}`` dicts, one per document.
        """
        for doc in docs:
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating.
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)

        ids = self.milvus.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete all vectors that belong to *kb_file*.

        The ids to delete are looked up through
        ``list_file_num_docs_id_by_kb_name_and_file_name``; nothing is sent to
        Milvus when that lookup is empty or no collection is available.
        """
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        if id_list and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')

        # Issue 2846, for windows -- alternative lookup by source file name:
        # if self.milvus.col:
        #     file_path = kb_file.filepath.replace("\\", "\\\\")
        #     file_name = os.path.basename(file_path)
        #     id_list = [item.get("pk") for item in
        #                self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
        #     self.milvus.col.delete(expr=f'pk in {id_list}')

    def do_clear_vs(self):
        """Empty the vector store by dropping the collection and re-initialising."""
        if self.milvus.col:
            self.do_drop_kb()
            self.do_init()
104
105
if __name__ == '__main__':
    # Manual check: create the database tables, then fetch one document by id.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    kb_service = MilvusKBService("test")
    # kb_service.add_doc(KnowledgeFile("README.md", "test"))

    fetched_docs = kb_service.get_doc_by_ids(["444022434274215486"])
    print(fetched_docs)
    # kb_service.delete_doc(KnowledgeFile("README.md", "test"))
    # kb_service.do_drop_kb()
    # print(kb_service.search_docs("如何启动api服务"))
118