# pytorch-lightning — serve component (python_server.py).
# NOTE: mirrored-page header ("Fork 0 · 328 lines · 11.0 KB") removed — it was
# GitVerse site chrome captured by the scraper, not part of the source file.
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import abc
16
import asyncio
17
import base64
18
import os
19
import platform
20
from typing import TYPE_CHECKING, Any, Dict, Optional
21

22
import requests
23
import uvicorn
24
from fastapi import FastAPI
25
from lightning_utilities.core.imports import compare_version, module_available
26
from pydantic import BaseModel
27

28
from lightning.app.core.work import LightningWork
29
from lightning.app.utilities.app_helpers import Logger
30
from lightning.app.utilities.imports import _is_torch_available, requires
31

32
if TYPE_CHECKING:
33
    from lightning.app.frontend.frontend import Frontend
34

35
# Module-level logger using Lightning's app-helper Logger wrapper.
logger = Logger(__name__)

# Skip doctests if requirements aren't available
# (the class docstring examples need `lightning_api_access` and torch to run).
if not module_available("lightning_api_access") or not _is_torch_available():
    __doctest_skip__ = ["PythonServer", "PythonServer.*"]
40

41

42
def _get_device():
    """Select the torch device for the local process.

    Preference order: Apple MPS (torch >= 1.12 on arm Macs), then CUDA, then
    CPU. The device index is taken from the ``LOCAL_RANK`` environment
    variable (defaulting to 0).
    """
    import operator

    import torch

    # ``torch.backends.mps`` only exists from torch 1.12 onwards, so the
    # version gate must be checked first to keep the attribute access safe.
    mps_capable_torch = compare_version("torch", operator.ge, "1.12.0")

    rank = int(os.getenv("LOCAL_RANK", "0"))

    on_arm_mac = platform.processor() in ("arm", "arm64")
    if mps_capable_torch and torch.backends.mps.is_available() and on_arm_mac:
        return torch.device("mps", rank)

    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    return torch.device("cpu")
55

56

57
class _DefaultInputData(BaseModel):
    """Default request schema: a JSON object with a single string key ``payload``."""

    payload: str
59

60

61
class _DefaultOutputData(BaseModel):
    """Default response schema: a JSON object with a single string key ``prediction``."""

    prediction: str
63

64

65
class Image(BaseModel):
    """Request/response schema carrying a base64-encoded image string."""

    image: Optional[str] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Download the demo cat picture and return it base64-encoded."""
        url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
        raw_bytes = requests.get(url).content
        encoded = base64.b64encode(raw_bytes).decode("UTF-8")
        return {"image": encoded}

    @staticmethod
    def request_code_sample(url: str) -> str:
        """Return a copy-pastable client snippet that POSTs an image to ``url``."""
        snippet = f"""
import base64
from pathlib import Path
import requests

imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
img = requests.get(imgurl).content
img = base64.b64encode(img).decode("UTF-8")
response = requests.post('{url}', json={{"image": img}})
# If you are using basic authentication for your app, you should add your credentials to the request:
# auth = requests.auth.HTTPBasicAuth('your_username', 'your_password')
# response = requests.post('{url}', json={{"image": img}}, auth=auth)
"""
        return snippet

    @staticmethod
    def response_code_sample() -> str:
        """Return a client snippet that decodes the response image to a PNG file."""
        return """img = response.json()["image"]
img = base64.b64decode(img.encode("utf-8"))
Path("response.png").write_bytes(img)
"""
97

98

99
class Category(BaseModel):
    """Response schema carrying a single integer class label."""

    category: Optional[int] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Return an example response payload."""
        return dict(category=463)

    @staticmethod
    def response_code_sample() -> str:
        """Return a client snippet that prints the predicted category."""
        return """print("Predicted category is: ", response.json()["category"])
"""
110

111

112
class Text(BaseModel):
    """Request/response schema carrying a plain-text string."""

    text: Optional[str] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Return an example request payload."""
        return dict(text="A portrait of a person looking away from the camera")

    @staticmethod
    def request_code_sample(url: str) -> str:
        """Return a copy-pastable client snippet that POSTs text to ``url``."""
        snippet = f"""
import base64
from pathlib import Path
import requests

response = requests.post('{url}', json={{
    "text": "A portrait of a person looking away from the camera"
}})
# If you are using basic authentication for your app, you should add your credentials to the request:
# response = requests.post('{url}', json={{
#     "text": "A portrait of a person looking away from the camera"
# }}, auth=requests.auth.HTTPBasicAuth('your_username', 'your_password'))
"""
        return snippet
134

135

136
class Number(BaseModel):
    """Response schema carrying a single integer prediction.

    Deprecated in favour of :class:`Category`; kept for backward compatibility.
    """

    # deprecated - TODO remove this in favour of Category
    prediction: Optional[int] = None

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        """Return an example response payload."""
        return {"prediction": 463}
143

144

145
class PythonServer(LightningWork, abc.ABC):
    # Start the work via "spawn" rather than "fork" — presumably so per-process
    # state (e.g. CUDA) initializes cleanly in the child; confirm against
    # LightningWork's start-method handling.
    _start_method = "spawn"

    @requires(["torch"])
    def __init__(  # type: ignore
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        """The PythonServer Class enables to easily get your machine learning server up and running.

        Arguments:
            input_type: Optional `input_type` to be provided. This needs to be a pydantic BaseModel class.
                The default data type is good enough for the basic usecases and it expects the data
                to be a json object that has one key called `payload`

                .. code-block:: python

                    input_data = {"payload": "some data"}

                and this can be accessed as `request.payload` in the `predict` method.

                .. code-block:: python

                    def predict(self, request):
                        data = request.payload

            output_type: Optional `output_type` to be provided. This needs to be a pydantic BaseModel class.
                The default data type is good enough for the basic usecases. It expects the return value of
                the `predict` method to be a dictionary with one key called `prediction`.

                .. code-block:: python

                    def predict(self, request):
                        # some code
                        return {"prediction": "some data"}

                and this can be accessed as `response.json()["prediction"]` in the client if
                you are using requests library

        Example:

            >>> from lightning.app.components.serve.python_server import PythonServer
            >>> from lightning.app import LightningApp
            ...
            >>> class SimpleServer(PythonServer):
            ...
            ...     def setup(self):
            ...         self._model = lambda x: x + " " + x
            ...
            ...     def predict(self, request):
            ...         return {"prediction": self._model(request.image)}
            ...
            >>> app = LightningApp(SimpleServer())

        """
        super().__init__(parallel=True, **kwargs)
        # Both types are validated eagerly so misconfiguration fails at
        # construction time rather than on the first request.
        if not issubclass(input_type, BaseModel):
            raise TypeError("input_type must be a pydantic BaseModel class")
        if not issubclass(output_type, BaseModel):
            raise TypeError("output_type must be a pydantic BaseModel class")
        self._input_type = input_type
        self._output_type = output_type

        # Flipped to True in ``run`` once the FastAPI app is wired up,
        # immediately before uvicorn starts serving.
        self.ready = False

    def setup(self, *args: Any, **kwargs: Any) -> None:
        """This method is called before the server starts. Override this if you need to download the model or
        initialize the weights, setting up pipelines etc.

        Note that this will be called exactly once on every work machines. So if you have multiple machines for serving,
        this will be called on each of them.

        """
        return

    def configure_input_type(self) -> type:
        """Return the pydantic model used to validate incoming requests; override to customize."""
        return self._input_type

    def configure_output_type(self) -> type:
        """Return the pydantic model used to serialize responses; override to customize."""
        return self._output_type

    @abc.abstractmethod
    def predict(self, request: Any) -> Any:
        """This method is called when a request is made to the server.

        This method must be overriden by the user with the prediction logic. The pre/post processing, actual prediction
        using the model(s) etc goes here

        """
        pass

    @staticmethod
    def _get_sample_dict_from_datatype(datatype: Any) -> dict:
        """Build an example payload dict for ``datatype``.

        Prefers the model's own ``get_sample_data`` hook when present; otherwise
        derives placeholder values from the model's ``schema()`` properties
        (pydantic v1-style API — verify if migrating to pydantic v2).

        Raises:
            TypeError: if a schema property has a type other than
                string/number/integer/boolean.
        """
        if hasattr(datatype, "get_sample_data"):
            return datatype.get_sample_data()

        datatype_props = datatype.schema()["properties"]
        out: Dict[str, Any] = {}
        for k, v in datatype_props.items():
            if v["type"] == "string":
                out[k] = "data string"
            elif v["type"] == "number":
                out[k] = 0.0
            elif v["type"] == "integer":
                out[k] = 0
            elif v["type"] == "boolean":
                out[k] = False
            else:
                raise TypeError("Unsupported type")
        return out

    def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
        """Register ``self.predict`` as the ``POST /predict`` route on ``fastapi_app``."""
        input_type: type = self.configure_input_type()
        output_type: type = self.configure_output_type()

        def predict_fn_sync(request: input_type):  # type: ignore
            return self.predict(request)

        async def async_predict_fn(request: input_type):  # type: ignore
            return await self.predict(request)

        # Wrap with the sync or async adapter depending on how the subclass
        # defined ``predict``, so FastAPI dispatches it correctly either way.
        if asyncio.iscoroutinefunction(self.predict):
            fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
        else:
            fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)

    def get_code_sample(self, url: str) -> Optional[str]:
        """Return a combined request/response client snippet for ``url``.

        Returns ``None`` unless the input type provides ``request_code_sample``
        AND the output type provides ``response_code_sample``.
        """
        input_type: Any = self.configure_input_type()
        output_type: Any = self.configure_output_type()

        if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
            return None
        return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"

    def configure_layout(self) -> Optional["Frontend"]:
        """Build the API-access frontend shown in the app UI.

        Returns ``None`` when ``lightning_api_access`` is not installed, or when
        sample payloads cannot be derived from the configured input/output types.
        """
        try:
            from lightning_api_access import APIAccessFrontend
        except ModuleNotFoundError:
            logger.warn(
                "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`"
            )
            return None

        class_name = self.__class__.__name__
        url = f"{self.url}/predict"

        try:
            request = self._get_sample_dict_from_datatype(self.configure_input_type())
            response = self._get_sample_dict_from_datatype(self.configure_output_type())
        except TypeError:
            # Sample payloads could not be derived (unsupported field type) — skip the UI.
            return None

        frontend_payload = {
            "name": class_name,
            "url": url,
            "method": "POST",
            "request": request,
            "response": response,
        }

        code_sample = self.get_code_sample(url)
        if code_sample:
            frontend_payload["code_sample"] = code_sample

        return APIAccessFrontend(apis=[frontend_payload])

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run method takes care of configuring and setting up a FastAPI server behind the scenes.

        Normally, you don't need to override this method.

        """
        self.setup(*args, **kwargs)

        fastapi_app = FastAPI()
        self._attach_predict_fn(fastapi_app)

        # Signal readiness before the blocking uvicorn call — ``uvicorn.run``
        # does not return while the server is up.
        self.ready = True
        logger.info(
            f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
        )
        uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error")
329

# NOTE: scraper artifact removed — the mirrored page's Russian cookie-consent
# banner (GitVerse site chrome) is not part of this source file.