Langchain-Chatchat

Форк
0
220 строк · 8.6 Кб
1
from fastapi import Body
2
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
3
                     OVERLAP_SIZE,
4
                     logger, log_verbose, )
5
from server.knowledge_base.utils import (list_files_from_folder)
6
from sse_starlette import EventSourceResponse
7
import json
8
from server.knowledge_base.kb_service.base import KBServiceFactory
9
from typing import List, Optional
10
from server.knowledge_base.kb_summary.base import KBSummaryService
11
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
12
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
13
from configs import LLM_MODELS, TEMPERATURE
14
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
15

16
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """
    Rebuild the per-file summaries of a single knowledge base.

    Drops and recreates the knowledge base's summary store, then summarizes
    every file found in the KB folder, streaming per-file progress as
    Server-Sent Events whose data payload is a JSON string.

    :param knowledge_base_name: name of the knowledge base to process
    :param allow_empty_kb: if True, a missing KB is not treated as an error
    :param vs_type: vector store type backing the KB
    :param embed_model: embedding model name
    :param file_description: extra description forwarded to the summarizer
    :param model_name: LLM used for summarization
    :param temperature: LLM sampling temperature
    :param max_tokens: cap on generated tokens (None = model maximum)
    :return: EventSourceResponse streaming one JSON message per file
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # FIX: a bare dict yielded into EventSourceResponse is expanded as
            # ServerSentEvent(**data); "code"/"msg" are not valid SSE fields,
            # so serialize to JSON like every other message in this stream.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
        else:
            # Rebuild the summary store from scratch.
            kb_summary = KBSummaryService(knowledge_base_name, embed_model)
            kb_summary.drop_kb_summary()
            kb_summary.create_kb_summary()

            llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            reduce_llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Text-summarization adapter (map LLM + reduce LLM).
            summary = SummaryAdapter.form_summary(llm=llm,
                                                  reduce_llm=reduce_llm,
                                                  overlap_size=OVERLAP_SIZE)
            files = list_files_from_folder(knowledge_base_name)

            # NOTE: removed the original dead `i = 0` / `i += 1` pair —
            # enumerate() rebinds `i` on every iteration, so the manual
            # increment had no effect.
            for i, file_name in enumerate(files):
                doc_infos = kb.list_docs(file_name=file_name)
                docs = summary.summarize(file_description=file_description,
                                         docs=doc_infos)

                status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
                if status_kb_summary:
                    logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
                    yield json.dumps({
                        "code": 200,
                        "msg": f"({i + 1} / {len(files)}): {file_name}",
                        "total": len(files),
                        "finished": i + 1,
                        "doc": file_name,
                    }, ensure_ascii=False)
                else:
                    msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                    logger.error(msg)
                    # FIX: keep non-ASCII text readable, consistent with the
                    # success branch.
                    yield json.dumps({
                        "code": 500,
                        "msg": msg,
                    }, ensure_ascii=False)

    return EventSourceResponse(output())
94

95

96
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """
    Summarize a single file of a knowledge base, selected by file name.

    Ensures the KB's summary store exists, summarizes the named file's
    documents and stores the result, streaming the outcome as a
    Server-Sent Event whose data payload is a JSON string.

    :param knowledge_base_name: name of the knowledge base
    :param file_name: file within the KB to summarize
    :param allow_empty_kb: if True, a missing KB is not treated as an error
    :param vs_type: vector store type backing the KB
    :param embed_model: embedding model name
    :param file_description: extra description forwarded to the summarizer
    :param model_name: LLM used for summarization
    :param temperature: LLM sampling temperature
    :param max_tokens: cap on generated tokens (None = model maximum)
    :return: EventSourceResponse streaming a single JSON result message
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # FIX: a bare dict yielded into EventSourceResponse is expanded as
            # ServerSentEvent(**data); "code"/"msg" are not valid SSE fields,
            # so serialize to JSON like every other message in this stream.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
        else:
            # Make sure the summary store exists (idempotent create).
            kb_summary = KBSummaryService(knowledge_base_name, embed_model)
            kb_summary.create_kb_summary()

            llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            reduce_llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Text-summarization adapter (map LLM + reduce LLM).
            summary = SummaryAdapter.form_summary(llm=llm,
                                                  reduce_llm=reduce_llm,
                                                  overlap_size=OVERLAP_SIZE)

            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f" {file_name} 总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"{file_name} 总结完成",
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                # FIX: keep non-ASCII text readable, consistent with the
                # success branch.
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return EventSourceResponse(output())
167

168

169
def summary_doc_ids_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        doc_ids: List = Body([], examples=[["uuid"]]),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
) -> BaseResponse:
    """
    Summarize knowledge-base documents selected by their vector-store ids.

    Unlike the streaming endpoints above, this returns the summary directly
    in a BaseResponse instead of writing it to the summary store.

    :param knowledge_base_name: name of the knowledge base
    :param doc_ids: vector-store ids of the documents to summarize
    :param vs_type: vector store type backing the KB
    :param embed_model: embedding model name
    :param file_description: extra description forwarded to the summarizer
    :param model_name: LLM used for summarization
    :param temperature: LLM sampling temperature
    :param max_tokens: cap on generated tokens (None = model maximum)
    :return: BaseResponse with {"summarize": [doc dicts]} on success,
             code 404 if the knowledge base does not exist
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists():
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
    else:
        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text-summarization adapter (map LLM + reduce LLM).
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.get_doc_by_ids(ids=doc_ids)
        # Wrap each raw document with its vector-store id so the summarizer
        # can report which documents contributed.
        doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id)
                             for with_id, doc in zip(doc_ids, doc_infos)]

        docs = summary.summarize(file_description=file_description,
                                 docs=doc_info_with_ids)

        # Serialize result documents to plain dicts for the JSON response.
        # (doc.dict() already returns a fresh dict; the original's {**...}
        # re-copy was redundant.)
        resp_summarize = [doc.dict() for doc in docs]

        return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})
221

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.