pytorch-lightning

Форк
0
160 строк · 6.3 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import multiprocessing
16
import os
17
from dataclasses import dataclass
18
from typing import Any, Union
19

20
import click
21

22
from lightning.app.api.http_methods import _add_tags_to_api, _validate_api
23
from lightning.app.core import constants
24
from lightning.app.core.api import start_server
25
from lightning.app.runners.backends import Backend
26
from lightning.app.runners.runtime import Runtime
27
from lightning.app.storage.orchestrator import StorageOrchestrator
28
from lightning.app.utilities.app_helpers import _is_headless, is_overridden
29
from lightning.app.utilities.commands.base import _commands_to_api, _prepare_commands
30
from lightning.app.utilities.component import _set_flow_context, _set_frontend_context
31
from lightning.app.utilities.load_app import extract_metadata_from_app
32
from lightning.app.utilities.network import find_free_network_port
33
from lightning.app.utilities.port import disable_port
34

35

36
@dataclass
class MultiProcessRuntime(Runtime):
    """Runtime to launch the LightningApp into multiple processes.

    The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach
    queues to enable communication between the different processes.

    """

    backend: Union[str, Backend] = "multiprocessing"
    # Set by ``dispatch`` after it has called ``terminate`` in its ``KeyboardInterrupt`` handler, so that the
    # ``finally`` clause does not call ``terminate`` a second time.
    _has_triggered_termination: bool = False

    def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
        """Method to dispatch and run the LightningApp.

        Prepares the backend queues, starts a server for every frontend, starts the storage orchestrator
        thread and, when ``self.start_server`` is set, the API server process. It then optionally opens the UI,
        connects this runtime to the app and runs the app. ``terminate`` is called when the method exits,
        whether it returns normally or an exception propagates.

        Args:
            *args: Not used by this implementation.
            open_ui: Whether to open the app URL with ``click.launch``. The URL is only opened when, in
                addition, ``PYTEST_CURRENT_TEST`` is not in the environment, the app is not headless and
                ``constants.LIGHTNING_CLOUDSPACE_HOST`` is ``None``.
            **kwargs: Not used by this implementation.

        Raises:
            KeyboardInterrupt: Re-raised after ``terminate`` has been called.

        """
        try:
            _set_flow_context()

            # Note: In case the runtime is used in the cloud.
            in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None
            self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host  # noqa: S104

            self.app.backend = self.backend
            self.backend._prepare_queues(self.app)
            self.backend.resolve_url(self.app, "http://127.0.0.1")
            self.app._update_index_file()

            # set env variables
            os.environ.update(self.env_vars)

            # refresh the layout with the populated urls.
            self.app._update_layout()

            _set_frontend_context()
            for frontend in self.app.frontends.values():
                port = find_free_network_port()

                server_host = "0.0.0.0" if in_cloudspace else "localhost"  # noqa: S104
                server_target = (
                    f"https://{port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"
                    if in_cloudspace
                    else f"http://localhost:{port}"
                )

                frontend.start_server(host=server_host, port=port)
                frontend.flow._layout["target"] = f"{server_target}/{frontend.flow.name}"

            _set_flow_context()

            storage_orchestrator = StorageOrchestrator(
                self.app,
                self.app.request_queues,
                self.app.response_queues,
                self.app.copy_request_queues,
                self.app.copy_response_queues,
            )
            self.threads.append(storage_orchestrator)
            # ``Thread.setDaemon`` is deprecated since Python 3.10; the ``daemon`` attribute is the
            # supported equivalent.
            storage_orchestrator.daemon = True
            storage_orchestrator.start()

            if self.start_server:
                self.app.should_publish_changes_to_api = True
                has_started_queue = self.backend.queues.get_has_server_started_queue()

                apis = []
                if is_overridden("configure_api", self.app.root):
                    apis = self.app.root.configure_api()
                    _validate_api(apis)
                    _add_tags_to_api(apis, ["app_api"])

                if is_overridden("configure_commands", self.app.root):
                    commands = _prepare_commands(self.app)
                    apis += _commands_to_api(commands, info=self.app.info)

                # Named ``server_kwargs`` so that it does not shadow the ``**kwargs`` parameter.
                server_kwargs = {
                    "apis": apis,
                    "host": self.host,
                    "port": self.port,
                    "api_response_queue": self.app.api_response_queue,
                    "api_publish_state_queue": self.app.api_publish_state_queue,
                    "api_delta_queue": self.app.api_delta_queue,
                    "has_started_queue": has_started_queue,
                    "spec": extract_metadata_from_app(self.app),
                    "root_path": self.app.root_path,
                }
                server_proc = multiprocessing.Process(target=start_server, kwargs=server_kwargs)
                self.processes["server"] = server_proc
                server_proc.start()
                # requires to wait for the UI to be clicked on.

                # wait for server to be ready
                has_started_queue.get()

            if all([
                open_ui,
                "PYTEST_CURRENT_TEST" not in os.environ,
                not _is_headless(self.app),
                constants.LIGHTNING_CLOUDSPACE_HOST is None,
            ]):
                click.launch(self._get_app_url())

            # Connect the runtime to the application.
            self.app.connect(self)

            # Once the bootstrapping is done, running the rank 0
            # app with all the components inactive
            self.app._run()
        except KeyboardInterrupt:
            self.terminate()
            # Record that termination already happened so the ``finally`` clause skips it.
            self._has_triggered_termination = True
            raise
        finally:
            if not self._has_triggered_termination:
                self.terminate()

    def terminate(self):
        """Disable the app's ports when the app server runs in the cloud, then call the parent ``terminate``."""
        if constants.APP_SERVER_IN_CLOUD:
            # Close all the ports open for the App within the App.
            ports = [self.port] + getattr(self.backend, "ports", [])
            for port in ports:
                disable_port(port)
        super().terminate()

    @staticmethod
    def _get_app_url() -> str:
        """Return the ``APP_SERVER_HOST`` environment variable, or ``http://127.0.0.1:7501/view`` when it is unset."""
        return os.getenv("APP_SERVER_HOST", "http://127.0.0.1:7501/view")
161

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.