pytorch-lightning

Форк
0
508 строк · 20.5 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import glob
15
import logging
16
import os
17
import pathlib
18
import re
19
import shutil
20
import tarfile
21
import tempfile
22
import urllib.request
23
from distutils.version import LooseVersion
24
from itertools import chain
25
from os.path import dirname, isfile
26
from pathlib import Path
27
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
28

29
from pkg_resources import Requirement, parse_requirements, yield_lines
30

31
# Requirement files of every sub-package, keyed by the sub-package short name.
# Each value is a tuple of paths of the form ``requirements/<package>/<name>.txt``.
REQUIREMENT_FILES = {
    package: tuple(f"requirements/{package}/{stem}.txt" for stem in stems)
    for package, stems in (
        ("pytorch", ("base", "extra", "strategies", "examples")),
        ("app", ("app", "cloud", "ui")),
        ("fabric", ("base", "strategies")),
        ("data", ("data",)),
    )
}
# Flat list with the requirement files of all sub-packages, in declaration order.
REQUIREMENT_FILES_ALL = list(chain.from_iterable(REQUIREMENT_FILES.values()))

# The folder that contains the folder holding this file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
52

53

54
class _RequirementWithComment(Requirement):
    """A requirement that also carries the comment and the pip argument found next to it.

    Instance attributes set by ``__init__``:

    - ``comment``: the comment text passed by the caller (empty string when there is none)
    - ``pip_argument``: the pip argument passed by the caller, or ``None``
    - ``strict``: ``True`` when the lower-cased comment contains ``strict_string``
    """

    # marker which, when present in the comment, protects the requirement from being adjusted
    strict_string = "# strict"

    def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.comment = comment
        assert pip_argument is None or pip_argument  # sanity check that it's not an empty str
        self.pip_argument = pip_argument
        self.strict = self.strict_string in comment.lower()

    def adjust(self, unfreeze: str) -> str:
        """Remove version restrictions unless they are strict.

        Args:
            unfreeze: ``"none"`` keeps the requirement as is, ``"major"`` replaces the first upper bound
                with ``<{major + 1}.0`` and ``"all"`` drops the first upper bound entirely.

        Returns:
            The requirement as a string; strict requirements are returned unchanged with the strict marker
            appended.

        Raises:
            ValueError: If ``unfreeze`` is not one of the values above (only checked for non-strict
                requirements).

        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none")
        'arrow<=1.2.2,>=1.2.0'
        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none")
        'arrow<=1.2.2,>=1.2.0  # strict'
        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all")
        'arrow>=1.2.0'
        >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all")
        'arrow<=1.2.2,>=1.2.0  # strict'
        >>> _RequirementWithComment("arrow").adjust("all")
        'arrow'
        >>> _RequirementWithComment("arrow<=1.2.2").adjust("all")
        'arrow'
        >>> _RequirementWithComment("arrow!=1.2.1,<=1.2.2").adjust("all")
        'arrow!=1.2.1'
        >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major")
        'arrow<2.0,>=1.2.0'
        >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major")
        'arrow<=1.2.2,>=1.2.0  # strict'
        >>> _RequirementWithComment("arrow>=1.2.0").adjust("major")
        'arrow>=1.2.0'
        >>> _RequirementWithComment("arrow").adjust("major")
        'arrow'

        """
        out = str(self)
        if self.strict:
            return f"{out}  {self.strict_string}"
        if unfreeze == "major":
            for operator, version in self.specs:
                if operator in ("<", "<="):
                    # read the leading integer directly; `distutils` (home of `LooseVersion`) is deprecated
                    leading = re.match(r"\d+", version)
                    if leading is None:
                        continue
                    major = int(leading.group())
                    # replace upper bound with major version increased by one
                    return out.replace(f"{operator}{version}", f"<{major + 1}.0")
        elif unfreeze == "all":
            for operator, version in self.specs:
                if operator in ("<", "<="):
                    bound = f"{operator}{version}"
                    # drop upper bound; it may be the first, the last or the only specifier, so remove it
                    # together with whichever comma separates it from its neighbours
                    for fragment in (f"{bound},", f",{bound}", bound):
                        if fragment in out:
                            return out.replace(fragment, "", 1)
        elif unfreeze != "none":
            raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.")
        return out
104

105

106
def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
    """Adapted from `pkg_resources.parse_requirements` to include comments.

    Args:
        strs: the requirements, either as one string or as an iterable of lines; it is handed to
            ``pkg_resources.yield_lines`` to obtain the individual lines.

    Yields:
        One ``_RequirementWithComment`` per requirement line. Lines starting with ``--`` are not yielded
        but attached as ``pip_argument`` to the next yielded requirement; lines starting with ``-r `` are
        skipped.

    >>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt']
    >>> [r.adjust('none') for r in _parse_requirements(txt)]
    ['this', 'example', 'foo  # strict', 'thing']
    >>> txt = '\\n'.join(txt)
    >>> [r.adjust('none') for r in _parse_requirements(txt)]
    ['this', 'example', 'foo  # strict', 'thing']

    """
    lines = yield_lines(strs)
    pip_argument = None
    for line in lines:
        # Drop comments -- a hash without a space may be in a URL.
        if " #" in line:
            comment_pos = line.find(" #")
            line, comment = line[:comment_pos], line[comment_pos:]
        else:
            comment = ""
        # If there is a line continuation, drop it, and append the next line.
        if line.endswith("\\"):
            # the continuation marker is a single character; cutting two would also eat the last
            # character of the requirement when no space precedes the backslash
            line = line[:-1].strip()
            try:
                line += next(lines)
            except StopIteration:
                return
        # If there's a pip argument, save it
        if line.startswith("--"):
            pip_argument = line
            continue
        if line.startswith("-r "):
            # linked requirement files are unsupported
            continue
        yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument)
        # a pip argument applies only to the requirement that directly follows it
        pip_argument = None
142

143

144
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
    """Load requirements from a file and adjust their version restrictions.

    Args:
        path_dir: folder that contains the requirement file
        file_name: name of the requirement file inside ``path_dir``
        unfreeze: one of ``"none"``, ``"major"`` or ``"all"``; forwarded to the ``adjust`` method of
            every parsed requirement

    Returns:
        The adjusted requirement strings, or an empty list (after logging a warning) when the file
        does not exist.

    >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
    >>> load_requirements(path_req, "docs.txt", unfreeze="major")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    ['sphinx<...]

    """
    assert unfreeze in {"none", "major", "all"}
    path = Path(path_dir) / file_name
    if not path.exists():
        logging.warning(f"Folder {path_dir} does not have any base requirements.")
        return []
    # read with an explicit encoding so the result does not depend on the platform locale
    text = path.read_text(encoding="utf-8")
    return [req.adjust(unfreeze) for req in _parse_requirements(text)]
160

161

162
def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
    """Load the readme as the package description.

    The text of ``README.md`` is adjusted so that it points to the given release: the quick-start GIF
    is dropped, relative static paths become absolute links and the badge / documentation links that
    refer to ``stable`` or ``master`` are rewritten to ``version``. The section between the
    "skipped from PyPI description" markers is replaced by an empty HTML comment.

    Args:
        path_dir: folder that contains ``README.md``
        homepage: base URL of the repository, used as prefix of the absolute links
        version: release identifier inserted into the links

    Returns:
        The adjusted readme text.

    >>> load_readme_description(_PROJECT_ROOT, "", "")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    '...PyTorch Lightning is just organized PyTorch...'

    """
    import posixpath  # URLs always use "/" whereas `os.path.join` follows the separator of the host OS

    path_readme = os.path.join(path_dir, "README.md")
    with open(path_readme, encoding="utf-8") as fo:
        text = fo.read()

    # drop images from readme
    text = text.replace(
        "![PT to PL](docs/source-pytorch/_static/images/general/pl_quick_start_full_compressed.gif)", ""
    )

    # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
    github_source_url = posixpath.join(homepage, "raw", version)
    # replace relative repository path to absolute link to the release
    #  do not replace all "docs" as in the readme we refer to some other sources with particular path to docs
    # NOTE(review): the relative path is `source-pytorch` but the absolute one is `source-app` -- confirm
    #  that this is intended
    text = text.replace(
        "docs/source-pytorch/_static/", posixpath.join(github_source_url, "docs/source-app/_static/")
    )

    # readthedocs badge
    text = text.replace("badge/?version=stable", f"badge/?version={version}")
    # keep the trailing slash, otherwise the version gets glued to the page path that follows it
    text = text.replace(
        "pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{version}/"
    )
    # codecov badge
    text = text.replace("/branch/master/graph/badge.svg", f"/release/{version}/graph/badge.svg")
    # github actions badge
    text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={version}")
    # azure pipelines badge
    text = text.replace("?branchName=master", f"?branchName=refs%2Ftags%2F{version}")

    skip_begin = r"<!-- following section will be skipped from PyPI description -->"
    skip_end = r"<!-- end skipping PyPI description -->"
    # todo: wrap content as commented description
    return re.sub(rf"{skip_begin}.+?{skip_end}", "<!--  -->", text, flags=re.IGNORECASE | re.DOTALL)

    # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
    # github_release_url = os.path.join(homepage, "releases", "download", version)
    # # download badge and replace url with local file
    # text = _parse_for_badge(text, github_release_url)
205

206

207
def distribute_version(src_folder: str, ver_file: str = "version.info") -> None:
    """Copy the global version file next to each package's ``__version__.py``.

    Args:
        src_folder: folder holding the global ``ver_file`` and the package sub-folders
        ver_file: name of the version file to distribute

    """
    template = os.path.join(src_folder, ver_file)
    # a direct sub-folder counts as a package when it contains `__version__.py`
    for version_module in glob.glob(os.path.join(src_folder, "*", "__version__.py")):
        destination = os.path.join(os.path.dirname(version_module), ver_file)
        print("Distributing the version to", destination)
        # get rid of a previously distributed copy before placing the fresh one
        if os.path.isfile(destination):
            os.remove(destination)
        shutil.copy2(template, destination)
217

218

219
def _download_frontend(pkg_path: str, version: str = "v0.0.0") -> None:
    """Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
    directory.

    Any failure is reported with a printed message instead of an exception, so that an installation
    from source keeps working without an internet connection.

    Args:
        pkg_path: package folder; the frontend ends up in its ``ui`` sub-folder, whose previous content
            is removed before the download starts
        version: name of the released archive (without the ``.tar.gz`` suffix)

    """
    download_dir = None
    try:
        frontend_dir = pathlib.Path(pkg_path, "ui")
        download_dir = tempfile.mkdtemp()

        shutil.rmtree(frontend_dir, ignore_errors=True)
        # TODO: remove this once lightning-ui package is ready as a dependency
        frontend_release_url = f"https://lightning-packages.s3.amazonaws.com/ui/{version}.tar.gz"
        # close both the HTTP response and the archive as soon as the extraction is over
        with urllib.request.urlopen(frontend_release_url) as response, tarfile.open(
            fileobj=response, mode="r|gz"
        ) as archive:
            archive.extractall(path=download_dir)  # noqa: S202

        shutil.move(download_dir, frontend_dir)
        print("The Lightning UI has successfully been downloaded!")

    # If installing from source without internet connection, we don't want to break the installation
    except Exception:
        # do not leave the (possibly partially filled) temporary folder behind
        if download_dir is not None:
            shutil.rmtree(download_dir, ignore_errors=True)
        print("The Lightning UI downloading has failed!")
241

242

243
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
    """Load all base requirements from all particular packages and prune duplicates.

    The merged, sorted requirements are written to ``base.txt`` inside ``req_dir``. Nothing is written
    when ``req_dir`` has no eligible sub-folder.

    Args:
        req_dir: folder whose sub-folders hold the per-package requirements
        freeze_requirements: load the requirements with ``unfreeze="none"`` instead of ``"major"``

    >>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))

    """
    unfreeze = "none" if freeze_requirements else "major"
    requires = [
        load_requirements(d, unfreeze=unfreeze)
        for d in glob.glob(os.path.join(req_dir, "*"))
        # skip empty folder (git artifacts), and resolving Will's special issue
        if os.path.isdir(d) and glob.glob(os.path.join(d, "*")) and not os.path.basename(d).startswith("_")
    ]
    if not requires:
        return
    # TODO: add some smarter version aggregation per each package
    requires = sorted(set(chain(*requires)))
    with open(os.path.join(req_dir, "base.txt"), "w", encoding="utf-8") as fp:
        # text mode translates "\n" to the platform line ending by itself; writing `os.linesep` here
        # would yield "\r\r\n" on Windows
        fp.writelines([ln + "\n" for ln in requires] + ["\n"])
261

262

263
def _retrieve_files(directory: str, *ext: str) -> List[str]:
    """Collect the paths of all files below ``directory``.

    Args:
        directory: folder which is walked recursively
        *ext: optional file-name endings; when given, only files whose lower-cased name ends with one of
            them are returned (the endings themselves are compared as given)

    Returns:
        The file paths, each one joined with the folder it was found in.

    """
    found: List[str] = []
    for folder, _, names in os.walk(directory):
        # without any ending every file qualifies
        wanted = names if not ext else [name for name in names if name.lower().endswith(ext)]
        found.extend(os.path.join(folder, name) for name in wanted)
    return found
271

272

273
def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]:
    """Replace imports of standalone package to lightning.

    Every ``(source, target)`` pair of ``mapping`` is applied in turn to all lines. An occurrence of
    ``source`` is left alone when it is preceded by ``_``, ``/`` or ``@``, or followed by ``_``, ``/``
    or a word character. If ``lightning_by`` is given, ``from lightning import `` and
    ``import lightning `` are additionally rewritten to use it; this happens right after each mapping
    pair, hence not at all for an empty ``mapping``. The input list is not modified.

    >>> lns = [
    ...     '"lightning_app"',
    ...     "lightning_app",
    ...     "lightning_app/",
    ...     "delete_cloud_lightning_apps",
    ...     "from lightning_app import",
    ...     "lightning_apps = []",
    ...     "lightning_app and pytorch_lightning are ours",
    ...     "def _lightning_app():",
    ...     ":class:`~lightning_app.core.flow.LightningFlow`",
    ...     "http://pytorch_lightning.ai",
    ...     "from lightning import __version__",
    ...     "@lightning.ai"
    ... ]
    >>> mapping = [("lightning_app", "lightning.app"), ("pytorch_lightning", "lightning.pytorch")]
    >>> _replace_imports(lns, mapping, lightning_by="lightning_fabric")  # doctest: +NORMALIZE_WHITESPACE
    ['"lightning.app"', \
     'lightning.app', \
     'lightning_app/', \
     'delete_cloud_lightning_apps', \
     'from lightning.app import', \
     'lightning_apps = []', \
     'lightning.app and lightning.pytorch are ours', \
     'def _lightning_app():', \
     ':class:`~lightning.app.core.flow.LightningFlow`', \
     'http://pytorch_lightning.ai', \
     'from lightning_fabric import __version__', \
     '@lightning.ai']

    """
    rewritten = list(lines)
    for source_import, target_import in mapping:
        pattern = re.compile(rf"([^_/@]|^){source_import}([^_\w/]|$)")
        substitute = rf"\1{target_import}\2"
        for idx, text in enumerate(rewritten):
            text = pattern.sub(substitute, text)
            if lightning_by:  # in addition, replace base package
                text = text.replace("from lightning import ", f"from {lightning_by} import ")
                text = text.replace("import lightning ", f"import {lightning_by} ")
            rewritten[idx] = text
    return rewritten
318

319

320
def copy_replace_imports(
    source_dir: str,
    source_imports: Sequence[str],
    target_imports: Sequence[str],
    target_dir: Optional[str] = None,
    lightning_by: str = "",
) -> None:
    """Copy package content with import adjustments.

    Images (``.png``, ``.jpg``, ``.ico``) are copied as they are unless the target already exists,
    ``.pyc`` files are ignored and every other file is read as UTF-8 text, passed through
    ``_replace_imports`` and written to the mirrored location. Files that cannot be decoded are skipped.

    Args:
        source_dir: folder whose files are processed
        source_imports: import names to be replaced
        target_imports: replacements, aligned with ``source_imports``
        target_dir: folder receiving the result; defaults to ``source_dir`` (in-place)
        lightning_by: forwarded to ``_replace_imports``

    """
    print(f"Replacing imports: {locals()}")
    assert len(source_imports) == len(target_imports), (
        "source and target imports must have the same length, "
        f"source: {len(source_imports)}, target: {len(target_imports)}"
    )
    if target_dir is None:
        target_dir = source_dir
    # the pairs are the same for every file, so build them only once
    mapping = list(zip(source_imports, target_imports))

    for fp in _retrieve_files(source_dir):
        # re-root the path by its location relative to the source; a plain text replacement would also
        # touch later occurrences of the folder name and breaks on a trailing separator
        fp_new = os.path.join(target_dir, os.path.relpath(fp, source_dir))
        _, ext = os.path.splitext(fp)
        # compare case-insensitively so that e.g. `LOGO.PNG` is treated as an image as well
        ext = ext.lower()
        if ext in (".png", ".jpg", ".ico"):
            os.makedirs(os.path.dirname(fp_new), exist_ok=True)
            if not os.path.isfile(fp_new):
                shutil.copy(fp, fp_new)
            continue
        if ext in (".pyc",):
            continue
        # Try to parse everything else
        try:
            with open(fp, encoding="utf-8") as fo:
                lines = fo.readlines()
        except UnicodeDecodeError:
            # a binary file, skip
            print(f"Skipped replacing imports for {fp}")
            continue
        lines = _replace_imports(lines, mapping, lightning_by=lightning_by)
        os.makedirs(os.path.dirname(fp_new), exist_ok=True)
        with open(fp_new, "w", encoding="utf-8") as fo:
            fo.writelines(lines)
359

360

361
def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
    """Replace imports and copy the code of every mapped package.

    Args:
        source_dir: folder relative to which both the source and the target package folders are resolved
        package_mapping: maps a name ``<key>`` (used as ``lightning.<key>``) to the name of the package
            that receives the copy; the ``"lightning"`` key is ignored

    """
    # leave out the `lightning` key to avoid replacing `lightning` to `lightning.lightning`
    mapping = {
        f"lightning.{sub_name}": mirror_name
        for sub_name, mirror_name in package_mapping.items()
        if sub_name != "lightning"
    }
    # pytorch_lightning uses lightning_fabric, so we need to replace all imports for all directories
    all_sources, all_targets = mapping.keys(), mapping.values()
    for pkg_from, pkg_to in mapping.items():
        from_dir = os.path.join(source_dir, pkg_from.replace(".", os.sep))
        to_dir = os.path.join(source_dir, pkg_to.replace(".", os.sep))
        copy_replace_imports(
            source_dir=from_dir,
            source_imports=all_sources,
            target_imports=all_targets,
            target_dir=to_dir,
            lightning_by=pkg_from,
        )
376

377

378
class AssistantCLI:
    """Collection of static helpers for requirement files, package copies and docs pages.

    The class is passed to ``jsonargparse.CLI`` at the bottom of this file.
    """

    @staticmethod
    def requirements_prune_pkgs(packages: Sequence[str], req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
        """Remove some packages from given requirement files.

        Args:
            packages: names of the packages to be removed
            req_files: requirement files to process; a single path given as a string is accepted too

        """
        if isinstance(req_files, str):
            req_files = [req_files]
        for req in req_files:
            AssistantCLI._prune_packages(req, packages)

    @staticmethod
    def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
        """Remove some packages from the given requirement file.

        Blank lines and comment lines are kept; every other line is parsed as a requirement and dropped
        when its name is listed in ``packages``. The file is rewritten in place.

        """
        path = Path(req_file)
        assert path.exists()
        text = path.read_text()
        lines = text.splitlines()
        final = []
        for line in lines:
            ln_ = line.strip()
            if not ln_ or ln_.startswith("#"):
                final.append(line)
                continue
            req = list(parse_requirements(ln_))[0]
            if req.name not in packages:
                final.append(line)
        print(final)
        path.write_text("\n".join(final) + "\n")

    @staticmethod
    def _replace_min(fname: str) -> None:
        """Turn every minimal version (``>=``) in the file into a fixed one (``==``).

        Only the part of a line in front of an environment marker (``;``) or a comment (`` #``) is
        rewritten, and whole-line comments are kept untouched; otherwise a marker such as
        ``python_version >= "3.8"`` would be changed as well.

        """
        with open(fname, encoding="utf-8") as fo:
            lines = fo.read().splitlines(keepends=True)
        pinned = []
        for line in lines:
            if line.lstrip().startswith("#"):
                spec_end = 0
            else:
                cuts = [pos for pos in (line.find(";"), line.find(" #")) if pos != -1]
                spec_end = min(cuts, default=len(line))
            pinned.append(line[:spec_end].replace(">=", "==") + line[spec_end:])
        with open(fname, "w", encoding="utf-8") as fw:
            fw.writelines(pinned)

    @staticmethod
    def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
        """Replace the min package version by fixed one."""
        for fname in requirement_fnames:
            print(fname)
            AssistantCLI._replace_min(fname)

    @staticmethod
    def copy_replace_imports(
        source_dir: str,
        source_import: str,
        target_import: str,
        target_dir: Optional[str] = None,
        lightning_by: str = "",
    ) -> None:
        """Copy package content with import adjustments.

        Args:
            source_dir: folder whose files are processed
            source_import: comma-separated import names to be replaced
            target_import: comma-separated replacements, aligned with ``source_import``
            target_dir: folder receiving the result
            lightning_by: forwarded unchanged to the module-level ``copy_replace_imports``

        """
        source_imports = source_import.strip().split(",")
        target_imports = target_import.strip().split(",")
        copy_replace_imports(
            source_dir, source_imports, target_imports, target_dir=target_dir, lightning_by=lightning_by
        )

    @staticmethod
    def pull_docs_files(
        gh_user_repo: str,
        target_dir: str = "docs/source-pytorch/XXX",
        checkout: str = "refs/tags/1.0.0",
        source_dir: str = "docs/source",
        single_page: Optional[str] = None,
        as_orphan: bool = False,
    ) -> None:
        """Pull docs pages from external source and append to local docs.

        Args:
            gh_user_repo: standard GitHub user/repo string
            target_dir: relative location inside the docs folder
            checkout: specific tag or branch to checkout
            source_dir: relative location inside the remote / external repo
            single_page: copy only single page from the remote repo and name it as the repo name
            as_orphan: append orphan statement to the page

        Raises:
            RuntimeError: If downloading the repository archive fails with an HTTP error.

        """
        import urllib.error  # used in the `except` below; do not rely on `urllib.request` importing it
        import zipfile

        zip_url = f"https://github.com/{gh_user_repo}/archive/{checkout}.zip"

        with tempfile.TemporaryDirectory() as tmp:
            zip_file = os.path.join(tmp, "repo.zip")
            try:
                urllib.request.urlretrieve(zip_url, zip_file)
            except urllib.error.HTTPError as err:
                # keep the original HTTP error attached as the cause
                raise RuntimeError(f"Requesting file '{zip_url}' does not exist or it is just unavailable.") from err

            with zipfile.ZipFile(zip_file, "r") as zip_ref:
                zip_ref.extractall(tmp)  # noqa: S202

            zip_dirs = [d for d in glob.glob(os.path.join(tmp, "*")) if os.path.isdir(d)]
            # check that the extracted archive has only repo folder
            assert len(zip_dirs) == 1
            repo_dir = zip_dirs[0]
            pages_root = os.path.join(repo_dir, source_dir)

            if single_page:  # special case for copying single page
                single_page = os.path.join(pages_root, single_page)
                assert os.path.isfile(single_page), f"File '{single_page}' does not exist."
                name = re.sub(r"lightning[-_]?", "", gh_user_repo.split("/")[-1])
                new_rst = os.path.join(_PROJECT_ROOT, target_dir, f"{name}.rst")
                # the target folder may not exist yet
                os.makedirs(os.path.dirname(new_rst), exist_ok=True)
                AssistantCLI._copy_rst(single_page, new_rst, as_orphan=as_orphan)
                return
            # continue with copying all pages; `**` descends into nested folders only with
            # `recursive=True`, and then matches the pages directly inside `pages_root` as well
            ls_pages = glob.glob(os.path.join(pages_root, "**", "*.rst"), recursive=True)
            for rst in ls_pages:
                rel_rst = os.path.relpath(rst, pages_root)
                rel_dir = os.path.dirname(rel_rst)
                os.makedirs(os.path.join(_PROJECT_ROOT, target_dir, rel_dir), exist_ok=True)
                new_rst = os.path.join(_PROJECT_ROOT, target_dir, rel_rst)
                if os.path.isfile(new_rst):
                    logging.warning(f"Page {new_rst} already exists in the local tree so it will be skipped.")
                    continue
                AssistantCLI._copy_rst(rst, new_rst, as_orphan=as_orphan)

    @staticmethod
    def _copy_rst(rst_in: str, rst_out: str, as_orphan: bool = False) -> None:
        """Copy RST page with optional inserting orphan statement.

        The ``:orphan:`` statement is prepended only when requested and not already present in the page.

        """
        with open(rst_in, encoding="utf-8") as fopen:
            page = fopen.read()
        if as_orphan and ":orphan:" not in page:
            page = ":orphan:\n\n" + page
        with open(rst_out, "w", encoding="utf-8") as fopen:
            fopen.write(page)
503

504

505
if __name__ == "__main__":
    # Script entry point: hand `AssistantCLI` over to jsonargparse, which builds the command line from it.
    from jsonargparse import CLI

    CLI(AssistantCLI, as_positional=False)
509

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.