7
from datetime import datetime
8
from distutils.util import strtobool
9
from pathlib import Path
12
# Matches a leading "v" on a tag name; stripping it turns "v1.7.1" into "1.7.1".
LEADING_V_PATTERN = re.compile(r"^v")
# Matches a trailing release-candidate marker such as "-rc1" (digits optional).
TRAILING_RC_PATTERN = re.compile(r"-rc[0-9]*$")
# Matches an "a0" suffix on version.txt contents, e.g. "1.8.0a0" -> "1.8.0".
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")
17
class NoGitTagException(Exception):
    """Raised when a release version is requested but HEAD is not on a git tag."""
21
def get_pytorch_root() -> Path:
    """Return the top-level directory of the current git checkout.

    Runs ``git rev-parse --show-toplevel`` from the current working directory.

    Raises:
        subprocess.CalledProcessError: if git exits non-zero (e.g. when the
            working directory is not inside a repository).
    """
    # The output must be returned, not discarded: callers pass the result as
    # ``cwd=`` and join paths onto it with ``/``.
    toplevel = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    # git terminates the printed path with a newline.
    return Path(toplevel.decode().strip())
30
# NOTE(review): fragment. This appears to be the body of get_tag(), which is
# called from get_release_version below. The `def` line, the `try:` that the
# `except` pairs with, the assignment of `dirty_tag`, and any return statements
# are not visible here; the embedded line numbers jump (30, 33, 37, 41, 44, 46).
# Reconstruct from version control before editing the logic.
root = get_pytorch_root()
33
# Succeeds only when HEAD is exactly at a tag; otherwise git exits non-zero and
# check_output raises CalledProcessError (caught below). As shown, the output
# is not bound, yet `dirty_tag` is read below, so the assignment is presumably
# on a missing line.
subprocess.check_output(["git", "describe", "--tags", "--exact"], cwd=root)
37
except subprocess.CalledProcessError:
41
# Drop a leading "v" from the tag name, e.g. "v1.7.1" -> "1.7.1".
tag = re.sub(LEADING_V_PATTERN, "", dirty_tag)
44
# Drop a trailing release-candidate marker, e.g. "1.7.1-rc1" -> "1.7.1".
tag = re.sub(TRAILING_RC_PATTERN, "", tag)
46
# Tags in the "ciflow/" namespace get special handling. The branch body is not
# visible; presumably such tags are not treated as version tags (TODO confirm).
if tag.startswith("ciflow/"):
51
def get_base_version() -> str:
    """Return the base version recorded in ``version.txt`` at the repo root.

    Surrounding whitespace is stripped, as is a trailing ``a0`` suffix
    (e.g. ``1.8.0a0`` -> ``1.8.0``).

    Raises:
        subprocess.CalledProcessError: propagated from ``get_pytorch_root``.
        OSError: if ``version.txt`` cannot be read.
    """
    root = get_pytorch_root()
    # Path.read_text() closes the file it opens; a bare open(...).read()
    # leaves the handle open until garbage collection.
    dirty_version = (root / "version.txt").read_text().strip()
    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
63
# NOTE(review): fragment. This is the tail of a constructor's parameter list
# and its body. The enclosing class header, the `def __init__(` line, `self`,
# the parameter that supplies `gpu_arch_type`, and the end of the signature
# are not visible (numbering jumps from 56 to 63 and skips 65). Presumably
# this is PytorchVersion.__init__, which is constructed positionally with three
# arguments further down; TODO confirm.
gpu_arch_version: str,
64
no_build_suffix: bool,
66
# Store the constructor arguments unchanged; the suffix/version methods read
# these attributes.
self.gpu_arch_type = gpu_arch_type
67
self.gpu_arch_version = gpu_arch_version
68
self.no_build_suffix = no_build_suffix
70
def get_post_build_suffix(self) -> str:
    """Return the build suffix appended to a version string.

    Returns:
        * ``""`` when ``self.no_build_suffix`` is truthy;
        * ``+cu<version with dots removed>`` for CUDA builds
          (``"11.3"`` -> ``"+cu113"``);
        * ``+<gpu_arch_type><gpu_arch_version>`` otherwise
          (e.g. ``"+cpu"`` for an empty version, ``"+rocm4.0"``).
    """
    if self.no_build_suffix:
        # Caller opted out of any build suffix.
        return ""
    if self.gpu_arch_type == "cuda":
        # CUDA suffixes drop the dots from the version ("11.3" -> "113").
        return f"+cu{self.gpu_arch_version.replace('.', '')}"
    return f"+{self.gpu_arch_type}{self.gpu_arch_version}"
77
def get_release_version(self) -> str:
    """Return ``<tag><build suffix>`` for a build made from a git tag.

    Raises:
        NoGitTagException: if ``get_tag()`` returns a falsy value
            (presumably when HEAD is not on a usable release tag).
    """
    # Query git once and reuse the result for both the check and the value.
    tag = get_tag()
    if not tag:
        raise NoGitTagException(
            "Not on a git tag, are you sure you want a release version?"
        )
    return f"{tag}{self.get_post_build_suffix()}"
84
def get_nightly_version(self) -> str:
    """Build a nightly version string of the form ``<base>.dev<YYYYMMDD><suffix>``.

    The date comes from ``datetime.today()`` (local time), and the base version
    comes from ``get_base_version()``.
    """
    stamp = datetime.today().strftime("%Y%m%d")
    suffix = self.get_post_build_suffix()
    base = get_base_version()
    return f"{base}.dev{stamp}{suffix}"
91
# NOTE(review): fragment of the command-line entry logic. The enclosing `def`,
# the `parser.add_argument(` openers, the flag names of the first two options,
# their closing parentheses, and the `try:` matching the `except` below are not
# visible (the embedded numbering skips 93-96, 99-102, 105-106, 108, 111 and
# 115-116). Reconstruct from version control before editing the logic.
parser = argparse.ArgumentParser(
92
description="Generate pytorch version for binary builds"
97
# Option controlling the build suffix; its default comes from $NO_BUILD_SUFFIX.
help="Whether or not to add a build suffix typically (+cpu)",
98
# NOTE(review): strtobool comes from distutils, which is deprecated by PEP 632
# and no longer ships with the standard library in Python 3.12+. It also
# returns 1/0 rather than a bool and raises ValueError on unrecognised strings.
default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False")),
103
# Defaults to $GPU_ARCH_TYPE, falling back to "cpu".
help="GPU arch you are building for, typically (cpu, cuda, rocm)",
104
default=os.environ.get("GPU_ARCH_TYPE", "cpu"),
107
"--gpu-arch-version",
109
# Defaults to $GPU_ARCH_VERSION, falling back to "" (left blank for CPU).
help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
110
default=os.environ.get("GPU_ARCH_VERSION", ""),
112
args = parser.parse_args()
113
version_obj = PytorchVersion(
114
args.gpu_arch_type, args.gpu_arch_version, args.no_build_suffix
117
# Print a release version when one can be produced. get_release_version raises
# NoGitTagException otherwise, in which case a nightly version is printed.
print(version_obj.get_release_version())
118
except NoGitTagException:
119
print(version_obj.get_nightly_version())
122
# Script entry guard. NOTE(review): the guarded body is not visible here (the
# embedded numbering ends at this line, 122); as shown, the `if` has no suite.
if __name__ == "__main__":