Langchain-Chatchat
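"""Reranker for Langchain-Chatchat: a LangChain document compressor that
scores (query, document) pairs with a local sentence-transformers
CrossEncoder and keeps only the top_n most relevant documents."""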
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from typing import Any, Optional, Sequence

from sentence_transformers import CrossEncoder
from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
# Use LangChain's pydantic v1 shim so Field/PrivateAttr match the
# BaseDocumentCompressor base class.
from langchain.pydantic_v1 import Field, PrivateAttr


class LangchainReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a local
    sentence-transformers CrossEncoder model."""

    model_name_or_path: str = Field()
    _model: Any = PrivateAttr()
    top_n: int = Field()
    device: str = Field()
    max_length: int = Field()
    batch_size: int = Field()
    num_workers: int = Field()

    def __init__(
            self,
            model_name_or_path: str,
            top_n: int = 3,
            device: str = "cuda",
            max_length: int = 1024,
            batch_size: int = 32,
            num_workers: int = 0,
    ):
        # Build the CrossEncoder on the requested device with the configured
        # maximum sequence length.
        self._model = CrossEncoder(model_name=model_name_or_path, max_length=max_length, device=device)
        super().__init__(
            top_n=top_n,
            model_name_or_path=model_name_or_path,
            device=device,
            max_length=max_length,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    def compress_documents(
            self,
            documents: Sequence[Document],
            query: str,
            callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents with the local CrossEncoder and keep the top_n hits.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of reranked documents, each annotated with a
            ``relevance_score`` in its metadata.
        """
        if len(documents) == 0:  # avoid an empty model call
            return []
        doc_list = list(documents)
        _docs = [d.page_content for d in doc_list]
        # A CrossEncoder scores each (query, passage) pair jointly.
        sentence_pairs = [[query, _doc] for _doc in _docs]
        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
                                      num_workers=self.num_workers,
                                      convert_to_tensor=True)
        top_k = self.top_n if self.top_n < len(results) else len(results)
        values, indices = results.topk(top_k)
        final_results = []
        for value, index in zip(values, indices):
            doc = doc_list[index]
            # Store a plain float so the metadata stays serializable.
            doc.metadata["relevance_score"] = value.item()
            final_results.append(doc)
        return final_results
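
# Usage note: as a document compressor, LangchainReranker can be plugged into
# langchain's ContextualCompressionRetriever (base_compressor=...,
# base_retriever=...) to rerank retriever output transparently.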


if __name__ == "__main__":
    from configs import (LLM_MODELS,
                         VECTOR_SEARCH_TOP_K,
                         SCORE_THRESHOLD,
                         TEMPERATURE,
                         USE_RERANKER,
                         RERANKER_MODEL,
                         RERANKER_MAX_LENGTH,
                         MODEL_PATH)
    from server.utils import embedding_device

    if USE_RERANKER:
        reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
        print("-----------------model path------------------")
        print(reranker_model_path)
        reranker_model = LangchainReranker(top_n=3,
                                           device=embedding_device(),
                                           max_length=RERANKER_MAX_LENGTH,
                                           model_name_or_path=reranker_model_path)
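
        # A minimal smoke test, assuming the model above loaded successfully;
        # the query and documents below are made-up examples.
        sample_docs = [
            Document(page_content="LangChain is a framework for building LLM applications."),
            Document(page_content="A CrossEncoder scores a query and a passage jointly."),
        ]
        reranked = reranker_model.compress_documents(documents=sample_docs,
                                                     query="What does a CrossEncoder do?")
        for doc in reranked:
            print(doc.metadata["relevance_score"], doc.page_content)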