pytorch

Форк
0
/
generate_pytorch_version.py 
123 строки · 3.5 Кб
1
#!/usr/bin/env python3
2

3
import argparse
4
import os
5
import re
6
import subprocess
7
from datetime import datetime
8
from distutils.util import strtobool
9
from pathlib import Path
10

11

12
LEADING_V_PATTERN = re.compile("^v")
13
TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$")
14
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
15

16

17
class NoGitTagException(Exception):
18
    pass
19

20

21
def get_pytorch_root() -> Path:
22
    return Path(
23
        subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
24
        .decode("ascii")
25
        .strip()
26
    )
27

28

29
def get_tag() -> str:
30
    root = get_pytorch_root()
31
    try:
32
        dirty_tag = (
33
            subprocess.check_output(["git", "describe", "--tags", "--exact"], cwd=root)
34
            .decode("ascii")
35
            .strip()
36
        )
37
    except subprocess.CalledProcessError:
38
        return ""
39
    # Strip leading v that we typically do when we tag branches
40
    # ie: v1.7.1 -> 1.7.1
41
    tag = re.sub(LEADING_V_PATTERN, "", dirty_tag)
42
    # Strip trailing rc pattern
43
    # ie: 1.7.1-rc1 -> 1.7.1
44
    tag = re.sub(TRAILING_RC_PATTERN, "", tag)
45
    # Ignore ciflow tags
46
    if tag.startswith("ciflow/"):
47
        return ""
48
    return tag
49

50

51
def get_base_version() -> str:
52
    root = get_pytorch_root()
53
    dirty_version = open(root / "version.txt").read().strip()
54
    # Strips trailing a0 from version.txt, not too sure why it's there in the
55
    # first place
56
    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
57

58

59
class PytorchVersion:
60
    def __init__(
61
        self,
62
        gpu_arch_type: str,
63
        gpu_arch_version: str,
64
        no_build_suffix: bool,
65
    ) -> None:
66
        self.gpu_arch_type = gpu_arch_type
67
        self.gpu_arch_version = gpu_arch_version
68
        self.no_build_suffix = no_build_suffix
69

70
    def get_post_build_suffix(self) -> str:
71
        if self.no_build_suffix:
72
            return ""
73
        if self.gpu_arch_type == "cuda":
74
            return f"+cu{self.gpu_arch_version.replace('.', '')}"
75
        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"
76

77
    def get_release_version(self) -> str:
78
        if not get_tag():
79
            raise NoGitTagException(
80
                "Not on a git tag, are you sure you want a release version?"
81
            )
82
        return f"{get_tag()}{self.get_post_build_suffix()}"
83

84
    def get_nightly_version(self) -> str:
85
        date_str = datetime.today().strftime("%Y%m%d")
86
        build_suffix = self.get_post_build_suffix()
87
        return f"{get_base_version()}.dev{date_str}{build_suffix}"
88

89

90
def main() -> None:
91
    parser = argparse.ArgumentParser(
92
        description="Generate pytorch version for binary builds"
93
    )
94
    parser.add_argument(
95
        "--no-build-suffix",
96
        action="store_true",
97
        help="Whether or not to add a build suffix typically (+cpu)",
98
        default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False")),
99
    )
100
    parser.add_argument(
101
        "--gpu-arch-type",
102
        type=str,
103
        help="GPU arch you are building for, typically (cpu, cuda, rocm)",
104
        default=os.environ.get("GPU_ARCH_TYPE", "cpu"),
105
    )
106
    parser.add_argument(
107
        "--gpu-arch-version",
108
        type=str,
109
        help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
110
        default=os.environ.get("GPU_ARCH_VERSION", ""),
111
    )
112
    args = parser.parse_args()
113
    version_obj = PytorchVersion(
114
        args.gpu_arch_type, args.gpu_arch_version, args.no_build_suffix
115
    )
116
    try:
117
        print(version_obj.get_release_version())
118
    except NoGitTagException:
119
        print(version_obj.get_nightly_version())
120

121

122
if __name__ == "__main__":
123
    main()
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.