pytorch

Форк
0
/
build_triton_wheel.py 
225 строк · 7.4 Кб
1
#!/usr/bin/env python3
2

3
import os
4
import shutil
5
import sys
6
from pathlib import Path
7
from subprocess import check_call
8
from tempfile import TemporaryDirectory
9
from typing import Optional
10

11

12
# Directory that contains this script.
SCRIPT_DIR = Path(__file__).parent
# Two directory levels above SCRIPT_DIR; the ".ci/docker" pin and version
# files are resolved relative to it (presumably the repository root).
REPO_DIR = SCRIPT_DIR.parent.parent
14

15

16
def read_triton_pin(device: str = "cuda") -> str:
    """Return the pinned Triton commit recorded for ``device``.

    ``"rocm"`` and ``"xpu"`` have dedicated pin files; any other value
    reads ``triton.txt``. The file content is returned with surrounding
    whitespace stripped.
    """
    device_pin_files = {"rocm": "triton-rocm.txt", "xpu": "triton-xpu.txt"}
    pin_name = device_pin_files.get(device, "triton.txt")
    pin_path = REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / pin_name
    with open(pin_path) as pin:
        contents = pin.read()
    return contents.strip()
24

25

26
def read_triton_version() -> str:
    """Return the content of ``.ci/docker/triton_version.txt`` without surrounding whitespace."""
    version_file = REPO_DIR / ".ci" / "docker" / "triton_version.txt"
    with open(version_file) as fh:
        contents = fh.read()
    return contents.strip()
29

30

31
def check_and_replace(inp: str, src: str, dst: str) -> str:
    """Replace every occurrence of ``src`` in ``inp`` with ``dst``.

    Args:
        inp: Text to patch.
        src: Substring that must be present in ``inp``.
        dst: Replacement for ``src``.

    Returns:
        ``inp`` with all occurrences of ``src`` replaced by ``dst``.

    Raises:
        RuntimeError: If ``src`` does not occur in ``inp``.
    """
    if src not in inp:
        # The message previously contained a stray "$" (shell-style "${src}"),
        # which rendered as "$<src>" in the error text.
        raise RuntimeError(f"Can't find {src} in the input")
    return inp.replace(src, dst)
36

37

38
def patch_init_py(
    path: Path, *, version: str, expected_version: Optional[str] = None
) -> None:
    """Rewrite the ``__version__`` assignment in the file at ``path``.

    The file must contain ``__version__ = '<expected_version>'`` (single
    quotes); that text is replaced by ``__version__ = "<version>"`` and the
    file is written back. When ``expected_version`` is falsy it is obtained
    from ``read_triton_version()``. ``check_and_replace`` raises
    ``RuntimeError`` if the expected assignment is not found, in which case
    the file is not rewritten.
    """
    expected = expected_version or read_triton_version()
    old_assignment = f"__version__ = '{expected}'"
    new_assignment = f'__version__ = "{version}"'

    with open(path) as src:
        contents = src.read()
    # Replace version
    patched = check_and_replace(contents, old_assignment, new_assignment)
    with open(path, "w") as dst:
        dst.write(patched)
51

52

53
# TODO: remove patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
54
def patch_setup_py(path: Path) -> None:
    """Redirect the LLVM build download URL inside the file at ``path``.

    The ``tritonlang`` blob-storage URL is swapped for the ``oaitriton`` one.
    This is best effort: when the old URL is not present, a notice is printed
    and the file is left unmodified.
    """
    old_url = "https://tritonlang.blob.core.windows.net/llvm-builds/"
    new_url = "https://oaitriton.blob.core.windows.net/public/llvm-builds/"

    with open(path) as src:
        contents = src.read()

    try:
        patched = check_and_replace(contents, old_url, new_url)
    except RuntimeError as e:
        print(
            f"Applying patch_setup_py() for llvm-build package failed: {e}.",
            "If you are trying to build a newer version of Triton, you can ignore this.",
        )
        return

    with open(path, "w") as dst:
        dst.write(patched)
70

71

72
def _write_conda_meta(meta_path: Path, version: str) -> None:
    """Write the conda-build recipe for the ``torchtriton`` package to ``meta_path``."""
    with open(meta_path, "w") as meta:
        print(
            f"package:\n  name: torchtriton\n  version: {version}\n",
            file=meta,
        )
        print("source:\n  path: .\n", file=meta)
        # Two positional arguments on purpose: print() joins them with a single
        # space, which supplies the second indent column of "  script_env:".
        print(
            "build:\n  string: py{{py}}\n  number: 1\n  script: cd python; "
            "python setup.py install --record=record.txt\n",
            " script_env:\n   - MAX_JOBS\n",
            file=meta,
        )
        print(
            "requirements:\n  host:\n    - python\n    - setuptools\n  run:\n    - python\n"
            "    - filelock\n    - pytorch\n",
            file=meta,
        )
        print(
            "about:\n  home: https://github.com/openai/triton\n  license: MIT\n  summary:"
            " 'A language and compiler for custom Deep Learning operation'",
            file=meta,
        )


def _first_match(directory: Path, pattern: str) -> Path:
    """Return the first path under ``directory`` matching ``pattern`` or raise a descriptive error."""
    found = next(directory.glob(pattern), None)
    if found is None:
        raise RuntimeError(
            f"No build artifact matching {pattern!r} found in {directory}"
        )
    return found


def build_triton(
    *,
    version: str,
    commit_hash: str,
    build_conda: bool = False,
    device: str = "cuda",
    py_version: Optional[str] = None,
    release: bool = False,
) -> Path:
    """Clone Triton, build a conda package or a wheel, and copy it to the current directory.

    Args:
        version: Base Triton version. For non-release builds the first ten
            characters of ``commit_hash`` are appended as ``+<hash>``.
        commit_hash: Commit checked out for non-release builds.
        build_conda: Build a ``torchtriton`` conda package instead of a wheel.
        device: ``"rocm"`` and ``"xpu"`` select their own wheel names (and, for
            ``"xpu"``, a different source repository); anything else builds
            ``pytorch-triton``.
        py_version: Python version passed to ``conda build``; defaults to the
            running interpreter's ``major.minor``. Only used with ``build_conda``.
        release: Check out the ``release/<major>.<minor>.x`` branch instead of
            ``commit_hash`` and omit the version suffix.

    Returns:
        Path of the built artifact inside the current working directory.

    Raises:
        subprocess.CalledProcessError: If any git/conda/build command fails.
        RuntimeError: If the expected build artifact is not produced, or if
            ``patch_init_py`` cannot find the expected version line.
    """
    env = os.environ.copy()
    if "MAX_JOBS" not in env:
        env["MAX_JOBS"] = str(os.cpu_count() or 1)

    version_suffix = ""
    if not release:
        # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
        # while release build should only include the version, i.e. 2.1.0
        version_suffix = f"+{commit_hash[:10]}"
        version += version_suffix

    with TemporaryDirectory() as tmpdir:
        triton_basedir = Path(tmpdir) / "triton"
        triton_pythondir = triton_basedir / "python"
        triton_repo = "https://github.com/openai/triton"
        if device == "rocm":
            triton_pkg_name = "pytorch-triton-rocm"
        elif device == "xpu":
            triton_pkg_name = "pytorch-triton-xpu"
            triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
        else:
            triton_pkg_name = "pytorch-triton"
        check_call(["git", "clone", triton_repo, "triton"], cwd=tmpdir)
        if release:
            # Only major and minor select the release branch; any further
            # components (patch level, etc.) are ignored rather than required.
            ver, rev, *_ = version.split(".")
            check_call(
                ["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir
            )
        else:
            check_call(["git", "checkout", commit_hash], cwd=triton_basedir)

        # TODO: remove this and patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
        patch_setup_py(triton_pythondir / "setup.py")

        if build_conda:
            _write_conda_meta(triton_basedir / "meta.yaml", version)

            patch_init_py(
                triton_pythondir / "triton" / "__init__.py",
                version=version,
            )
            if py_version is None:
                py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
            check_call(
                [
                    "conda",
                    "build",
                    "--python",
                    py_version,
                    "-c",
                    "pytorch-nightly",
                    "--output-folder",
                    tmpdir,
                    ".",
                ],
                cwd=triton_basedir,
                env=env,
            )
            conda_path = _first_match(Path(tmpdir), "linux-64/torchtriton*.bz2")
            shutil.copy(conda_path, Path.cwd())
            return Path.cwd() / conda_path.name

        # change built wheel name and version
        env["TRITON_WHEEL_NAME"] = triton_pkg_name
        env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
        patch_init_py(
            triton_pythondir / "triton" / "__init__.py",
            version=version,
        )

        if device == "rocm":
            # NOTE(review): shell=True with a one-element list runs the script
            # through the shell, and this call does not receive `env` (unlike
            # the wheel build below). Left unchanged to preserve behavior.
            check_call(
                [f"{SCRIPT_DIR}/amd/package_triton_wheel.sh"],
                cwd=triton_basedir,
                shell=True,
            )
            print("ROCm libraries setup for triton installation...")

        check_call(
            [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
        )

        whl_path = _first_match(triton_pythondir / "dist", "*.whl")
        shutil.copy(whl_path, Path.cwd())

        if device == "rocm":
            # The patch script is given the directory the wheel was copied to.
            check_call(
                [f"{SCRIPT_DIR}/amd/patch_triton_wheel.sh", Path.cwd()],
                cwd=triton_basedir,
            )

        return Path.cwd() / whl_path.name
196

197

198
def main() -> None:
    """Parse command-line options and build the requested Triton package.

    ``--commit-hash`` falls back to the pin for the selected device and
    ``--triton-version`` falls back to the version file; both files are read
    only when the corresponding option is not supplied.
    """
    from argparse import ArgumentParser

    # The first positional parameter of ArgumentParser is `prog`, so the
    # descriptive text has to be passed by keyword to show up as description.
    parser = ArgumentParser(description="Build Triton binaries")
    parser.add_argument("--release", action="store_true")
    parser.add_argument("--build-conda", action="store_true")
    parser.add_argument(
        "--device", type=str, default="cuda", choices=["cuda", "rocm", "xpu"]
    )
    parser.add_argument("--py-version", type=str)
    parser.add_argument("--commit-hash", type=str)
    parser.add_argument("--triton-version", type=str)
    args = parser.parse_args()

    # Resolve the default lazily so the version file is not read (and need
    # not exist) when an explicit version is given.
    triton_version = args.triton_version
    if triton_version is None:
        triton_version = read_triton_version()

    build_triton(
        device=args.device,
        commit_hash=args.commit_hash or read_triton_pin(args.device),
        version=triton_version,
        build_conda=args.build_conda,
        py_version=args.py_version,
        release=args.release,
    )


if __name__ == "__main__":
    main()
226

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.