# pytorch-lightning
# 508 lines · 20.5 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import glob15import logging16import os17import pathlib18import re19import shutil20import tarfile21import tempfile22import urllib.request23from distutils.version import LooseVersion24from itertools import chain25from os.path import dirname, isfile26from pathlib import Path27from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union28
29from pkg_resources import Requirement, parse_requirements, yield_lines30
# Requirement files per top-level Lightning package; paths are relative to the project root.
REQUIREMENT_FILES = {
    "pytorch": (
        "requirements/pytorch/base.txt",
        "requirements/pytorch/extra.txt",
        "requirements/pytorch/strategies.txt",
        "requirements/pytorch/examples.txt",
    ),
    "app": (
        "requirements/app/app.txt",
        "requirements/app/cloud.txt",
        "requirements/app/ui.txt",
    ),
    "fabric": (
        "requirements/fabric/base.txt",
        "requirements/fabric/strategies.txt",
    ),
    "data": ("requirements/data/data.txt",),
}
# Flat list of every requirement file across all packages.
REQUIREMENT_FILES_ALL = list(chain(*REQUIREMENT_FILES.values()))

# Repository root — assumes this file sits exactly one directory below the root (e.g. `.actions/`).
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
53
class _RequirementWithComment(Requirement):
    """A ``pkg_resources.Requirement`` that also remembers the trailing comment and the pip argument
    (if any) from the requirements-file line it was parsed from."""

    # marker in the trailing comment that pins a requirement to its exact specifier set
    strict_string = "# strict"

    def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # trailing comment from the requirements line, including the leading " #"
        self.comment = comment
        assert pip_argument is None or pip_argument  # sanity check that it's not an empty str
        # preceding `--...` pip argument that applies to this requirement, if any
        self.pip_argument = pip_argument
        # "strict" requirements keep their version pins regardless of the `adjust` mode
        self.strict = self.strict_string in comment.lower()

    def adjust(self, unfreeze: str) -> str:
        """Remove version restrictions unless they are strict.

        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none")
        'arrow<=1.2.2,>=1.2.0'
        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none")
        'arrow<=1.2.2,>=1.2.0 # strict'
        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all")
        'arrow>=1.2.0'
        >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all")
        'arrow<=1.2.2,>=1.2.0 # strict'
        >>> _RequirementWithComment("arrow").adjust("all")
        'arrow'
        >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major")
        'arrow<2.0,>=1.2.0'
        >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major")
        'arrow<=1.2.2,>=1.2.0 # strict'
        >>> _RequirementWithComment("arrow>=1.2.0").adjust("major")
        'arrow>=1.2.0'
        >>> _RequirementWithComment("arrow").adjust("major")
        'arrow'

        """
        out = str(self)
        if self.strict:
            # strict requirements are returned verbatim, with the marker re-attached
            return f"{out} {self.strict_string}"
        if unfreeze == "major":
            for operator, version in self.specs:
                if operator in ("<", "<="):
                    major = LooseVersion(version).version[0]
                    # replace upper bound with major version increased by one
                    return out.replace(f"{operator}{version}", f"<{major + 1}.0")
        elif unfreeze == "all":
            for operator, version in self.specs:
                if operator in ("<", "<="):
                    # drop upper bound
                    return out.replace(f"{operator}{version},", "")
        elif unfreeze != "none":
            raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.")
        return out
105
def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
    """Adapted from `pkg_resources.parse_requirements` to include comments.

    >>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt']
    >>> [r.adjust('none') for r in _parse_requirements(txt)]
    ['this', 'example', 'foo # strict', 'thing']
    >>> txt = '\\n'.join(txt)
    >>> [r.adjust('none') for r in _parse_requirements(txt)]
    ['this', 'example', 'foo # strict', 'thing']

    """
    line_iter = yield_lines(strs)
    pip_argument = None
    for raw in line_iter:
        # Split off any trailing comment; a hash without a leading space may be part of a URL.
        comment = ""
        marker = raw.find(" #")
        if marker >= 0:
            raw, comment = raw[:marker], raw[marker:]
        # Merge an explicit line continuation with the following physical line.
        if raw.endswith("\\"):
            raw = raw[:-2].strip()
            try:
                raw += next(line_iter)
            except StopIteration:
                return
        # A pip argument applies to the next requirement yielded.
        if raw.startswith("--"):
            pip_argument = raw
            continue
        # linked requirement files are unsupported
        if raw.startswith("-r "):
            continue
        yield _RequirementWithComment(raw, comment=comment, pip_argument=pip_argument)
        pip_argument = None
143
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
    """Load requirements from a file and optionally relax their version pins.

    Args:
        path_dir: folder containing the requirement file.
        file_name: name of the requirement file inside ``path_dir``.
        unfreeze: one of ``"none"`` (keep pins as-is), ``"major"`` (cap at the next major version),
            or ``"all"`` (drop upper bounds entirely). Requirements marked ``# strict`` are never relaxed.

    Returns:
        The adjusted requirement strings, or an empty list when the file does not exist.

    >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
    >>> load_requirements(path_req, "docs.txt", unfreeze="major")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    ['sphinx<...]

    """
    assert unfreeze in {"none", "major", "all"}
    path = Path(path_dir) / file_name
    if not path.exists():
        # best effort: a missing requirements file yields no requirements instead of failing the build
        logging.warning(f"Folder {path_dir} does not have any base requirements.")
        return []
    # NOTE: the original code re-asserted `path.exists()` here, which is unreachable after the
    # early return above — the dead assert has been removed.
    text = path.read_text()
    return [req.adjust(unfreeze) for req in _parse_requirements(text)]
161
def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
    """Load the repository README and adapt it for use as a package long description.

    Args:
        path_dir: folder containing ``README.md``.
        homepage: project homepage URL, used to build absolute links to release assets.
        version: release tag used to pin badges and documentation links to this release.

    >>> load_readme_description(_PROJECT_ROOT, "", "")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    '...PyTorch Lightning is just organized PyTorch...'

    """
    path_readme = os.path.join(path_dir, "README.md")
    with open(path_readme, encoding="utf-8") as fo:
        text = fo.read()

    # drop images from readme
    text = text.replace(
        "![PT to PL](docs/source-pytorch/_static/images/general/pl_quick_start_full_compressed.gif)", ""
    )

    # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
    github_source_url = os.path.join(homepage, "raw", version)
    # replace relative repository path to absolute link to the release
    # do not replace all "docs" as in the readme we refer to some other sources with a particular path to docs
    # NOTE(review): the source pattern is `source-pytorch` but the target uses `source-app` — confirm intentional
    text = text.replace(
        "docs/source-pytorch/_static/", f"{os.path.join(github_source_url, 'docs/source-app/_static/')}"
    )

    # readthedocs badge
    text = text.replace("badge/?version=stable", f"badge/?version={version}")
    text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{version}")
    # codecov badge
    text = text.replace("/branch/master/graph/badge.svg", f"/release/{version}/graph/badge.svg")
    # github actions badge
    text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={version}")
    # azure pipelines badge
    text = text.replace("?branchName=master", f"?branchName=refs%2Ftags%2F{version}")

    # strip the sections explicitly marked to be excluded from the PyPI description
    skip_begin = r"<!-- following section will be skipped from PyPI description -->"
    skip_end = r"<!-- end skipping PyPI description -->"
    # todo: wrap content as commented description
    return re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)

    # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
    # github_release_url = os.path.join(homepage, "releases", "download", version)
    # # download badge and replace url with local file
    # text = _parse_for_badge(text, github_release_url)
206
def distribute_version(src_folder: str, ver_file: str = "version.info") -> None:
    """Mirror the top-level version file into every package folder that contains a ``__version__.py``."""
    ver_template = os.path.join(src_folder, ver_file)
    version_markers = glob.glob(os.path.join(src_folder, "*", "__version__.py"))
    for marker in version_markers:
        target = os.path.join(os.path.dirname(marker), ver_file)
        print("Distributing the version to", target)
        # replace any stale copy before distributing the fresh one
        if os.path.isfile(target):
            os.remove(target)
        shutil.copy2(ver_template, target)
218
def _download_frontend(pkg_path: str, version: str = "v0.0.0"):
    """Download the released Lightning frontend archive and unpack it into ``<pkg_path>/ui``.

    Best-effort: any failure (e.g. installing from source without a network connection) is
    swallowed and only reported on stdout, so it never breaks the installation.
    """
    # TODO: remove this once lightning-ui package is ready as a dependency
    release_url = f"https://lightning-packages.s3.amazonaws.com/ui/{version}.tar.gz"
    ui_dir = pathlib.Path(pkg_path, "ui")
    try:
        staging_dir = tempfile.mkdtemp()
        # drop any previously extracted frontend before replacing it
        shutil.rmtree(ui_dir, ignore_errors=True)
        stream = urllib.request.urlopen(release_url)

        # stream-decompress the gzip tarball straight from the HTTP response
        bundle = tarfile.open(fileobj=stream, mode="r|gz")
        bundle.extractall(path=staging_dir)  # noqa: S202

        shutil.move(staging_dir, ui_dir)
        print("The Lightning UI has successfully been downloaded!")
    except Exception:
        print("The Lightning UI downloading has failed!")
242
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
    """Collect base requirements from every package folder and write the deduplicated union to ``base.txt``.

    >>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))

    """
    unfreeze = "none" if freeze_requirements else "major"
    per_package = []
    for entry in glob.glob(os.path.join(req_dir, "*")):
        # consider only non-empty folders (skip files, git artifacts) whose name is not underscore-prefixed
        if not os.path.isdir(entry):
            continue
        if not glob.glob(os.path.join(entry, "*")):
            continue
        if os.path.basename(entry).startswith("_"):
            continue
        per_package.append(load_requirements(entry, unfreeze=unfreeze))
    if not per_package:
        return
    # TODO: add some smarter version aggregation per each package
    merged = sorted(set(chain(*per_package)))
    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        fp.writelines([req + os.linesep for req in merged] + [os.linesep])
262
263def _retrieve_files(directory: str, *ext: str) -> List[str]:264all_files = []265for root, _, files in os.walk(directory):266for fname in files:267if not ext or any(os.path.split(fname)[1].lower().endswith(e) for e in ext):268all_files.append(os.path.join(root, fname))269
270return all_files271
272
273def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]:274"""Replace imports of standalone package to lightning.275
276>>> lns = [
277... '"lightning_app"',
278... "lightning_app",
279... "lightning_app/",
280... "delete_cloud_lightning_apps",
281... "from lightning_app import",
282... "lightning_apps = []",
283... "lightning_app and pytorch_lightning are ours",
284... "def _lightning_app():",
285... ":class:`~lightning_app.core.flow.LightningFlow`",
286... "http://pytorch_lightning.ai",
287... "from lightning import __version__",
288... "@lightning.ai"
289... ]
290>>> mapping = [("lightning_app", "lightning.app"), ("pytorch_lightning", "lightning.pytorch")]
291>>> _replace_imports(lns, mapping, lightning_by="lightning_fabric") # doctest: +NORMALIZE_WHITESPACE
292['"lightning.app"', \
293'lightning.app', \
294'lightning_app/', \
295'delete_cloud_lightning_apps', \
296'from lightning.app import', \
297'lightning_apps = []', \
298'lightning.app and lightning.pytorch are ours', \
299'def _lightning_app():', \
300':class:`~lightning.app.core.flow.LightningFlow`', \
301'http://pytorch_lightning.ai', \
302'from lightning_fabric import __version__', \
303'@lightning.ai']
304
305"""
306out = lines[:]307for source_import, target_import in mapping:308for i, ln in enumerate(out):309out[i] = re.sub(310rf"([^_/@]|^){source_import}([^_\w/]|$)",311rf"\1{target_import}\2",312ln,313)314if lightning_by: # in addition, replace base package315out[i] = out[i].replace("from lightning import ", f"from {lightning_by} import ")316out[i] = out[i].replace("import lightning ", f"import {lightning_by} ")317return out318
319
def copy_replace_imports(
    source_dir: str,
    source_imports: Sequence[str],
    target_imports: Sequence[str],
    target_dir: Optional[str] = None,
    lightning_by: str = "",
) -> None:
    """Copy package content with import adjustments."""
    print(f"Replacing imports: {locals()}")
    assert len(source_imports) == len(target_imports), (
        "source and target imports must have the same length, "
        f"source: {len(source_imports)}, target: {len(target_imports)}"
    )
    if target_dir is None:
        target_dir = source_dir

    import_pairs = list(zip(source_imports, target_imports))
    for src_path in _retrieve_files(source_dir):
        dst_path = src_path.replace(source_dir, target_dir)
        ext = os.path.splitext(src_path)[1]
        if ext == ".pyc":
            # compiled artifacts are never copied
            continue
        if ext in (".png", ".jpg", ".ico"):
            # binary assets are copied once, without content rewriting
            os.makedirs(dirname(dst_path), exist_ok=True)
            if not isfile(dst_path):
                shutil.copy(src_path, dst_path)
            continue
        # try to parse everything else as text
        with open(src_path, encoding="utf-8") as fo:
            try:
                content = fo.readlines()
            except UnicodeDecodeError:
                # a binary file, skip
                print(f"Skipped replacing imports for {src_path}")
                continue
        content = _replace_imports(content, import_pairs, lightning_by=lightning_by)
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        with open(dst_path, "w", encoding="utf-8") as fo:
            fo.writelines(content)
360
def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
    """Mirror each `lightning.<pkg>` subpackage into its standalone counterpart, rewriting imports."""
    # drop the plain `lightning` key to avoid replacing `lightning` with `lightning.lightning`
    pruned = {key: value for key, value in package_mapping.items() if key != "lightning"}
    unified_to_standalone = {f"lightning.{key}": value for key, value in pruned.items()}
    for unified_pkg, standalone_pkg in unified_to_standalone.items():
        copy_replace_imports(
            source_dir=os.path.join(source_dir, unified_pkg.replace(".", os.sep)),
            # pytorch_lightning uses lightning_fabric, so we need to replace all imports for all directories
            source_imports=unified_to_standalone.keys(),
            target_imports=unified_to_standalone.values(),
            target_dir=os.path.join(source_dir, standalone_pkg.replace(".", os.sep)),
            lightning_by=unified_pkg,
        )
377
class AssistantCLI:
    """Collection of release/CI helper commands, exposed as a CLI via ``jsonargparse`` (see ``__main__``)."""

    @staticmethod
    def requirements_prune_pkgs(packages: Sequence[str], req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
        """Remove some packages from given requirement files."""
        # allow passing a single file path as a plain string
        if isinstance(req_files, str):
            req_files = [req_files]
        for req in req_files:
            AssistantCLI._prune_packages(req, packages)

    @staticmethod
    def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
        """Remove some packages from given requirement files."""
        path = Path(req_file)
        assert path.exists()
        text = path.read_text()
        lines = text.splitlines()
        final = []
        for line in lines:
            ln_ = line.strip()
            # keep blank lines and comment-only lines untouched
            if not ln_ or ln_.startswith("#"):
                final.append(line)
                continue
            req = list(parse_requirements(ln_))[0]
            # keep the line only when its project name is not among the pruned packages
            if req.name not in packages:
                final.append(line)
        print(final)
        path.write_text("\n".join(final) + "\n")

    @staticmethod
    def _replace_min(fname: str) -> None:
        # pin every `>=` lower bound as an exact `==` version (used for oldest-dependency CI runs)
        with open(fname, encoding="utf-8") as fo:
            req = fo.read().replace(">=", "==")
        with open(fname, "w", encoding="utf-8") as fw:
            fw.write(req)

    @staticmethod
    def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
        """Replace the min package version by fixed one."""
        for fname in requirement_fnames:
            print(fname)
            AssistantCLI._replace_min(fname)

    @staticmethod
    def copy_replace_imports(
        source_dir: str,
        source_import: str,
        target_import: str,
        target_dir: Optional[str] = None,
        lightning_by: str = "",
    ) -> None:
        """Copy package content with import adjustments."""
        # comma-separated CLI strings are expanded into the parallel lists the module-level helper expects
        source_imports = source_import.strip().split(",")
        target_imports = target_import.strip().split(",")
        copy_replace_imports(
            source_dir, source_imports, target_imports, target_dir=target_dir, lightning_by=lightning_by
        )

    @staticmethod
    def pull_docs_files(
        gh_user_repo: str,
        target_dir: str = "docs/source-pytorch/XXX",
        checkout: str = "refs/tags/1.0.0",
        source_dir: str = "docs/source",
        single_page: Optional[str] = None,
        as_orphan: bool = False,
    ) -> None:
        """Pull docs pages from external source and append to local docs.

        Args:
            gh_user_repo: standard GitHub user/repo string
            target_dir: relative location inside the docs folder
            checkout: specific tag or branch to checkout
            source_dir: relative location inside the remote / external repo
            single_page: copy only single page from the remote repo and name it as the repo name
            as_orphan: append orphan statement to the page

        """
        import zipfile

        zip_url = f"https://github.com/{gh_user_repo}/archive/{checkout}.zip"

        with tempfile.TemporaryDirectory() as tmp:
            zip_file = os.path.join(tmp, "repo.zip")
            try:
                urllib.request.urlretrieve(zip_url, zip_file)
            except urllib.error.HTTPError:
                raise RuntimeError(f"Requesting file '{zip_url}' does not exist or it is just unavailable.")

            with zipfile.ZipFile(zip_file, "r") as zip_ref:
                zip_ref.extractall(tmp)  # noqa: S202

            zip_dirs = [d for d in glob.glob(os.path.join(tmp, "*")) if os.path.isdir(d)]
            # check that the extracted archive has only repo folder
            assert len(zip_dirs) == 1
            repo_dir = zip_dirs[0]

            if single_page:  # special case for copying single page
                single_page = os.path.join(repo_dir, source_dir, single_page)
                assert os.path.isfile(single_page), f"File '{single_page}' does not exist."
                # derive the local page name from the repo name, stripping any `lightning-`/`lightning_` prefix
                name = re.sub(r"lightning[-_]?", "", gh_user_repo.split("/")[-1])
                new_rst = os.path.join(_PROJECT_ROOT, target_dir, f"{name}.rst")
                AssistantCLI._copy_rst(single_page, new_rst, as_orphan=as_orphan)
                return
            # continue with copying all pages
            ls_pages = glob.glob(os.path.join(repo_dir, source_dir, "*.rst"))
            ls_pages += glob.glob(os.path.join(repo_dir, source_dir, "**", "*.rst"))
            for rst in ls_pages:
                # path of the page relative to the remote docs source folder
                rel_rst = rst.replace(os.path.join(repo_dir, source_dir) + os.path.sep, "")
                rel_dir = os.path.dirname(rel_rst)
                os.makedirs(os.path.join(_PROJECT_ROOT, target_dir, rel_dir), exist_ok=True)
                new_rst = os.path.join(_PROJECT_ROOT, target_dir, rel_rst)
                if os.path.isfile(new_rst):
                    logging.warning(f"Page {new_rst} already exists in the local tree so it will be skipped.")
                    continue
                AssistantCLI._copy_rst(rst, new_rst, as_orphan=as_orphan)

    @staticmethod
    def _copy_rst(rst_in: str, rst_out: str, as_orphan: bool = False) -> None:
        """Copy RST page with optional inserting orphan statement."""
        with open(rst_in, encoding="utf-8") as fopen:
            page = fopen.read()
        # prepend the Sphinx `:orphan:` directive so the page does not need to appear in a toctree
        if as_orphan and ":orphan:" not in page:
            page = ":orphan:\n\n" + page
        with open(rst_out, "w", encoding="utf-8") as fopen:
            fopen.write(page)
504
if __name__ == "__main__":
    # expose every public `AssistantCLI` method as a CLI sub-command
    import jsonargparse

    jsonargparse.CLI(AssistantCLI, as_positional=False)