import logging
from dataclasses import dataclass
from typing import Any, List, Optional, cast

import llama_index.legacy
from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.core.embeddings.base import BaseEmbedding
from llama_index.legacy.indices.prompt_helper import PromptHelper
from llama_index.legacy.llm_predictor import LLMPredictor
from llama_index.legacy.llm_predictor.base import BaseLLMPredictor, LLMMetadata
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.utils import LLMType, resolve_llm
from llama_index.legacy.logger import LlamaLogger
from llama_index.legacy.node_parser.interface import NodeParser, TextSplitter
from llama_index.legacy.node_parser.text.sentence import (
    DEFAULT_CHUNK_SIZE,
    SENTENCE_CHUNK_OVERLAP,
    SentenceSplitter,
)
from llama_index.legacy.prompts.base import BasePromptTemplate
from llama_index.legacy.schema import TransformComponent
from llama_index.legacy.types import PydanticProgramMode

logger = logging.getLogger(__name__)

def _get_default_node_parser(
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = SENTENCE_CHUNK_OVERLAP,
    callback_manager: Optional[CallbackManager] = None,
) -> NodeParser:
    """Get default node parser."""
    return SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        callback_manager=callback_manager or CallbackManager(),
    )


def _get_default_prompt_helper(
    llm_metadata: LLMMetadata,
    context_window: Optional[int] = None,
    num_output: Optional[int] = None,
) -> PromptHelper:
    """Get default prompt helper."""
    if context_window is not None:
        llm_metadata.context_window = context_window
    if num_output is not None:
        llm_metadata.num_output = num_output
    return PromptHelper.from_llm_metadata(llm_metadata=llm_metadata)
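
# A hypothetical sketch (values are illustrative, not from the original source):
# the helper mutates the metadata object it is given, then derives a
# PromptHelper from it, so explicit overrides win over the model defaults.
#
#     helper = _get_default_prompt_helper(LLMMetadata(), context_window=2048)
#     assert helper.context_window == 2048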


class ServiceContextData(BaseModel):
    """Serialized form of a ServiceContext (see ``to_dict`` / ``from_dict``)."""

    llm: dict
    llm_predictor: dict
    prompt_helper: dict
    embed_model: dict
    transformations: List[dict]

@dataclass
class ServiceContext:
    """Service Context container.

    The service context container is a utility container for LlamaIndex
    index and query classes. It contains the following:
    - llm_predictor: BaseLLMPredictor
    - prompt_helper: PromptHelper
    - embed_model: BaseEmbedding
    - transformations: List[TransformComponent] (node_parser is exposed
      as a derived property)
    - llama_logger: LlamaLogger (deprecated)
    - callback_manager: CallbackManager

    """

    llm_predictor: BaseLLMPredictor
    prompt_helper: PromptHelper
    embed_model: BaseEmbedding
    transformations: List[TransformComponent]
    llama_logger: LlamaLogger
    callback_manager: CallbackManager
    @classmethod
    def from_defaults(
        cls,
        llm_predictor: Optional[BaseLLMPredictor] = None,
        llm: Optional[LLMType] = "default",
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[Any] = "default",
        node_parser: Optional[NodeParser] = None,
        text_splitter: Optional[TextSplitter] = None,
        transformations: Optional[List[TransformComponent]] = None,
        llama_logger: Optional[LlamaLogger] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        # pydantic program mode (used if output_cls is specified)
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # prompt helper kwargs
        context_window: Optional[int] = None,
        num_output: Optional[int] = None,
        # deprecated kwargs
        chunk_size_limit: Optional[int] = None,
    ) -> "ServiceContext":
        """Create a ServiceContext from defaults.

        If an argument is specified, then use the argument value provided for that
        parameter. If an argument is not specified, then use the default value.

        You can change the base defaults by setting
        llama_index.legacy.global_service_context to a ServiceContext object
        with your desired settings.

        Args:
            llm_predictor (Optional[BaseLLMPredictor]): LLMPredictor
            prompt_helper (Optional[PromptHelper]): PromptHelper
            embed_model (Optional[BaseEmbedding]): BaseEmbedding
                or "local" (use local model)
            node_parser (Optional[NodeParser]): NodeParser
            llama_logger (Optional[LlamaLogger]): LlamaLogger (deprecated)
            chunk_size (Optional[int]): chunk_size
            callback_manager (Optional[CallbackManager]): CallbackManager
            system_prompt (Optional[str]): System-wide prompt to be prepended
                to all input prompts, used to guide system "decision making"
            query_wrapper_prompt (Optional[BasePromptTemplate]): A format to wrap
                passed-in input queries.

        Deprecated Args:
            chunk_size_limit (Optional[int]): renamed to chunk_size
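
        Example:
            A minimal sketch (assumes credentials for the default LLM and
            embedding model are configured so that ``resolve_llm("default")``
            and ``resolve_embed_model("default")`` succeed)::

                from llama_index.legacy import ServiceContext

                service_context = ServiceContext.from_defaults(
                    chunk_size=512,
                    chunk_overlap=20,
                )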
133"""
        from llama_index.legacy.embeddings.utils import EmbedType, resolve_embed_model

        embed_model = cast(EmbedType, embed_model)

        if chunk_size_limit is not None and chunk_size is None:
            logger.warning(
                "chunk_size_limit is deprecated, please specify chunk_size instead"
            )
            chunk_size = chunk_size_limit

        if llama_index.legacy.global_service_context is not None:
            return cls.from_service_context(
                llama_index.legacy.global_service_context,
                llm=llm,
                llm_predictor=llm_predictor,
                prompt_helper=prompt_helper,
                embed_model=embed_model,
                node_parser=node_parser,
                text_splitter=text_splitter,
                llama_logger=llama_logger,
                callback_manager=callback_manager,
                context_window=context_window,
                chunk_size=chunk_size,
                chunk_size_limit=chunk_size_limit,
                chunk_overlap=chunk_overlap,
                num_output=num_output,
                system_prompt=system_prompt,
                query_wrapper_prompt=query_wrapper_prompt,
                transformations=transformations,
            )

        callback_manager = callback_manager or CallbackManager([])
        if llm != "default":
            if llm_predictor is not None:
                raise ValueError("Cannot specify both llm and llm_predictor")
            llm = resolve_llm(llm)
            llm.system_prompt = llm.system_prompt or system_prompt
            llm.query_wrapper_prompt = llm.query_wrapper_prompt or query_wrapper_prompt
            llm.pydantic_program_mode = (
                llm.pydantic_program_mode or pydantic_program_mode
            )

        if llm_predictor is not None:
            print("LLMPredictor is deprecated, please use LLM instead.")
        llm_predictor = llm_predictor or LLMPredictor(
            llm=llm, pydantic_program_mode=pydantic_program_mode
        )
        if isinstance(llm_predictor, LLMPredictor):
            llm_predictor.llm.callback_manager = callback_manager
            if system_prompt:
                llm_predictor.system_prompt = system_prompt
            if query_wrapper_prompt:
                llm_predictor.query_wrapper_prompt = query_wrapper_prompt

        # NOTE: the embed_model isn't used in all indices
        # NOTE: embed model should be a transformation, but the way the service
        # context works, we can't put in there yet.
        embed_model = resolve_embed_model(embed_model)
        embed_model.callback_manager = callback_manager

        prompt_helper = prompt_helper or _get_default_prompt_helper(
            llm_metadata=llm_predictor.metadata,
            context_window=context_window,
            num_output=num_output,
        )

        if text_splitter is not None and node_parser is not None:
            raise ValueError("Cannot specify both text_splitter and node_parser")

        node_parser = (
            text_splitter  # text splitter extends node parser
            or node_parser
            or _get_default_node_parser(
                chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
                chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
                callback_manager=callback_manager,
            )
        )

        transformations = transformations or [node_parser]

        llama_logger = llama_logger or LlamaLogger()

        return cls(
            llm_predictor=llm_predictor,
            embed_model=embed_model,
            prompt_helper=prompt_helper,
            transformations=transformations,
            llama_logger=llama_logger,  # deprecated
            callback_manager=callback_manager,
        )
    @classmethod
    def from_service_context(
        cls,
        service_context: "ServiceContext",
        llm_predictor: Optional[BaseLLMPredictor] = None,
        llm: Optional[LLMType] = "default",
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[Any] = "default",
        node_parser: Optional[NodeParser] = None,
        text_splitter: Optional[TextSplitter] = None,
        transformations: Optional[List[TransformComponent]] = None,
        llama_logger: Optional[LlamaLogger] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # prompt helper kwargs
        context_window: Optional[int] = None,
        num_output: Optional[int] = None,
        # deprecated kwargs
        chunk_size_limit: Optional[int] = None,
    ) -> "ServiceContext":
        """Instantiate a new service context using a previous as the defaults."""
        from llama_index.legacy.embeddings.utils import EmbedType, resolve_embed_model

        embed_model = cast(EmbedType, embed_model)

        if chunk_size_limit is not None and chunk_size is None:
            # NOTE: logger.warning takes no warning category; pass the
            # deprecation notice as a plain message.
            logger.warning(
                "chunk_size_limit is deprecated, please specify chunk_size instead"
            )
            chunk_size = chunk_size_limit

        callback_manager = callback_manager or service_context.callback_manager
        if llm != "default":
            if llm_predictor is not None:
                raise ValueError("Cannot specify both llm and llm_predictor")
            llm = resolve_llm(llm)
            llm_predictor = LLMPredictor(llm=llm)

        llm_predictor = llm_predictor or service_context.llm_predictor
        if isinstance(llm_predictor, LLMPredictor):
            llm_predictor.llm.callback_manager = callback_manager
            if system_prompt:
                llm_predictor.system_prompt = system_prompt
            if query_wrapper_prompt:
                llm_predictor.query_wrapper_prompt = query_wrapper_prompt

        # NOTE: the embed_model isn't used in all indices
        # default to using the embed model passed from the service context
        if embed_model == "default":
            embed_model = service_context.embed_model
        embed_model = resolve_embed_model(embed_model)
        embed_model.callback_manager = callback_manager

        prompt_helper = prompt_helper or service_context.prompt_helper
        if context_window is not None or num_output is not None:
            prompt_helper = _get_default_prompt_helper(
                llm_metadata=llm_predictor.metadata,
                context_window=context_window,
                num_output=num_output,
            )

        transformations = transformations or []
        node_parser_found = False
        for transform in service_context.transformations:
            if isinstance(transform, NodeParser):
                node_parser_found = True
                node_parser = transform
                break

        if text_splitter is not None and node_parser is not None:
            raise ValueError("Cannot specify both text_splitter and node_parser")

        if not node_parser_found:
            node_parser = (
                text_splitter  # text splitter extends node parser
                or node_parser
                or _get_default_node_parser(
                    chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
                    chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
                    callback_manager=callback_manager,
                )
            )

        transformations = transformations or service_context.transformations

        llama_logger = llama_logger or service_context.llama_logger

        return cls(
            llm_predictor=llm_predictor,
            embed_model=embed_model,
            prompt_helper=prompt_helper,
            transformations=transformations,
            llama_logger=llama_logger,  # deprecated
            callback_manager=callback_manager,
        )
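
    # A hypothetical usage sketch (``base_context`` is illustrative): deriving
    # a new context from an existing one; unspecified settings are inherited.
    # Note that an existing node parser in the base transformations is reused,
    # so chunking kwargs only apply when the base context has no node parser.
    #
    #     new_context = ServiceContext.from_service_context(
    #         base_context,
    #         system_prompt="Answer concisely.",
    #     )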

    @property
    def llm(self) -> LLM:
        return self.llm_predictor.llm

    @property
    def node_parser(self) -> NodeParser:
        """Get the node parser."""
        for transform in self.transformations:
            if isinstance(transform, NodeParser):
                return transform
        raise ValueError("No node parser found.")

    def to_dict(self) -> dict:
        """Convert service context to dict."""
        llm_dict = self.llm_predictor.llm.to_dict()
        llm_predictor_dict = self.llm_predictor.to_dict()

        embed_model_dict = self.embed_model.to_dict()

        prompt_helper_dict = self.prompt_helper.to_dict()

        transform_list_dict = [x.to_dict() for x in self.transformations]

        return ServiceContextData(
            llm=llm_dict,
            llm_predictor=llm_predictor_dict,
            prompt_helper=prompt_helper_dict,
            embed_model=embed_model_dict,
            transformations=transform_list_dict,
        ).dict()

    @classmethod
    def from_dict(cls, data: dict) -> "ServiceContext":
        """Load a service context from a dict (as produced by ``to_dict``)."""
        from llama_index.legacy.embeddings.loading import load_embed_model
        from llama_index.legacy.extractors.loading import load_extractor
        from llama_index.legacy.llm_predictor.loading import load_predictor
        from llama_index.legacy.node_parser.loading import load_parser

        service_context_data = ServiceContextData.parse_obj(data)

        llm_predictor = load_predictor(service_context_data.llm_predictor)

        embed_model = load_embed_model(service_context_data.embed_model)

        prompt_helper = PromptHelper.from_dict(service_context_data.prompt_helper)

        transformations: List[TransformComponent] = []
        for transform in service_context_data.transformations:
            # node parsers and extractors share one serialized list; try the
            # parser loader first and fall back to the extractor loader
            try:
                transformations.append(load_parser(transform))
            except ValueError:
                transformations.append(load_extractor(transform))

        return cls.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=embed_model,
            transformations=transformations,
        )
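
    # A hypothetical round-trip sketch (``service_context`` is illustrative):
    # ``to_dict`` serializes every component via ServiceContextData, and
    # ``from_dict`` reloads each one with its registered loader.
    #
    #     payload = service_context.to_dict()
    #     restored = ServiceContext.from_dict(payload)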


def set_global_service_context(service_context: Optional[ServiceContext]) -> None:
    """Helper function to set the global service context."""
    llama_index.legacy.global_service_context = service_context
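
# A minimal usage sketch (hypothetical; assumes default LLM and embedding
# credentials are configured): set a global default so that subsequent
# ServiceContext.from_defaults() calls inherit these settings.
#
#     ctx = ServiceContext.from_defaults(chunk_size=512)
#     set_global_service_context(ctx)
#     # later ServiceContext.from_defaults() calls now route through
#     # cls.from_service_context(ctx, ...)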