Langchain-Chatchat
470 строк · 15.9 Кб
1import operator2from abc import ABC, abstractmethod3
4import os5from pathlib import Path6import numpy as np7from langchain.embeddings.base import Embeddings8from langchain.docstore.document import Document9
10from server.db.repository.knowledge_base_repository import (11add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,12load_kb_from_db, get_kb_detail,13)
14from server.db.repository.knowledge_file_repository import (15add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,16count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,17list_docs_from_db,18)
19
20from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,21EMBEDDING_MODEL, KB_INFO)22from server.knowledge_base.utils import (23get_kb_path, get_doc_path, KnowledgeFile,24list_kbs_from_folder, list_files_from_folder,25)
26
27from typing import List, Union, Dict, Optional, Tuple28
29from server.embeddings_api import embed_texts, aembed_texts, embed_documents30from server.knowledge_base.model.kb_document_model import DocumentWithVSId31
32
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalize every row of a 2-D batch of embedding vectors.

    Stand-in for ``sklearn.preprocessing.normalize`` (L2 norm) so that scipy
    and scikit-learn do not have to be installed.

    Args:
        embeddings: A 2-D sequence (or array) with one vector per row.

    Returns:
        An array of the same shape in which each row has unit L2 norm.
        Rows whose norm is zero are returned unchanged (all zeros) rather
        than being divided by zero, matching scikit-learn's behaviour.
    """
    vectors = np.asarray(embeddings)
    # keepdims=True yields shape (n, 1), which broadcasts across each row;
    # no explicit reshape/tile is needed.
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Avoid 0/0 -> NaN for all-zero rows: dividing by 1 leaves them as zeros.
    norms[norms == 0] = 1
    return vectors / norms
42
class SupportedVSType:
    """Names of the vector store types a knowledge base can be backed by.

    Every attribute name is the upper-cased form of its string value;
    ``KBServiceFactory.get_service`` depends on this when it resolves a type
    given as a plain string via ``getattr(SupportedVSType, name.upper())``.
    """

    CHROMADB = "chromadb"
    DEFAULT = "default"
    ES = "es"
    FAISS = "faiss"
    MILVUS = "milvus"
    PG = "pg"
    ZILLIZ = "zilliz"
52
class KBService(ABC):
    """Abstract base class of a knowledge base backed by a vector store.

    The public methods combine two things: the vector-store work, which is
    delegated to the ``do_*`` hooks implemented by subclasses, and the
    bookkeeping records, which are written through the repository functions
    (``add_kb_to_db``, ``add_file_to_db``, ...).
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Store the knowledge base settings and run the subclass ``do_init``.

        Args:
            knowledge_base_name: Name of the knowledge base.
            embed_model: Name of the embedding model to use.
        """
        self.kb_name = knowledge_base_name
        # Fall back to a generated description when none is configured.
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        """Persist the vector store.

        FAISS saves to disk, Milvus saves to its database; PGVector is not
        supported yet. The base implementation does nothing.
        """
        pass

    def create_kb(self):
        """Create the knowledge base.

        Ensures the document folder exists, runs the subclass ``do_create_kb``
        and registers the knowledge base in the database.

        Returns:
            The status returned by ``add_kb_to_db``.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """Delete everything stored in the vector store.

        Returns:
            The status returned by ``delete_files_from_db``.
        """
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """Delete the knowledge base.

        Returns:
            The status returned by ``delete_kb_from_db``.
        """
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        """Convert ``List[Document]`` into arguments accepted by ``VectorStore.add_embeddings``."""
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    def add_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """Add a file to the knowledge base.

        When ``docs`` is supplied those documents are used as-is instead of
        loading the file through ``kb_file.file2text()``, and the database
        record is flagged ``custom_docs=True``.

        Args:
            kb_file: The file being added.
            docs: Optional pre-built documents for the file.
            **kwargs: Forwarded to ``do_add_doc``.

        Returns:
            The status returned by ``add_file_to_db``, or ``False`` when there
            are no documents to add.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False

        if docs:
            # Rewrite metadata["source"] as a path relative to the doc folder.
            for doc in docs:
                try:
                    source = doc.metadata.get("source", "")
                    if os.path.isabs(source):
                        rel_path = Path(source).relative_to(self.doc_path)
                        doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
                except Exception as e:
                    # Best effort: a source that cannot be converted is kept as it is.
                    print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
            # Remove any previous version of the file before adding the new one.
            self.delete_doc(kb_file)
            doc_infos = self.do_add_doc(docs, **kwargs)
            status = add_file_to_db(kb_file,
                                    custom_docs=custom_docs,
                                    docs_count=len(docs),
                                    doc_infos=doc_infos)
        else:
            status = False
        return status

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """Delete a file from the knowledge base.

        Args:
            kb_file: The file to delete.
            delete_content: When true, also remove the file at
                ``kb_file.filepath`` from disk if it exists.
            **kwargs: Forwarded to ``do_delete_doc``.

        Returns:
            The status returned by ``delete_file_from_db``.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """Update the knowledge base description.

        Returns:
            The status returned by ``add_kb_to_db``.
        """
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """Update the vector store from the file in the content folder.

        When ``docs`` is supplied the custom documents are used and the
        database record is flagged ``custom_docs=True``.

        Returns:
            The result of ``add_doc``; ``None`` when ``kb_file.filepath`` does
            not exist, in which case nothing is done.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        """Report whether ``file_name`` of this knowledge base is recorded in the database."""
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        """Return the files recorded in the database for this knowledge base."""
        return list_files_from_db(self.kb_name)

    def count_files(self):
        """Return the number of files recorded in the database for this knowledge base."""
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    ) -> List[Document]:
        """Search the knowledge base by delegating to the subclass ``do_search``.

        NOTE(review): ``do_search`` is annotated as returning
        ``List[Tuple[Document, float]]`` while this method is annotated
        ``List[Document]``; confirm which one callers rely on.
        """
        docs = self.do_search(query, top_k, score_threshold)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the documents with the given ids; the base implementation returns ``[]``."""
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the documents with the given ids; must be provided by subclasses."""
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        """Replace documents in the vector store by id.

        Args:
            docs: Mapping ``{doc_id: Document, ...}``. An entry whose value is
                ``None`` or whose ``page_content`` is blank is only deleted.

        Returns:
            Always ``True``.
        """
        self.del_doc_by_ids(list(docs.keys()))
        # Use separate locals so the ``docs`` mapping is not overwritten
        # before it has been iterated.
        kept_ids = []
        kept_docs = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            kept_ids.append(doc_id)
            kept_docs.append(doc)
        # Nothing to re-add when every entry was a deletion.
        if kept_docs:
            self.do_add_doc(docs=kept_docs, ids=kept_ids)
        return True

    def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
        """Retrieve documents by ``file_name`` or ``metadata``.

        Ids recorded in the database for which ``get_doc_by_ids`` returns no
        document are skipped.
        """
        if metadata is None:
            metadata = {}
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            # get_doc_by_ids may return an empty list (the base class does),
            # so guard before indexing.
            found = self.get_doc_by_ids([x["id"]])
            doc_info = found[0] if found else None
            if doc_info is not None:
                doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
                docs.append(doc_with_id)
        return docs

    def get_relative_source_path(self, filepath: str):
        """Convert a file path to the relative form used when querying.

        An absolute path is made relative to ``self.doc_path``. If that
        conversion fails the path is returned unchanged, which is also what
        ``add_doc`` leaves in ``metadata["source"]`` in that situation.

        Returns:
            The path as a POSIX-style string with leading and trailing ``/``
            removed, or ``filepath`` itself when it could not be converted.
        """
        relative_path = Path(filepath)
        if os.path.isabs(filepath):
            try:
                relative_path = relative_path.relative_to(self.doc_path)
            except Exception as e:
                print(f"cannot convert absolute path ({filepath}) to relative path. error is : {e}")
                return filepath

        return str(relative_path.as_posix().strip("/"))

    @abstractmethod
    def do_create_kb(self):
        """Create the knowledge base; subclasses implement their own logic."""
        pass

    @staticmethod
    def list_kbs_type():
        """Return the keys of ``kbs_config``."""
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        """Return the knowledge bases recorded in the database."""
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        """Report whether ``kb_name`` (this knowledge base by default) exists in the database."""
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        """Return the vector store type name of this service."""
        pass

    @abstractmethod
    def do_init(self):
        """Subclass initialisation hook, called at the end of ``__init__``."""
        pass

    @abstractmethod
    def do_drop_kb(self):
        """Delete the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  ) -> List[Tuple[Document, float]]:
        """Search the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Add documents to the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete a file's documents from the knowledge base; subclasses implement their own logic.

        ``delete_doc`` forwards its extra keyword arguments to this hook.
        """
        pass

    @abstractmethod
    def do_clear_vs(self):
        """Delete all vectors from the knowledge base; subclasses implement their own logic."""
        pass
309
class KBServiceFactory:
    """Creates the ``KBService`` implementation matching a vector store type.

    Each backend module is imported only inside the branch that selects it.
    """

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Build the service for ``kb_name`` with the requested vector store type.

        A string type is resolved through the upper-cased attribute of
        ``SupportedVSType``; an unknown name therefore raises
        ``AttributeError``. ``None`` is returned when no branch matches.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())

        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.DEFAULT == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            # other milvus parameters are set in model_config.kbs_config
            return MilvusKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)

        # NOTE(review): this branch can never run, because DEFAULT is already
        # handled above and returns MilvusKBService. Decide which of the two
        # is intended and delete the other.
        if SupportedVSType.DEFAULT == vector_store_type:
            # kb_exists of default kbservice is False, to make validation easier.
            from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
            return DefaultKBService(kb_name)

        return None

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """Build the service for a knowledge base from its database record.

        Returns ``None`` when the first field returned by ``load_kb_from_db``
        is ``None``, i.e. the knowledge base is not in the database.
        """
        stored_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if stored_name is None:
            return None
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Build the service for the knowledge base named ``"default"`` with the DEFAULT type."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
356
def get_kb_details() -> List[Dict]:
    """Merge the knowledge bases found on disk with those recorded in the database.

    Every knowledge base folder gets a placeholder entry, which is then
    overlaid with the database detail when one exists. Knowledge bases that
    exist only in the database are appended with ``in_folder`` set to False.

    Returns:
        One dict per knowledge base, each numbered from 1 under the key ``No``.
    """
    folder_kbs = list_kbs_from_folder()
    db_kbs = KBService.list_kbs()

    details = {
        name: {
            "kb_name": name,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in folder_kbs
    }

    for name in db_kbs:
        db_detail = get_kb_detail(name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        if name in details:
            details[name].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[name] = db_detail

    numbered = []
    for position, entry in enumerate(details.values(), start=1):
        entry['No'] = position
        numbered.append(entry)
    return numbered
391
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """Merge the files found in a knowledge base folder with those recorded in the database.

    Database files are matched to folder files case-insensitively. A file
    present only in the database is added with ``in_folder`` set to False.

    Args:
        kb_name: Name of the knowledge base.

    Returns:
        One dict per file, each numbered from 1 under the key ``No``; an empty
        list when no service can be built for ``kb_name``.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    if service is None:
        return []

    folder_files = list_files_from_folder(kb_name)
    db_files = service.list_files()

    details = {
        file_name: {
            "kb_name": kb_name,
            "file_name": file_name,
            "file_ext": os.path.splitext(file_name)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for file_name in folder_files
    }

    # Lower-cased folder file name -> name as it appears in the folder.
    # Built once from the folder entries only.
    folder_name_by_lower = {name.lower(): name for name in details}

    for file_name in db_files:
        db_detail = get_file_detail(kb_name, file_name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        lowered = file_name.lower()
        if lowered in folder_name_by_lower:
            details[folder_name_by_lower[lowered]].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[file_name] = db_detail

    numbered = []
    for position, entry in enumerate(details.values(), start=1):
        entry['No'] = position
        numbered.append(entry)
    return numbered
432
class EmbeddingsFunAdapter(Embeddings):
    """``Embeddings`` implementation backed by ``embed_texts`` / ``aembed_texts``.

    Every vector returned by the embedding functions is passed through
    ``normalize`` before being handed back as plain Python lists.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` (``to_query=False``) and return the normalized vectors."""
        response = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(response.data).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query (``to_query=True``) and return its normalized vector."""
        response = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        # normalize works on rows, so view the single vector as a 1 x d
        # matrix and take the only row back out afterwards.
        as_row = np.reshape(response.data[0], (1, -1))
        return normalize(as_row)[0].tolist()

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``embed_documents``."""
        response = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(response.data).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async counterpart of ``embed_query``."""
        response = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        # Same 1 x d round trip as in embed_query.
        as_row = np.reshape(response.data[0], (1, -1))
        return normalize(as_row)[0].tolist()
459
def score_threshold_process(score_threshold, k, docs):
    """Filter ``(doc, similarity)`` pairs by a score threshold and keep the first ``k``.

    Args:
        score_threshold: Pairs are kept when ``similarity <= score_threshold``.
            ``None`` disables the filtering.
        k: Maximum number of entries to return.
        docs: Sequence of ``(doc, similarity)`` pairs.

    Returns:
        The first ``k`` entries of ``docs`` when no threshold is given,
        otherwise the first ``k`` pairs that pass the threshold, as tuples.
    """
    if score_threshold is None:
        return docs[:k]

    kept = []
    for doc, similarity in docs:
        if similarity <= score_threshold:
            kept.append((doc, similarity))
    return kept[:k]