Langchain-Chatchat

Форк
0
470 строк · 15.9 Кб
1
import operator
2
from abc import ABC, abstractmethod
3

4
import os
5
from pathlib import Path
6
import numpy as np
7
from langchain.embeddings.base import Embeddings
8
from langchain.docstore.document import Document
9

10
from server.db.repository.knowledge_base_repository import (
11
    add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
12
    load_kb_from_db, get_kb_detail,
13
)
14
from server.db.repository.knowledge_file_repository import (
15
    add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
16
    count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
17
    list_docs_from_db,
18
)
19

20
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
21
                     EMBEDDING_MODEL, KB_INFO)
22
from server.knowledge_base.utils import (
23
    get_kb_path, get_doc_path, KnowledgeFile,
24
    list_kbs_from_folder, list_files_from_folder,
25
)
26

27
from typing import List, Union, Dict, Optional, Tuple
28

29
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
30
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
31

32

33
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    '''
    Row-wise L2 normalisation, used instead of
    ``sklearn.preprocessing.normalize`` (norm="l2") to avoid installing
    scipy / scikit-learn.

    :param embeddings: a 2-D sequence (or array) of vectors, one per row.
    :return: a float array of the same shape in which every row has unit
        L2 norm. Rows whose norm is zero are returned as all zeros instead
        of NaN, matching sklearn's behaviour. An empty input yields an empty
        ``(0, 0)`` array.
    '''
    matrix = np.asarray(embeddings)
    if matrix.ndim == 1 and matrix.size == 0:
        # No vectors at all (e.g. nothing was embedded): nothing to scale.
        return np.empty((0, 0))
    # keepdims=True keeps the norms as a column so they broadcast across each row.
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Avoid 0/0 -> NaN: leave zero vectors untouched, as sklearn does.
    norms[norms == 0] = 1.0
    return matrix / norms
41

42

43
class SupportedVSType:
    """
    String identifiers for the vector-store backends.

    Attribute names are the upper-cased identifiers, so
    ``getattr(SupportedVSType, name.upper())`` maps a lower-case name to its
    constant. The values are plain strings, so they compare equal to the raw
    names stored elsewhere.
    """
    CHROMADB = 'chromadb'
    DEFAULT = 'default'
    ES = 'es'
    FAISS = 'faiss'
    MILVUS = 'milvus'
    PG = 'pg'
    ZILLIZ = 'zilliz'
51

52

53
class KBService(ABC):
    """
    Abstract base class for a knowledge-base service.

    It holds the knowledge base's metadata (name, description, embedding
    model, on-disk paths) and keeps the knowledge-base / file records in the
    database in step with the vector store. Every backend-specific operation
    is delegated to the abstract ``do_*`` hooks and ``vs_type``.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        self.kb_name = knowledge_base_name
        # Description from KB_INFO, falling back to a generated
        # "knowledge base about <name>" string.
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        # Backend initialisation runs last so it can use the attributes above.
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        '''
        Persist the vector store. The base implementation does nothing;
        backends override it as needed (FAISS saves to disk, milvus to its
        database; PGVector is not supported yet).
        '''
        pass

    def create_kb(self):
        """
        Create the knowledge base.

        Ensures the document folder exists, lets the backend create its
        store, then registers the knowledge base in the database.

        :return: the status returned by ``add_kb_to_db``.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """
        Delete all content from the vector store and all of this knowledge
        base's file records from the database.

        :return: the status returned by ``delete_files_from_db``.
        """
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """
        Delete the knowledge base: backend storage first, then its database record.

        :return: the status returned by ``delete_kb_from_db``.
        """
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        '''
        Convert ``List[Document]`` into the arguments accepted by
        ``VectorStore.add_embeddings``.
        '''
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    def add_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """
        Add a file to the knowledge base.

        :param kb_file: the file being added.
        :param docs: optional pre-split documents. When given (non-empty),
            the file is not parsed again; these docs are stored as-is and the
            database record is marked ``custom_docs=True``.
        :return: the status from ``add_file_to_db``, or False when the file
            produced no documents.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False

        if not docs:
            return False

        # Store metadata["source"] relative to the knowledge base's doc folder.
        for doc in docs:
            source = doc.metadata.get("source", "")
            try:
                if os.path.isabs(source):
                    rel_path = Path(source).relative_to(self.doc_path)
                    doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
            except Exception as e:
                # Best effort: keep the original source and carry on.
                print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")

        # Drop any previous entries for this file before re-adding it.
        self.delete_doc(kb_file)
        doc_infos = self.do_add_doc(docs, **kwargs)
        status = add_file_to_db(kb_file,
                                custom_docs=custom_docs,
                                docs_count=len(docs),
                                doc_infos=doc_infos)
        return status

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """
        Remove a file from the knowledge base.

        Deletes the backend entries, then the database record, and, when
        ``delete_content`` is True, the file itself on disk.

        :return: the status returned by ``delete_file_from_db``.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """
        Update the knowledge base description and save it to the database.

        :return: the status returned by ``add_kb_to_db``.
        """
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """
        Re-index a file from the content folder.

        :param docs: optional custom docs; when given, they are stored instead
            of re-parsing the file and the record is marked ``custom_docs=True``.
        :return: the result of ``add_doc``, or None when the file does not
            exist on disk.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            # Pass a list, as before, in case a subclass overrides add_doc.
            return self.add_doc(kb_file, docs=docs or [], **kwargs)

    def exist_doc(self, file_name: str):
        """Return whether ``file_name`` is recorded in the database for this knowledge base."""
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        """List the file names recorded in the database for this knowledge base."""
        return list_files_from_db(self.kb_name)

    def count_files(self):
        """Count the files recorded in the database for this knowledge base."""
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    ) -> List[Tuple[Document, float]]:
        """Search the vector store; returns whatever the backend's ``do_search`` returns."""
        docs = self.do_search(query, top_k, score_threshold)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by vector-store id. The base implementation returns an empty list."""
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by vector-store id. Must be implemented by backends that support it."""
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        '''
        Replace documents by id.

        :param docs: ``{doc_id: Document, ...}``. If the value for a
            ``doc_id`` is None or its ``page_content`` is blank, that document
            is only deleted, not re-added.
        :return: True.
        '''
        self.del_doc_by_ids(list(docs.keys()))
        # Separate names: rebinding ``docs`` here would lose the input mapping.
        new_docs = []
        new_ids = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            new_ids.append(doc_id)
            new_docs.append(doc)
        if new_docs:
            self.do_add_doc(docs=new_docs, ids=new_ids)
        return True

    def list_docs(self, file_name: Optional[str] = None,
                  metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
        '''
        Retrieve documents by ``file_name`` and/or ``metadata``.

        Records found in the database whose id cannot be fetched from the
        vector store are skipped.
        '''
        if metadata is None:
            metadata = {}
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            found = self.get_doc_by_ids([x["id"]])
            # get_doc_by_ids may return [] (base implementation) or [None].
            doc_info = found[0] if found else None
            if doc_info is None:
                continue
            docs.append(DocumentWithVSId(**doc_info.dict(), id=x["id"]))
        return docs

    def get_relative_source_path(self, filepath: str):
        '''
        Convert a file path to a posix path relative to ``self.doc_path``, so
        lookups use the same form as the stored ``source`` metadata.

        Relative paths are only normalised (posix separators, no leading or
        trailing "/"). If an absolute path is not under ``doc_path``, a
        message is printed and the absolute path is normalised the same way.
        '''
        relative_path = Path(filepath)
        if os.path.isabs(filepath):
            try:
                relative_path = relative_path.relative_to(self.doc_path)
            except ValueError as e:
                print(f"cannot convert absolute path ({filepath}) to relative path. error is : {e}")
        return str(relative_path.as_posix().strip("/"))

    @abstractmethod
    def do_create_kb(self):
        """
        Create the knowledge base; each subclass implements its own logic.
        """
        pass

    @staticmethod
    def list_kbs_type():
        """Return the vector-store types configured in ``kbs_config``."""
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        """Return the knowledge bases recorded in the database."""
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        """Return whether ``kb_name`` (default: this knowledge base) exists in the database."""
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        """Return the vector-store type identifier of this backend."""
        pass

    @abstractmethod
    def do_init(self):
        """Backend-specific initialisation, called at the end of ``__init__``."""
        pass

    @abstractmethod
    def do_drop_kb(self):
        """
        Delete the knowledge base; each subclass implements its own logic.
        """
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  ) -> List[Tuple[Document, float]]:
        """
        Search the knowledge base; each subclass implements its own logic.
        """
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """
        Add documents to the knowledge base; each subclass implements its own logic.
        """
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile):
        """
        Delete a file's documents from the knowledge base; each subclass
        implements its own logic.
        """
        pass

    @abstractmethod
    def do_clear_vs(self):
        """
        Delete all vectors from the knowledge base; each subclass implements
        its own logic.
        """
        pass
308

309

310
class KBServiceFactory:
    """
    Build the KBService subclass that matches a vector-store type.

    Backend modules are imported lazily inside ``get_service``, so only the
    selected backend's module is imported.
    """

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> Optional[KBService]:
        """
        Return a service for ``kb_name`` backed by ``vector_store_type``.

        :param kb_name: knowledge base name.
        :param vector_store_type: a ``SupportedVSType`` value, or its name in
            any letter case (e.g. "faiss").
        :param embed_model: embedding model name passed to the service.
        :return: the service instance, or None if the type matches no backend.
        :raises AttributeError: if a string type is not defined on SupportedVSType.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())

        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.MILVUS == vector_store_type or SupportedVSType.DEFAULT == vector_store_type:
            # DEFAULT is served by Milvus; other milvus parameters are set in model_config.kbs_config.
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)
        # NOTE(review): a second DEFAULT branch returning DefaultKBService
        # ("kb_exists of default kbservice is False, to make validation easier")
        # used to follow here but was unreachable, because DEFAULT is matched by
        # the Milvus branch above. Confirm which backend DEFAULT should use
        # before changing that mapping.
        return None

    @staticmethod
    def get_service_by_name(kb_name: str) -> Optional[KBService]:
        """
        Build the service for a knowledge base recorded in the database.

        :return: the service, or None if ``kb_name`` is not in the database.
        """
        stored_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if stored_name is None:  # kb not in db
            return None
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Return the service for the knowledge base named "default" using the DEFAULT type."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
355

356

357
def get_kb_details() -> List[Dict]:
    """
    Merge the knowledge bases found on disk with those recorded in the database.

    Folder-only entries get placeholder values. Database details override
    those placeholders, and database-only entries are flagged
    ``in_folder=False``. Each returned dict gets a 1-based ``No`` field in
    iteration order.
    """
    details = {
        name: {
            "kb_name": name,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in list_kbs_from_folder()
    }

    for name in KBService.list_kbs():
        db_detail = get_kb_detail(name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        if name in details:
            details[name].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[name] = db_detail

    rows = list(details.values())
    for number, row in enumerate(rows, start=1):
        row['No'] = number
    return rows
390

391

392
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """
    Merge a knowledge base's files on disk with the files recorded in the database.

    Database records are matched to folder files case-insensitively. Matches
    update the folder entry; unmatched records are added with
    ``in_folder=False``. Each returned dict gets a 1-based ``No`` field.
    Returns [] if the knowledge base is not in the database.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    if service is None:
        return []

    folder_files = list_files_from_folder(kb_name)
    db_files = service.list_files()

    details = {}
    for file_name in folder_files:
        details[file_name] = {
            "kb_name": kb_name,
            "file_name": file_name,
            "file_ext": os.path.splitext(file_name)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }

    # Lower-cased folder names, captured before any DB-only entries are added.
    by_lower = {name.lower(): name for name in details}
    for file_name in db_files:
        db_detail = get_file_detail(kb_name, file_name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        folder_name = by_lower.get(file_name.lower())
        if folder_name is None:
            db_detail["in_folder"] = False
            details[file_name] = db_detail
        else:
            details[folder_name].update(db_detail)

    rows = list(details.values())
    for number, row in enumerate(rows, start=1):
        row['No'] = number
    return rows
431

432

433
class EmbeddingsFunAdapter(Embeddings):
    """
    Expose the project's ``embed_texts`` / ``aembed_texts`` API as a
    LangChain ``Embeddings`` object.

    Every returned vector is L2-normalised with ``normalize``.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    @staticmethod
    def _normalize_single(vector) -> List[float]:
        # normalize() expects a batch, so treat the vector as a one-row matrix
        # and unwrap the single row afterwards.
        as_matrix = np.reshape(vector, (1, -1))
        return normalize(as_matrix)[0].tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed document texts (``to_query=False``) and return normalised vectors."""
        response = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(response.data).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query (``to_query=True``) and return its normalised vector."""
        response = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        return self._normalize_single(response.data[0])

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``embed_documents``."""
        response = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(response.data).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async counterpart of ``embed_query``."""
        response = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        return self._normalize_single(response.data[0])
458

459

460
def score_threshold_process(score_threshold, k, docs):
    """
    Filter ``(doc, score)`` pairs by score, then keep at most ``k`` of them.

    When ``score_threshold`` is not None, only pairs with
    ``score <= score_threshold`` are kept, so lower scores pass. Presumably
    the scores are distances; confirm with the calling vector store. Input
    order is preserved.
    """
    if score_threshold is None:
        return docs[:k]
    kept = [(doc, score) for doc, score in docs if score <= score_threshold]
    return kept[:k]
471

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.