llama-index

Форк
0
/
sagemaker_llm_endpoint_utils.py 
73 строки · 2.5 Кб
1
import abc
2
import codecs
3
import json
4
from typing import TYPE_CHECKING
5

6
if TYPE_CHECKING:
7
    from botocore.response import StreamingBody
8

9
from llama_index.legacy.bridge.pydantic import BaseModel, Field
10

11

12
class BaseIOHandler(BaseModel, metaclass=abc.ABCMeta):
    """Abstract interface for encoding requests to and decoding responses from an endpoint.

    Concrete handlers declare the request/response MIME types and implement
    how a prompt is turned into a request body and how the endpoint's
    response (whole or streamed chunk) is turned back into text.
    """

    # MIME type sent as the request body's content type.
    content_type: str = Field(
        description="The MIME type of the input data in the request body.",
    )
    # MIME type requested for the model container's response.
    accept: str = Field(
        description="The desired MIME type of the inference response from the model container.",
    )

    @classmethod
    def __subclasshook__(cls, subclass: type) -> bool:
        """Recognise structurally compatible classes as virtual subclasses.

        The structural check applies only to ``BaseIOHandler`` itself. For a
        concrete subclass, this hook defers to the normal inheritance check.
        This mirrors the ``collections.abc`` pattern. Without the guard, any
        duck-typed handler would also count as a subclass of every concrete
        handler that inherits this hook.

        Returns:
            ``True`` if ``subclass`` exposes ``content_type``, ``accept`` and
            callable ``serialize_input``, ``deserialize_output``,
            ``deserialize_streaming_output`` and ``remove_prefix``; otherwise
            ``NotImplemented``, so ``ABCMeta`` falls back to its regular
            subclass check.
        """
        if cls is not BaseIOHandler:
            return NotImplemented
        required_methods = (
            "serialize_input",
            "deserialize_output",
            "deserialize_streaming_output",
            "remove_prefix",
        )
        if (
            hasattr(subclass, "content_type")
            and hasattr(subclass, "accept")
            and all(
                callable(getattr(subclass, name, None)) for name in required_methods
            )
        ):
            return True
        return NotImplemented

    @abc.abstractmethod
    def serialize_input(self, request: str, model_kwargs: dict) -> bytes:
        """Encode ``request`` and generation ``model_kwargs`` into a request body."""
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_output(self, response: "StreamingBody") -> str:
        """Decode a complete (non-streaming) response body into generated text."""
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_streaming_output(self, response: bytes) -> str:
        """Decode one chunk of a streamed response into text."""
        raise NotImplementedError

    @abc.abstractmethod
    def remove_prefix(self, response: str, prompt: str) -> str:
        """Return ``response`` with an echoed ``prompt`` removed from its start."""
        raise NotImplementedError
51

52

53
class IOHandler(BaseIOHandler):
    """Default JSON handler.

    Requests are sent as ``{"inputs": <prompt>, "parameters": <model_kwargs>}``.
    Responses are expected to be shaped like ``[{"generated_text": "..."}]``.
    """

    content_type: str = "application/json"
    accept: str = "application/json"

    def serialize_input(self, request: str, model_kwargs: dict) -> bytes:
        """Encode the prompt and generation parameters as a UTF-8 JSON body."""
        request_str = json.dumps({"inputs": request, "parameters": model_kwargs})
        return request_str.encode("utf-8")

    def deserialize_output(self, response: "StreamingBody") -> str:
        """Read the whole body as UTF-8 JSON and return ``[0]["generated_text"]``."""
        return json.load(codecs.getreader("utf-8")(response))[0]["generated_text"]

    def deserialize_streaming_output(self, response: bytes) -> str:
        """Decode one streamed chunk of a ``[{"generated_text":"..."}]`` payload.

        The opening ``[{"generated_text":"`` and closing ``"}]`` envelope is
        removed when present. The remaining text is then JSON-unescaped.

        Args:
            response: Raw UTF-8 bytes of a single stream chunk.

        Returns:
            The decoded text carried by the chunk.

        Raises:
            json.JSONDecodeError: If the chunk's payload is not a valid
                JSON-escaped string fragment.
        """
        prefix = '[{"generated_text":"'
        suffix = '"}]'
        response_str = response.decode("utf-8")
        # Strip the envelope as exact strings. str.lstrip/str.rstrip treat
        # their argument as a *set of characters*, which would also eat
        # legitimate text (e.g. "texting" -> "ing") or the quote of a
        # trailing escaped \" (which then breaks json.loads below).
        if response_str.startswith(prefix):
            response_str = response_str[len(prefix) :]
        if response_str.endswith(suffix):
            response_str = response_str[: -len(suffix)]
        # The fragment is still JSON-escaped (\n, \", \uXXXX, ...); wrap it
        # in a JSON string literal so json.loads performs the unescaping.
        clean_response = '{"response":"' + response_str + '"}'

        return json.loads(clean_response)["response"]

    def remove_prefix(self, raw_text: str, prompt: str) -> str:
        """Remove an echoed ``prompt`` from the start of ``raw_text``.

        The prompt is only removed when ``raw_text`` actually begins with it.
        If the endpoint returned just the completion, ``raw_text`` is returned
        unchanged rather than losing its first ``len(prompt)`` characters.
        """
        if raw_text.startswith(prompt):
            return raw_text[len(prompt) :]
        return raw_text
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.