pytorch-lightning
149 строк · 5.8 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import os
import uuid
from contextlib import contextmanager
from pathlib import Path
from shutil import copytree, rmtree
from typing import Iterator, List, Optional

from lightning.app.core.constants import DOT_IGNORE_FILENAME, SYS_CUSTOMIZATIONS_SYNC_PATH
from lightning.app.source_code.copytree import _IGNORE_FUNCTION, _copytree
from lightning.app.source_code.tar import _tar_path
from lightning.app.source_code.uploader import FileUploader
26
27
class LocalSourceCodeDir:
    """Represents the source code directory and provides the utilities to manage it.

    The directory can be listed (respecting ignore rules), staged into a packaging session,
    packed into a ``.tar.gz`` file inside a local cache directory, and uploaded to a URL.

    Args:
        path: Root directory of the source code.
        ignore_functions: Extra ignore callbacks forwarded to ``_copytree``.
        default_ignore: If ``True`` and ``path`` has no ``DOT_IGNORE_FILENAME`` file, write a
            default one into ``path`` (ignoring ``venv/``, plus ``bin/``, ``include/``, ``lib/``
            and ``pyvenv.cfg`` when ``path`` itself looks like a virtualenv).
        package_source: If ``False``, the source tree is neither listed nor copied into the package.
        sys_customizations_root: Optional directory whose contents are copied into the package
            under ``SYS_CUSTOMIZATIONS_SYNC_PATH``.

    """

    # Maximum number of files kept directly under the cache location by ``_prune_cache``.
    _MAX_CACHED_PACKAGES = 10

    def __init__(
        self,
        path: Path,
        ignore_functions: Optional[List["_IGNORE_FUNCTION"]] = None,
        default_ignore: bool = True,
        package_source: bool = True,
        sys_customizations_root: Optional[Path] = None,
    ) -> None:
        if "LIGHTNING_VSCODE_WORKSPACE" in os.environ:
            # Don't use home to store the tarball. This won't play nice with symlinks.
            cache_location = Path("/tmp", ".lightning", "cache", "repositories")
        else:
            cache_location = Path.home() / ".lightning" / "cache" / "repositories"
        self.cache_location: Path = cache_location

        self.path = path
        self.ignore_functions = ignore_functions
        self.package_source = package_source
        self.sys_customizations_root = sys_customizations_root

        # Lazily computed and memoized by the ``version`` and ``files`` properties.
        self._version: Optional[str] = None
        self._non_ignored_files: Optional[List[str]] = None

        # Create the global cache location; ``exist_ok`` makes this a no-op when it already exists.
        self.cache_location.mkdir(parents=True, exist_ok=True)

        # Create a default dotignore if requested and it doesn't exist.
        dotignore_path = path / DOT_IGNORE_FILENAME
        if default_ignore and not dotignore_path.is_file():
            with open(dotignore_path, "w") as f:
                f.write("venv/\n")
                if (path / "bin" / "activate").is_file() or (path / "pyvenv.cfg").is_file():
                    # The user is developing inside a venv: ignore the venv layout as well.
                    f.write("bin/\ninclude/\nlib/\npyvenv.cfg\n")

        # Clean old cache entries.
        self._prune_cache()

    @property
    def files(self) -> List[str]:
        """Return the list of files that are not ignored (e.g. by ``.lightningignore``).

        The list is computed once via a dry-run ``_copytree`` and then cached on the instance.
        It is empty when ``package_source`` is ``False``.
        """
        if self._non_ignored_files is None:
            if self.package_source:
                self._non_ignored_files = _copytree(self.path, "", ignore_functions=self.ignore_functions, dry_run=True)
            else:
                self._non_ignored_files = []
        return self._non_ignored_files

    @property
    def version(self) -> str:
        """Return the identifier of this source-code snapshot.

        Despite the name, this is not a checksum of the directory: a random UUID4 hex string is
        generated on first access and reused for the lifetime of this instance.
        """
        if self._version is None:
            self._version = uuid.uuid4().hex
        return self._version

    @property
    def package_path(self) -> Path:
        """Location of the tarball in the local cache: ``<cache_location>/<version>.tar.gz``."""
        return self.cache_location / f"{self.version}.tar.gz"

    @contextmanager
    def packaging_session(self) -> Iterator[Path]:
        """Stage the files to package in a session directory and yield its path.

        The directory is ``<cache_location>/packaging_sessions/<version>``. It is wiped before
        staging and removed (best effort) when the context exits, even if the body raises.
        The source tree is copied in when ``package_source`` is true, and the contents of
        ``sys_customizations_root`` are copied under ``SYS_CUSTOMIZATIONS_SYNC_PATH`` when it is
        set; if neither applies, the yielded path does not exist.
        """
        session_path = self.cache_location / "packaging_sessions" / self.version
        try:
            # Start from a clean slate in case an earlier session was left behind.
            rmtree(session_path, ignore_errors=True)
            if self.package_source:
                _copytree(self.path, session_path, ignore_functions=self.ignore_functions)
            if self.sys_customizations_root is not None:
                path_to_sync = Path(session_path, SYS_CUSTOMIZATIONS_SYNC_PATH)
                copytree(self.sys_customizations_root, path_to_sync, dirs_exist_ok=True)
            yield session_path
        finally:
            rmtree(session_path, ignore_errors=True)

    def _prune_cache(self) -> None:
        """Prune the cache, keeping only the most recently modified files.

        Only regular files directly under ``cache_location`` are considered; sub-directories such
        as ``packaging_sessions`` are neither counted nor removed. At most
        ``_MAX_CACHED_PACKAGES`` files survive.
        """
        cached_files = [entry for entry in self.cache_location.iterdir() if entry.is_file()]
        # Newest first, so everything past the limit is the oldest and gets deleted.
        cached_files.sort(key=os.path.getmtime, reverse=True)
        for stale in cached_files[self._MAX_CACHED_PACKAGES :]:
            # Tolerate the file having vanished in the meantime.
            stale.unlink(missing_ok=True)

    def package(self) -> Path:
        """Package the staged source into a compressed tarball in the cache and return its path.

        The tarball is built at most once per instance (per ``version``); later calls return the
        already existing file.
        """
        package_path = self.package_path
        if not package_path.exists():
            with self.packaging_session() as session_path:
                _tar_path(source_path=session_path, target_file=str(package_path), compression=True)
        return package_path

    def upload(self, url: str) -> None:
        """Upload the package to ``url``, usually a pre-signed URL.

        The package must already exist (see ``package``); otherwise ``stat`` raises
        ``FileNotFoundError``.

        Notes
        -----
        Since we do not use multipart uploads here, we cannot upload any
        packaged repository files which have a size > 2GB.

        This limitation should be removed during the datastore upload redesign.

        Raises
        ------
        OSError
            If the packaged file is larger than 2GB (2e9 bytes).

        """
        package_path = self.package_path
        package_size = package_path.stat().st_size
        if package_size > 2e9:
            raise OSError("cannot upload directory code whose total file size is greater than 2GB (2e9 bytes)")

        uploader = FileUploader(
            presigned_url=url,
            source_file=str(package_path),
            name=package_path.name,
            total_size=package_size,
        )
        uploader.upload()
150