pytorch-lightning

Форк
0
188 строк · 7.3 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from dataclasses import asdict, dataclass
16
from typing import Dict, List, Optional, Tuple, Union
17
from uuid import uuid4
18

19
from lightning.app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER, enable_interruptible_works
20
from lightning.app.storage.mount import Mount
21

22
# Value stored under the "type" key of a serialized CloudCompute dict. CloudCompute.to_dict writes it,
# CloudCompute.from_dict checks it, and _maybe_create_cloud_compute uses it to recognize such dicts.
__CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__"
23

24

25
@dataclass
26
class _CloudComputeStore:
27
    id: str
28
    component_names: List[str]
29

30
    def add_component_name(self, new_component_name: str) -> None:
31
        found_index = None
32
        # When the work is being named by the flow, pop its previous names
33
        for index, component_name in enumerate(self.component_names):
34
            if new_component_name.endswith(component_name.replace("root.", "")):
35
                found_index = index
36

37
        if found_index is not None:
38
            self.component_names[found_index] = new_component_name
39
        else:
40
            if (
41
                len(self.component_names) == 1
42
                and not ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
43
                and self.id != "default"
44
            ):
45
                raise Exception(
46
                    f"A Cloud Compute can be assigned only to a single Work. Attached to {self.component_names[0]}"
47
                )
48
            self.component_names.append(new_component_name)
49

50
    def remove(self, new_component_name: str) -> None:
51
        found_index = None
52
        for index, component_name in enumerate(self.component_names):
53
            if new_component_name == component_name:
54
                found_index = index
55

56
        if found_index is not None:
57
            del self.component_names[found_index]
58

59

60
# Module-level registry of _CloudComputeStore objects.
# NOTE(review): presumably keyed by CloudCompute id (compare _CloudComputeStore.id); confirm against callers.
_CLOUD_COMPUTE_STORE = {}
61

62

63
@dataclass
class CloudCompute:
    """Configure the cloud runtime for a lightning work or flow.

    Arguments:
        name: The name of the hardware to use. A full list of supported options can be found in
            :doc:`/core_api/lightning_work/compute`. If you have a request for more hardware options, please contact
            `onprem@lightning.ai <mailto:onprem@lightning.ai>`_.
            The name is lower-cased, and ``"default"`` / ``"cpu"`` are rewritten to ``"cpu-small"``.

        disk_size: The disk size in Gigabytes.
            The value you set here will be allocated to the /home folder.

        idle_timeout: The number of seconds to wait before pausing the compute when the work is running and idle.
            This timeout starts whenever your run() method succeeds (or fails).
            If the timeout is reached, the instance pauses until the next run() call happens.

        shm_size: Shared memory size in MiB, backed by RAM. min 512, max 8192, it will auto update in steps of 512.
            For example 1100 will become 1024. If left unset (``None``), it becomes 1024 when the name
            contains ``"gpu"`` and 0 otherwise. A value of zero gets the default 64MiB inside docker.

        mounts: External data sources which should be mounted into a work as a filesystem at runtime.
            Every mount must use a distinct ``mount_path``.

        colocation_group_id: Identifier for groups of works to be colocated in the same datacenter.
            Set this to a string of max. 64 characters and all works with this group id will run in the same datacenter.
            If not set, the works are not guaranteed to be colocated.

        interruptible: Whether to run on an interruptible machine, i.e. the machine can be stopped
            at any time by the providers. This is also known as spot or preemptible machines.
            Compared to on-demand machines, they tend to be cheaper. Only supported for GPU machines.

    Raises:
        ValueError: If several mounts share a ``mount_path``, if ``interruptible=True`` is not enabled or is
            requested for a non-GPU machine, or if ``colocation_group_id`` is not a string of at most 64 characters.

    """

    name: str = "default"
    disk_size: int = 0
    idle_timeout: Optional[int] = None
    shm_size: Optional[int] = None
    mounts: Optional[Union[Mount, List[Mount]]] = None
    colocation_group_id: Optional[str] = None
    interruptible: bool = False
    # Generated in __post_init__ when not provided; exposed through the ``id`` property.
    _internal_id: Optional[str] = None

    def __post_init__(self) -> None:
        _verify_mount_root_dirs_are_unique(self.mounts)

        self.name = self.name.lower()

        # Unset shm_size: GPU machines get 1024 MiB, everything else 0.
        if self.shm_size is None:
            self.shm_size = 1024 if "gpu" in self.name else 0

        if self.interruptible:
            if not enable_interruptible_works():
                raise ValueError("CloudCompute with `interruptible=True` isn't supported yet.")
            if "gpu" not in self.name:
                raise ValueError("CloudCompute `interruptible=True` is supported only with GPU.")

        # FIXME: Clean the mess on the platform side
        # "default" and "cpu" are aliases of "cpu-small"; all of them share the "default" id.
        if self.name in ("default", "cpu"):
            self.name = "cpu-small"
            self._internal_id = "default"

        # TODO: Remove from the platform first.
        # Plain attribute rather than a dataclass field, so asdict()/to_dict() do not include it.
        self.preemptible = self.interruptible

        # All `default` CloudCompute are identified in the same way.
        if self._internal_id is None:
            self._internal_id = self._generate_id()

        if self.colocation_group_id is not None and (
            not isinstance(self.colocation_group_id, str) or len(self.colocation_group_id) > 64
        ):
            raise ValueError("colocation_group_id can only be a string of maximum 64 characters.")

    def to_dict(self) -> dict:
        """Serialize the dataclass fields, tagged with ``"type": __CLOUD_COMPUTE_IDENTIFIER__``."""
        _verify_mount_root_dirs_are_unique(self.mounts)
        return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)}

    @classmethod
    def from_dict(cls, d: dict) -> "CloudCompute":
        """Rebuild a CloudCompute from the output of :meth:`to_dict`.

        The input dict is not modified.

        Raises:
            KeyError: If ``d`` has no ``"type"`` key.
            TypeError: If ``mounts`` is not None, a dict, or a list/tuple of dicts.
            ValueError: If the mounts do not have unique ``mount_path`` values.

        """
        # Shallow copy: callers may pass live app state, which must not lose its "type" key or
        # have its "mounts" entry replaced by Mount objects.
        d = dict(d)
        # Pop outside the assert so the key is still removed when assertions are stripped (python -O).
        serialized_type = d.pop("type")
        assert serialized_type == __CLOUD_COMPUTE_IDENTIFIER__
        mounts = d.pop("mounts", None)
        if mounts is None:
            pass
        elif isinstance(mounts, dict):
            d["mounts"] = Mount(**mounts)
        elif isinstance(mounts, (list, tuple)):
            # asdict() keeps tuples as tuples, so accept both sequence kinds.
            d["mounts"] = [Mount(**mount) for mount in mounts]
        else:
            raise TypeError(
                f"mounts argument must be one of [None, Mount, List[Mount]], "
                f"received {mounts} of type {type(mounts)}"
            )
        _verify_mount_root_dirs_are_unique(d.get("mounts"))
        return cls(**d)

    @property
    def id(self) -> Optional[str]:
        """Identifier of this compute ("default" for the default CPU machine)."""
        return self._internal_id

    def is_default(self) -> bool:
        """Return True for the default CPU machine."""
        return self.name in ("default", "cpu-small")

    def _generate_id(self) -> str:
        # After __post_init__ the name is never "default" (it is rewritten to "cpu-small"),
        # so in practice this returns a fresh 7-character random hex id.
        return "default" if self.name == "default" else uuid4().hex[:7]

    def clone(self) -> "CloudCompute":
        """Return a copy of this compute carrying a newly generated id."""
        new_dict = self.to_dict()
        new_dict["_internal_id"] = self._generate_id()
        return self.from_dict(new_dict)
176

177

178
def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
    """Raise ValueError when a collection of mounts reuses a ``mount_path``.

    ``None`` and a single mount are accepted unchanged, since nothing can collide with them.
    """
    if not isinstance(mounts, (list, tuple, set)):
        return
    paths = [entry.mount_path for entry in mounts]
    distinct_paths = set(paths)
    if len(distinct_paths) != len(paths):
        raise ValueError("Every Mount attached to a work must have a unique 'mount_path' argument.")
183

184

185
def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
    """Turn a serialized CloudCompute dict back into a CloudCompute; return anything else as-is."""
    looks_serialized = bool(state) and state.get("type") == __CLOUD_COMPUTE_IDENTIFIER__
    if not looks_serialized:
        return state
    return CloudCompute.from_dict(state)
189

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.