llama-index
50 lines · 1.9 KB
1from inspect import signature2from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast3
4from llama_index.legacy.bridge.pydantic import BaseModel, FieldInfo, create_model5
6
def _field_from_default(field_type: Any, default: Any) -> Tuple[Any, FieldInfo]:
    """Build a ``(type, FieldInfo)`` pair for a field that has a default.

    If ``default`` is already a ``FieldInfo`` (e.g. the result of
    ``pydantic.Field(...)``), it is used as the field definition directly so
    its metadata is kept. Any other value is wrapped as the field's plain
    default.
    """
    if isinstance(default, FieldInfo):
        return field_type, default
    return field_type, FieldInfo(default=default)


def create_schema_from_function(
    name: str,
    func: Callable[..., Any],
    additional_fields: Optional[
        List[Union[Tuple[str, Type, Any], Tuple[str, Type]]]
    ] = None,
) -> Type[BaseModel]:
    """Create a pydantic model class describing ``func``'s parameters.

    Every parameter reported by ``inspect.signature(func)`` becomes a field of
    the generated model:

    * the parameter annotation is the field type, or ``Any`` if unannotated;
    * a parameter without a default becomes a required field;
    * a ``FieldInfo`` default (``pydantic.Field(...)``) is used as the field
      definition directly; any other default becomes the field's default.

    NOTE(review): variadic ``*args`` / ``**kwargs`` parameters are not
    filtered out and become ordinary fields. Confirm this is intended.

    Args:
        name: Name of the generated model class.
        func: Callable whose signature is turned into model fields.
        additional_fields: Extra fields added after ``func``'s parameters.
            Each entry is either ``(name, type)`` (a required field) or
            ``(name, type, default)``. A ``FieldInfo`` default is handled the
            same way as a ``FieldInfo`` parameter default. An extra field whose
            name matches a parameter replaces that parameter's field.

    Returns:
        The model class produced by ``create_model``.

    Raises:
        ValueError: If an entry of ``additional_fields`` does not have length
            2 or 3.
    """
    fields = {}
    for param in signature(func).parameters.values():
        param_type = Any if param.annotation is param.empty else param.annotation
        if param.default is param.empty:
            # Required field: a FieldInfo without a default.
            fields[param.name] = (param_type, FieldInfo())
        else:
            fields[param.name] = _field_from_default(param_type, param.default)

    for field_info in additional_fields or []:
        if len(field_info) == 3:
            field_name, field_type, field_default = cast(
                Tuple[str, Type, Any], field_info
            )
            # Share the parameter logic so a FieldInfo default is not wrapped
            # inside a second FieldInfo (which would drop its metadata).
            fields[field_name] = _field_from_default(field_type, field_default)
        elif len(field_info) == 2:
            # Required field has no default value
            field_name, field_type = cast(Tuple[str, Type], field_info)
            fields[field_name] = (field_type, FieldInfo())
        else:
            raise ValueError(
                f"Invalid additional field info: {field_info}. "
                "Must be a tuple of length 2 or 3."
            )

    return create_model(name, **fields)  # type: ignore