7
# no type stub for conda command line interface
8
import conda.cli.python_api # type: ignore[import]
9
from conda.cli.python_api import Commands as conda_commands
11
# blas_compare.py will fail to import these when it's inside a conda env,
12
# but that's fine as it only wants the constants.
# Root directory under which one conda environment per BLAS configuration is
# created (env_path = os.path.join(WORKING_ROOT, env_name)). The setup routine
# removes and recreates this directory on every run (shutil.rmtree followed by
# os.makedirs), so anything stored here is disposable.
WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"

# Sub-environment names. Each one is used as a key of SUB_ENVS and as the
# directory name of that environment under WORKING_ROOT. blas_compare.py also
# imports these constants (it only needs the constants, not the conda API).
MKL_2020_3 = "mkl_2020_3"
MKL_2020_0 = "mkl_2020_0"
OPEN_BLAS = "open_blas"

# Build settings appended to every sub-environment's `environment_variables`.
# They are applied with `conda env config vars set` before PyTorch is built.
# NOTE(review): presumably these keep every build CPU-only (no CUDA / ROCm);
# confirm against PyTorch's setup.py.
GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0")
35
SubEnvSpec = collections.namedtuple(
39
"environment_variables",
42
"expected_blas_symbols",
43
"expected_mkl_version",
48
MKL_2020_3: SubEnvSpec(
50
special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")),
51
environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
52
expected_blas_symbols=("mkl_blas_sgemm",),
53
expected_mkl_version="2020.0.3",
56
MKL_2020_0: SubEnvSpec(
58
special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")),
59
environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
60
expected_blas_symbols=("mkl_blas_sgemm",),
61
expected_mkl_version="2020.0.0",
64
OPEN_BLAS: SubEnvSpec(
65
generic_installs=("openblas",),
67
environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS,
68
expected_blas_symbols=("exec_blas",),
69
expected_mkl_version=None,
73
# generic_installs=(),
74
# special_installs=(),
75
# environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
76
# expected_blas_symbols=(),
82
"""Convenience method."""
83
stdout, stderr, retcode = conda.cli.python_api.run_command(*args)
85
raise OSError(f"conda error: {str(args)} retcode: {retcode}\n{stderr}")
91
if os.path.exists(WORKING_ROOT):
92
print("Cleaning: removing old working root.")
93
shutil.rmtree(WORKING_ROOT)
94
os.makedirs(WORKING_ROOT)
96
git_root = subprocess.check_output(
97
"git rev-parse --show-toplevel",
99
cwd=os.path.dirname(os.path.realpath(__file__))
100
).decode("utf-8").strip()
102
for env_name, env_spec in SUB_ENVS.items():
103
env_path = os.path.join(WORKING_ROOT, env_name)
104
print(f"Creating env: {env_name}: ({env_path})")
106
conda_commands.CREATE,
107
"--no-default-packages",
108
"--prefix", env_path,
112
print("Testing that env can be activated:")
113
base_source = subprocess.run(
114
f"source activate {env_path}",
119
if base_source.returncode:
121
"Failed to source base environment:\n"
122
f" stdout: {base_source.stdout.decode('utf-8')}\n"
123
f" stderr: {base_source.stderr.decode('utf-8')}"
126
print("Installing packages:")
128
conda_commands.INSTALL,
129
"--prefix", env_path,
130
*(BASE_PKG_DEPS + env_spec.generic_installs)
133
if env_spec.special_installs:
134
channel, channel_deps = env_spec.special_installs
135
print(f"Installing packages from channel: {channel}")
137
conda_commands.INSTALL,
138
"--prefix", env_path,
139
"-c", channel, *channel_deps
142
if env_spec.environment_variables:
143
print("Setting environment variables.")
145
# This does not appear to be possible using the python API.
146
env_set = subprocess.run(
147
f"source activate {env_path} && "
148
f"conda env config vars set {' '.join(env_spec.environment_variables)}",
153
if env_set.returncode:
155
"Failed to set environment variables:\n"
156
f" stdout: {env_set.stdout.decode('utf-8')}\n"
157
f" stderr: {env_set.stderr.decode('utf-8')}"
160
# Check that they were actually set correctly.
161
actual_env_vars = subprocess.run(
162
f"source activate {env_path} && env",
166
).stdout.decode("utf-8").strip().splitlines()
167
for e in env_spec.environment_variables:
168
assert e in actual_env_vars, f"{e} not in envs"
170
print(f"Building PyTorch for env: `{env_name}`")
171
# We have to re-run during each build to pick up the new
172
# build config settings.
173
build_run = subprocess.run(
174
f"source activate {env_path} && "
176
"python setup.py install --cmake",
182
print("Checking configuration:")
183
check_run = subprocess.run(
184
# Shameless abuse of `python -c ...`
185
f"source activate {env_path} && "
188
"from torch.utils.benchmark import Timer;"
189
"print(torch.__config__.show());"
190
"setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
191
"counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
192
"stats = counts.as_standardized().stats(inclusive=True);"
193
"print(stats.filter(lambda l: 'blas' in l.lower()))\"",
198
if check_run.returncode:
200
"Failed to set environment variables:\n"
201
f" stdout: {check_run.stdout.decode('utf-8')}\n"
202
f" stderr: {check_run.stderr.decode('utf-8')}"
204
check_run_stdout = check_run.stdout.decode('utf-8')
205
print(check_run_stdout)
207
for e in env_spec.environment_variables:
209
assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}"
211
for s in env_spec.expected_blas_symbols:
212
assert s in check_run_stdout
214
if env_spec.expected_mkl_version is not None:
215
assert f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}" in check_run_stdout
217
print(f"Build complete: {env_name}")
220
if __name__ == "__main__":