import os
from configs import (
    KB_ROOT_PATH,
    CHUNK_SIZE,
    OVERLAP_SIZE,
    ZH_TITLE_ENHANCE,
    logger,
    log_verbose,
    text_splitter_dict,
    LLM_MODELS,
    TEXT_SPLITTER_NAME,
)
import importlib
from text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
from server.utils import run_in_thread_pool, get_model_worker_config
import json
from typing import List, Union, Dict, Tuple, Generator
import chardet


def validate_kb_name(knowledge_base_id: str) -> bool:
    # reject unexpected characters and path-traversal keywords
    if "../" in knowledge_base_id:
        return False
    return True
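
# Quick illustration (not executed at import time): names containing a "../"
# traversal sequence are rejected, anything else is accepted.
#   validate_kb_name("samples")        -> True
#   validate_kb_name("../etc/passwd")  -> False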


def get_kb_path(knowledge_base_name: str):
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str, vector_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)


def get_file_path(knowledge_base_name: str, doc_name: str):
    doc_path = Path(get_doc_path(knowledge_base_name))
    file_path = doc_path / doc_name
    if file_path.is_relative_to(doc_path):
        return str(file_path)
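
# Directory layout these helpers resolve to (a sketch; the actual root comes from
# KB_ROOT_PATH in configs):
#   get_kb_path("samples")               -> <KB_ROOT_PATH>/samples
#   get_doc_path("samples")              -> <KB_ROOT_PATH>/samples/content
#   get_vs_path("samples", "vs_name")    -> <KB_ROOT_PATH>/samples/vector_store/vs_name
#   get_file_path("samples", "a/b.txt")  -> <KB_ROOT_PATH>/samples/content/a/b.txt
# get_file_path returns None when the requested name would escape the content directory.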


def list_kbs_from_folder():
    return [f for f in os.listdir(KB_ROOT_PATH)
            if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]


def list_files_from_folder(kb_name: str):
    doc_path = get_doc_path(kb_name)
    result = []

    def is_skiped_path(path: str):
        tail = os.path.basename(path).lower()
        for x in ["temp", "tmp", ".", "~$"]:
            if tail.startswith(x):
                return True
        return False

    def process_entry(entry):
        if is_skiped_path(entry.path):
            return

        if entry.is_symlink():
            target_path = os.path.realpath(entry.path)
            with os.scandir(target_path) as target_it:
                for target_entry in target_it:
                    process_entry(target_entry)
        elif entry.is_file():
            file_path = (Path(os.path.relpath(entry.path, doc_path)).as_posix())  # normalize the path to POSIX form
            result.append(file_path)
        elif entry.is_dir():
            with os.scandir(entry.path) as it:
                for sub_entry in it:
                    process_entry(sub_entry)

    with os.scandir(doc_path) as it:
        for entry in it:
            process_entry(entry)

    return result
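
# Note: returned paths are relative to the knowledge base's content directory and use
# POSIX separators (e.g. "subdir/report.txt"); entries whose names start with "temp",
# "tmp", "." or "~$" are skipped, and symlinked directories are followed.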


LOADER_DICT = {"UnstructuredHTMLLoader": ['.html', '.htm'],
               "MHTMLLoader": ['.mhtml'],
               "UnstructuredMarkdownLoader": ['.md'],
               "JSONLoader": [".json"],
               "JSONLinesLoader": [".jsonl"],
               "CSVLoader": [".csv"],
               # "FilteredCSVLoader": [".csv"],  # enable this if CSV files need custom splitting
               "RapidOCRPDFLoader": [".pdf"],
               "RapidOCRDocLoader": ['.docx', '.doc'],
               "RapidOCRPPTLoader": ['.ppt', '.pptx'],
               "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.epub', '.odt', '.tsv'],
               "UnstructuredEmailLoader": ['.eml', '.msg'],
               "UnstructuredEPubLoader": ['.epub'],
               "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
               "NotebookLoader": ['.ipynb'],
               "UnstructuredODTLoader": ['.odt'],
               "PythonLoader": ['.py'],
               "UnstructuredRSTLoader": ['.rst'],
               "UnstructuredRTFLoader": ['.rtf'],
               "SRTLoader": ['.srt'],
               "TomlLoader": ['.toml'],
               "UnstructuredTSVLoader": ['.tsv'],
               "UnstructuredWordDocumentLoader": ['.docx', '.doc'],
               "UnstructuredXMLLoader": ['.xml'],
               "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
               "EverNoteLoader": ['.enex'],
               }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
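
# SUPPORTED_EXTS flattens the mapping above, so a membership check is enough to tell
# whether a file type can be ingested at all, e.g. ".pdf" in SUPPORTED_EXTS -> True,
# while an extension not listed above (say ".zip") is rejected by KnowledgeFile.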


# patch json.dumps to disable ensure_ascii
def _new_json_dumps(obj, **kwargs):
    kwargs["ensure_ascii"] = False
    return _origin_json_dumps(obj, **kwargs)

if json.dumps is not _new_json_dumps:
    _origin_json_dumps = json.dumps
    json.dumps = _new_json_dumps
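
# Effect of the patch (sketch): non-ASCII text stays verbatim in JSON output instead of
# being escaped, e.g. json.dumps({"title": "知识库"}) now yields '{"title": "知识库"}'
# rather than '{"title": "\u77e5\u8bc6\u5e93"}'.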


class JSONLinesLoader(langchain.document_loaders.JSONLoader):
    '''
    Loader for line-delimited JSON (JSON Lines); the file extension must be .jsonl.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._json_lines = True


langchain.document_loaders.JSONLinesLoader = JSONLinesLoader
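
# Registering the class on the langchain.document_loaders module (above) lets
# get_loader() below resolve "JSONLinesLoader" via getattr() exactly like the built-in
# loaders, so .jsonl files need no special casing.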


def get_LoaderClass(file_extension):
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass
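
# Illustration: the first matching entry in LOADER_DICT wins, so with the mapping above
# get_LoaderClass(".docx") -> "RapidOCRDocLoader" (not "UnstructuredWordDocumentLoader"),
# and an unknown extension returns None.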

def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
    '''
    Return a document loader instance based on loader_name and the file path or content.
    '''
    loader_kwargs = loader_kwargs or {}
    try:
        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
                           "RapidOCRDocLoader", "RapidOCRPPTLoader"]:
            document_loaders_module = importlib.import_module('document_loaders')
        else:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

    if loader_name == "UnstructuredFileLoader":
        loader_kwargs.setdefault("autodetect_encoding", True)
    elif loader_name == "CSVLoader":
        if not loader_kwargs.get("encoding"):
            # If no encoding is given, detect it automatically so the langchain loader
            # does not fail with an encoding error.
            with open(file_path, 'rb') as struct_file:
                encode_detect = chardet.detect(struct_file.read())
            if encode_detect is None:
                encode_detect = {"encoding": "utf-8"}
            loader_kwargs["encoding"] = encode_detect["encoding"]

    elif loader_name == "JSONLoader":
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)
    elif loader_name == "JSONLinesLoader":
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)

    loader = DocumentLoader(file_path, **loader_kwargs)
    return loader
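
# Minimal usage sketch (assumes a local "data.csv"; its encoding is then auto-detected
# with chardet as implemented above):
#   loader = get_loader("CSVLoader", "data.csv")
#   docs = loader.load()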


def make_text_splitter(
        splitter_name: str = TEXT_SPLITTER_NAME,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        llm_model: str = LLM_MODELS[0],
):
    """
    Build the requested text splitter from the given parameters.
    """
    splitter_name = splitter_name or "SpacyTextSplitter"
    try:
        if splitter_name == "MarkdownHeaderTextSplitter":  # MarkdownHeaderTextSplitter needs special handling
            headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
            text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
                headers_to_split_on=headers_to_split_on)
        else:

            try:  # prefer a user-defined text_splitter
                text_splitter_module = importlib.import_module('text_splitter')
                TextSplitter = getattr(text_splitter_module, splitter_name)
            except Exception:  # otherwise fall back to langchain's text_splitter
                text_splitter_module = importlib.import_module('langchain.text_splitter')
                TextSplitter = getattr(text_splitter_module, splitter_name)

            if text_splitter_dict[splitter_name]["source"] == "tiktoken":  # load from tiktoken
                try:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                except Exception:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
            elif text_splitter_dict[splitter_name]["source"] == "huggingface":  # load from huggingface
                if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
                    config = get_model_worker_config(llm_model)
                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
                        config.get("model_path")

                if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
                    from transformers import GPT2TokenizerFast
                    from langchain.text_splitter import CharacterTextSplitter
                    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
                else:  # otherwise load the tokenizer by name or local path
                    from transformers import AutoTokenizer
                    tokenizer = AutoTokenizer.from_pretrained(
                        text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        trust_remote_code=True)
                text_splitter = TextSplitter.from_huggingface_tokenizer(
                    tokenizer=tokenizer,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
            else:
                try:
                    text_splitter = TextSplitter(
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                except Exception:
                    text_splitter = TextSplitter(
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
    except Exception as e:
        print(e)
        text_splitter_module = importlib.import_module('langchain.text_splitter')
        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
        text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # If you use SpacyTextSplitter you can run the split on GPU, see Issue #1287:
    # text_splitter._tokenizer.max_length = 37016792
    # text_splitter._tokenizer.prefer_gpu()
    return text_splitter
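
# Minimal usage sketch (relies on the configured TEXT_SPLITTER_NAME; note the function
# falls back to RecursiveCharacterTextSplitter if anything above fails):
#   splitter = make_text_splitter(chunk_size=250, chunk_overlap=50)
#   chunks = splitter.split_documents(docs)  # docs: List[Document]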


class KnowledgeFile:
    def __init__(
            self,
            filename: str,
            knowledge_base_name: str,
            loader_kwargs: Dict = {},
    ):
        '''
        Represents a file under the knowledge base directory; it must exist on disk
        before it can be vectorized or otherwise processed.
        '''
        self.kb_name = knowledge_base_name
        self.filename = str(Path(filename).as_posix())
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.filename}")
        self.loader_kwargs = loader_kwargs
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.splited_docs = None
        self.document_loader_name = get_LoaderClass(self.ext)
        self.text_splitter_name = TEXT_SPLITTER_NAME

    def file2docs(self, refresh: bool = False):
        if self.docs is None or refresh:
            logger.info(f"{self.document_loader_name} used for {self.filepath}")
            loader = get_loader(loader_name=self.document_loader_name,
                                file_path=self.filepath,
                                loader_kwargs=self.loader_kwargs)
            self.docs = loader.load()
        return self.docs

    def docs2texts(
            self,
            docs: List[Document] = None,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        docs = docs or self.file2docs(refresh=refresh)
        if not docs:
            return []
        if self.ext not in [".csv"]:
            if text_splitter is None:
                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap)
            if self.text_splitter_name == "MarkdownHeaderTextSplitter":
                docs = text_splitter.split_text(docs[0].page_content)
            else:
                docs = text_splitter.split_documents(docs)

        if not docs:
            return []

        print(f"文档切分示例:{docs[0]}")
        if zh_title_enhance:
            docs = func_zh_title_enhance(docs)
        self.splited_docs = docs
        return self.splited_docs

    def file2text(
            self,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        if self.splited_docs is None or refresh:
            docs = self.file2docs()
            self.splited_docs = self.docs2texts(docs=docs,
                                                zh_title_enhance=zh_title_enhance,
                                                refresh=refresh,
                                                chunk_size=chunk_size,
                                                chunk_overlap=chunk_overlap,
                                                text_splitter=text_splitter)
        return self.splited_docs

    def file_exist(self):
        return os.path.isfile(self.filepath)

    def get_mtime(self):
        return os.path.getmtime(self.filepath)

    def get_size(self):
        return os.path.getsize(self.filepath)
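
# Typical call chain (sketch; "report.md" must already exist under the knowledge base's
# content directory):
#   kb_file = KnowledgeFile(filename="report.md", knowledge_base_name="samples")
#   docs = kb_file.file2docs()    # raw Documents produced by the selected loader
#   chunks = kb_file.file2text()  # Documents split with the configured text splitter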


def files2docs_in_thread(
        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
) -> Generator:
    '''
    Convert files on disk into langchain Documents in batch using a thread pool.
    If an item is passed as a Tuple, its form is (filename, kb_name).
    The generator yields status, (kb_name, file_name, docs | error).
    '''

    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
        try:
            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
        except Exception as e:
            msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            return False, (file.kb_name, file.filename, msg)

    kwargs_list = []
    for i, file in enumerate(files):
        kwargs = {}
        try:
            if isinstance(file, tuple) and len(file) >= 2:
                filename = file[0]
                kb_name = file[1]
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, dict):
                filename = file.pop("filename")
                kb_name = file.pop("kb_name")
                kwargs.update(file)
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, KnowledgeFile):
                # keep kb_name/filename defined so the error path below cannot raise NameError
                filename = file.filename
                kb_name = file.kb_name
            kwargs["file"] = file
            kwargs["chunk_size"] = chunk_size
            kwargs["chunk_overlap"] = chunk_overlap
            kwargs["zh_title_enhance"] = zh_title_enhance
            kwargs_list.append(kwargs)
        except Exception as e:
            yield False, (kb_name, filename, str(e))

    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
        yield result
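
# Consumption sketch: each yielded item is (status, payload); on success the payload is
# (kb_name, file_name, docs), on failure it is (kb_name, file_name, error_message).
#   for ok, (kb, name, data) in files2docs_in_thread([("report.md", "samples")]):
#       if ok:
#           print(f"{kb}/{name}: {len(data)} chunks")
#       else:
#           print(f"{kb}/{name} failed: {data}")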


if __name__ == "__main__":
    from pprint import pprint

    kb_file = KnowledgeFile(
        filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
        knowledge_base_name="samples")
    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
    docs = kb_file.file2docs()
    # pprint(docs[-1])