pytorch-lightning
72 строки · 2.8 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16from typing import Any
17
18from fastapi import HTTPException
19from pydantic import BaseModel
20
21from lightning.app.utilities.imports import _is_aiohttp_available, requires
22
23if _is_aiohttp_available():
24import aiohttp
25import aiohttp.client_exceptions
26
27
class ColdStartProxy:
    """Proxy used by the load balancer while the work is cold starting.

    ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
    starting. This is useful for services that get realtime requests but whose workers have a high startup time.

    If the proxy service accepts the same request body via the POST method, the default implementation of
    `handle_request` can be used. In that case, initialize the proxy with the proxy url. Otherwise, the user can
    override `handle_request`.

    Args:
        proxy_url (str): The url of the proxy service

    """

    @requires(["aiohttp"])
    def __init__(self, proxy_url: str):
        self.proxy_url = proxy_url
        # Timeout handed to aiohttp for each forwarded request (a plain number is a total timeout in seconds).
        self.proxy_timeout = 50
        # `handle_request` may be overridden by subclasses; reject overrides that are not coroutine functions
        # up front instead of failing later when a request arrives.
        if not asyncio.iscoroutinefunction(self.handle_request):
            raise TypeError("handle_request must be an `async` function")

    async def handle_request(self, request: BaseModel) -> Any:
        """Handle a request that is received while the work is cold starting.

        The default implementation forwards the request body to the proxy service with the POST method, but the
        user can override this method to handle the request in any way.

        Args:
            request: The request body, a pydantic model that is being forwarded by load balancer which
                is a FastAPI service

        Returns:
            The JSON-decoded body of the proxy service's response.

        Raises:
            HTTPException: With status code 500 if building, sending or decoding the proxied request fails.

        """
        headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        try:
            # aiohttp's documented timeout type is `ClientTimeout`; plain numbers are only accepted for backward
            # compatibility and mean "total timeout". An already-built `ClientTimeout` is passed through as-is.
            timeout = self.proxy_timeout
            if not isinstance(timeout, aiohttp.ClientTimeout):
                timeout = aiohttp.ClientTimeout(total=timeout)

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.proxy_url,
                    json=request.dict(),
                    timeout=timeout,
                    headers=headers,
                ) as response:
                    return await response.json()
        except Exception as ex:
            # Any failure is reported to the client as a 500; chain the cause so the original traceback
            # is preserved for debugging.
            raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}") from ex
73