pytorch-lightning

Форк
0
149 строк · 5.8 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import uuid
17
from contextlib import contextmanager
18
from pathlib import Path
19
from shutil import copytree, rmtree
20
from typing import List, Optional
21

22
from lightning.app.core.constants import DOT_IGNORE_FILENAME, SYS_CUSTOMIZATIONS_SYNC_PATH
23
from lightning.app.source_code.copytree import _IGNORE_FUNCTION, _copytree
24
from lightning.app.source_code.tar import _tar_path
25
from lightning.app.source_code.uploader import FileUploader
26

27

28
class LocalSourceCodeDir:
    """Represents the source code directory and provides the utilities to manage it.

    The directory at ``path`` can be copied into a temporary packaging session, archived into a
    ``<version>.tar.gz`` tarball inside a local cache directory and uploaded to a URL.

    Args:
        path: Root of the source code directory.
        ignore_functions: Optional ignore callbacks forwarded to ``_copytree``.
        default_ignore: When ``True`` and ``path`` has no dotignore file yet, a default one is written.
        package_source: When ``False`` the content of ``path`` is neither listed nor copied into the package.
        sys_customizations_root: Optional directory that is additionally copied into the packaging session
            under ``SYS_CUSTOMIZATIONS_SYNC_PATH``.

    """

    def __init__(
        self,
        path: Path,
        # Quoted forward reference: the alias is only needed by type checkers, not when the class is created.
        ignore_functions: Optional[List["_IGNORE_FUNCTION"]] = None,
        default_ignore: bool = True,
        package_source: bool = True,
        sys_customizations_root: Optional[Path] = None,
    ) -> None:
        if "LIGHTNING_VSCODE_WORKSPACE" in os.environ:
            # Don't use home to store the tar ball. This won't play nice with symlinks
            self.cache_location: Path = Path("/tmp", ".lightning", "cache", "repositories")
        else:
            self.cache_location: Path = Path.home() / ".lightning" / "cache" / "repositories"

        self.path = path
        self.ignore_functions = ignore_functions
        self.package_source = package_source
        self.sys_customizations_root = sys_customizations_root

        # lazily computed values, cached for the lifetime of the instance
        self._version: Optional[str] = None
        self._non_ignored_files: Optional[List[str]] = None

        # create global cache location if it doesn't exist
        if not self.cache_location.exists():
            self.cache_location.mkdir(parents=True, exist_ok=True)

        # Create a default dotignore if requested and it doesn't exist
        if default_ignore and not (path / DOT_IGNORE_FILENAME).is_file():
            with open(path / DOT_IGNORE_FILENAME, "w") as f:
                f.write("venv/\n")
                if (path / "bin" / "activate").is_file() or (path / "pyvenv.cfg").is_file():
                    # the user is developing inside venv
                    f.write("bin/\ninclude/\nlib/\npyvenv.cfg\n")

        # clean old cache entries
        self._prune_cache()

    @property
    def files(self) -> List[str]:
        """Returns the list of files that are not ignored by .lightningignore.

        The list is computed once through a dry run of ``_copytree`` and cached afterwards. It is empty when
        ``package_source`` is ``False``.

        """
        if self._non_ignored_files is None:
            if self.package_source:
                self._non_ignored_files = _copytree(self.path, "", ignore_functions=self.ignore_functions, dry_run=True)
            else:
                self._non_ignored_files = []
        return self._non_ignored_files

    @property
    def version(self) -> str:
        """Returns the version ID of this source code directory.

        The ID is a random UUID (hex encoded) generated on first access and cached on the instance; it is not
        derived from the content of the directory.

        """
        # cache value to prevent doing this over again
        if self._version is not None:
            return self._version

        # create a random version ID and store it
        self._version = uuid.uuid4().hex
        return self._version

    @property
    def package_path(self) -> Path:
        """Location of the tarball in the local cache (``<cache_location>/<version>.tar.gz``)."""
        filename = f"{self.version}.tar.gz"
        return self.cache_location / filename

    @contextmanager
    def packaging_session(self) -> Path:
        """Creates a local directory with source code that is used for creating a local source-code package.

        Used as a context manager: it yields the path of the session directory, which is removed again when
        the ``with`` block is left (also on error).

        """
        session_path = self.cache_location / "packaging_sessions" / self.version
        try:
            # start from a clean slate in case a previous session with the same version was left behind
            rmtree(session_path, ignore_errors=True)
            if self.package_source:
                _copytree(self.path, session_path, ignore_functions=self.ignore_functions)
            if self.sys_customizations_root is not None:
                path_to_sync = Path(session_path, SYS_CUSTOMIZATIONS_SYNC_PATH)
                copytree(self.sys_customizations_root, path_to_sync, dirs_exist_ok=True)
            yield session_path
        finally:
            rmtree(session_path, ignore_errors=True)

    def _prune_cache(self) -> None:
        """Prunes cache; only keeps the 10 most recent items.

        Only regular files are deleted; directories in the cache location are left in place.

        """
        # Sort newest first so that the slice below selects everything *older* than the 10 most recent
        # entries. Sorting in ascending order would delete the newest entries instead.
        packages = sorted(self.cache_location.iterdir(), key=os.path.getmtime, reverse=True)
        for package in packages[10:]:
            if package.is_file():
                package.unlink()

    def package(self) -> Path:
        """Packages local path using tar.

        Returns:
            The path of the tarball in the local cache. An already existing tarball is reused.

        """
        if self.package_path.exists():
            return self.package_path
        # create a packaging session if not available
        with self.packaging_session() as session_path:
            _tar_path(source_path=session_path, target_file=str(self.package_path), compression=True)
        return self.package_path

    def upload(self, url: str) -> None:
        """Uploads package to URL, usually pre-signed URL.

        Notes
        -----
        Since we do not use multipart uploads here, we cannot upload any
        packaged repository files which have a size > 2GB.

        This limitation should be removed during the datastore upload redesign

        Raises:
            OSError: If the packaged tarball is larger than 2GB (2e9 bytes).

        """
        # stat the tarball once and reuse the size for both the limit check and the uploader
        package_size = self.package_path.stat().st_size
        if package_size > 2e9:
            raise OSError(
                "cannot upload directory code whose total fize size is greater than 2GB (2e9 bytes)"
            ) from None

        uploader = FileUploader(
            presigned_url=url,
            source_file=str(self.package_path),
            name=self.package_path.name,
            total_size=package_size,
        )
        uploader.upload()
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.