pytorch-lightning

Форк
0
208 строк · 7.5 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import inspect
16
import os
17
import re
18
from dataclasses import asdict, dataclass, field
19
from pathlib import Path
20
from typing import TYPE_CHECKING, Dict, List, Optional, Union
21

22
from typing_extensions import Self
23

24
from lightning.app.utilities.app_helpers import Logger
25
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
26

27
if TYPE_CHECKING:
28
    from lightning.app.core.work import LightningWork
29

30
# Module-level logger created with the Lightning app ``Logger`` helper and named after this module.
logger = Logger(__name__)
31

32

33
def load_requirements(
    path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True
) -> List[str]:
    """Read a pip-style requirements file and return the requirement lines it contains.

    Arguments:
        path_dir: Directory containing the requirements file.
        file_name: Name of the requirements file inside ``path_dir``.
        comment_char: Marker that starts a comment; everything from it to the end of the line is dropped.
        unfreeze: When ``True``, strip upper-bound (``<`` / ``<=``) version constraints unless the line's
            comment contains the word ``strict``.

    Returns:
        The cleaned requirement strings in file order, or an empty list when the file does not exist.
        Blank lines, comment-only lines and dependencies given directly as ``http...`` URLs are skipped.

    """
    target = os.path.join(path_dir, file_name)
    if not os.path.isfile(target):
        return []

    with open(target) as handle:
        raw_lines = handle.readlines()

    collected: List[str] = []
    for raw in raw_lines:
        line = raw.strip()
        # Split off the comment (if any); the comment text is kept to look for a "strict" marker.
        marker = line.find(comment_char)
        if marker >= 0:
            body, note = line[:marker], line[marker:]
        else:
            body, note = line, ""
        spec = body.strip()
        # skip empty lines and directly installed (URL) dependencies
        is_direct_url = spec.startswith("http") or "@http" in spec
        if not spec or is_direct_url:
            continue
        # relax upper version bounds unless the line is marked as strict
        if unfreeze and "<" in spec and "strict" not in note:
            spec = re.sub(r",? *<=? *[\d\.\*]+", "", spec).strip()
        collected.append(spec)
    return collected
59

60

61
@dataclass()
class _Dockerfile:
    """A Dockerfile loaded into memory: the location it came from and its lines."""

    # Filesystem location the Dockerfile was read from.
    path: str
    # Raw lines of the Dockerfile (``BuildConfig`` fills this with ``readlines()``, so newlines are kept).
    data: List[str]
65

66

67
@dataclass
class BuildConfig:
    """The Build Configuration describes how the environment a LightningWork runs in should be set up.

    Arguments:
        requirements: List of requirements or list of paths to requirement files. If not passed, they will be
            automatically extracted from a `requirements.txt` if it exists.
        dockerfile: The path to a dockerfile to be used to build your container.
            You need to add those lines to ensure your container works in the cloud.

            .. warning:: This feature isn't supported yet, but coming soon.

            Example::

                WORKDIR /gridai/project
                COPY . .
        image: The base image that the work runs on. This should be a publicly accessible image from a registry that
            doesn't enforce rate limits (such as DockerHub) to pull this image, otherwise your application will not
            start.

    Relative requirement-file and dockerfile paths are resolved against the directory of the source file that
    constructed the ``BuildConfig``.

    """

    requirements: List[str] = field(default_factory=list)
    dockerfile: Optional[Union[str, Path, _Dockerfile]] = None
    image: Optional[str] = None

    def __post_init__(self) -> None:
        # Frame chain: __post_init__ <- dataclass-generated __init__ <- the code that built this config.
        # Relative paths passed by the user are resolved against that caller's directory.
        current_frame = inspect.currentframe()
        try:
            co_filename = current_frame.f_back.f_back.f_code.co_filename  # type: ignore[union-attr]
        finally:
            # A frame references its locals (including this variable), so drop it explicitly to avoid a
            # reference cycle, as recommended by the ``inspect`` documentation.
            del current_frame
        self._call_dir = os.path.dirname(co_filename)
        self._prepare_requirements()
        self._prepare_dockerfile()

    def build_commands(self) -> List[str]:
        """Override to run some commands before your requirements are installed.

        .. note:: If you provide your own dockerfile, this would be ignored.

        Example::

            from dataclasses import dataclass
            from lightning.app import BuildConfig

            @dataclass
            class MyOwnBuildConfig(BuildConfig):

                def build_commands(self):
                    return ["apt-get install libsparsehash-dev"]

            BuildConfig(requirements=["git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0"])

        """
        return []

    def on_work_init(self, work: "LightningWork", cloud_compute: Optional["CloudCompute"] = None) -> None:
        """Override with your own logic to load the requirements or dockerfile.

        The default implementation uses the ``requirements.txt`` / ``Dockerfile`` found next to the work's source
        file only when none were passed explicitly. Otherwise the discovered files are ignored and an info message
        is logged.

        Arguments:
            work: The work this build config is attached to.
            cloud_compute: Not used by the default implementation.

        """
        found_requirements = self._find_requirements(work)
        if self.requirements:
            if found_requirements and self.requirements != found_requirements:
                # notify the user of this silent behaviour
                logger.info(
                    f"A 'requirements.txt' exists with {found_requirements} but {self.requirements} was passed to"
                    f" the `{type(self).__name__}` in {work.name!r}. The `requirements.txt` file will be ignored."
                )
        else:
            self.requirements = found_requirements
        self._prepare_requirements()

        found_dockerfile = self._find_dockerfile(work)
        if self.dockerfile:
            configured_path = self._dockerfile_path()
            # ``__post_init__`` may already have loaded ``self.dockerfile`` into a ``_Dockerfile``, which never
            # equals a path string. Compare resolved file locations instead, so no "ignored" message is logged when
            # both refer to the same file.
            same_file = (
                configured_path is not None
                and found_dockerfile is not None
                and os.path.realpath(configured_path) == os.path.realpath(found_dockerfile)
            )
            if found_dockerfile and not same_file:
                # notify the user of this silent behaviour; show the path rather than the full file contents
                shown = configured_path if configured_path is not None else self.dockerfile
                logger.info(
                    f"A Dockerfile exists at {found_dockerfile!r} but {shown!r} was passed to"
                    f" the `{type(self).__name__}` in {work.name!r}. {found_dockerfile!r} will be ignored."
                )
        else:
            self.dockerfile = found_dockerfile
        self._prepare_dockerfile()

    def _dockerfile_path(self) -> Optional[str]:
        """Return the file location of the configured dockerfile, or ``None`` if it has no known location."""
        if isinstance(self.dockerfile, _Dockerfile):
            return self.dockerfile.path
        if isinstance(self.dockerfile, (str, Path)):
            # Same resolution rule as ``_prepare_dockerfile``.
            return os.path.join(self._call_dir, self.dockerfile)
        return None

    def _find_requirements(self, work: "LightningWork", filename: str = "requirements.txt") -> List[str]:
        """Load ``filename`` from the directory of the work's source file, excluding the ``lightning`` entry."""
        # 1. Get work file
        file = _get_work_file(work)
        if file is None:
            return []
        # 2. Try to find a requirement file associated the file.
        dirname = os.path.dirname(file)
        try:
            requirements = load_requirements(dirname, filename)
        except NotADirectoryError:
            return []
        return [r for r in requirements if r != "lightning"]

    def _find_dockerfile(self, work: "LightningWork", filename: str = "Dockerfile") -> Optional[str]:
        """Return the path of ``filename`` next to the work's source file if it exists, else ``None``."""
        # 1. Get work file
        file = _get_work_file(work)
        if file is None:
            return None
        # 2. Check for Dockerfile.
        dirname = os.path.dirname(file)
        dockerfile = os.path.join(dirname, filename)
        if os.path.isfile(dockerfile):
            return dockerfile
        return None

    def _prepare_requirements(self) -> None:
        """Expand entries naming an existing file (relative to ``_call_dir``) into that file's requirements.

        Entries that are not files are kept verbatim.

        """
        requirements = []
        for req in self.requirements:
            # 1. Check for relative path
            path = os.path.join(self._call_dir, req)
            if os.path.isfile(path):
                try:
                    new_requirements = load_requirements(self._call_dir, req)
                except NotADirectoryError:
                    continue
                requirements.extend(new_requirements)
            else:
                requirements.append(req)
        self.requirements = requirements

    def _prepare_dockerfile(self) -> None:
        """Load a str/Path dockerfile (relative to ``_call_dir``) into a ``_Dockerfile`` when the path exists."""
        if isinstance(self.dockerfile, (str, Path)):
            path = os.path.join(self._call_dir, self.dockerfile)
            if os.path.exists(path):
                with open(path) as f:
                    self.dockerfile = _Dockerfile(path, f.readlines())

    def to_dict(self) -> Dict:
        """Serialize the dataclass fields under the ``"__build_config__"`` key."""
        return {"__build_config__": asdict(self)}

    @classmethod
    def from_dict(cls, d: Dict) -> Self:
        """Rebuild a config from the output of :meth:`to_dict`.

        NOTE(review): ``asdict`` turns a loaded ``_Dockerfile`` into a plain ``dict``, which is passed back here
        unchanged.

        """
        return cls(**d["__build_config__"])
200

201

202
def _get_work_file(work: "LightningWork") -> Optional[str]:
    """Locate the source file that defines the class of ``work``.

    Returns ``None`` (after logging at debug level) when ``inspect.getfile`` raises ``TypeError``
    for the class, e.g. for built-in classes.
    """
    work_cls = work.__class__
    try:
        source_path = inspect.getfile(work_cls)
    except TypeError:
        logger.debug(f"The {work_cls.__name__} file couldn't be found.")
        return None
    return source_path
209

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.