Langchain-Chatchat

Форк
0
406 строк · 18.6 Кб
1
import os
2
import urllib
3
from fastapi import File, Form, Body, Query, UploadFile
4
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
5
                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
6
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
7
                     logger, log_verbose, )
8
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
9
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
10
                                         files2docs_in_thread, KnowledgeFile)
11
from fastapi.responses import FileResponse
12
from sse_starlette import EventSourceResponse
13
from pydantic import Json
14
import json
15
from server.knowledge_base.kb_service.base import KBServiceFactory
16
from server.db.repository.knowledge_file_repository import get_file_detail
17
from langchain.docstore.document import Document
18
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
19
from typing import List, Dict
20

21

22
def search_docs(
23
        query: str = Body("", description="用户输入", examples=["你好"]),
24
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
25
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
26
        score_threshold: float = Body(SCORE_THRESHOLD,
27
                                      description="知识库匹配相关度阈值,取值范围在0-1之间,"
28
                                                  "SCORE越小,相关度越高,"
29
                                                  "取到1相当于不筛选,建议设置在0.5左右",
30
                                      ge=0, le=1),
31
        file_name: str = Body("", description="文件名称,支持 sql 通配符"),
32
        metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
33
) -> List[DocumentWithVSId]:
34
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
35
    data = []
36
    if kb is not None:
37
        if query:
38
            docs = kb.search_docs(query, top_k, score_threshold)
39
            data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
40
        elif file_name or metadata:
41
            data = kb.list_docs(file_name=file_name, metadata=metadata)
42
            for d in data:
43
                if "vector" in d.metadata:
44
                    del d.metadata["vector"]
45
    return data
46

47

48
def update_docs_by_id(
49
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
50
        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
51
) -> BaseResponse:
52
    '''
53
    按照文档 ID 更新文档内容
54
    '''
55
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
56
    if kb is None:
57
        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
58
    if kb.update_doc_by_ids(docs=docs):
59
        return BaseResponse(msg=f"文档更新成功")
60
    else:
61
        return BaseResponse(msg=f"文档更新失败")
62

63

64
def list_files(
65
        knowledge_base_name: str
66
) -> ListResponse:
67
    if not validate_kb_name(knowledge_base_name):
68
        return ListResponse(code=403, msg="Don't attack me", data=[])
69

70
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
71
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
72
    if kb is None:
73
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
74
    else:
75
        all_doc_names = kb.list_files()
76
        return ListResponse(data=all_doc_names)
77

78

79
def _save_files_in_thread(files: List[UploadFile],
80
                          knowledge_base_name: str,
81
                          override: bool):
82
    """
83
    通过多线程将上传的文件保存到对应知识库目录内。
84
    生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
85
    """
86

87
    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
88
        '''
89
        保存单个文件。
90
        '''
91
        try:
92
            filename = file.filename
93
            file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
94
            data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
95

96
            file_content = file.file.read()  # 读取上传文件的内容
97
            if (os.path.isfile(file_path)
98
                    and not override
99
                    and os.path.getsize(file_path) == len(file_content)
100
            ):
101
                file_status = f"文件 {filename} 已存在。"
102
                logger.warn(file_status)
103
                return dict(code=404, msg=file_status, data=data)
104

105
            if not os.path.isdir(os.path.dirname(file_path)):
106
                os.makedirs(os.path.dirname(file_path))
107
            with open(file_path, "wb") as f:
108
                f.write(file_content)
109
            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
110
        except Exception as e:
111
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
112
            logger.error(f'{e.__class__.__name__}: {msg}',
113
                         exc_info=e if log_verbose else None)
114
            return dict(code=500, msg=msg, data=data)
115

116
    params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
117
    for result in run_in_thread_pool(save_file, params=params):
118
        yield result
119

120

121
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
122
#                 knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
123
#                 override: bool = Form(False, description="覆盖已有文件"),
124
#                 save: bool = Form(True, description="是否将文件保存到知识库目录")):
125
#     def save_files(files, knowledge_base_name, override):
126
#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
127
#             yield json.dumps(result, ensure_ascii=False)
128

129
#     def files_to_docs(files):
130
#         for result in files2docs_in_thread(files):
131
#             yield json.dumps(result, ensure_ascii=False)
132

133

134
def upload_docs(
135
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
136
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
137
        override: bool = Form(False, description="覆盖已有文件"),
138
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
139
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
140
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
141
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
142
        docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
143
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
144
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
145
) -> BaseResponse:
146
    """
147
    API接口:上传文件,并/或向量化
148
    """
149
    if not validate_kb_name(knowledge_base_name):
150
        return BaseResponse(code=403, msg="Don't attack me")
151

152
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
153
    if kb is None:
154
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
155

156
    failed_files = {}
157
    file_names = list(docs.keys())
158

159
    # 先将上传的文件保存到磁盘
160
    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
161
        filename = result["data"]["file_name"]
162
        if result["code"] != 200:
163
            failed_files[filename] = result["msg"]
164

165
        if filename not in file_names:
166
            file_names.append(filename)
167

168
    # 对保存的文件进行向量化
169
    if to_vector_store:
170
        result = update_docs(
171
            knowledge_base_name=knowledge_base_name,
172
            file_names=file_names,
173
            override_custom_docs=True,
174
            chunk_size=chunk_size,
175
            chunk_overlap=chunk_overlap,
176
            zh_title_enhance=zh_title_enhance,
177
            docs=docs,
178
            not_refresh_vs_cache=True,
179
        )
180
        failed_files.update(result.data["failed_files"])
181
        if not not_refresh_vs_cache:
182
            kb.save_vector_store()
183

184
    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
185

186

187
def delete_docs(
188
        knowledge_base_name: str = Body(..., examples=["samples"]),
189
        file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
190
        delete_content: bool = Body(False),
191
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
192
) -> BaseResponse:
193
    if not validate_kb_name(knowledge_base_name):
194
        return BaseResponse(code=403, msg="Don't attack me")
195

196
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
197
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
198
    if kb is None:
199
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
200

201
    failed_files = {}
202
    for file_name in file_names:
203
        if not kb.exist_doc(file_name):
204
            failed_files[file_name] = f"未找到文件 {file_name}"
205

206
        try:
207
            kb_file = KnowledgeFile(filename=file_name,
208
                                    knowledge_base_name=knowledge_base_name)
209
            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
210
        except Exception as e:
211
            msg = f"{file_name} 文件删除失败,错误信息:{e}"
212
            logger.error(f'{e.__class__.__name__}: {msg}',
213
                         exc_info=e if log_verbose else None)
214
            failed_files[file_name] = msg
215

216
    if not not_refresh_vs_cache:
217
        kb.save_vector_store()
218

219
    return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
220

221

222
def update_info(
223
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
224
        kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
225
):
226
    if not validate_kb_name(knowledge_base_name):
227
        return BaseResponse(code=403, msg="Don't attack me")
228

229
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
230
    if kb is None:
231
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
232
    kb.update_info(kb_info)
233

234
    return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})
235

236

237
def update_docs(
238
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
239
        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
240
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
241
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
242
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
243
        override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
244
        docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
245
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
246
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
247
) -> BaseResponse:
248
    """
249
    更新知识库文档
250
    """
251
    if not validate_kb_name(knowledge_base_name):
252
        return BaseResponse(code=403, msg="Don't attack me")
253

254
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
255
    if kb is None:
256
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
257

258
    failed_files = {}
259
    kb_files = []
260

261
    # 生成需要加载docs的文件列表
262
    for file_name in file_names:
263
        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
264
        # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖
265
        if file_detail.get("custom_docs") and not override_custom_docs:
266
            continue
267
        if file_name not in docs:
268
            try:
269
                kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
270
            except Exception as e:
271
                msg = f"加载文档 {file_name} 时出错:{e}"
272
                logger.error(f'{e.__class__.__name__}: {msg}',
273
                             exc_info=e if log_verbose else None)
274
                failed_files[file_name] = msg
275

276
    # 从文件生成docs,并进行向量化。
277
    # 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile
278
    for status, result in files2docs_in_thread(kb_files,
279
                                               chunk_size=chunk_size,
280
                                               chunk_overlap=chunk_overlap,
281
                                               zh_title_enhance=zh_title_enhance):
282
        if status:
283
            kb_name, file_name, new_docs = result
284
            kb_file = KnowledgeFile(filename=file_name,
285
                                    knowledge_base_name=knowledge_base_name)
286
            kb_file.splited_docs = new_docs
287
            kb.update_doc(kb_file, not_refresh_vs_cache=True)
288
        else:
289
            kb_name, file_name, error = result
290
            failed_files[file_name] = error
291

292
    # 将自定义的docs进行向量化
293
    for file_name, v in docs.items():
294
        try:
295
            v = [x if isinstance(x, Document) else Document(**x) for x in v]
296
            kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
297
            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
298
        except Exception as e:
299
            msg = f"为 {file_name} 添加自定义docs时出错:{e}"
300
            logger.error(f'{e.__class__.__name__}: {msg}',
301
                         exc_info=e if log_verbose else None)
302
            failed_files[file_name] = msg
303

304
    if not not_refresh_vs_cache:
305
        kb.save_vector_store()
306

307
    return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
308

309

310
def download_doc(
311
        knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
312
        file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
313
        preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
314
):
315
    """
316
    下载知识库文档
317
    """
318
    if not validate_kb_name(knowledge_base_name):
319
        return BaseResponse(code=403, msg="Don't attack me")
320

321
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
322
    if kb is None:
323
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
324

325
    if preview:
326
        content_disposition_type = "inline"
327
    else:
328
        content_disposition_type = None
329

330
    try:
331
        kb_file = KnowledgeFile(filename=file_name,
332
                                knowledge_base_name=knowledge_base_name)
333

334
        if os.path.exists(kb_file.filepath):
335
            return FileResponse(
336
                path=kb_file.filepath,
337
                filename=kb_file.filename,
338
                media_type="multipart/form-data",
339
                content_disposition_type=content_disposition_type,
340
            )
341
    except Exception as e:
342
        msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
343
        logger.error(f'{e.__class__.__name__}: {msg}',
344
                     exc_info=e if log_verbose else None)
345
        return BaseResponse(code=500, msg=msg)
346

347
    return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
348

349

350
def recreate_vector_store(
351
        knowledge_base_name: str = Body(..., examples=["samples"]),
352
        allow_empty_kb: bool = Body(True),
353
        vs_type: str = Body(DEFAULT_VS_TYPE),
354
        embed_model: str = Body(EMBEDDING_MODEL),
355
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
356
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
357
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
358
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
359
):
360
    """
361
    recreate vector store from the content.
362
    this is usefull when user can copy files to content folder directly instead of upload through network.
363
    by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
364
    set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
365
    """
366

367
    def output():
368
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
369
        if not kb.exists() and not allow_empty_kb:
370
            yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
371
        else:
372
            if kb.exists():
373
                kb.clear_vs()
374
            kb.create_kb()
375
            files = list_files_from_folder(knowledge_base_name)
376
            kb_files = [(file, knowledge_base_name) for file in files]
377
            i = 0
378
            for status, result in files2docs_in_thread(kb_files,
379
                                                       chunk_size=chunk_size,
380
                                                       chunk_overlap=chunk_overlap,
381
                                                       zh_title_enhance=zh_title_enhance):
382
                if status:
383
                    kb_name, file_name, docs = result
384
                    kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
385
                    kb_file.splited_docs = docs
386
                    yield json.dumps({
387
                        "code": 200,
388
                        "msg": f"({i + 1} / {len(files)}): {file_name}",
389
                        "total": len(files),
390
                        "finished": i + 1,
391
                        "doc": file_name,
392
                    }, ensure_ascii=False)
393
                    kb.add_doc(kb_file, not_refresh_vs_cache=True)
394
                else:
395
                    kb_name, file_name, error = result
396
                    msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
397
                    logger.error(msg)
398
                    yield json.dumps({
399
                        "code": 500,
400
                        "msg": msg,
401
                    })
402
                i += 1
403
            if not not_refresh_vs_cache:
404
                kb.save_vector_store()
405

406
    return EventSourceResponse(output())
407

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.