llama-index
50 строк · 1.6 Кб
1import abc2import json3from typing import TYPE_CHECKING, List4
5if TYPE_CHECKING:6from botocore.response import StreamingBody7
8from llama_index.legacy.bridge.pydantic import Field9
10
class BaseIOHandler(metaclass=abc.ABCMeta):
    """Abstract interface for request/response (de)serialization.

    Implementations declare the MIME types they send and accept, turn a list
    of input strings into a request body, and turn a streamed response body
    back into a list of float vectors.
    """

    content_type: str = Field(
        description="The MIME type of the input data in the request body.",
    )
    accept: str = Field(
        description="The desired MIME type of the inference response from the model container.",
    )

    @classmethod
    def __subclasshook__(cls, subclass: type) -> bool:
        """Structurally decide whether ``subclass`` satisfies this interface.

        Returns ``True`` when ``subclass`` exposes ``content_type`` and
        ``accept`` attributes plus callable ``serialize_input`` and
        ``deserialize_output`` members; otherwise returns ``NotImplemented``
        so the regular ABC subclass machinery takes over.
        """
        # NOTE(review): the hook does not inspect ``cls``, so any class that
        # inherits it applies the same structural check -- confirm intended.
        for attribute in ("content_type", "accept"):
            if not hasattr(subclass, attribute):
                return NotImplemented
        for method in ("serialize_input", "deserialize_output"):
            # A missing member resolves to None, which is not callable.
            if not callable(getattr(subclass, method, None)):
                return NotImplemented
        return True

    @abc.abstractmethod
    def serialize_input(self, request: List[str], model_kwargs: dict) -> bytes:
        """Encode ``request`` and ``model_kwargs`` into a request body.

        Args:
            request: Input strings to send.
            model_kwargs: Extra parameters to include with the request.

        Returns:
            The encoded request body.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_output(self, response: "StreamingBody") -> List[List[float]]:
        """Decode a streamed response body into a list of float vectors.

        Args:
            response: Readable body of the response.

        Returns:
            One list of floats per vector contained in the response.
        """
        raise NotImplementedError
39
class IOHandler(BaseIOHandler):
    """JSON-based handler: sends and expects ``application/json`` payloads."""

    content_type: str = "application/json"
    accept: str = "application/json"

    def serialize_input(self, request: List[str], model_kwargs: dict) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        The body is a JSON object with the input strings under ``"inputs"``
        followed by every entry of ``model_kwargs`` (an ``"inputs"`` key in
        ``model_kwargs`` replaces the request value).

        Args:
            request: Input strings to send.
            model_kwargs: Extra top-level entries for the JSON object.

        Returns:
            The JSON document encoded as UTF-8 bytes.
        """
        payload = {"inputs": request}
        payload.update(model_kwargs)
        encoded = json.dumps(payload).encode("utf-8")
        return encoded

    def deserialize_output(self, response: "StreamingBody") -> List[List[float]]:
        """Read the response body and return its ``"vectors"`` entry.

        Args:
            response: Readable body whose content is a UTF-8 JSON object.

        Returns:
            The value stored under the ``"vectors"`` key.

        Raises:
            KeyError: If the decoded JSON object has no ``"vectors"`` key.
        """
        raw_body = response.read()
        document = json.loads(raw_body.decode("utf-8"))
        return document["vectors"]