pytorch-lightning
328 lines · 11.0 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc16import asyncio17import base6418import os19import platform20from typing import TYPE_CHECKING, Any, Dict, Optional21
22import requests23import uvicorn24from fastapi import FastAPI25from lightning_utilities.core.imports import compare_version, module_available26from pydantic import BaseModel27
28from lightning.app.core.work import LightningWork29from lightning.app.utilities.app_helpers import Logger30from lightning.app.utilities.imports import _is_torch_available, requires31
32if TYPE_CHECKING:33from lightning.app.frontend.frontend import Frontend34
35logger = Logger(__name__)36
37# Skip doctests if requirements aren't available
38if not module_available("lightning_api_access") or not _is_torch_available():39__doctest_skip__ = ["PythonServer", "PythonServer.*"]40
41
def _get_device():
    """Select the torch device for this worker.

    Preference order: Apple MPS (torch >= 1.12 on Apple Silicon), then CUDA,
    then CPU. The device index comes from the ``LOCAL_RANK`` environment
    variable (defaulting to 0).
    """
    import operator

    import torch

    rank = int(os.getenv("LOCAL_RANK", "0"))
    mps_capable_torch = compare_version("torch", operator.ge, "1.12.0")
    on_apple_silicon = platform.processor() in ("arm", "arm64")

    if mps_capable_torch and torch.backends.mps.is_available() and on_apple_silicon:
        return torch.device("mps", rank)
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    return torch.device("cpu")
56
57class _DefaultInputData(BaseModel):58payload: str59
60
61class _DefaultOutputData(BaseModel):62prediction: str63
64
class Image(BaseModel):
    """Request/response schema carrying a single base64-encoded image."""

    # Base64-encoded image bytes as UTF-8 text; ``None`` until populated.
    image: Optional[str] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Download a small sample image and return it base64-encoded.

        Returns:
            A dict of the form ``{"image": <base64 string>}`` usable as an
            example request body in the API access UI.
        """
        url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
        # Bound the network call: without a timeout, requests can block forever
        # if the sample-image host is unreachable, hanging the UI rendering.
        img = requests.get(url, timeout=10).content
        img = base64.b64encode(img).decode("UTF-8")
        return {"image": img}

    @staticmethod
    def request_code_sample(url: str) -> str:
        """Return client-side Python (as text) that posts a sample image to ``url``."""
        return f"""
import base64
from pathlib import Path
import requests

imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
img = requests.get(imgurl).content
img = base64.b64encode(img).decode("UTF-8")
response = requests.post('{url}', json={{"image": img}})
# If you are using basic authentication for your app, you should add your credentials to the request:
# auth = requests.auth.HTTPBasicAuth('your_username', 'your_password')
# response = requests.post('{url}', json={{"image": img}}, auth=auth)
"""

    @staticmethod
    def response_code_sample() -> str:
        """Return client-side Python (as text) that decodes the image out of the response."""
        return """img = response.json()["image"]
img = base64.b64decode(img.encode("utf-8"))
Path("response.png").write_bytes(img)
"""
97
98
class Category(BaseModel):
    """Response schema carrying a single predicted class index."""

    # Integer class index produced by the model; ``None`` until populated.
    category: Optional[int] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Return an example response body shown in the API access UI."""
        return {"category": 463}

    @staticmethod
    def response_code_sample() -> str:
        """Return client-side Python (as text) that prints the predicted category."""
        return """print("Predicted category is: ", response.json()["category"])
"""
110
111
class Text(BaseModel):
    """Request/response schema carrying a single free-form text field."""

    # Text payload (e.g. a prompt); ``None`` until populated.
    text: Optional[str] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Return an example request body shown in the API access UI."""
        return {"text": "A portrait of a person looking away from the camera"}

    @staticmethod
    def request_code_sample(url: str) -> str:
        """Return client-side Python (as text) that posts a sample text request to ``url``."""
        return f"""
import base64
from pathlib import Path
import requests

response = requests.post('{url}', json={{
"text": "A portrait of a person looking away from the camera"
}})
# If you are using basic authentication for your app, you should add your credentials to the request:
# response = requests.post('{url}', json={{
# "text": "A portrait of a person looking away from the camera"
# }}, auth=requests.auth.HTTPBasicAuth('your_username', 'your_password'))
"""
134
135
class Number(BaseModel):
    """Response schema carrying a single integer prediction."""

    # deprecated - TODO remove this in favour of Category
    prediction: Optional[int] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Return an example response body shown in the API access UI."""
        return {"prediction": 463}
144
class PythonServer(LightningWork, abc.ABC):
    """LightningWork that serves a user-defined ``predict`` method behind a FastAPI POST /predict endpoint."""

    # NOTE(review): "spawn" (rather than "fork") is presumably required so the
    # serving process can safely initialize CUDA / torch — confirm.
    _start_method = "spawn"

    @requires(["torch"])
    def __init__(  # type: ignore
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        """The PythonServer Class enables to easily get your machine learning server up and running.

        Arguments:
            input_type: Optional `input_type` to be provided. This needs to be a pydantic BaseModel class.
                The default data type is good enough for the basic usecases and it expects the data
                to be a json object that has one key called `payload`

                .. code-block:: python

                    input_data = {"payload": "some data"}

                and this can be accessed as `request.payload` in the `predict` method.

                .. code-block:: python

                    def predict(self, request):
                        data = request.payload

            output_type: Optional `output_type` to be provided. This needs to be a pydantic BaseModel class.
                The default data type is good enough for the basic usecases. It expects the return value of
                the `predict` method to be a dictionary with one key called `prediction`.

                .. code-block:: python

                    def predict(self, request):
                        # some code
                        return {"prediction": "some data"}

                and this can be accessed as `response.json()["prediction"]` in the client if
                you are using requests library

        Example:

            >>> from lightning.app.components.serve.python_server import PythonServer
            >>> from lightning.app import LightningApp
            ...
            >>> class SimpleServer(PythonServer):
            ...
            ...     def setup(self):
            ...         self._model = lambda x: x + " " + x
            ...
            ...     def predict(self, request):
            ...         return {"prediction": self._model(request.image)}
            ...
            >>> app = LightningApp(SimpleServer())

        """
        # parallel=True: the server keeps running while the rest of the app proceeds.
        super().__init__(parallel=True, **kwargs)
        # Validate eagerly so a bad type fails at construction, not at first request.
        if not issubclass(input_type, BaseModel):
            raise TypeError("input_type must be a pydantic BaseModel class")
        if not issubclass(output_type, BaseModel):
            raise TypeError("output_type must be a pydantic BaseModel class")
        self._input_type = input_type
        self._output_type = output_type

        # Flipped to True in ``run`` once the FastAPI app has been wired up.
        self.ready = False

    def setup(self, *args: Any, **kwargs: Any) -> None:
        """This method is called before the server starts. Override this if you need to download the model or
        initialize the weights, setting up pipelines etc.

        Note that this will be called exactly once on every work machines. So if you have multiple machines for serving,
        this will be called on each of them.

        """
        return

    def configure_input_type(self) -> type:
        # Override to derive the request schema dynamically instead of via __init__.
        return self._input_type

    def configure_output_type(self) -> type:
        # Override to derive the response schema dynamically instead of via __init__.
        return self._output_type

    @abc.abstractmethod
    def predict(self, request: Any) -> Any:
        """This method is called when a request is made to the server.

        This method must be overriden by the user with the prediction logic. The pre/post processing, actual prediction
        using the model(s) etc goes here

        """
        pass

    @staticmethod
    def _get_sample_dict_from_datatype(datatype: Any) -> dict:
        """Build an example payload dict for ``datatype`` (used by the API access UI).

        Prefers the type's own ``get_sample_data`` hook; otherwise derives placeholder
        values from the pydantic JSON schema properties.

        Raises:
            TypeError: if a schema property has a JSON type with no placeholder mapping.
        """
        if hasattr(datatype, "get_sample_data"):
            return datatype.get_sample_data()

        # NOTE(review): ``.schema()`` is the pydantic v1 introspection API — confirm
        # the pinned pydantic version before migrating to ``model_json_schema``.
        datatype_props = datatype.schema()["properties"]
        out: Dict[str, Any] = {}
        for k, v in datatype_props.items():
            if v["type"] == "string":
                out[k] = "data string"
            elif v["type"] == "number":
                out[k] = 0.0
            elif v["type"] == "integer":
                out[k] = 0
            elif v["type"] == "boolean":
                out[k] = False
            else:
                raise TypeError("Unsupported type")
        return out

    def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
        """Register ``self.predict`` as the POST /predict route on ``fastapi_app``."""
        input_type: type = self.configure_input_type()
        output_type: type = self.configure_output_type()

        def predict_fn_sync(request: input_type):  # type: ignore
            return self.predict(request)

        async def async_predict_fn(request: input_type):  # type: ignore
            return await self.predict(request)

        # Register the wrapper matching the user's implementation: a coroutine
        # ``predict`` must be awaited, a plain one is called synchronously.
        if asyncio.iscoroutinefunction(self.predict):
            fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
        else:
            fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)

    def get_code_sample(self, url: str) -> Optional[str]:
        """Return a client code snippet for ``url``, or ``None`` if the configured
        input/output types do not both provide code-sample hooks."""
        input_type: Any = self.configure_input_type()
        output_type: Any = self.configure_output_type()

        if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
            return None
        return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"

    def configure_layout(self) -> Optional["Frontend"]:
        """Expose an API-access frontend describing the /predict endpoint.

        Returns ``None`` when the optional UI dependency is missing or when no
        sample payload can be derived for the configured types.
        """
        try:
            from lightning_api_access import APIAccessFrontend
        except ModuleNotFoundError:
            logger.warn(
                "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`"
            )
            return None

        class_name = self.__class__.__name__
        url = f"{self.url}/predict"

        try:
            request = self._get_sample_dict_from_datatype(self.configure_input_type())
            response = self._get_sample_dict_from_datatype(self.configure_output_type())
        except TypeError:
            # A schema property had an unsupported JSON type; skip the UI rather than crash.
            return None

        frontend_payload = {
            "name": class_name,
            "url": url,
            "method": "POST",
            "request": request,
            "response": response,
        }

        code_sample = self.get_code_sample(url)
        if code_sample:
            frontend_payload["code_sample"] = code_sample

        return APIAccessFrontend(apis=[frontend_payload])

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run method takes care of configuring and setting up a FastAPI server behind the scenes.

        Normally, you don't need to override this method.

        """
        self.setup(*args, **kwargs)

        fastapi_app = FastAPI()
        self._attach_predict_fn(fastapi_app)

        self.ready = True
        logger.info(
            f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
        )
        # Blocks here serving requests until the work process is terminated.
        uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error")