pytorch-lightning

Форк
0
143 строки · 5.1 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from abc import ABC, abstractmethod
16
from functools import partial
17
from typing import TYPE_CHECKING, Any, Callable, List, Optional
18

19
from lightning.app.core.queues import QueuingSystem
20
from lightning.app.utilities.proxies import ProxyWorkRun, unwrap
21

22
if TYPE_CHECKING:
23
    import lightning.app
24

25

26
class Backend(ABC):
    """The Backend provides an interface for the framework to communicate with resources in the cloud."""

    def __init__(self, entrypoint_file: str, queues: "QueuingSystem", queue_id: str) -> None:
        """Store the queue factory, the queue identifier and the entrypoint file.

        Args:
            entrypoint_file: Path of the entrypoint file; stored as-is on the instance.
            queues: Factory object whose ``get_*_queue`` methods are used to build every queue.
            queue_id: Identifier forwarded as ``queue_id`` to every queue factory call.

        """
        self.queues: QueuingSystem = queues
        self.queue_id = queue_id
        self.entrypoint_file = entrypoint_file

    @abstractmethod
    def create_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
        """Create ``work`` for ``app``; called by ``_dynamic_run_wrapper`` once the work queues are registered."""

    @abstractmethod
    def update_work_statuses(self, works: List["lightning.app.LightningWork"]) -> None:
        """Update the statuses of ``works``; the exact contract is defined by the concrete backend."""

    @abstractmethod
    def stop_all_works(self, works: List["lightning.app.LightningWork"]) -> None:
        """Stop every work in ``works``; the exact contract is defined by the concrete backend."""

    @abstractmethod
    def resolve_url(self, app: "lightning.app.LightningApp", base_url: Optional[str] = None) -> None:
        """Resolve URLs for ``app``, optionally relative to ``base_url``; defined by the concrete backend."""

    @abstractmethod
    def stop_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
        """Stop a single ``work`` of ``app``; the exact contract is defined by the concrete backend."""

    def _dynamic_run_wrapper(
        self,
        *args: Any,
        app: "lightning.app.LightningApp",
        work: "lightning.app.LightningWork",
        work_run: Callable,
        **kwargs: Any,
    ) -> Any:
        """Lazily set up a work on its first ``run`` call, then forward the call to a ``ProxyWorkRun``.

        Args:
            *args: Positional arguments forwarded to the work's run proxy.
            app: The application owning the per-work queue dictionaries.
            work: The work whose ``run`` attribute is being replaced.
            work_run: The original (unwrapped) run callable of ``work``.
            **kwargs: Keyword arguments forwarded to the work's run proxy.

        Returns:
            Whatever the ``ProxyWorkRun`` call returns.

        Raises:
            AttributeError: If ``work.name`` is empty.

        """
        if not work.name:
            # the name is empty, which means this work was never assigned to a parent flow
            raise AttributeError(
                f"Failed to create process for {work.__class__.__name__}."
                " Make sure to set this work as an attribute of a `LightningFlow` before calling the run method."
            )

        # 1. Create and register the queues associated with the work
        self._register_queues(app, work)

        # Put the original run callable back before creating the work.
        # NOTE(review): presumably so `create_work` sees the raw method rather than this wrapper - confirm.
        work.run = work_run

        # 2. Create the work
        self.create_work(app, work)

        # 3. Attach backend
        work._backend = self

        # 4. Create the work proxy to manipulate the work
        work.run = ProxyWorkRun(
            work_run=work_run,
            work_name=work.name,
            work=work,
            caller_queue=app.caller_queues[work.name],
        )

        # 5. Run the work proxy
        return work.run(*args, **kwargs)

    def _wrap_run_method(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
        """Replace ``work.run`` with a partial of ``_dynamic_run_wrapper``; a no-op if already wrapped.

        The wrapper installed here is a ``functools.partial``, which has no ``__name__`` of its own,
        so the "already wrapped" check has to inspect the partial's underlying ``func``.
        """
        current_run = work.run
        target = current_run.func if isinstance(current_run, partial) else current_run
        # Defaulted lookup: callable objects other than functions may not define ``__name__`` either.
        if getattr(target, "__name__", None) == "_dynamic_run_wrapper":
            return

        work.run = partial(self._dynamic_run_wrapper, app=app, work=work, work_run=unwrap(current_run))

    def _prepare_queues(self, app: "lightning.app.LightningApp") -> None:
        """Create the app-level queues on ``app`` and reset its per-work queue dictionaries to empty."""
        kw = {"queue_id": self.queue_id}
        app.delta_queue = self.queues.get_delta_queue(**kw)
        app.readiness_queue = self.queues.get_readiness_queue(**kw)
        app.api_response_queue = self.queues.get_api_response_queue(**kw)
        app.error_queue = self.queues.get_error_queue(**kw)
        app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
        # The API delta queue is the very same object as the delta queue, not a separate queue.
        app.api_delta_queue = app.delta_queue
        app.request_queues = {}
        app.response_queues = {}
        app.copy_request_queues = {}
        app.copy_response_queues = {}
        app.caller_queues = {}
        app.work_queues = {}
        app.flow_to_work_delta_queues = {}

    def _register_queues(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
        """Create the queues of ``work`` and store them on ``app`` keyed by ``work.name``.

        Expects the per-work dictionaries on ``app`` to exist already (``_prepare_queues`` creates them).
        """
        kw = {"queue_id": self.queue_id, "work_name": work.name}
        app.request_queues[work.name] = self.queues.get_orchestrator_request_queue(**kw)
        app.response_queues[work.name] = self.queues.get_orchestrator_response_queue(**kw)
        app.copy_request_queues[work.name] = self.queues.get_orchestrator_copy_request_queue(**kw)
        app.copy_response_queues[work.name] = self.queues.get_orchestrator_copy_response_queue(**kw)
        app.caller_queues[work.name] = self.queues.get_caller_queue(**kw)
        app.flow_to_work_delta_queues[work.name] = self.queues.get_flow_to_work_delta_queue(**kw)
121

122

123
class WorkManager(ABC):
    """Interface through which the backend and the runtime control a LightningWork."""

    def __init__(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork"):
        """Accept the app and the work to manage; the base implementation stores nothing."""

    @abstractmethod
    def start(self) -> None:
        """Start the managed work; concrete managers must implement this."""

    @abstractmethod
    def kill(self) -> None:
        """Kill the managed work; concrete managers must implement this."""

    @abstractmethod
    def restart(self) -> None:
        """Restart the managed work; concrete managers must implement this."""

    @abstractmethod
    def is_alive(self) -> bool:
        """Tell whether the managed work is alive; concrete managers must implement this."""
144

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.