pytorch-lightning
170 lines · 5.8 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc16import inspect17import os18import pydoc19import subprocess20import sys21from typing import Any, Callable, Optional22
23import fastapi # noqa E51124import uvicorn25from fastapi import FastAPI26from fastapi.responses import JSONResponse27
28from lightning.app.components.serve.types import _DESERIALIZER, _SERIALIZER29from lightning.app.core.work import LightningWork30from lightning.app.utilities.app_helpers import Logger31
# Module-level logger for this serve component.
logger = Logger(__name__)


# Shared FastAPI app. It lives at module scope so that spawned uvicorn workers can
# import it as "serve:fastapi_service" (see ModelInferenceAPI.run).
fastapi_service = FastAPI()
37
38class _InferenceCallable:39def __init__(40self,41deserialize: Callable,42predict: Callable,43serialize: Callable,44):45self.deserialize = deserialize46self.predict = predict47self.serialize = serialize48
49async def run(self, data) -> Any:50return self.serialize(self.predict(self.deserialize(data)))51
52
class ModelInferenceAPI(LightningWork, abc.ABC):
    """Serve a user-defined model behind a FastAPI ``/predict`` endpoint.

    Subclasses must implement :meth:`build_model` and :meth:`predict`. They may
    either override :meth:`deserialize`/:meth:`serialize` or select built-in ones
    through the ``input``/``output`` constructor arguments.
    """

    def __init__(
        self,
        input: Optional[str] = None,
        output: Optional[str] = None,
        host: str = "127.0.0.1",
        port: int = 7777,
        workers: int = 0,
    ):
        """The ModelInferenceAPI Class enables to easily get your model served.

        Arguments:
            input: Optional input format name. When provided, selects a built-in deserializer.
            output: Optional output format name. When provided, selects a built-in serializer.
            host: Address to be used to serve the model.
            port: Port to be used to serve the model.
            workers: Number of workers for uvicorn. Warning, this won't work if your subclass takes more arguments.

        """
        super().__init__(parallel=True, host=host, port=port)
        # ValueError is a subclass of Exception, so pre-existing handlers still match.
        if input and input not in _DESERIALIZER:
            raise ValueError(f"Only input in {_DESERIALIZER.keys()} are supported.")
        if output and output not in _SERIALIZER:
            raise ValueError(f"Only output in {_SERIALIZER.keys()} are supported.")
        self.input = input
        self.output = output
        self.workers = workers
        self._model = None  # built lazily by _populate_app via build_model()

        self.ready = False  # set to True once the server has been launched

    @property
    def model(self):
        """The model built by :meth:`build_model` (``None`` until the app is populated)."""
        return self._model

    @abc.abstractmethod
    def build_model(self) -> Any:
        """Override to define your model."""

    def deserialize(self, data) -> Any:
        """Default deserializer: pass the request payload through unchanged."""
        return data

    @abc.abstractmethod
    def predict(self, data) -> Any:
        """Override to add your predict logic."""

    def serialize(self, data) -> Any:
        """Default serializer: pass the prediction through unchanged."""
        return data

    def run(self):
        """Start serving: spawn a multi-worker uvicorn subprocess, or serve in-process."""
        global fastapi_service
        if self.workers > 1:
            # TODO: This approach is quite limited.
            # Find a more reliable solution to enable multi workers serving.
            # The subclass location is exported through environment variables so that
            # each uvicorn worker can re-create the instance on import
            # (see ``_maybe_create_instance`` at the bottom of this module).
            env = os.environ.copy()
            module = inspect.getmodule(self).__file__
            env["LIGHTNING_MODEL_INFERENCE_API_FILE"] = module
            env["LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME"] = self.__class__.__name__
            if self.input:
                env["LIGHTNING_MODEL_INFERENCE_API_INPUT"] = self.input
            if self.output:
                env["LIGHTNING_MODEL_INFERENCE_API_OUTPUT"] = self.output
            command = [
                sys.executable,
                "-m",
                "uvicorn",
                "--workers",
                str(self.workers),
                "--host",
                str(self.host),
                "--port",
                str(self.port),
                "serve:fastapi_service",
            ]
            process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
            self.ready = True
            try:
                process.wait()
            finally:
                # Don't leak the uvicorn child if wait() is interrupted.
                if process.poll() is None:
                    process.terminate()
        else:
            self._populate_app(fastapi_service)
            self.ready = True
            self._launch_server(fastapi_service)

    def _populate_app(self, fastapi_service: FastAPI):
        """Build the model and register the ``/predict`` POST route on the app."""
        self._model = self.build_model()

        fastapi_service.post("/predict", response_class=JSONResponse)(
            _InferenceCallable(
                deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize,
                predict=self.predict,
                serialize=_SERIALIZER[self.output] if self.output else self.serialize,
            ).run
        )

    def _launch_server(self, fastapi_service: FastAPI):
        """Run uvicorn in-process; blocks until the server stops."""
        logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
        uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error")

    def configure_layout(self) -> str:
        """Expose the auto-generated OpenAPI docs page as this work's UI."""
        return f"{self.url}/docs"
153
154def _maybe_create_instance() -> Optional[ModelInferenceAPI]:155"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi workers156are present."""
157render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None)158render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None)159if render_fn_name is None or render_fn_module_file is None:160return None161module = pydoc.importfile(render_fn_module_file)162cls = getattr(module, render_fn_name)163input = os.getenv("LIGHTNING_MODEL_INFERENCE_API_INPUT", None)164output = os.getenv("LIGHTNING_MODEL_INFERENCE_API_OUTPUT", None)165return cls(input=input, output=output)166
167
# Import-time hook for multi-worker mode: when a uvicorn worker imports this module,
# the environment variables set by ``ModelInferenceAPI.run`` let us rebuild the user's
# instance and register its /predict route on the shared FastAPI app.
instance = _maybe_create_instance()
if instance:
    instance._populate_app(fastapi_service)