6
from pathlib import Path
7
from subprocess import check_call
8
from tempfile import TemporaryDirectory
9
from typing import Optional
# Filesystem anchors resolved from this script's own location: the
# directory that holds it, and the repository root two levels above that.
SCRIPT_DIR: Path = Path(__file__).parent
REPO_DIR: Path = SCRIPT_DIR.parent.parent
def read_triton_pin(device: str = "cuda") -> str:
    """Return the pinned Triton commit for ``device``.

    The pin is read from ``.ci/docker/ci_commit_pins/<file>`` under the
    repository root. ``"rocm"`` and ``"xpu"`` have dedicated pin files; every
    other value (including the default ``"cuda"``) uses ``triton.txt``.

    Args:
        device: Target backend; the command line offers "cuda", "rocm" and
            "xpu".

    Returns:
        The pin file's contents with surrounding whitespace stripped.
    """
    # Only ROCm and XPU have their own pin files; anything else falls back to
    # the default pin rather than failing.
    device_pin_files = {"rocm": "triton-rocm.txt", "xpu": "triton-xpu.txt"}
    triton_file = device_pin_files.get(device, "triton.txt")
    with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / triton_file) as f:
        return f.read().strip()
def read_triton_version() -> str:
    """Return the Triton version recorded in ``.ci/docker/triton_version.txt``.

    The file is read as text relative to the repository root and surrounding
    whitespace (such as a trailing newline) is removed.
    """
    version_file = REPO_DIR / ".ci" / "docker" / "triton_version.txt"
    return version_file.read_text().strip()
def check_and_replace(inp: str, src: str, dst: str) -> str:
    """Return ``inp`` with every occurrence of ``src`` replaced by ``dst``.

    Args:
        inp: Text to patch.
        src: Substring that must be present in ``inp``.
        dst: Replacement text.

    Returns:
        The patched text.

    Raises:
        RuntimeError: If ``src`` does not occur in ``inp``, so callers can
            detect that the file no longer has the expected contents.
    """
    if src not in inp:
        # Plain f-string interpolation: a "${src}" placeholder would leave a
        # stray "$" in the message.
        raise RuntimeError(f"Can't find {src} in the input")
    return inp.replace(src, dst)
# NOTE(review): this function's `def` line (and therefore its name), its
# closing `) -> ...:` line, and several body lines (reading the file into
# `orig`, closing the check_and_replace call, writing `orig` back) are not
# visible in this view; the comments below describe only the lines shown.
#
# Rewrites the `__version__` assignment in the file at `path`.
#   path: file to patch -- presumably a package `__init__.py`; the visible
#       call arguments later in this file pass `.../triton/__init__.py`.
#   version: new version, written as `__version__ = "<version>"`.
#   expected_version: version expected in the file, matched as
#       `__version__ = '<expected_version>'`; when falsy, the value returned
#       by read_triton_version() is used instead.
# check_and_replace raises RuntimeError if the expected assignment is absent.
39
path: Path, *, version: str, expected_version: Optional[str] = None
41
# Fall back to the repository's recorded Triton version.
if not expected_version:
42
expected_version = read_triton_version()
46
# The match requires single quotes around the old version, while the
# replacement is written with double quotes.
orig = check_and_replace(
47
orig, f"__version__ = '{expected_version}'", f'__version__ = "{version}"'
49
with open(path, "w") as f:
def patch_setup_py(path: Path) -> None:
    """Point Triton's ``setup.py`` at the mirrored LLVM build downloads.

    Replaces ``https://tritonlang.blob.core.windows.net/llvm-builds/`` with
    ``https://oaitriton.blob.core.windows.net/public/llvm-builds/`` in the
    file at ``path``, rewriting it in place.

    This is best-effort: if the old URL is not present (``check_and_replace``
    raises ``RuntimeError``), the file is left untouched and a message is
    printed instead of failing the build.

    Args:
        path: Location of the ``setup.py`` to patch.
    """
    with open(path) as f:
        orig = f.read()
    try:
        orig = check_and_replace(
            orig,
            "https://tritonlang.blob.core.windows.net/llvm-builds/",
            "https://oaitriton.blob.core.windows.net/public/llvm-builds/",
        )
    except RuntimeError as e:
        # Newer Triton releases may no longer reference the old URL.
        print(
            f"Applying patch_setup_py() for llvm-build package failed: {e}.",
            "If you are trying to build a newer version of Triton, you can ignore this.",
        )
        return
    # Only rewrite the file once the substitution has succeeded.
    with open(path, "w") as f:
        f.write(orig)
# NOTE(review): this function's `def` line and its leading parameters, plus
# many control-flow lines in its body (if/else headers, call openings and
# closing parentheses), are not visible in this view. The comments below
# describe only the statements that are shown.
#
# Builds a Triton package from a fresh git clone in a temporary directory
# and returns the path of the artifact copied into the current working
# directory: either a conda `torchtriton*.bz2` package or a wheel.
#   build_conda: presumably selects the conda path over the wheel path.
#   py_version: Python version for the conda build; defaults to the running
#       interpreter's "major.minor".
#   release: presumably builds from the `release/<ver>.<rev>.x` branch without
#       a commit-hash version suffix -- TODO confirm (guarding `if`s not shown).
76
build_conda: bool = False,
78
py_version: Optional[str] = None,
79
release: bool = False,
# Build in a copy of the environment; default MAX_JOBS to the CPU count
# (or 1 when os.cpu_count() returns None) unless the caller already set it.
81
env = os.environ.copy()
82
if "MAX_JOBS" not in env:
83
max_jobs = os.cpu_count() or 1
84
env["MAX_JOBS"] = str(max_jobs)
# Append "+<first 10 chars of commit_hash>" to the version; presumably only
# for non-release builds -- the guarding condition is not visible.
90
version_suffix = f"+{commit_hash[:10]}"
91
version += version_suffix
# Everything below happens inside a temporary directory that is removed on
# exit, which is why the final artifact is copied to the cwd before returning.
93
with TemporaryDirectory() as tmpdir:
94
triton_basedir = Path(tmpdir) / "triton"
95
triton_pythondir = triton_basedir / "python"
# Source repository and package name per device. The selecting conditions
# are not visible; from the values: ROCm -> "pytorch-triton-rocm", XPU ->
# "pytorch-triton-xpu" from Intel's fork, otherwise "pytorch-triton".
96
triton_repo = "https://github.com/openai/triton"
98
triton_pkg_name = "pytorch-triton-rocm"
100
triton_pkg_name = "pytorch-triton-xpu"
101
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
103
triton_pkg_name = "pytorch-triton"
104
check_call(["git", "clone", triton_repo, "triton"], cwd=tmpdir)
# Presumably for release builds: check out the "release/<major>.<minor>.x"
# branch. The unpacking requires `version` to have exactly three
# dot-separated parts (ValueError otherwise).
106
ver, rev, patch = version.split(".")
108
["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir
# Otherwise the exact commit_hash is checked out (the if/else choosing
# between the two checkouts is not visible).
111
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
# Best-effort switch of the LLVM download URL in Triton's setup.py.
# NOTE(review): any condition guarding this call is not visible.
114
patch_setup_py(triton_pythondir / "setup.py")
# Conda path: write a conda-build recipe (meta.yaml) at the clone root with
# package, source, build, requirements and about sections; the build script
# runs `python setup.py install` from the python/ subdirectory. The
# "py{{py}}" text is a plain (non-f) string, so the braces are written
# literally -- presumably a conda-build template variable.
117
with open(triton_basedir / "meta.yaml", "w") as meta:
119
f"package:\n name: torchtriton\n version: {version}\n",
122
print("source:\n path: .\n", file=meta)
124
"build:\n string: py{{py}}\n number: 1\n script: cd python; "
125
"python setup.py install --record=record.txt\n",
126
" script_env:\n - MAX_JOBS\n",
130
"requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
131
" - filelock\n - pytorch\n",
135
"about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
136
" 'A language and compiler for custom Deep Learning operation'",
# Rewrite __version__ in the cloned triton/__init__.py to `version` (the
# callee's name is on a line not visible here).
141
triton_pythondir / "triton" / "__init__.py",
142
version=f"{version}",
# Default to the running interpreter's major.minor version.
144
if py_version is None:
145
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
# Take the first torchtriton*.bz2 under <tmpdir>/linux-64 (the conda build
# invocation itself is not visible), copy it to the cwd, return the copy.
# NOTE(review): next() raises a bare StopIteration if nothing was produced;
# an explicit error message would be clearer.
161
conda_path = next(iter(Path(tmpdir).glob("linux-64/torchtriton*.bz2")))
162
shutil.copy(conda_path, Path.cwd())
163
return Path.cwd() / conda_path.name
# Wheel path: pass the package name and version suffix to the build through
# environment variables.
# NOTE(review): version_suffix is only assigned on a (presumably
# conditional) line above; confirm it is initialised on a line not visible
# here, otherwise release builds would hit NameError.
166
env["TRITON_WHEEL_NAME"] = triton_pkg_name
167
env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
# Rewrite __version__ in the cloned triton/__init__.py, matching against the
# repository's recorded version (expected_version=None).
169
triton_pythondir / "triton" / "__init__.py",
170
version=f"{version}",
171
expected_version=None,
# ROCm-specific packaging script under SCRIPT_DIR/amd (presumably guarded by
# a device check that is not visible).
176
[f"{SCRIPT_DIR}/amd/package_triton_wheel.sh"],
180
print("ROCm libraries setup for triton installation...")
# Build the wheel with the current interpreter and the prepared env
# (MAX_JOBS, TRITON_WHEEL_NAME, TRITON_WHEEL_VERSION_SUFFIX).
183
[sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
# Copy the first wheel found in python/dist to the cwd.
# NOTE(review): next() raises a bare StopIteration if no wheel was built.
186
whl_path = next(iter((triton_pythondir / "dist").glob("*.whl")))
187
shutil.copy(whl_path, Path.cwd())
# ROCm-specific post-processing of the copied wheel (guarding condition
# not visible).
191
[f"{SCRIPT_DIR}/amd/patch_triton_wheel.sh", Path.cwd()],
195
return Path.cwd() / whl_path.name
# NOTE(review): this function's `def` line, the opening of the `--device`
# add_argument call, and the opening of the final build call (including the
# `if` half of its commit_hash conditional) are not visible in this view.
#
# Command-line entry point: parses the build options and forwards them as
# keyword arguments to the build routine.
199
from argparse import ArgumentParser
201
parser = ArgumentParser("Build Triton binaries")
202
parser.add_argument("--release", action="store_true")
203
parser.add_argument("--build-conda", action="store_true")
205
"--device", type=str, default="cuda", choices=["cuda", "rocm", "xpu"]
207
parser.add_argument("--py-version", type=str)
208
parser.add_argument("--commit-hash", type=str)
# The default is computed while the parser is being built, so the version
# file is read even when --triton-version is passed explicitly.
209
parser.add_argument("--triton-version", type=str, default=read_triton_version())
210
args = parser.parse_args()
# Use --commit-hash when given, otherwise the pinned commit for --device
# (the `if` half of this conditional expression is on a line not visible).
214
commit_hash=args.commit_hash
216
else read_triton_pin(args.device),
217
version=args.triton_version,
218
build_conda=args.build_conda,
219
py_version=args.py_version,
220
release=args.release,
224
if __name__ == "__main__":