pytorch-lightning

Форк
0
219 строк · 7.6 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import functools
16
import logging
17
import os
18
import pathlib
19
import shutil
20
import subprocess
21
import sys
22
import tarfile
23
import tempfile
24
import urllib.request
25
from pathlib import Path
26
from typing import Any, Callable, Optional
27

28
from packaging.version import Version
29

30
from lightning.app import _PROJECT_ROOT, _logger, _root_logger
31
from lightning.app import __version__ as version
32
from lightning.app.core.constants import FRONTEND_DIR, PACKAGE_LIGHTNING
33
from lightning.app.utilities.app_helpers import Logger
34
from lightning.app.utilities.git import check_github_repository, get_dir_name
35

36
logger = Logger(__name__)

# FIXME(alecmerdler): switch to GitHub release artifacts once the `lightning-ui` repo is public.
# Pinned Lightning UI build archive that `download_frontend` streams and extracts.
LIGHTNING_FRONTEND_RELEASE_URL = "https://storage.googleapis.com/grid-packages/lightning-ui/v0.0.0/build.tar.gz"
41

42

43
def download_frontend(root: str = _PROJECT_ROOT):
    """Download the pinned Lightning UI build archive and install it into ``FRONTEND_DIR``.

    The archive at ``LIGHTNING_FRONTEND_RELEASE_URL`` is streamed, extracted into a temporary
    directory, and its top-level ``build`` folder is moved to ``FRONTEND_DIR``. Any existing
    frontend directory is replaced only after the new build has been extracted successfully.

    Args:
        root: Currently unused; kept so existing callers keep working.

    """
    build_dir = "build"
    frontend_dir = pathlib.Path(FRONTEND_DIR)

    # A self-cleaning temporary directory avoids leaking the extracted archive on disk.
    with tempfile.TemporaryDirectory() as download_dir:
        with urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL) as response:  # noqa: S310
            # "r|gz" reads the gzip stream sequentially, so the archive never has to be buffered whole.
            with tarfile.open(fileobj=response, mode="r|gz") as archive:
                extract_kwargs = {}
                # The archive comes from the network: when this Python ships extraction filters,
                # use the "data" filter to reject absolute paths / traversal / unsafe links.
                if hasattr(tarfile, "data_filter"):
                    extract_kwargs["filter"] = "data"
                archive.extractall(path=download_dir, **extract_kwargs)  # noqa: S202

        # Remove the old frontend only now, so a failed download does not leave the app without a UI.
        shutil.rmtree(frontend_dir, ignore_errors=True)
        shutil.move(os.path.join(download_dir, build_dir), frontend_dir)

    print("The Lightning UI has successfully been downloaded!")
59

60

61
def _cleanup(*tar_files: str):
62
    for tar_file in tar_files:
63
        shutil.rmtree(os.path.join(_PROJECT_ROOT, "dist"), ignore_errors=True)
64
        os.remove(tar_file)
65

66

67
def _prepare_wheel(path):
68
    with open("log.txt", "w") as logfile:
69
        with subprocess.Popen(
70
            ["rm", "-r", "dist"], stdout=logfile, stderr=logfile, bufsize=0, close_fds=True, cwd=path
71
        ) as proc:
72
            proc.wait()
73

74
        with subprocess.Popen(
75
            ["python", "setup.py", "sdist"],
76
            stdout=logfile,
77
            stderr=logfile,
78
            bufsize=0,
79
            close_fds=True,
80
            cwd=path,
81
        ) as proc:
82
            proc.wait()
83

84
    os.remove("log.txt")
85

86

87
def _copy_tar(project_root, dest: Path) -> str:
88
    dist_dir = os.path.join(project_root, "dist")
89
    tar_files = os.listdir(dist_dir)
90
    assert len(tar_files) == 1
91
    tar_name = tar_files[0]
92
    tar_path = os.path.join(dist_dir, tar_name)
93
    shutil.copy(tar_path, dest)
94
    return tar_name
95

96

97
def get_dist_path_if_editable_install(project_name) -> str:
    """Return the first ``sys.path`` entry containing a ``<project_name>.egg-info`` directory.

    This is a modified version of pip's editable-install detection: it looks for an ``.egg-info``
    directory rather than an ``.egg-link`` file. Callers treat a non-empty result as "installed in
    editable mode".

    Returns:
        The matching ``sys.path`` entry, or ``""`` when none contains the ``.egg-info`` directory.

    """
    # Entries that are not directories (zip files, "" for the cwd, missing paths) are skipped
    # before the .egg-info lookup.
    existing_dirs = (entry for entry in sys.path if os.path.isdir(entry))
    return next(
        (entry for entry in existing_dirs if os.path.isdir(os.path.join(entry, project_name + ".egg-info"))),
        "",
    )
108

109

110
def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = "lightning") -> Optional[Callable]:
    """Package the local Lightning source (and editable utility packages) alongside an app.

    This only does work for developers. When ``package_name`` is installed in editable mode and packaging
    is enabled, it does the following:

    * downloads the frontend;
    * builds an sdist of the Lightning project and copies it into ``root``;
    * packages editable installs of ``lightning_cloud`` / ``lightning_launcher`` the same way.

    For normal users who install via PyPI or Conda, this function does nothing.

    Args:
        root: Directory into which the built archives are copied.
        package_name: Distribution whose editable install triggers packaging.

    Returns:
        A callable that removes the copied archives and the project's ``dist`` directory, or ``None``
        when nothing was packaged.

    """
    if not get_dist_path_if_editable_install(package_name):
        return None

    # NOTE(review): the split literal ``"lightning" + "_app"`` is kept verbatim; presumably it guards
    # against automated package-name rewriting — confirm before simplifying.
    os.environ["PACKAGE_NAME"] = "app" if package_name == "lightning" + "_app" else "lightning"

    # Packaging the Lightning codebase happens only inside the `lightning` repo.
    git_dir_name = get_dir_name() if check_github_repository() else None

    is_lightning = git_dir_name and git_dir_name == package_name

    # PACKAGE_LIGHTNING == "0" always disables packaging; when unset, packaging happens only inside the repo.
    if (PACKAGE_LIGHTNING is None and not is_lightning) or PACKAGE_LIGHTNING == "0":
        return None

    download_frontend(_PROJECT_ROOT)
    _prepare_wheel(_PROJECT_ROOT)

    # todo: check why logging.info is missing in outputs
    print(f"Packaged Lightning with your application. Version: {version}")

    tar_name = _copy_tar(_PROJECT_ROOT, root)

    tar_files = [os.path.join(root, tar_name)]

    # Don't skip by default. The env-var spelling ("LIGHTING") is kept unchanged since external scripts may set it.
    if (PACKAGE_LIGHTNING or is_lightning) and not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "0"))):
        # Build and copy utility-package sdists for those installed in editable mode.
        for module_name, display_name in (
            ("lightning_cloud", "Lightning Cloud"),
            ("lightning_launcher", "Lightning Launcher"),
        ):
            utility_tar = _package_editable_utility(module_name, display_name, root)
            if utility_tar:
                tar_files.append(utility_tar)

    return functools.partial(_cleanup, *tar_files)


def _package_editable_utility(module_name: str, display_name: str, root: Path) -> Optional[str]:
    """Build and copy an sdist of ``module_name`` into ``root`` if it is installed in editable mode.

    Returns:
        The path of the copied archive inside ``root``, or ``None`` when ``module_name`` is not an
        editable install.

    """
    import importlib

    project_path = get_dist_path_if_editable_install(module_name)
    if not project_path:
        return None

    # Equivalent to ``from <module_name>.__version__ import __version__``.
    package_version = importlib.import_module(f"{module_name}.__version__").__version__

    # todo: check why logging.info is missing in outputs
    print(f"Packaged {display_name} with your application. Version: {package_version}")
    _prepare_wheel(project_path)
    tar_name = _copy_tar(project_path, root)
    return os.path.join(root, tar_name)
164

165

166
def _enable_debugging():
    """Switch on DEBUG logging when ``lightning-<version>.tar.gz`` exists in the working directory.

    The archive's presence acts as the switch; without it this function does nothing. When enabled,
    ``_root_logger`` and ``_logger`` both propagate records and ``_root_logger`` is set to DEBUG.

    """
    archive_path = os.path.join(os.getcwd(), f"lightning-{version}.tar.gz")
    if os.path.exists(archive_path):
        for lightning_logger in (_root_logger, _logger):
            lightning_logger.propagate = True
        _root_logger.setLevel(logging.DEBUG)
        _root_logger.debug("Setting debugging mode.")
176

177

178
def enable_debugging(func: Callable) -> Callable:
    """Decorate ``func`` so it runs with Lightning debug logging enabled when applicable.

    Before each call, ``_enable_debugging()`` turns on DEBUG logging if a
    ``lightning-<version>.tar.gz`` archive is present in the working directory. After the call,
    ``_logger`` is set back to INFO, including when ``func`` raises.

    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        _enable_debugging()
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the level even if the wrapped call fails, so it is not left in debug mode.
            _logger.setLevel(logging.INFO)

    return wrapper
189

190

191
def _fetch_latest_version(package_name: str) -> str:
192
    args = [
193
        sys.executable,
194
        "-m",
195
        "pip",
196
        "install",
197
        f"{package_name}==1000",
198
    ]
199

200
    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True)
201
    if proc.stdout:
202
        logs = " ".join([line.decode("utf-8") for line in iter(proc.stdout.readline, b"")])
203
        return logs.split(")\n")[0].split(",")[-1].replace(" ", "")
204
    return version
205

206

207
def _verify_lightning_version():
    """Ensure the installed Lightning is the newest release, as required for running in the cloud.

    Raises:
        Exception: When pip reports a newer Lightning release than the installed ``version``.
            The check is skipped entirely on Windows.

    """
    # TODO (tchaton) Add support for windows
    if sys.platform != "win32":
        newest_release = _fetch_latest_version("lightning")
        if Version(version) < Version(newest_release):
            raise Exception(
                f"You need to use the latest version of Lightning ({newest_release}) to run in the cloud. "
                "Please, run `pip install -U lightning`"
            )
220

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.