4
"""Checks out the nightly development version of PyTorch and installs pre-built
7
You can use this script to check out a new nightly branch with the following::
9
$ ./tools/nightly.py checkout -b my-nightly-branch
10
$ conda activate pytorch-deps
12
Or if you would like to re-use an existing conda environment, you can pass in
13
the regular environment parameters (--name or --prefix)::
15
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
16
$ conda activate my-env
18
You can also use this tool to pull the nightly commits into the current branch as
19
well. This can be done with
21
$ ./tools/nightly.py pull -n my-env
22
$ conda activate my-env
24
Pulling will reinstall the conda dependencies as well as the nightly binaries into
41
from argparse import ArgumentParser
42
from ast import literal_eval
59
# Module-wide logger slot; the ``timed`` decorator casts it to ``logging.Logger``
# when a decorated function runs.
LOGGER: Optional[logging.Logger] = None
# Template expanded with a package record of the conda solve (``URL_FORMAT.format(**pkg)``).
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
# strftime pattern for the timestamp prefix of a per-run log directory name.
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
# A 40-hex-digit SHA-1, captured as group 1.
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
# Everything between "://" and "@" in a URL, i.e. embedded credentials.
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
# Name of a per-run log directory: a DATETIME_FORMAT timestamp, "_", then a UUID.
LOG_DIRNAME_RE = re.compile(
    r"(\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_" r"[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12}"
)
# Package specs appended to the conda solve command line.
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")
70
class Formatter(logging.Formatter):
    """Log formatter that scrubs URL credentials and registered secrets.

    INFO and DEBUG records are rendered as the bare message; every other
    level goes through the configured format string.  The fully formatted
    text is then passed through ``_filter`` before it is returned.
    """

    # Maps a sensitive substring to the text that replaces it in log output.
    redactions: Dict[str, str]

    def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None):
        super().__init__(fmt, datefmt)
        # The class-level annotation creates no attribute, so each instance
        # needs its own mapping before ``_filter``/``redact`` can use it.
        self.redactions = {}

    def _filter(self, s: str) -> str:
        """Return ``s`` with URL credentials and all redactions replaced."""
        s = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for needle, replace in self.redactions.items():
            s = s.replace(needle, replace)
        return s

    def formatMessage(self, record: logging.LogRecord) -> str:
        """Render INFO/DEBUG records unadorned; defer to the base class otherwise."""
        if record.levelno in (logging.INFO, logging.DEBUG):
            return record.getMessage()
        return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record`` and scrub the result."""
        return self._filter(super().format(record))

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Redact specific strings; e.g., authorization tokens. This won't
        retroactively redact stuff you've already leaked, so make sure
        you redact things as soon as possible.
        """
        # ``str.replace("", x)`` inserts x between every character, so an
        # empty needle would garble every subsequent log line.
        if not needle:
            return
        self.redactions[needle] = replace
109
def logging_base_dir() -> str:
    """Return ``<cwd>/nightly/log``, creating the directory if necessary."""
    meta_dir = os.getcwd()
    base_dir = os.path.join(meta_dir, "nightly", "log")
    os.makedirs(base_dir, exist_ok=True)
    # The function is annotated ``-> str`` and callers list/join on the
    # result, so the directory path must be handed back.
    return base_dir
117
def logging_run_dir() -> str:
    """Return the log directory for the current run, creating it on first use.

    The directory name is a ``DATETIME_FORMAT`` timestamp joined to a UUID
    with an underscore, which is the shape ``LOG_DIRNAME_RE`` matches.

    The log file, the ``argv`` record and the ``exception`` record are each
    located through a separate call to this function, so the first result is
    remembered on the function object; otherwise every call would mint a new
    UUID and scatter one run's files over several directories.
    """
    cached = getattr(logging_run_dir, "_run_dir", None)
    if cached is not None:
        return cached
    cur_dir = os.path.join(
        # NOTE(review): the first path component was missing from the excerpt;
        # the log base directory is assumed because ``logging_rotate`` looks
        # for run directories there -- confirm.
        logging_base_dir(),
        f"{datetime.datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid1()}",
    )
    os.makedirs(cur_dir, exist_ok=True)
    logging_run_dir._run_dir = cur_dir  # type: ignore[attr-defined]
    return cur_dir
127
def logging_record_argv() -> None:
    """Write the command line of this process to ``argv`` in the run's log directory."""
    s = subprocess.list2cmdline(sys.argv)
    with open(os.path.join(logging_run_dir(), "argv"), "w") as f:
        # Without this write the file is created empty and ``s`` is unused.
        f.write(s)
133
def logging_record_exception(e: BaseException) -> None:
    """Write the class name of ``e`` to ``exception`` in the run's log directory."""
    with open(os.path.join(logging_run_dir(), "exception"), "w") as f:
        f.write(type(e).__name__)
138
def logging_rotate() -> None:
    """Delete run directories beyond the 1000 newest entries of the log base dir.

    Entries are ordered by name, descending; because run directory names
    start with a timestamp this puts the most recent runs first.
    """
    log_base = logging_base_dir()
    old_logs = os.listdir(log_base)
    old_logs.sort(reverse=True)
    for stale_log in old_logs[1000:]:
        # Only remove entries whose name has the run-directory shape, so
        # unrelated files placed in the log directory are left alone.
        if LOG_DIRNAME_RE.fullmatch(stale_log) is not None:
            shutil.rmtree(os.path.join(log_base, stale_log))
148
# Context manager that configures the "conda-pytorch" logger and yields a
# ``logging.Logger`` (per the ``Generator[logging.Logger, None, None]``
# annotation).  Visible behavior:
#   * the logger level is DEBUG; a stream handler and a file handler writing
#     ``nightly.log`` inside ``logging_run_dir()`` share one ``Formatter``;
#   * the console handler level is set to DEBUG and to INFO on adjacent
#     lines -- presumably the two arms of an ``if debug: ... else: ...``
#     whose headers are not visible here (TODO confirm);
#   * the command line is recorded via ``logging_record_argv()``;
#   * both ``except`` handlers log, record the exception type and print the
#     log file location.
# NOTE(review): this excerpt is missing lines (the ``try:`` header, the
# ``yield``, the docstring terminator and whatever follows each handler's
# ``print``); restore them from the full file before relying on this block.
# NOTE(review): the handlers call the module-level ``logging.exception`` /
# ``logging.info`` rather than the logger configured here -- confirm intended.
@contextlib.contextmanager
149
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
150
"""Setup logging. If a failure starts here we won't
151
be able to save the user in a reasonable way.
153
Logging structure: there is one logger (the root logger)
154
and in processes all events. There are two handlers:
155
stderr (INFO) and file handler (DEBUG).
157
formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
158
root_logger = logging.getLogger("conda-pytorch")
159
root_logger.setLevel(logging.DEBUG)
161
console_handler = logging.StreamHandler()
163
console_handler.setLevel(logging.DEBUG)
165
console_handler.setLevel(logging.INFO)
166
console_handler.setFormatter(formatter)
167
root_logger.addHandler(console_handler)
169
log_file = os.path.join(logging_run_dir(), "nightly.log")
171
file_handler = logging.FileHandler(log_file)
172
file_handler.setFormatter(formatter)
173
root_logger.addHandler(file_handler)
174
logging_record_argv()
178
print(f"log file: {log_file}")
180
except Exception as e:
181
logging.exception("Fatal exception")
182
logging_record_exception(e)
183
print(f"log file: {log_file}")
185
except BaseException as e:
189
logging.info("", exc_info=True)
190
logging_record_exception(e)
191
print(f"log file: {log_file}")
195
def check_in_repo() -> Optional[str]:
    """Ensures that we are in the PyTorch repo.

    Returns an error message when the current directory has no ``setup.py``
    or when that file does not mention "PyTorch"; returns None otherwise.
    """
    if not os.path.isfile("setup.py"):
        return "Not in root-level PyTorch repo, no setup.py found"
    with open("setup.py") as f:
        # The membership test below needs the file's text.
        s = f.read()
    if "PyTorch" not in s:
        return "Not in PyTorch repo, 'PyTorch' not found in setup.py"
    return None
206
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
    """Checks that the branch name can be checked out.

    Only the "checkout" subcommand is validated.  Returns an error message
    when no branch name was given, when tracked files have uncommitted
    changes, or when the branch already exists; returns None otherwise.
    """
    if subcommand != "checkout":
        return None
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"
    cmd = ["git", "status", "--untracked-files=no", "--porcelain"]
    # NOTE(review): the arguments of this call were missing from the excerpt.
    # Text-mode capture is required because stdout is concatenated to a str
    # below; ``check=True`` is an assumption -- confirm.
    p = subprocess.run(cmd, capture_output=True, check=True, text=True)
    if p.stdout.strip():
        return "Need to have clean working tree to checkout!\n\n" + p.stdout
    cmd = ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch]
    p = subprocess.run(cmd, capture_output=True, check=False)
    # ``git show-ref --verify --quiet`` exits 0 exactly when the ref exists.
    if p.returncode == 0:
        return f"Branch {branch!r} already exists"
    return None
231
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Timed context manager.

    Logs ``"<prefix> took <seconds> [s]"`` at INFO level on ``logger`` once
    the managed block finishes without raising.
    """
    start_time = time.time()
    # A ``contextmanager`` generator must yield exactly once; the managed
    # block runs at this point.
    yield
    logger.info("%s took %.3f [s]", prefix, time.time() - start_time)
239
# Type variable for the callable being decorated: ``timed`` is annotated as
# returning ``Callable[[F], F]`` and casts its wrapper back to ``F``.
F = TypeVar("F", bound=Callable[..., Any])
242
def timed(prefix: str) -> Callable[[F], F]:
    """Decorator for timing functions.

    The returned decorator wraps a function so that each call runs inside
    ``timer(LOGGER, prefix)``.  ``LOGGER`` is read at call time, not at
    decoration time.
    """
    # Function-scope import: keeps the wrapper's ``__name__``/``__doc__``
    # intact without relying on a module-level import of functools.
    import functools

    def dec(f: F) -> F:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # NOTE(review): some lines of this wrapper were missing from the
            # excerpt; confirm nothing else (e.g. an extra log call) belongs here.
            logger = cast(logging.Logger, LOGGER)
            with timer(logger, prefix):
                return f(*args, **kwargs)

        return cast(F, wrapper)

    return dec
259
def _make_channel_args(
260
channels: Iterable[str] = ("pytorch-nightly",),
261
override_channels: bool = False,
264
for channel in channels:
265
args.append("--channel")
267
if override_channels:
268
args.append("--override-channels")
272
# Solve step (called as ``conda_solve(...)`` elsewhere in this file).
# NOTE(review): the ``def`` line and a long run of lines after the last
# ``env_opts`` assignment are missing from this excerpt -- including whatever
# defines ``cmd``, ``existing_env``, ``deps`` and ``pytorch`` and the loop
# header that binds ``pkg``.  Restore them from the full file before editing.
@timed("Solving conda environment")
274
name: Optional[str] = None,
275
prefix: Optional[str] = None,
276
channels: Iterable[str] = ("pytorch-nightly",),
277
override_channels: bool = False,
278
) -> Tuple[List[str], str, str, bool, List[str]]:
279
"""Performs the conda solve and splits the deps from the package."""
# Environment selection: an explicit prefix takes precedence over a name;
# the third assignment targets an environment named "pytorch-deps"
# (presumably the fallback arm of the chain -- its ``else:`` is not visible).
281
if prefix is not None:
283
env_opts = ["--prefix", prefix]
284
elif name is not None:
286
env_opts = ["--name", name]
290
env_opts = ["--name", "pytorch-deps"]
# The command gets the channel options and then the package specs appended.
311
channel_args = _make_channel_args(
312
channels=channels, override_channels=override_channels
314
cmd.extend(channel_args)
315
cmd.extend(SPECS_TO_INSTALL)
# check=True: a failing command raises CalledProcessError.
316
p = subprocess.run(cmd, capture_output=True, check=True)
# stdout is parsed as JSON; the package records are read from actions -> LINK.
318
solve = json.loads(p.stdout)
319
link = solve["actions"]["LINK"]
# Each record is expanded into a URL; the record named "pytorch" also
# supplies the platform string that is returned.
322
url = URL_FORMAT.format(**pkg)
323
if pkg["name"] == "pytorch":
325
platform = pkg["platform"]
328
return deps, pytorch, platform, existing_env, env_opts
331
@timed("Installing dependencies")
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
    """Install dependencies to deps environment.

    Runs ``conda install`` into an existing environment, or ``conda create``
    for a new one, passing ``deps`` with ``--no-deps``.  Either command
    raises ``subprocess.CalledProcessError`` on failure.
    """
    # NOTE(review): the guard around the removal was missing from the
    # excerpt.  It is assumed to apply only when the environment is about to
    # be (re)created, since ``install`` needs the environment to exist -- confirm.
    if not existing_env:
        cmd = ["conda", "env", "remove", "--yes"] + env_opts
        subprocess.run(cmd, check=True)
    inst_opt = "install" if existing_env else "create"
    cmd = ["conda", inst_opt, "--yes", "--no-deps"] + env_opts + deps
    subprocess.run(cmd, check=True)
344
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
    """Install pytorch into a temporary directory.

    Creates a conda environment from ``url`` (with ``--no-deps``) whose
    prefix is a fresh temporary directory and returns the
    ``TemporaryDirectory`` object; the caller reads its ``.name``.
    """
    pytdir = tempfile.TemporaryDirectory()
    cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
    subprocess.run(cmd, check=True)
    return pytdir
353
def _site_packages(dirname: str, platform: str) -> str:
354
if platform.startswith("win"):
355
template = os.path.join(dirname, "Lib", "site-packages")
357
template = os.path.join(dirname, "lib", "python*.*", "site-packages")
358
spdir = glob.glob(template)[0]
362
def _ensure_commit(git_sha1: str) -> None:
    """Make sure that we actually have the commit locally.

    Fetches ``git_sha1`` from the PyTorch GitHub repository when
    ``git cat-file -e`` cannot find it as a commit in the local repository.
    Raises ``subprocess.CalledProcessError`` if the fetch fails.
    """
    cmd = ["git", "cat-file", "-e", git_sha1 + "^{commit}"]
    p = subprocess.run(cmd, capture_output=True, check=False)
    if p.returncode == 0:
        # The commit is already present; nothing to fetch.
        return
    cmd = ["git", "fetch", "https://github.com/pytorch/pytorch.git", git_sha1]
    subprocess.run(cmd, check=True)
374
def _nightly_version(spdir: str) -> str:
    """Return the SHA-1 referenced by the installed build's commit subject.

    Reads ``git_version`` from ``<spdir>/torch/version.py``, makes sure that
    commit is available locally, then takes the first 40-hex-digit SHA-1
    found in that commit's subject line and makes sure it is available too.

    Raises RuntimeError when ``git_version`` is not defined in the file or
    when the subject contains no SHA-1.
    """
    version_fname = os.path.join(spdir, "torch", "version.py")
    with open(version_fname) as f:
        lines = f.read().splitlines()
    for line in lines:
        if not line.startswith("git_version"):
            continue
        # The right-hand side is a Python literal; evaluate it safely.
        git_version = literal_eval(line.partition("=")[2].strip())
        break
    else:
        raise RuntimeError(f"Could not find git_version in {version_fname}")
    print(f"Found released git version {git_version}")
    _ensure_commit(git_version)
    cmd = ["git", "show", "--no-patch", "--format=%s", git_version]
    # NOTE(review): the arguments of this call were missing from the excerpt;
    # text-mode capture is required because stdout is searched with a str
    # pattern.  ``check=True`` is an assumption -- confirm.
    p = subprocess.run(cmd, capture_output=True, check=True, text=True)
    m = SHA1_RE.search(p.stdout)
    if m is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n {p.stdout}"
        )
    nightly_version = m.group(1)
    print(f"Found nightly release version {nightly_version}")
    _ensure_commit(nightly_version)
    return nightly_version
408
@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, spdir: str) -> None:
    """Gets the nightly version and then checks it out.

    Creates ``branch`` at the commit returned by ``_nightly_version(spdir)``
    via ``git checkout -b``; raises ``subprocess.CalledProcessError`` if git
    fails.
    """
    nightly_version = _nightly_version(spdir)
    cmd = ["git", "checkout", "-b", branch, nightly_version]
    subprocess.run(cmd, check=True)
416
@timed("Pulling nightly PyTorch")
def pull_nightly_version(spdir: str) -> None:
    """Fetches the nightly version and then merges it.

    Merges the commit returned by ``_nightly_version(spdir)`` into the
    current branch; raises ``subprocess.CalledProcessError`` if git fails.
    """
    nightly_version = _nightly_version(spdir)
    cmd = ["git", "merge", nightly_version]
    subprocess.run(cmd, check=True)
424
def _get_listing_linux(source_dir: str) -> List[str]:
425
listing = glob.glob(os.path.join(source_dir, "*.so"))
426
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
430
def _get_listing_osx(source_dir: str) -> List[str]:
432
listing = glob.glob(os.path.join(source_dir, "*.so"))
433
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
437
def _get_listing_win(source_dir: str) -> List[str]:
438
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
439
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
440
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
444
def _glob_pyis(d: str) -> Set[str]:
445
search = os.path.join(d, "**", "*.pyi")
446
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
450
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
    """Return paths under ``source_dir`` of ``.pyi`` files that ``target_dir`` lacks.

    Files are compared by their path relative to the respective directory.
    The order of the result is unspecified (it comes from a set difference).
    """
    source_pyis = _glob_pyis(source_dir)
    target_pyis = _glob_pyis(target_dir)
    missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
    return missing_pyis
458
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
459
if platform.startswith("linux"):
460
listing = _get_listing_linux(source_dir)
461
elif platform.startswith("osx"):
462
listing = _get_listing_osx(source_dir)
463
elif platform.startswith("win"):
464
listing = _get_listing_win(source_dir)
466
raise RuntimeError(f"Platform {platform!r} not recognized")
467
listing.extend(_find_missing_pyi(source_dir, target_dir))
468
listing.append(os.path.join(source_dir, "version.py"))
469
listing.append(os.path.join(source_dir, "testing", "_internal", "generated"))
470
listing.append(os.path.join(source_dir, "bin"))
471
listing.append(os.path.join(source_dir, "include"))
475
def _remove_existing(trg: str, is_dir: bool) -> None:
476
if os.path.exists(trg):
487
# NOTE(review): the ``def`` line, the first three parameters, ``verb`` and
# several short lines (loop headers, ``else:``, the ``mover`` calls) were
# missing from the excerpt.  The signature is reconstructed from the calls in
# ``_copy_files``/``_link_files``; verify against the full file.
def _move_single(
    src: str,
    source_dir: str,
    target_dir: str,
    mover: Callable[[str, str], None],
    verb: str,
) -> None:
    """Transfer ``src`` to the same relative location under ``target_dir``.

    ``src`` must lie under ``source_dir``.  Whatever already exists at the
    destination is removed first.  A directory is recreated level by level
    and ``mover`` is applied to each file in it; a plain file is handed to
    ``mover`` directly.  ``verb`` prefixes the progress line printed for
    every file.
    """
    is_dir = os.path.isdir(src)
    relpath = os.path.relpath(src, source_dir)
    trg = os.path.join(target_dir, relpath)
    _remove_existing(trg, is_dir)
    if is_dir:
        os.makedirs(trg, exist_ok=True)
        for root, dirs, files in os.walk(src):
            relroot = os.path.relpath(root, src)
            for name in files:
                relname = os.path.join(relroot, name)
                s = os.path.join(src, relname)
                t = os.path.join(trg, relname)
                print(f"{verb} {s} -> {t}")
                mover(s, t)
            # Create the subdirectories now so they exist by the time the
            # walk descends into them and moves their files.
            for name in dirs:
                relname = os.path.join(relroot, name)
                os.makedirs(os.path.join(trg, relname), exist_ok=True)
    else:
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)
513
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
    """Copy every entry of ``listing`` from ``source_dir`` into ``target_dir`` with ``shutil.copy2``."""
    for src in listing:
        _move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
518
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
    """Hard-link every entry of ``listing`` from ``source_dir`` into ``target_dir`` with ``os.link``."""
    for src in listing:
        _move_single(src, source_dir, target_dir, os.link, "Linking")
523
@timed("Moving nightly files into repo")
def move_nightly_files(spdir: str, platform: str) -> None:
    """Moves PyTorch files from temporary installed location to repo.

    The source is ``<spdir>/torch`` and the target is the ``torch`` directory
    of the current working directory.  Files are copied on platforms starting
    with "win"; elsewhere they are hard-linked, with copying as the fallback.
    """
    source_dir = os.path.join(spdir, "torch")
    target_dir = os.path.abspath("torch")
    listing = _get_listing(source_dir, target_dir, platform)
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
    else:
        # NOTE(review): the lines between the visible link and copy calls were
        # missing from the excerpt; a try/except fallback is assumed, and the
        # exception type caught is a guess -- confirm.
        try:
            _link_files(listing, source_dir, target_dir)
        except Exception:
            _copy_files(listing, source_dir, target_dir)
540
def _available_envs() -> Dict[str, str]:
    """Map each entry printed by ``conda env list`` to its location.

    Blank lines and lines starting with "#" are skipped.  For every other
    line the first whitespace-separated field is the key and the last field
    is the value.
    """
    cmd = ["conda", "env", "list"]
    # NOTE(review): the arguments of this call were missing from the excerpt;
    # text-mode capture is required because stdout is split into str lines.
    # ``check=True`` is an assumption -- confirm.
    p = subprocess.run(cmd, capture_output=True, check=True, text=True)
    lines = p.stdout.splitlines()
    envs: Dict[str, str] = {}
    for line in map(str.strip, lines):
        if not line or line.startswith("#"):
            continue
        parts = line.split()
        # NOTE(review): assumed -- a line with a single field has no separate
        # name to key on, so it is skipped; confirm against the full file.
        if len(parts) == 1:
            continue
        envs[parts[0]] = parts[-1]
    return envs
561
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: List[str], platform: str) -> None:
    """Writes Python path file for this dir.

    ``env_opts`` is a two-element list: ``"--name"`` or ``"--prefix"``
    followed by the environment name or location.  A name is resolved to a
    location through ``_available_envs()`` (KeyError if it is unknown).  The
    file ``pytorch-nightly.pth`` is written into that environment's
    site-packages directory.
    """
    env_type, env_dir = env_opts
    if env_type == "--name":
        envs = _available_envs()
        env_dir = envs[env_dir]
    spdir = _site_packages(env_dir, platform)
    pth = os.path.join(spdir, "pytorch-nightly.pth")
    s = (
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        # NOTE(review): the final line of the content was missing from the
        # excerpt; the current working directory is assumed from the
        # docstring ("for this dir") -- confirm.
        f"{os.getcwd()}\n"
    )
    with open(pth, "w") as f:
        f.write(s)
583
# Top-level workflow of the tool (docstring below: "Development install of
# PyTorch").
# NOTE(review): the ``def`` line, the end of the signature, several closing
# parentheses, an ``else:`` and the head/tail of the final logging call are
# missing from this excerpt; the function's name is therefore not visible
# here.  Restore from the full file before editing.
logger: logging.Logger,
584
subcommand: str = "checkout",
585
branch: Optional[str] = None,
586
name: Optional[str] = None,
587
prefix: Optional[str] = None,
588
channels: Iterable[str] = ("pytorch-nightly",),
589
override_channels: bool = False,
591
"""Development install of PyTorch"""
# Step 1: solve the environment; the result carries the dependency list, the
# pytorch package reference, the platform string and the environment options.
592
deps, pytorch, platform, existing_env, env_opts = conda_solve(
593
name=name, prefix=prefix, channels=channels, override_channels=override_channels
# Step 2: install the dependencies, then install pytorch itself into a
# temporary directory and locate that directory's site-packages.
596
deps_install(deps, existing_env, env_opts)
597
pytdir = pytorch_install(pytorch)
598
spdir = _site_packages(pytdir.name, platform)
# Step 3: bring the git working tree to the nightly commit, either on a new
# branch ("checkout") or by merging into the current one ("pull").
599
if subcommand == "checkout":
600
checkout_nightly_version(cast(str, branch), spdir)
601
elif subcommand == "pull":
602
pull_nightly_version(spdir)
604
raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
# Step 4: move the installed files into the repo and write the .pth file.
605
move_nightly_files(spdir, platform)
606
write_pth(env_opts, platform)
# The message below takes one %-style argument after "conda activate"; the
# call it belongs to and that argument are not visible here.
609
"-------\nPyTorch Development Environment set up!\nPlease activate to "
610
"enable this environment:\n $ conda activate %s",
615
# Builds the command-line parser: program name "nightly" with the
# subcommands "checkout" and "pull", selected via the ``subcmd`` attribute.
# NOTE(review): most ``add_argument`` calls are only partially visible in
# this excerpt (their option strings, actions and closing parentheses are
# missing), as is the return statement; restore from the full file.
def make_parser() -> ArgumentParser:
616
p = ArgumentParser("nightly")
618
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
619
co = subcmd.add_parser("checkout", help="checkout a new branch")
623
help="Branch name to checkout",
628
pull = subcmd.add_parser(
629
"pull", help="pulls the nightly commits into the current branch"
637
help="Name of environment",
640
metavar="ENVIRONMENT",
645
help="Full path to environment location (i.e. prefix)",
653
help="Provide debugging info",
659
"--override-channels",
660
help="Do not search default or .condarc channels.",
661
dest="override_channels",
668
help="Additional channel to search for packages. 'pytorch-nightly' will always be prepended to this list.",
676
# Command-line entry point: parses ``args``, validates the repository and
# branch, then runs the workflow inside ``logging_manager``.
# NOTE(review): lines are missing from this excerpt -- the creation of ``p``,
# the handling of a non-empty ``status``, and the head of the call whose
# keyword arguments appear at the end; restore from the full file.
def main(args: Optional[Sequence[str]] = None) -> None:
677
"""Main entry point"""
680
ns = p.parse_args(args)
# The "pull" subparser may not define ``branch``; default it to None so the
# attribute access below is safe for either subcommand.
681
ns.branch = getattr(ns, "branch", None)
# Each check returns an error message or None; the branch check only runs
# when the repository check found no problem.
682
status = check_in_repo()
683
status = status or check_branch(ns.subcmd, ns.branch)
# "pytorch-nightly" always comes first, followed by user-supplied channels.
686
channels = ["pytorch-nightly"]
688
channels.extend(ns.channels)
689
with logging_manager(debug=ns.verbose) as logger:
692
subcommand=ns.subcmd,
698
override_channels=ns.override_channels,
702
if __name__ == "__main__":