import logging
from dataclasses import dataclass
from typing import Any, List, Optional, cast

import llama_index.legacy
from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.core.embeddings.base import BaseEmbedding
from llama_index.legacy.indices.prompt_helper import PromptHelper
from llama_index.legacy.llm_predictor import LLMPredictor
from llama_index.legacy.llm_predictor.base import BaseLLMPredictor, LLMMetadata
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.utils import LLMType, resolve_llm
from llama_index.legacy.logger import LlamaLogger
from llama_index.legacy.node_parser.interface import NodeParser, TextSplitter
from llama_index.legacy.node_parser.text.sentence import (
    DEFAULT_CHUNK_SIZE,
    SENTENCE_CHUNK_OVERLAP,
    SentenceSplitter,
)
from llama_index.legacy.prompts.base import BasePromptTemplate
from llama_index.legacy.schema import TransformComponent
from llama_index.legacy.types import PydanticProgramMode

logger = logging.getLogger(__name__)


def _get_default_node_parser(
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = SENTENCE_CHUNK_OVERLAP,
    callback_manager: Optional[CallbackManager] = None,
) -> NodeParser:
    """Get default node parser."""
    return SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        callback_manager=callback_manager or CallbackManager(),
    )
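

# Usage sketch (illustrative, not part of the original module): the default
# node parser is a SentenceSplitter, so documents are chunked into roughly
# chunk_size-token pieces with chunk_overlap tokens shared between neighbors.
# `Document` is the legacy schema class; the values below are arbitrary.
def _example_default_node_parser() -> None:
    from llama_index.legacy.schema import Document

    parser = _get_default_node_parser(chunk_size=64, chunk_overlap=8)
    nodes = parser.get_nodes_from_documents(
        [Document(text="First sentence. Second sentence. Third sentence.")]
    )
    print(f"produced {len(nodes)} node(s)")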


def _get_default_prompt_helper(
    llm_metadata: LLMMetadata,
    context_window: Optional[int] = None,
    num_output: Optional[int] = None,
) -> PromptHelper:
    """Get default prompt helper."""
    if context_window is not None:
        llm_metadata.context_window = context_window
    if num_output is not None:
        llm_metadata.num_output = num_output
    return PromptHelper.from_llm_metadata(llm_metadata=llm_metadata)
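

# Usage sketch (illustrative, not part of the original module): explicit
# context_window / num_output arguments overwrite the values advertised by
# the LLM metadata before the PromptHelper is derived from it.
def _example_default_prompt_helper() -> None:
    metadata = LLMMetadata(context_window=4096, num_output=256)
    helper = _get_default_prompt_helper(llm_metadata=metadata, context_window=2048)
    # The helper is sized for the overridden 2048-token window, not 4096.
    print(helper.context_window)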


class ServiceContextData(BaseModel):
    llm: dict
    llm_predictor: dict
    prompt_helper: dict
    embed_model: dict
    transformations: List[dict]

@dataclass
class ServiceContext:
    """Service Context container.

    The service context is a utility container for LlamaIndex
    index and query classes. It contains the following:
    - llm_predictor: BaseLLMPredictor
    - prompt_helper: PromptHelper
    - embed_model: BaseEmbedding
    - transformations: List[TransformComponent]
    - llama_logger: LlamaLogger (deprecated)
    - callback_manager: CallbackManager

    """

    llm_predictor: BaseLLMPredictor
    prompt_helper: PromptHelper
    embed_model: BaseEmbedding
    transformations: List[TransformComponent]
    llama_logger: LlamaLogger
    callback_manager: CallbackManager

    @classmethod
    def from_defaults(
        cls,
        llm_predictor: Optional[BaseLLMPredictor] = None,
        llm: Optional[LLMType] = "default",
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[Any] = "default",
        node_parser: Optional[NodeParser] = None,
        text_splitter: Optional[TextSplitter] = None,
        transformations: Optional[List[TransformComponent]] = None,
        llama_logger: Optional[LlamaLogger] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        # pydantic program mode (used if output_cls is specified)
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # prompt helper kwargs
        context_window: Optional[int] = None,
        num_output: Optional[int] = None,
        # deprecated kwargs
        chunk_size_limit: Optional[int] = None,
    ) -> "ServiceContext":
        """Create a ServiceContext from defaults.

        If an argument is specified, its value is used for that parameter;
        otherwise the default value is used.

        You can change the base defaults by setting
        llama_index.legacy.global_service_context to a ServiceContext object
        with your desired settings.

        Args:
            llm_predictor (Optional[BaseLLMPredictor]): LLMPredictor
            prompt_helper (Optional[PromptHelper]): PromptHelper
            embed_model (Optional[BaseEmbedding]): BaseEmbedding
                or "local" (use local model)
            node_parser (Optional[NodeParser]): NodeParser
            llama_logger (Optional[LlamaLogger]): LlamaLogger (deprecated)
            chunk_size (Optional[int]): chunk_size
            callback_manager (Optional[CallbackManager]): CallbackManager
            system_prompt (Optional[str]): System-wide prompt to be prepended
                to all input prompts, used to guide system "decision making"
            query_wrapper_prompt (Optional[BasePromptTemplate]): A format to wrap
                passed-in input queries.

        Deprecated Args:
            chunk_size_limit (Optional[int]): renamed to chunk_size

        """
        from llama_index.legacy.embeddings.utils import EmbedType, resolve_embed_model

        embed_model = cast(EmbedType, embed_model)

        if chunk_size_limit is not None and chunk_size is None:
            logger.warning(
                "chunk_size_limit is deprecated, please specify chunk_size instead"
            )
            chunk_size = chunk_size_limit

        if llama_index.legacy.global_service_context is not None:
            return cls.from_service_context(
                llama_index.legacy.global_service_context,
                llm=llm,
                llm_predictor=llm_predictor,
                prompt_helper=prompt_helper,
                embed_model=embed_model,
                node_parser=node_parser,
                text_splitter=text_splitter,
                llama_logger=llama_logger,
                callback_manager=callback_manager,
                context_window=context_window,
                chunk_size=chunk_size,
                chunk_size_limit=chunk_size_limit,
                chunk_overlap=chunk_overlap,
                num_output=num_output,
                system_prompt=system_prompt,
                query_wrapper_prompt=query_wrapper_prompt,
                transformations=transformations,
            )

        callback_manager = callback_manager or CallbackManager([])
        if llm != "default":
            if llm_predictor is not None:
                raise ValueError("Cannot specify both llm and llm_predictor")
            llm = resolve_llm(llm)
            llm.system_prompt = llm.system_prompt or system_prompt
            llm.query_wrapper_prompt = llm.query_wrapper_prompt or query_wrapper_prompt
            llm.pydantic_program_mode = (
                llm.pydantic_program_mode or pydantic_program_mode
            )

        if llm_predictor is not None:
            print("LLMPredictor is deprecated, please use LLM instead.")
        llm_predictor = llm_predictor or LLMPredictor(
            llm=llm, pydantic_program_mode=pydantic_program_mode
        )
        if isinstance(llm_predictor, LLMPredictor):
            llm_predictor.llm.callback_manager = callback_manager
            if system_prompt:
                llm_predictor.system_prompt = system_prompt
            if query_wrapper_prompt:
                llm_predictor.query_wrapper_prompt = query_wrapper_prompt

        # NOTE: the embed_model isn't used in all indices
        # NOTE: embed model should be a transformation, but the way the service
        # context works, we can't put it in there yet.
        embed_model = resolve_embed_model(embed_model)
        embed_model.callback_manager = callback_manager

        prompt_helper = prompt_helper or _get_default_prompt_helper(
            llm_metadata=llm_predictor.metadata,
            context_window=context_window,
            num_output=num_output,
        )

        if text_splitter is not None and node_parser is not None:
            raise ValueError("Cannot specify both text_splitter and node_parser")

        node_parser = (
            text_splitter  # text splitter extends node parser
            or node_parser
            or _get_default_node_parser(
                chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
                chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
                callback_manager=callback_manager,
            )
        )

        transformations = transformations or [node_parser]

        llama_logger = llama_logger or LlamaLogger()

        return cls(
            llm_predictor=llm_predictor,
            embed_model=embed_model,
            prompt_helper=prompt_helper,
            transformations=transformations,
            llama_logger=llama_logger,  # deprecated
            callback_manager=callback_manager,
        )
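
    # Usage sketch (illustrative comment, not part of the original module):
    #
    #     service_context = ServiceContext.from_defaults(
    #         llm=OpenAI(model="gpt-3.5-turbo"),  # any LLMType; resolved via resolve_llm
    #         embed_model="local",                # or a BaseEmbedding instance
    #         chunk_size=512,
    #     )
    #
    # Explicit arguments win; anything left at its default is filled in from
    # the global service context (if set) or from the library defaults.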

    @classmethod
    def from_service_context(
        cls,
        service_context: "ServiceContext",
        llm_predictor: Optional[BaseLLMPredictor] = None,
        llm: Optional[LLMType] = "default",
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[Any] = "default",
        node_parser: Optional[NodeParser] = None,
        text_splitter: Optional[TextSplitter] = None,
        transformations: Optional[List[TransformComponent]] = None,
        llama_logger: Optional[LlamaLogger] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # prompt helper kwargs
        context_window: Optional[int] = None,
        num_output: Optional[int] = None,
        # deprecated kwargs
        chunk_size_limit: Optional[int] = None,
    ) -> "ServiceContext":
        """Instantiate a new service context using a previous one as the defaults."""
        from llama_index.legacy.embeddings.utils import EmbedType, resolve_embed_model

        embed_model = cast(EmbedType, embed_model)

        if chunk_size_limit is not None and chunk_size is None:
            logger.warning(
                "chunk_size_limit is deprecated, please specify chunk_size instead"
            )
            chunk_size = chunk_size_limit

        callback_manager = callback_manager or service_context.callback_manager
        if llm != "default":
            if llm_predictor is not None:
                raise ValueError("Cannot specify both llm and llm_predictor")
            llm = resolve_llm(llm)
            llm_predictor = LLMPredictor(llm=llm)

        llm_predictor = llm_predictor or service_context.llm_predictor
        if isinstance(llm_predictor, LLMPredictor):
            llm_predictor.llm.callback_manager = callback_manager
            if system_prompt:
                llm_predictor.system_prompt = system_prompt
            if query_wrapper_prompt:
                llm_predictor.query_wrapper_prompt = query_wrapper_prompt

        # NOTE: the embed_model isn't used in all indices
        # default to using the embed model passed from the service context
        if embed_model == "default":
            embed_model = service_context.embed_model
        embed_model = resolve_embed_model(embed_model)
        embed_model.callback_manager = callback_manager

        prompt_helper = prompt_helper or service_context.prompt_helper
        if context_window is not None or num_output is not None:
            prompt_helper = _get_default_prompt_helper(
                llm_metadata=llm_predictor.metadata,
                context_window=context_window,
                num_output=num_output,
            )

        transformations = transformations or []
        node_parser_found = False
        for transform in service_context.transformations:
            if isinstance(transform, NodeParser):
                node_parser_found = True
                node_parser = transform
                break

        if text_splitter is not None and node_parser is not None:
            raise ValueError("Cannot specify both text_splitter and node_parser")

        if not node_parser_found:
            node_parser = (
                text_splitter  # text splitter extends node parser
                or node_parser
                or _get_default_node_parser(
                    chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
                    chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
                    callback_manager=callback_manager,
                )
            )

        transformations = transformations or service_context.transformations

        llama_logger = llama_logger or service_context.llama_logger

        return cls(
            llm_predictor=llm_predictor,
            embed_model=embed_model,
            prompt_helper=prompt_helper,
            transformations=transformations,
            llama_logger=llama_logger,  # deprecated
            callback_manager=callback_manager,
        )
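
    # Note (illustrative comment, not part of the original module): when the
    # source context already carries a NodeParser among its transformations,
    # that parser is reused as-is, so chunk_size / chunk_overlap passed here
    # do not rebuild it; only a context without a node parser gets a new one.
    #
    #     faster = ServiceContext.from_service_context(base_context, num_output=128)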

    @property
    def llm(self) -> LLM:
        return self.llm_predictor.llm

    @property
    def node_parser(self) -> NodeParser:
        """Get the node parser."""
        for transform in self.transformations:
            if isinstance(transform, NodeParser):
                return transform
        raise ValueError("No node parser found.")

    def to_dict(self) -> dict:
        """Convert service context to dict."""
        llm_dict = self.llm_predictor.llm.to_dict()
        llm_predictor_dict = self.llm_predictor.to_dict()

        embed_model_dict = self.embed_model.to_dict()

        prompt_helper_dict = self.prompt_helper.to_dict()

        transform_list_dict = [x.to_dict() for x in self.transformations]

        return ServiceContextData(
            llm=llm_dict,
            llm_predictor=llm_predictor_dict,
            prompt_helper=prompt_helper_dict,
            embed_model=embed_model_dict,
            transformations=transform_list_dict,
        ).dict()

    @classmethod
    def from_dict(cls, data: dict) -> "ServiceContext":
        from llama_index.legacy.embeddings.loading import load_embed_model
        from llama_index.legacy.extractors.loading import load_extractor
        from llama_index.legacy.llm_predictor.loading import load_predictor
        from llama_index.legacy.node_parser.loading import load_parser

        service_context_data = ServiceContextData.parse_obj(data)

        llm_predictor = load_predictor(service_context_data.llm_predictor)

        embed_model = load_embed_model(service_context_data.embed_model)

        prompt_helper = PromptHelper.from_dict(service_context_data.prompt_helper)

        transformations: List[TransformComponent] = []
        for transform in service_context_data.transformations:
            try:
                transformations.append(load_parser(transform))
            except ValueError:
                transformations.append(load_extractor(transform))

        return cls.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=embed_model,
            transformations=transformations,
        )
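

# Usage sketch (illustrative, not part of the original module). MockLLM and
# MockEmbedding are assumed to be exported by the legacy package as in
# llama-index 0.9.x; swap in real models for production use.
def _example_service_context() -> None:
    from llama_index.legacy import MockEmbedding
    from llama_index.legacy.llms import MockLLM

    service_context = ServiceContext.from_defaults(
        llm=MockLLM(),
        embed_model=MockEmbedding(embed_dim=8),
        chunk_size=128,
    )
    # The node_parser property recovers the parser from the transformations.
    print(type(service_context.node_parser).__name__)  # SentenceSplitter
    # to_dict() serializes each component via ServiceContextData.
    print(sorted(service_context.to_dict().keys()))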


def set_global_service_context(service_context: Optional[ServiceContext]) -> None:
    """Helper function to set the global service context."""
    llama_index.legacy.global_service_context = service_context
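

# Usage sketch (illustrative, not part of the original module): once a global
# service context is installed, from_defaults() starts from it, so per-call
# arguments act as overrides on top of the global settings. MockLLM and
# MockEmbedding imports are assumptions, as in the example above.
def _example_global_service_context() -> None:
    from llama_index.legacy import MockEmbedding
    from llama_index.legacy.llms import MockLLM

    base = ServiceContext.from_defaults(
        llm=MockLLM(), embed_model=MockEmbedding(embed_dim=8)
    )
    set_global_service_context(base)
    # Inherits the MockLLM and MockEmbedding from the global context.
    derived = ServiceContext.from_defaults(system_prompt="Answer tersely.")
    print(type(derived.llm).__name__)  # MockLLM
    set_global_service_context(None)  # clear the global default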