pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import inspect
import os
import pydoc
import subprocess
import sys
from typing import Any, Callable, Optional

import fastapi  # noqa E511
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from lightning.app.components.serve.types import _DESERIALIZER, _SERIALIZER
from lightning.app.core.work import LightningWork
from lightning.app.utilities.app_helpers import Logger

logger = Logger(__name__)


fastapi_service = FastAPI()


class _InferenceCallable:
    def __init__(
        self,
        deserialize: Callable,
        predict: Callable,
        serialize: Callable,
    ):
        self.deserialize = deserialize
        self.predict = predict
        self.serialize = serialize

    async def run(self, data) -> Any:
        return self.serialize(self.predict(self.deserialize(data)))


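# --- Editor's note (not part of the original file) ---
# `_InferenceCallable` simply composes the three user-supplied stages. A
# hypothetical JSON-in/JSON-out wrapper around a callable `model` would be:
#
#     import json
#     handler = _InferenceCallable(
#         deserialize=json.loads, predict=model, serialize=json.dumps
#     )

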
class ModelInferenceAPI(LightningWork, abc.ABC):
    def __init__(
        self,
        input: Optional[str] = None,
        output: Optional[str] = None,
        host: str = "127.0.0.1",
        port: int = 7777,
        workers: int = 0,
    ):
        """The ModelInferenceAPI class enables you to easily serve your model.

        Arguments:
            input: Optional input format. When provided, the matching built-in deserializer is used.
            output: Optional output format. When provided, the matching built-in serializer is used.
            host: Address to be used to serve the model.
            port: Port to be used to serve the model.
            workers: Number of uvicorn workers. Warning: multi-worker serving won't work if your
                subclass's ``__init__`` takes additional arguments.

        """
        super().__init__(parallel=True, host=host, port=port)
        if input and input not in _DESERIALIZER:
            raise Exception(f"Only inputs in {_DESERIALIZER.keys()} are supported.")
        if output and output not in _SERIALIZER:
            raise Exception(f"Only outputs in {_SERIALIZER.keys()} are supported.")
        self.input = input
        self.output = output
        self.workers = workers
        self._model = None

        self.ready = False

    @property
    def model(self):
        return self._model

    @abc.abstractmethod
    def build_model(self) -> Any:
        """Override to define your model."""

    def deserialize(self, data) -> Any:
        return data

    @abc.abstractmethod
    def predict(self, data) -> Any:
        """Override to add your predict logic."""

    def serialize(self, data) -> Any:
        return data

    def run(self):
        global fastapi_service
        if self.workers > 1:
            # TODO: This is quite limited.
            # Find a more reliable solution to enable multi-worker serving.
            # Each uvicorn worker re-imports this module; the environment
            # variables below let `_maybe_create_instance` re-create the user's
            # subclass inside every worker process.
            env = os.environ.copy()
            module = inspect.getmodule(self).__file__
            env["LIGHTNING_MODEL_INFERENCE_API_FILE"] = module
            env["LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME"] = self.__class__.__name__
            if self.input:
                env["LIGHTNING_MODEL_INFERENCE_API_INPUT"] = self.input
            if self.output:
                env["LIGHTNING_MODEL_INFERENCE_API_OUTPUT"] = self.output
            command = [
                sys.executable,
                "-m",
                "uvicorn",
                "--workers",
                str(self.workers),
                "--host",
                str(self.host),
                "--port",
                str(self.port),
                "serve:fastapi_service",
            ]
            process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
            self.ready = True
            process.wait()
        else:
            self._populate_app(fastapi_service)
            self.ready = True
            self._launch_server(fastapi_service)

    def _populate_app(self, fastapi_service: FastAPI):
        self._model = self.build_model()

        fastapi_service.post("/predict", response_class=JSONResponse)(
            _InferenceCallable(
                deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize,
                predict=self.predict,
                serialize=_SERIALIZER[self.output] if self.output else self.serialize,
            ).run
        )

    def _launch_server(self, fastapi_service: FastAPI):
        logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
        uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error")

    def configure_layout(self) -> str:
        return f"{self.url}/docs"


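# --- Editor's usage sketch (not part of the original file) ---
# A minimal subclass, assuming a trivial identity "model"; `MyModelAPI` is a
# hypothetical name. Only `build_model` and `predict` must be overridden;
# `deserialize`/`serialize` fall back to the identity defaults above. Kept
# commented out so importing this module stays side-effect free.
#
#     class MyModelAPI(ModelInferenceAPI):
#         def build_model(self) -> Any:
#             return lambda x: x  # stand-in for a real model
#
#         def predict(self, data) -> Any:
#             return self.model(data)

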
def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
    """Try to re-create the user's `ModelInferenceAPI` subclass when the environment variables
    associated with multi-worker serving are present."""
    render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None)
    render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None)
    if render_fn_name is None or render_fn_module_file is None:
        return None
    module = pydoc.importfile(render_fn_module_file)
    cls = getattr(module, render_fn_name)
    input = os.getenv("LIGHTNING_MODEL_INFERENCE_API_INPUT", None)
    output = os.getenv("LIGHTNING_MODEL_INFERENCE_API_OUTPUT", None)
    return cls(input=input, output=output)


instance = _maybe_create_instance()
if instance:
    instance._populate_app(fastapi_service)
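

# --- Editor's client sketch (not part of the original file) ---
# Because `_InferenceCallable.run` declares `data` without a type annotation,
# FastAPI should expose it as a query parameter, so a request against the
# default host/port would look roughly like this (hypothetical client code):
#
#     import requests
#
#     response = requests.post("http://127.0.0.1:7777/predict", params={"data": "hello"})
#     print(response.json())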