pytorch-lightning
160 строк · 6.3 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import multiprocessing16import os17from dataclasses import dataclass18from typing import Any, Union19
20import click21
22from lightning.app.api.http_methods import _add_tags_to_api, _validate_api23from lightning.app.core import constants24from lightning.app.core.api import start_server25from lightning.app.runners.backends import Backend26from lightning.app.runners.runtime import Runtime27from lightning.app.storage.orchestrator import StorageOrchestrator28from lightning.app.utilities.app_helpers import _is_headless, is_overridden29from lightning.app.utilities.commands.base import _commands_to_api, _prepare_commands30from lightning.app.utilities.component import _set_flow_context, _set_frontend_context31from lightning.app.utilities.load_app import extract_metadata_from_app32from lightning.app.utilities.network import find_free_network_port33from lightning.app.utilities.port import disable_port34
35
@dataclass
class MultiProcessRuntime(Runtime):
    """Runtime to launch the LightningApp into multiple processes.

    The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach
    queues to enable communication between the different processes.

    """

    # Backend used by ``dispatch`` to prepare the app queues and resolve the app URL.
    # NOTE(review): ``dispatch`` calls methods on it, so a ``str`` value is presumably
    # converted to a ``Backend`` instance by the parent ``Runtime`` - confirm.
    backend: Union[str, Backend] = "multiprocessing"
    # Set after ``terminate`` has run from the ``KeyboardInterrupt`` handler so that the
    # ``finally`` clause of ``dispatch`` does not call ``terminate`` a second time.
    _has_triggered_termination: bool = False

    def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
        """Method to dispatch and run the LightningApp.

        Prepares the backend queues, starts one server per frontend, starts the storage
        orchestrator thread, optionally starts the API server in a separate process and
        finally runs the app. ``terminate`` is always called once when this method exits,
        whether it returns normally, is interrupted, or fails.

        Args:
            *args: Accepted for signature compatibility; not used by this implementation.
            open_ui: If ``True``, open the app URL with ``click.launch`` once the server has
                started. The UI is not opened when running under pytest, when the app is
                headless, or when ``constants.LIGHTNING_CLOUDSPACE_HOST`` is set.
            **kwargs: Accepted for signature compatibility; not used by this implementation.

        Raises:
            KeyboardInterrupt: Re-raised after the runtime has been terminated.

        """
        try:
            _set_flow_context()

            # Note: In case the runtime is used in the cloud.
            in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None
            self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host  # noqa: S104

            self.app.backend = self.backend
            self.backend._prepare_queues(self.app)
            self.backend.resolve_url(self.app, "http://127.0.0.1")
            self.app._update_index_file()

            # set env variables
            os.environ.update(self.env_vars)

            # refresh the layout with the populated urls.
            self.app._update_layout()

            # Frontend servers are started under the frontend context; the flow context
            # is restored right after the loop.
            _set_frontend_context()
            for frontend in self.app.frontends.values():
                port = find_free_network_port()

                server_host = "0.0.0.0" if in_cloudspace else "localhost"  # noqa: S104
                server_target = (
                    f"https://{port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"
                    if in_cloudspace
                    else f"http://localhost:{port}"
                )

                frontend.start_server(host=server_host, port=port)
                frontend.flow._layout["target"] = f"{server_target}/{frontend.flow.name}"

            _set_flow_context()

            storage_orchestrator = StorageOrchestrator(
                self.app,
                self.app.request_queues,
                self.app.response_queues,
                self.app.copy_request_queues,
                self.app.copy_response_queues,
            )
            self.threads.append(storage_orchestrator)
            # ``Thread.setDaemon`` is deprecated since Python 3.10; set the attribute instead.
            storage_orchestrator.daemon = True
            storage_orchestrator.start()

            if self.start_server:
                self.app.should_publish_changes_to_api = True
                has_started_queue = self.backend.queues.get_has_server_started_queue()

                apis = []
                if is_overridden("configure_api", self.app.root):
                    apis = self.app.root.configure_api()
                    _validate_api(apis)
                    _add_tags_to_api(apis, ["app_api"])

                if is_overridden("configure_commands", self.app.root):
                    commands = _prepare_commands(self.app)
                    apis += _commands_to_api(commands, info=self.app.info)

                # Named ``server_kwargs`` so it does not shadow this method's ``**kwargs``.
                server_kwargs = {
                    "apis": apis,
                    "host": self.host,
                    "port": self.port,
                    "api_response_queue": self.app.api_response_queue,
                    "api_publish_state_queue": self.app.api_publish_state_queue,
                    "api_delta_queue": self.app.api_delta_queue,
                    "has_started_queue": has_started_queue,
                    "spec": extract_metadata_from_app(self.app),
                    "root_path": self.app.root_path,
                }
                server_proc = multiprocessing.Process(target=start_server, kwargs=server_kwargs)
                self.processes["server"] = server_proc
                server_proc.start()
                # requires to wait for the UI to be clicked on.

                # wait for server to be ready
                has_started_queue.get()

            if all([
                open_ui,
                "PYTEST_CURRENT_TEST" not in os.environ,
                not _is_headless(self.app),
                constants.LIGHTNING_CLOUDSPACE_HOST is None,
            ]):
                click.launch(self._get_app_url())

            # Connect the runtime to the application.
            self.app.connect(self)

            # Once the bootstrapping is done, running the rank 0
            # app with all the components inactive
            self.app._run()
        except KeyboardInterrupt:
            # Terminate right away and record it, so the ``finally`` clause below
            # does not terminate the runtime a second time.
            self.terminate()
            self._has_triggered_termination = True
            raise
        finally:
            if not self._has_triggered_termination:
                self.terminate()

    def terminate(self):
        """Disable the app's ports when the server runs in the cloud, then run the parent ``terminate``."""
        if constants.APP_SERVER_IN_CLOUD:
            # Close all the ports open for the App within the App.
            ports = [self.port] + getattr(self.backend, "ports", [])
            for port in ports:
                disable_port(port)
        super().terminate()

    @staticmethod
    def _get_app_url() -> str:
        """Return the ``APP_SERVER_HOST`` environment variable, or the default local view URL if it is unset."""
        return os.getenv("APP_SERVER_HOST", "http://127.0.0.1:7501/view")