pytorch-lightning

Форк
0
308 строк · 10.9 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import errno
16
import inspect
17
import os
18
import os.path as osp
19
import shutil
20
import sys
21
import traceback
22
from dataclasses import asdict
23
from getpass import getuser
24
from importlib.util import module_from_spec, spec_from_file_location
25
from tempfile import gettempdir
26
from typing import Any, Callable, Dict, List, Optional, Union
27

28
import requests
29
from fastapi import HTTPException
30
from pydantic import BaseModel
31

32
from lightning.app.api.http_methods import Post
33
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
34
from lightning.app.utilities import frontend
35
from lightning.app.utilities.app_helpers import Logger, is_overridden
36
from lightning.app.utilities.cloud import _get_project
37
from lightning.app.utilities.network import LightningClient
38
from lightning.app.utilities.state import AppState
39

40
# Module-level logger named after this module; the request handlers further down
# use it to report failures raised by user-provided methods.
logger = Logger(__name__)
41

42

43
def makedirs(path: str):
    r"""Recursively create the directory ``path``, tolerating an existing one.

    The path is normalised and ``~`` is expanded before the directory is created.

    Args:
        path: Directory to create, together with any missing parents.

    Raises:
        OSError: If the directory cannot be created, for instance because of
            missing permissions or because ``path`` (or one of its parents)
            already exists as a regular file.
    """
    # ``exist_ok=True`` only tolerates an already existing *directory*; every
    # other failure propagates. The previous hand-written errno check had its
    # condition inverted and silently swallowed genuine errors.
    os.makedirs(osp.expanduser(osp.normpath(path)), exist_ok=True)
50

51

52
class ClientCommand:
    """Base class for commands that pair a server-side ``method`` with client-side logic.

    Subclasses override :meth:`run` and may set ``description`` and ``requirements``
    as class attributes.
    """

    # Human-readable description; falls back to the docstring of ``method`` when left empty.
    description: str = ""
    # Class-level default shared by every command that does not override it: override it in
    # a subclass rather than mutating this list in place.
    requirements: List[str] = []

    def __init__(self, method: Callable):
        """Store ``method`` and derive the description and owner name from it.

        Args:
            method: The callable executed by :meth:`__call__`. It may be ``None``; the
                command loader in this module instantiates commands that way.
        """
        self.method = method
        if not self.description:
            # Guard against ``method=None``: reading ``None.__doc__`` would otherwise pick up
            # NoneType's own docstring on interpreters that define one.
            self.description = (method.__doc__ if method is not None else None) or ""
        # Bound methods expose their instance through ``__self__``; plain callables do not.
        flow = getattr(self.method, "__self__", None)
        self.owner = flow.name if flow else None
        self.models: Optional[Dict[str, BaseModel]] = None
        self.app_url = None
        self._state = None

    def _setup(self, command_name: str, app_url: str) -> None:
        """Record the command name and the URL of the app this command talks to."""
        self.command_name = command_name
        self.app_url = app_url

    @property
    def state(self):
        """Return the app state, created and requested from ``app_url`` on first access."""
        if self._state is None:
            assert self.app_url
            # TODO: Resolve this hack
            os.environ["LIGHTNING_APP_STATE_URL"] = "1"
            try:
                self._state = AppState(host=self.app_url)
                self._state._request_state()
            finally:
                # Always undo the temporary environment change, even when building or
                # requesting the state fails.
                os.environ.pop("LIGHTNING_APP_STATE_URL", None)
        return self._state

    def run(self, **cli_kwargs) -> None:
        """Overrides with the logic to execute on the client side."""

    def invoke_handler(self, config: Optional[BaseModel] = None) -> Dict[str, Any]:
        """POST ``config`` to the command endpoint of the app and return the decoded JSON reply.

        Args:
            config: Optional model, serialised with ``config.json()`` as the request body.

        Returns:
            The JSON payload of the response.

        On any status code other than 200 the failure is printed and the process exits.
        """
        command = self.command_name.replace(" ", "_")
        resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None)
        if resp.status_code != 200:
            try:
                detail = str(resp.json())
            except Exception:
                # The error body is not guaranteed to be JSON.
                detail = "Internal Server Error"
            print(f"Failed with status code {resp.status_code}. Detail: {detail}")
            # NOTE(review): exits with status 0 although the request failed; kept unchanged
            # because callers may rely on the exit code - confirm whether this is intended.
            sys.exit(0)

        return resp.json()

    def _to_dict(self):
        """Return the owner name and requirements of this command as a dict."""
        return {"owner": self.owner, "requirements": self.requirements}

    def __call__(self, **kwargs: Any):
        """Forward the keyword arguments to the wrapped ``method`` and return its result."""
        return self.method(**kwargs)
102

103

104
def _download_command(
    command_name: str,
    cls_path: str,
    cls_name: str,
    app_id: Optional[str] = None,
    debug_mode: bool = False,
    target_file: Optional[str] = None,
) -> ClientCommand:
    """Fetch the source file of a command, import it and return an instance of its class.

    Args:
        command_name: Name of the command; spaces are replaced by underscores.
        cls_path: Local path of the file defining the command, copied when no ``app_id`` is given.
        cls_name: Name of the ``ClientCommand`` subclass to load from the file.
        app_id: When set, the file is downloaded from the artifacts of this app instance
            (only if ``target_file`` does not exist yet).
        debug_mode: When true, nothing is downloaded or copied and ``target_file`` is loaded as-is.
        target_file: Where to store the file. Defaults to a per-user directory under the system
            temp dir, which is removed again before returning.

    Returns:
        An instance of the loaded class, created with ``method=None``.

    Raises:
        ValueError: If ``cls_name`` in the loaded module is not a ``ClientCommand`` subclass.
        requests.HTTPError: If downloading the artifact returns an error status.
    """
    # TODO: This is a skateboard implementation and the final version will rely on versioned
    # immutable commands for security concerns
    command_name = command_name.replace(" ", "_")
    tmpdir = None
    if not target_file:
        tmpdir = osp.join(gettempdir(), f"{getuser()}_commands")
        makedirs(tmpdir)
        target_file = osp.join(tmpdir, f"{command_name}.py")

    try:
        if not debug_mode:
            if app_id:
                if not os.path.exists(target_file):
                    client = LightningClient(retry=False)
                    project_id = _get_project(client).project_id
                    response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(
                        project_id=project_id, id=app_id
                    )
                    for artifact in response.artifacts:
                        if f"commands/{command_name}.py" == artifact.filename:
                            resp = requests.get(artifact.url, allow_redirects=True)
                            # Never persist (and later execute) the body of an error response.
                            resp.raise_for_status()

                            with open(target_file, "wb") as f:
                                f.write(resp.content)
                            # The matching artifact has been stored; no need to scan the rest.
                            break
            else:
                shutil.copy(cls_path, target_file)

        spec = spec_from_file_location(cls_name, target_file)
        mod = module_from_spec(spec)
        sys.modules[cls_name] = mod
        spec.loader.exec_module(mod)
        command_type = getattr(mod, cls_name)
        if not issubclass(command_type, ClientCommand):
            raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand`.")
        return command_type(method=None)
    finally:
        # Remove the scratch directory on failure as well, so a partial or stale file is never
        # picked up by the ``os.path.exists`` shortcut of a later call.
        if tmpdir and os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)
150

151

152
def _to_annotation(anno: str) -> str:
153
    anno = anno.split("'")[1]
154
    if "." in anno:
155
        return anno.split(".")[-1]
156
    return anno
157

158

159
def _validate_client_command(command: ClientCommand):
    """Check that every parameter of ``command.method`` is annotated with a pydantic model.

    Each annotation name is looked up in the module defining the command class and must
    resolve there to a ``BaseModel`` subclass. As a side effect ``command.models`` is
    reset to an empty dict.

    Args:
        command: The command whose ``method`` signature is validated.

    Raises:
        Exception: If a parameter is not annotated, if its annotation is only defined in
            the module of the method rather than the module of the command class, or if it
            does not resolve to a ``BaseModel`` subclass.
    """
    method = command.method
    parameters = inspect.signature(method).parameters
    annotations = {p.name: _to_annotation(str(p.annotation)) for p in parameters.values()}
    command.models = {}
    for arg_name, annotation in annotations.items():
        if annotation == "_empty":
            raise Exception(
                f"Please, annotate your method {method} with pydantic BaseModel. Refer to the documentation."
            )
        config = getattr(sys.modules[command.__module__], annotation, None)
        if config is None:
            # Not found next to the command class: check whether it lives next to the method
            # instead, to give a more helpful error.
            config = getattr(sys.modules[method.__module__], annotation, None)
            if config:
                raise Exception(
                    f"The provided annotation for the argument {arg_name} should be defined in the file "
                    f"{inspect.getfile(command.__class__)}, not {inspect.getfile(command.method)}."
                )
        # ``issubclass`` raises a TypeError on non-classes, hence the ``isclass`` guard.
        if config is None or not (inspect.isclass(config) and issubclass(config, BaseModel)):
            raise Exception(
                f"The provided annotation for the argument {arg_name} should be a subclass of pydantic BaseModel."
            )
187

188

189
def _upload(name: str, prefix: str, obj: Any) -> Optional[str]:
    """Copy the source file of ``obj``'s class to ``artifacts/<prefix>/<name>.py`` in shared storage.

    Spaces in ``name`` are replaced by underscores.

    Returns:
        The relative path ``<prefix>/<name>.py`` when the file was uploaded, ``None`` when
        s3fs is not available or the filesystem is not an ``S3FileSystem``.
    """
    from lightning.app.storage.path import _filesystem, _is_s3fs_available, _shared_storage_path

    safe_name = name.replace(" ", "_")
    filepath = f"{prefix}/{safe_name}.py"
    fs = _filesystem()

    # Guard clauses: uploading only makes sense against an S3 filesystem.
    if not _is_s3fs_available():
        return None

    from s3fs import S3FileSystem

    if not isinstance(fs, S3FileSystem):
        return None

    source_file = str(inspect.getfile(obj.__class__))
    destination = _shared_storage_path() / "artifacts" / filepath
    fs.put(source_file, str(destination))
    return filepath
207

208

209
def _prepare_commands(app) -> List:
    """Collect the commands configured on the root flow, upload client commands and cache them.

    Returns:
        The list returned by ``app.root.configure_commands()``, or an empty list when the
        root does not override ``configure_commands``.
    """
    if not is_overridden("configure_commands", app.root):
        return []

    configured = app.root.configure_commands()

    # Step 1: push the source of every client command to the shared storage.
    for mapping in configured:
        for name, candidate in mapping.items():
            if not isinstance(candidate, ClientCommand):
                continue
            _upload(name, "commands", candidate)

    # Step 2: keep the commands on the app for later lookups.
    app.commands = configured
    return configured
223

224

225
def _process_api_request(app, request: _APIRequest):
    """Run the component method targeted by ``request`` and wrap the outcome in a response.

    Args:
        app: The app used to resolve the component named ``request.name``.
        request: Carries the component name, method name, call arguments and request id.

    Returns:
        A dict with the ``_RequestResponse`` under ``"response"`` and the request id under
        ``"id"``. An ``HTTPException`` is turned into its own status code and detail; any
        other exception becomes a 500 response.
    """
    flow = app.get_component_by_name(request.name)
    method = getattr(flow, request.method_name)
    try:
        response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200)
    except HTTPException as ex:
        logger.error(repr(ex))
        response = _RequestResponse(status_code=ex.status_code, content=ex.detail)
    except Exception:
        # ``traceback.print_exc()`` returns ``None``, so the log record used to read "None";
        # ``format_exc()`` hands the actual traceback text to the logger.
        logger.error(traceback.format_exc())
        response = _RequestResponse(status_code=500)
    return {"response": response, "id": request.id}
237

238

239
def _process_command_requests(app, request: _CommandRequest):
    """Run the command matching ``request.method_name`` and wrap the outcome in a response.

    Command names are compared after replacing spaces by underscores.

    Args:
        app: The app whose cached ``commands`` are searched.
        request: Carries the command name, call arguments and request id.

    Returns:
        A dict with the ``_RequestResponse`` under ``"response"`` and the request id under
        ``"id"``, or ``None`` when no command matches. An ``HTTPException`` is turned into
        its own status code and detail; any other exception becomes a 500 response.
    """
    for command in app.commands:
        for command_name, method in command.items():
            command_name = command_name.replace(" ", "_")
            if request.method_name == command_name:
                # 2.1: Evaluate the method associated to a specific command.
                # Validation is done on the CLI side.
                try:
                    response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200)
                except HTTPException as ex:
                    logger.error(repr(ex))
                    response = _RequestResponse(status_code=ex.status_code, content=ex.detail)
                except Exception:
                    # ``traceback.print_exc()`` returns ``None``, so the log record used to read
                    # "None"; ``format_exc()`` hands the actual traceback text to the logger.
                    logger.error(traceback.format_exc())
                    response = _RequestResponse(status_code=500)
                return {"response": response, "id": request.id}
    return None
256

257

258
def _process_requests(app, requests: List[Union[_APIRequest, _CommandRequest]]) -> None:
    """Dispatch each request to its handler and publish the collected responses.

    API requests go to ``_process_api_request``, everything else to
    ``_process_command_requests``. Falsy results are dropped; the remaining ones are
    put on ``app.api_response_queue`` as a single list.
    """
    collected = []
    for item in requests:
        handler = _process_api_request if isinstance(item, _APIRequest) else _process_command_requests
        outcome = handler(app, item)
        if not outcome:
            continue
        collected.append(outcome)

    app.api_response_queue.put(collected)
271

272

273
def _collect_open_api_extras(command, info) -> Dict:
    """Build the extra OpenAPI metadata attached to a command endpoint.

    For a plain callable only its docstring (if any) is reported. For a ``ClientCommand``
    the defining file, class name and description are reported, plus the requirements and
    the app info when they are set.
    """
    if isinstance(command, ClientCommand):
        command_cls = command.__class__
        extras = {
            "cls_path": inspect.getfile(command_cls),
            "cls_name": command_cls.__name__,
            "description": command.description,
        }
        if command.requirements:
            extras["requirements"] = command.requirements
        if info:
            extras["app_info"] = asdict(info)
        return extras

    doc = command.__doc__
    return {} if doc is None else {"description": doc}
289

290

291
def _commands_to_api(
    commands: List[Dict[str, Union[Callable, ClientCommand]]], info: Optional[frontend.AppInfo] = None
) -> List:
    """Turn every configured command into a ``Post`` endpoint under ``/command/<name>``.

    Spaces in command names are replaced by underscores. Client commands expose their
    wrapped ``method`` and are tagged ``app_client_command``; plain callables are exposed
    directly and tagged ``app_command``.
    """
    endpoints = []
    for mapping in commands:
        for raw_name, handler in mapping.items():
            route_name = raw_name.replace(" ", "_")
            is_client_command = isinstance(handler, ClientCommand)
            target = handler.method if is_client_command else handler
            tags = ["app_client_command"] if is_client_command else ["app_command"]
            endpoint = Post(
                f"/command/{route_name}",
                target,
                method_name=route_name,
                tags=tags,
                openapi_extra=_collect_open_api_extras(handler, info),
            )
            endpoints.append(endpoint)
    return endpoints
309

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.