Langchain-Chatchat
406 строк · 18.6 Кб
1import os2import urllib3from fastapi import File, Form, Body, Query, UploadFile4from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,5VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,6CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,7logger, log_verbose, )8from server.utils import BaseResponse, ListResponse, run_in_thread_pool9from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,10files2docs_in_thread, KnowledgeFile)11from fastapi.responses import FileResponse12from sse_starlette import EventSourceResponse13from pydantic import Json14import json15from server.knowledge_base.kb_service.base import KBServiceFactory16from server.db.repository.knowledge_file_repository import get_file_detail17from langchain.docstore.document import Document18from server.knowledge_base.model.kb_document_model import DocumentWithVSId19from typing import List, Dict20
21
def search_docs(
        query: str = Body("", description="用户输入", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD,
                                      description="知识库匹配相关度阈值,取值范围在0-1之间,"
                                                  "SCORE越小,相关度越高,"
                                                  "取到1相当于不筛选,建议设置在0.5左右",
                                      ge=0, le=1),
        file_name: str = Body("", description="文件名称,支持 sql 通配符"),
        metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[DocumentWithVSId]:
    """Search a knowledge base by query text, or list its docs by file name / metadata.

    * Non-empty ``query``: run ``kb.search_docs(query, top_k, score_threshold)`` and wrap
      each ``(document, score)`` hit as a ``DocumentWithVSId``. Its ``id`` is taken from
      the document's ``metadata["id"]`` (``None`` when absent).
    * Otherwise, when ``file_name`` or ``metadata`` is given: return ``kb.list_docs(...)``
      with any ``"vector"`` entry stripped from each document's metadata.

    Returns an empty list when the knowledge base is not found or no criterion is given.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return []

    if query:
        hits = kb.search_docs(query, top_k, score_threshold)
        return [
            DocumentWithVSId(**hit[0].dict(), score=hit[1], id=hit[0].metadata.get("id"))
            for hit in hits
        ]

    if file_name or metadata:
        listed = kb.list_docs(file_name=file_name, metadata=metadata)
        # Raw embedding vectors are not useful to API clients; drop them from the payload.
        for doc in listed:
            doc.metadata.pop("vector", None)
        return listed

    return []
47
def update_docs_by_id(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
) -> BaseResponse:
    """
    Update document contents by document ID.

    :param knowledge_base_name: name of the target knowledge base.
    :param docs: mapping of document id -> new ``Document``; passed through unchanged
        to ``kb.update_doc_by_ids``.
    :return: ``BaseResponse`` with one of these codes:
        * 403 if the name fails ``validate_kb_name``;
        * 500 if the knowledge base does not exist (code kept as before);
        * the default code on success;
        * 500 if ``update_doc_by_ids`` reports failure.
    """
    # Same guard every other endpoint in this module applies to the kb name.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")

    if kb.update_doc_by_ids(docs=docs):
        return BaseResponse(msg="文档更新成功")
    # Previously the failure was returned with the default (success) code, so clients
    # checking the code could not tell it apart from a successful update.
    return BaseResponse(code=500, msg="文档更新失败")
63
def list_files(
        knowledge_base_name: str
) -> ListResponse:
    """
    List the file names registered in a knowledge base.

    :param knowledge_base_name: knowledge base name, possibly percent-encoded.
    :return: ``ListResponse`` whose ``data`` is ``kb.list_files()``.
        Returns code 403 if the decoded name fails ``validate_kb_name``,
        and code 404 if the knowledge base does not exist.
    """
    # `import urllib` at module level does not guarantee that the `urllib.parse`
    # submodule is loaded, so import the function explicitly.
    from urllib.parse import unquote

    # Decode BEFORE validating: validating the raw value and decoding afterwards let
    # anything validate_kb_name rejects be smuggled in percent-encoded (e.g. "..%2F").
    knowledge_base_name = unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
    return ListResponse(data=kb.list_files())
78
def _save_files_in_thread(files: List[UploadFile],
                          knowledge_base_name: str,
                          override: bool):
    """
    Save uploaded files into the knowledge base's directory using a thread pool.

    This is a generator. Each yielded item is a result dict such as
    {"code": 200, "msg": "xxx", "data": {"knowledge_base_name": "xxx", "file_name": "xxx"}}.
    Items come in whatever order ``run_in_thread_pool`` produces them
    (possibly completion order rather than input order; not verifiable here).
    """

    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
        """
        Save a single uploaded file and return a result dict.

        Result codes:
            200 -- file written;
            404 -- a file of the same size already exists and ``override`` is False;
            500 -- any exception while resolving the path, reading, or writing.
        """
        # Bind these before the try block: the error handler below formats both, and
        # previously `data` was unbound if get_file_path() raised.
        filename = file.filename
        data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
        try:
            file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)

            file_content = file.file.read()  # read the whole upload into memory
            # Skip only when override is off AND the existing file has the same byte size;
            # an equal size is used as a cheap "same file" heuristic.
            if (os.path.isfile(file_path)
                    and not override
                    and os.path.getsize(file_path) == len(file_content)):
                file_status = f"文件 {filename} 已存在。"
                logger.warning(file_status)  # Logger.warn is a deprecated alias
                return dict(code=404, msg=file_status, data=data)

            # exist_ok=True avoids a check-then-create race between worker threads
            # saving files into the same directory.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_content)
            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
        except Exception as e:
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            return dict(code=500, msg=msg, data=data)

    params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
    yield from run_in_thread_pool(save_file, params=params)
120
121# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
122# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
123# override: bool = Form(False, description="覆盖已有文件"),
124# save: bool = Form(True, description="是否将文件保存到知识库目录")):
125# def save_files(files, knowledge_base_name, override):
126# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
127# yield json.dumps(result, ensure_ascii=False)
128
129# def files_to_docs(files):
130# for result in files2docs_in_thread(files):
131# yield json.dumps(result, ensure_ascii=False)
132
133
def upload_docs(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
        override: bool = Form(False, description="覆盖已有文件"),
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
    API endpoint: upload files and/or vectorize them.

    Steps:
      1. Reject invalid names (403) and unknown knowledge bases (404).
      2. Save every upload to disk via ``_save_files_in_thread``; any non-200 save
         result is recorded in ``failed_files``.
      3. If ``to_vector_store``, call ``update_docs`` with the custom-doc file names
         (keys of ``docs``) plus every uploaded file name, merge its ``failed_files``,
         and save the vector store unless ``not_refresh_vs_cache`` is set.

    Always returns code 200, with ``data={"failed_files": {file_name: error_msg}}``.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    # Names that carry custom docs come first; uploaded names are appended (deduplicated).
    target_names = list(docs.keys())

    # Step 2: persist uploads. A failed save is still queued for vectorization.
    save_outcomes = _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override)
    for outcome in save_outcomes:
        saved_name = outcome["data"]["file_name"]
        if outcome["code"] != 200:
            failed_files[saved_name] = outcome["msg"]
        if saved_name not in target_names:
            target_names.append(saved_name)

    # Step 3: vectorize. update_docs is told not to persist; persistence happens once here.
    if to_vector_store:
        vectorize_result = update_docs(
            knowledge_base_name=knowledge_base_name,
            file_names=target_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failed_files.update(vectorize_result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
186
def delete_docs(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
        delete_content: bool = Body(False),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
    Delete files from a knowledge base.

    :param knowledge_base_name: knowledge base name, possibly percent-encoded.
    :param file_names: files to delete.
    :param delete_content: forwarded to ``kb.delete_doc``.
    :param not_refresh_vs_cache: when True, skip ``kb.save_vector_store()`` at the end.
    :return: code 403 for an invalid (decoded) name, 404 for an unknown knowledge base,
        otherwise 200 with ``data={"failed_files": {file_name: error_msg}}``.
    """
    # `import urllib` alone does not guarantee `urllib.parse` is loaded.
    from urllib.parse import unquote

    # Decode BEFORE validating so an encoded value cannot bypass validate_kb_name.
    knowledge_base_name = unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    for file_name in file_names:
        if not kb.exist_doc(file_name):
            failed_files[file_name] = f"未找到文件 {file_name}"
            # NOTE(review): there is deliberately no `continue` here. Deletion is still
            # attempted for files kb.exist_doc() does not know about, presumably so that
            # stray files can still be cleaned up. Confirm before changing.

        try:
            kb_file = KnowledgeFile(filename=file_name,
                                    knowledge_base_name=knowledge_base_name)
            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"{file_name} 文件删除失败,错误信息:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failed_files[file_name] = msg

    # Persist once after all deletions instead of once per file.
    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(code=200, msg="文件删除完成", data={"failed_files": failed_files})
221
def update_info(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
    """
    Replace a knowledge base's description text.

    Returns code 403 for an invalid name and 404 for an unknown knowledge base.
    Otherwise it calls ``kb.update_info(kb_info)`` and returns code 200 with
    ``data={"kb_info": kb_info}``.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is not None:
        kb.update_info(kb_info)
        return BaseResponse(code=200, msg="知识库介绍修改完成", data={"kb_info": kb_info})

    return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
236
def update_docs(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
        docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
    Update (re-vectorize) documents in a knowledge base.

    :param file_names: files to reload from disk. A file whose stored detail has
        ``custom_docs`` set is skipped unless ``override_custom_docs`` is True.
        Files that appear in ``docs`` are not loaded from disk.
    :param docs: mapping of file name -> list of ``Document`` (or dicts that are
        turned into ``Document(**x)``). These are written as the file's docs.
    :param not_refresh_vs_cache: when True, skip ``kb.save_vector_store()`` at the end.
    :return: code 403 for an invalid name, 404 for an unknown knowledge base,
        otherwise 200 with ``data={"failed_files": {file_name: error_msg}}``.
        Per-file errors are collected in ``failed_files`` rather than aborting the request.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    kb_files = []

    # Build the list of files whose docs must be loaded from disk.
    for file_name in file_names:
        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
        # If the file previously used custom docs, skip or overwrite it depending on the flag.
        if file_detail.get("custom_docs") and not override_custom_docs:
            continue
        if file_name not in docs:
            try:
                kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
            except Exception as e:
                msg = f"加载文档 {file_name} 时出错:{e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                failed_files[file_name] = msg

    # Load and split files into docs in worker threads, then vectorize. The docs produced
    # in the threads are handed to a fresh KnowledgeFile through `splited_docs`.
    for status, result in files2docs_in_thread(kb_files,
                                               chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap,
                                               zh_title_enhance=zh_title_enhance):
        if not status:
            _, file_name, error = result
            failed_files[file_name] = error
            continue

        _, file_name, new_docs = result
        # Guard each file like the custom-docs loop below. Previously one failing
        # update_doc aborted the request, lost the failed_files report, and skipped
        # save_vector_store for files that had already been updated.
        try:
            kb_file = KnowledgeFile(filename=file_name,
                                    knowledge_base_name=knowledge_base_name)
            kb_file.splited_docs = new_docs
            kb.update_doc(kb_file, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"向量化文档 {file_name} 时出错:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failed_files[file_name] = msg

    # Vectorize the caller-supplied custom docs.
    for file_name, v in docs.items():
        try:
            v = [x if isinstance(x, Document) else Document(**x) for x in v]
            kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"为 {file_name} 添加自定义docs时出错:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failed_files[file_name] = msg

    # Persist once after all updates.
    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(code=200, msg="更新文档完成", data={"failed_files": failed_files})
309
def download_doc(
        knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
        file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
        preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
):
    """
    Download a knowledge base document, or preview it in the browser.

    :param preview: True sends the file with ``Content-Disposition: inline`` (render
        in the browser); False sends ``attachment`` (download).
    :return: a ``FileResponse`` when the file exists. Otherwise a ``BaseResponse``:
        403 for an invalid name, 404 for an unknown knowledge base,
        500 if the file is missing or cannot be read.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # Passing None (as before) made Starlette format a literal "None; filename=..."
    # Content-Disposition header; "attachment" is the explicit download disposition.
    content_disposition_type = "inline" if preview else "attachment"

    try:
        kb_file = KnowledgeFile(filename=file_name,
                                knowledge_base_name=knowledge_base_name)

        if os.path.exists(kb_file.filepath):
            # No media_type: FileResponse guesses it from the file name. The previous
            # "multipart/form-data" type stopped browsers from rendering inline previews.
            return FileResponse(
                path=kb_file.filepath,
                filename=kb_file.filename,
                content_disposition_type=content_disposition_type,
            )
    except Exception as e:
        # Use file_name here: kb_file is unbound if KnowledgeFile() itself raised.
        msg = f"{file_name} 读取文件失败,错误信息是:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)

    # Reached only when KnowledgeFile() succeeded but the file is not on disk.
    return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
349
def recreate_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
):
    """
    Recreate the vector store from the files in the knowledge base's content folder.

    This is useful when the user copies files into the content folder directly
    instead of uploading them through the API.

    If the knowledge base does not exist and ``allow_empty_kb`` is False, a single
    404 event is emitted. Otherwise:
      * an existing vector store is cleared;
      * ``kb.create_kb()`` is called;
      * every file from ``list_files_from_folder`` is split and added;
      * one event is streamed per file, carrying progress on success or a 500 message
        for files that failed to load (those files are skipped).

    Every event payload is a JSON string sent through ``EventSourceResponse``.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Must be a JSON string like every other event. sse_starlette interprets a
            # dict item as ServerSentEvent fields, and "code"/"msg" are not valid fields.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
            return

        if kb.exists():
            kb.clear_vs()
        kb.create_kb()
        files = list_files_from_folder(knowledge_base_name)
        kb_files = [(file, knowledge_base_name) for file in files]
        for i, (status, result) in enumerate(files2docs_in_thread(kb_files,
                                                                  chunk_size=chunk_size,
                                                                  chunk_overlap=chunk_overlap,
                                                                  zh_title_enhance=zh_title_enhance)):
            if status:
                kb_name, file_name, docs = result
                kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
                kb_file.splited_docs = docs
                # Progress is reported before the file is actually added to the store.
                yield json.dumps({
                    "code": 200,
                    "msg": f"({i + 1} / {len(files)}): {file_name}",
                    "total": len(files),
                    "finished": i + 1,
                    "doc": file_name,
                }, ensure_ascii=False)
                kb.add_doc(kb_file, not_refresh_vs_cache=True)
            else:
                kb_name, file_name, error = result
                msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)
        # Persist once after all files have been added.
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return EventSourceResponse(output())