pytorch

Форк
0
/
blas_compare_setup.py 
221 строка · 7.0 Кб
1
import collections
2
import os
3
import shutil
4
import subprocess
5

6
try:
7
    # no type stub for conda command line interface
8
    import conda.cli.python_api  # type: ignore[import]
9
    from conda.cli.python_api import Commands as conda_commands
10
except ImportError:
11
    # blas_compare.py will fail to import these when it's inside a conda env,
12
    # but that's fine as it only wants the constants.
13
    pass
14

15

16
# Root directory under which one conda environment per BLAS backend is created.
WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"

# Names of the sub-environments. Each is a key of SUB_ENVS and is also used as
# the environment's directory name under WORKING_ROOT.
MKL_2020_3 = "mkl_2020_3"
MKL_2020_0 = "mkl_2020_0"
OPEN_BLAS = "open_blas"
EIGEN = "eigen"


# "NAME=value" settings shared by every sub-environment.
GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0")

# Packages installed into every sub-environment before any backend-specific ones.
BASE_PKG_DEPS = (
    "cmake",
    "hypothesis",
    "ninja",
    "numpy",
    "pyyaml",
    "setuptools",
    "typing_extensions",
)


# Description of a single sub-environment:
#   generic_installs:      extra packages installed together with BASE_PKG_DEPS.
#   special_installs:      either empty, or a `(channel, packages)` pair that is
#                          installed from that specific conda channel.
#   environment_variables: "NAME=value" strings stored in the environment.
#   expected_blas_symbols: substrings that must appear in the build check output.
#   expected_mkl_version:  MKL version string expected in the reported build
#                          config, or None to skip that check.
SubEnvSpec = collections.namedtuple(
    "SubEnvSpec", (
        "generic_installs",
        "special_installs",
        "environment_variables",

        # Validate install.
        "expected_blas_symbols",
        "expected_mkl_version",
    ))


SUB_ENVS = {
    MKL_2020_3: SubEnvSpec(
        generic_installs=(),
        special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")),
        environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.3",
    ),

    MKL_2020_0: SubEnvSpec(
        generic_installs=(),
        special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")),
        environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.0",
    ),

    OPEN_BLAS: SubEnvSpec(
        generic_installs=("openblas",),
        special_installs=(),
        environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("exec_blas",),
        expected_mkl_version=None,
    ),

    # Disabled. `expected_mkl_version` is listed because SubEnvSpec has no
    # field defaults, so every field must be supplied if this is re-enabled.
    # EIGEN: SubEnvSpec(
    #     generic_installs=(),
    #     special_installs=(),
    #     environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
    #     expected_blas_symbols=(),
    #     expected_mkl_version=None,
    # ),
}
79

80

81
def conda_run(*args):
    """Run a conda command through the conda Python API and return its stdout.

    Args:
        *args: Forwarded unchanged to `conda.cli.python_api.run_command`
            (the call sites in this file pass a `conda_commands` member
            followed by command-line arguments).

    Returns:
        The stdout value reported by `run_command`.

    Raises:
        OSError: If `run_command` reports a non-zero return code; the message
            includes the arguments, the return code and the captured stderr.
    """
    stdout, stderr, retcode = conda.cli.python_api.run_command(*args)
    if retcode:
        raise OSError(f"conda error: {str(args)}  retcode: {retcode}\n{stderr}")

    return stdout
88

89

90
def main():
    """Create a conda environment per BLAS backend and build PyTorch in each.

    Any existing `WORKING_ROOT` is deleted and recreated. Then, for every
    entry in `SUB_ENVS`:

      1. a new conda environment is created under `WORKING_ROOT`;
      2. `BASE_PKG_DEPS`, the spec's generic packages and (if present) its
         channel-specific packages are installed;
      3. the spec's environment variables are stored in the environment and
         read back to confirm they are set;
      4. PyTorch is built from the git checkout containing this file;
      5. the build is checked against the spec's expected `BLAS=...` setting,
         BLAS symbols and MKL version.

    Raises:
        OSError: If a conda command fails, the environment cannot be
            activated, its variables cannot be set, or the configuration
            check command fails.
        subprocess.CalledProcessError: If locating the git root, listing the
            environment variables, or the PyTorch build fails.
        AssertionError: If the environment or the resulting build does not
            match the spec.
    """
    if os.path.exists(WORKING_ROOT):
        print("Cleaning: removing old working root.")
        shutil.rmtree(WORKING_ROOT)
    os.makedirs(WORKING_ROOT)

    # Fixed command, so an argument list is used instead of a shell string.
    git_root = subprocess.check_output(
        ["git", "rev-parse", "--show-toplevel"],
        cwd=os.path.dirname(os.path.realpath(__file__))
    ).decode("utf-8").strip()

    for env_name, env_spec in SUB_ENVS.items():
        env_path = os.path.join(WORKING_ROOT, env_name)
        print(f"Creating env: {env_name}: ({env_path})")
        conda_run(
            conda_commands.CREATE,
            "--no-default-packages",
            "--prefix", env_path,
            "python=3",
        )

        print("Testing that env can be activated:")
        base_source = subprocess.run(
            f"source activate {env_path}",
            shell=True,
            capture_output=True,
            check=False,
        )
        if base_source.returncode:
            raise OSError(
                "Failed to source base environment:\n"
                f"  stdout: {base_source.stdout.decode('utf-8')}\n"
                f"  stderr: {base_source.stderr.decode('utf-8')}"
            )

        print("Installing packages:")
        conda_run(
            conda_commands.INSTALL,
            "--prefix", env_path,
            *(BASE_PKG_DEPS + env_spec.generic_installs)
        )

        if env_spec.special_installs:
            channel, channel_deps = env_spec.special_installs
            print(f"Installing packages from channel: {channel}")
            conda_run(
                conda_commands.INSTALL,
                "--prefix", env_path,
                "-c", channel, *channel_deps
            )

        if env_spec.environment_variables:
            print("Setting environment variables.")

            # This does not appear to be possible using the python API.
            env_set = subprocess.run(
                f"source activate {env_path} && "
                f"conda env config vars set {' '.join(env_spec.environment_variables)}",
                shell=True,
                capture_output=True,
                check=False,
            )
            if env_set.returncode:
                raise OSError(
                    "Failed to set environment variables:\n"
                    f"  stdout: {env_set.stdout.decode('utf-8')}\n"
                    f"  stderr: {env_set.stderr.decode('utf-8')}"
                )

            # Check that they were actually set correctly.
            actual_env_vars = subprocess.run(
                f"source activate {env_path} && env",
                shell=True,
                capture_output=True,
                check=True,
            ).stdout.decode("utf-8").strip().splitlines()
            for e in env_spec.environment_variables:
                assert e in actual_env_vars, f"{e} not in envs"

        print(f"Building PyTorch for env: `{env_name}`")
        # We have to re-run during each build to pick up the new
        # build config settings.
        # The checkout is selected with `cwd` rather than an unquoted
        # `cd {git_root}` in the shell string, so paths containing spaces or
        # shell metacharacters work.
        subprocess.run(
            f"source activate {env_path} && "
            "python setup.py install --cmake",
            shell=True,
            cwd=git_root,
            capture_output=True,
            check=True,
        )

        print("Checking configuration:")
        check_run = subprocess.run(
            # Shameless abuse of `python -c ...`
            f"source activate {env_path} && "
            "python -c \""
            "import torch;"
            "from torch.utils.benchmark import Timer;"
            "print(torch.__config__.show());"
            "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
            "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
            "stats = counts.as_standardized().stats(inclusive=True);"
            "print(stats.filter(lambda l: 'blas' in l.lower()))\"",
            shell=True,
            capture_output=True,
            check=False,
        )
        if check_run.returncode:
            raise OSError(
                "Failed to check configuration:\n"
                f"  stdout: {check_run.stdout.decode('utf-8')}\n"
                f"  stderr: {check_run.stderr.decode('utf-8')}"
            )
        check_run_stdout = check_run.stdout.decode('utf-8')
        print(check_run_stdout)

        # The printed build config must echo the requested `BLAS=...` setting.
        for e in env_spec.environment_variables:
            if "BLAS" in e:
                assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}"

        # The filtered callgrind stats must mention every expected symbol.
        for s in env_spec.expected_blas_symbols:
            assert s in check_run_stdout, f"Expected BLAS symbol not found: {s}"

        if env_spec.expected_mkl_version is not None:
            expected_mkl_line = (
                f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}"
            )
            assert expected_mkl_line in check_run_stdout, (
                f"Expected MKL version not reported: {env_spec.expected_mkl_version}"
            )

        print(f"Build complete: {env_name}")
218

219

220
# Run the environment setup only when executed as a script, not on import.
if __name__ == "__main__":
    main()
222

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.