llama-index
73 строки · 2.5 Кб
1import abc
2import codecs
3import json
4from typing import TYPE_CHECKING
5
6if TYPE_CHECKING:
7from botocore.response import StreamingBody
8
9from llama_index.legacy.bridge.pydantic import BaseModel, Field
10
11
class BaseIOHandler(BaseModel, metaclass=abc.ABCMeta):
    """Abstract contract for turning prompts into request bodies and responses into text.

    Concrete handlers declare the MIME types they send/accept and implement the
    four (de)serialization hooks below. Besides normal subclassing, any class
    exposing the same attributes and callables is accepted by ``issubclass``
    through ``__subclasshook__``.
    """

    # MIME type written into the request body by serialize_input.
    content_type: str = Field(
        description="The MIME type of the input data in the request body.",
    )
    # MIME type the caller asks the model container to respond with.
    accept: str = Field(
        description="The desired MIME type of the inference response from the model container.",
    )

    @classmethod
    def __subclasshook__(cls, subclass: type) -> bool:
        """Structurally recognise handler-like classes for ``issubclass`` checks.

        Returns:
            ``True`` when ``subclass`` has ``content_type`` and ``accept`` and
            callable ``serialize_input``, ``deserialize_output``,
            ``deserialize_streaming_output`` and ``remove_prefix`` attributes;
            otherwise ``NotImplemented`` so the regular ABC subclass logic
            decides.
        """
        # Plain attributes only need to exist.
        for attr_name in ("content_type", "accept"):
            if not hasattr(subclass, attr_name):
                return NotImplemented

        # Hooks must exist AND be callable; checked in the same order as the
        # hasattr/callable pairs so attribute access order is unchanged.
        for hook_name in (
            "serialize_input",
            "deserialize_output",
            "deserialize_streaming_output",
            "remove_prefix",
        ):
            if not hasattr(subclass, hook_name):
                return NotImplemented
            if not callable(getattr(subclass, hook_name)):
                return NotImplemented

        return True

    @abc.abstractmethod
    def serialize_input(self, request: str, model_kwargs: dict) -> bytes:
        """Encode the prompt ``request`` and generation ``model_kwargs`` as a request body."""
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_output(self, response: "StreamingBody") -> str:
        """Decode a complete (non-streaming) response body into generated text."""
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_streaming_output(self, response: bytes) -> str:
        """Decode one raw chunk of a streamed response into generated text."""
        raise NotImplementedError

    @abc.abstractmethod
    def remove_prefix(self, response: str, prompt: str) -> str:
        """Strip the echoed ``prompt`` from the start of ``response``."""
        raise NotImplementedError
51
52
class IOHandler(BaseIOHandler):
    """Default JSON handler.

    Requests are sent as ``{"inputs": ..., "parameters": ...}``. Responses are
    expected to look like ``[{"generated_text": "..."}]``; that response shape
    is assumed from the parsing code, not validated.
    """

    content_type: str = "application/json"
    accept: str = "application/json"

    def serialize_input(self, request: str, model_kwargs: dict) -> bytes:
        """Return the UTF-8 encoded JSON body ``{"inputs": request, "parameters": model_kwargs}``."""
        request_str = json.dumps({"inputs": request, "parameters": model_kwargs})
        return request_str.encode("utf-8")

    def deserialize_output(self, response: "StreamingBody") -> str:
        """Parse the full UTF-8 JSON body and return ``body[0]["generated_text"]``.

        Raises whatever ``json.load`` / indexing raises (e.g. ``JSONDecodeError``,
        ``IndexError``, ``KeyError``) if the body does not have that shape.
        """
        return json.load(codecs.getreader("utf-8")(response))[0]["generated_text"]

    def deserialize_streaming_output(self, response: bytes) -> str:
        """Decode one streamed chunk into plain generated text.

        The exact envelope prefix ``[{"generated_text":"`` and suffix ``"}]``
        are removed when present. The remaining fragment is still a
        JSON-escaped string body, so it is wrapped and run through
        ``json.loads`` to unescape it (e.g. ``\\"`` -> ``"``, ``\\n`` -> newline).

        NOTE(review): an envelope split across chunk boundaries (e.g. a chunk
        that holds only part of the prefix), or an escape sequence or multi-byte
        UTF-8 character split across chunks, is not reassembled here and will
        make decoding fail.
        """
        prefix = '[{"generated_text":"'
        suffix = '"}]'
        response_str = response.decode("utf-8")
        # Remove the literal envelope. str.lstrip/rstrip would treat these as
        # *character sets* and also eat real generated text (any chunk starting
        # with a, e, t, n, r, ... or ending with an escaped quote).
        if response_str.startswith(prefix):
            response_str = response_str[len(prefix) :]
        if response_str.endswith(suffix):
            response_str = response_str[: -len(suffix)]
        clean_response = '{"response":"' + response_str + '"}'

        return json.loads(clean_response)["response"]

    def remove_prefix(self, raw_text: str, prompt: str) -> str:
        """Drop the first ``len(prompt)`` characters of ``raw_text``.

        Assumes the model echoed the prompt verbatim; the text is not compared
        against ``prompt``, only its length is used.
        """
        return raw_text[len(prompt) :]
74