onnxruntime

Форк
0
/
setup.py 
773 строки · 31.1 Кб
1
# ------------------------------------------------------------------------
2
# Copyright (c) Microsoft Corporation. All rights reserved.
3
# Licensed under the MIT License.
4
# ------------------------------------------------------------------------
5
# pylint: disable=C0103
6

7
import datetime
8
import logging
9
import platform
10
import shlex
11
import subprocess
12
import sys
13
from glob import glob, iglob
14
from os import environ, getcwd, path, popen, remove
15
from pathlib import Path
16
from shutil import copyfile
17

18
from packaging.tags import sys_tags
19
from setuptools import Extension, setup
20
from setuptools.command.build_ext import build_ext as _build_ext
21
from setuptools.command.install import install as InstallCommandBase
22

23
nightly_build = False
24
package_name = "onnxruntime"
25
wheel_name_suffix = None
26
logger = logging.getLogger()
27

28

29
def parse_arg_remove_boolean(argv, arg_name):
    """Return True and remove *arg_name* from *argv* if the flag is present.

    Args:
        argv: Argument list to scan and mutate in place (callers pass sys.argv).
        arg_name: Exact flag string to look for, e.g. "--nightly_build".

    Returns:
        True if the flag was found (and removed from *argv*), else False.
    """
    # Bug fix: the original consulted the global sys.argv instead of the argv
    # parameter it was given. Every visible caller passes sys.argv, so using
    # the parameter is behavior-compatible and makes the helper testable.
    if arg_name in argv:
        argv.remove(arg_name)
        return True
    return False
36

37

38
def parse_arg_remove_string(argv, arg_name_equal):
    """Return the value of a "--name=value" argument and remove it from *argv*.

    Args:
        argv: Argument list to scan and mutate in place (callers pass sys.argv).
        arg_name_equal: Flag prefix including the '=', e.g. "--cuda_version=".

    Returns:
        The text after the prefix for the first matching argument, or None
        if no argument starts with the prefix.
    """
    # Bug fix: the original iterated sys.argv[1:] and removed from sys.argv,
    # ignoring the argv parameter. Every visible caller passes sys.argv, so
    # using the parameter is behavior-compatible. argv[0] (the program name)
    # is skipped, matching the original behavior.
    for arg in argv[1:]:
        if arg.startswith(arg_name_equal):
            argv.remove(arg)
            return arg[len(arg_name_equal):]
    return None
47

48

49
# Any combination of the following arguments can be applied

if parse_arg_remove_boolean(sys.argv, "--nightly_build"):
    package_name = "ort-nightly"
    nightly_build = True

wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=")

# Execution-provider selection state; exactly one branch of the chain below
# may fire, so these flags are mutually exclusive.
cuda_version = None
rocm_version = None
is_migraphx = False
is_rocm = False
is_openvino = False
# The following arguments are mutually exclusive
if wheel_name_suffix == "gpu":
    # TODO: how to support multiple CUDA versions?
    cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=")
elif parse_arg_remove_boolean(sys.argv, "--use_rocm"):
    is_rocm = True
    rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=")
elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"):
    is_migraphx = True
elif parse_arg_remove_boolean(sys.argv, "--use_openvino"):
    is_openvino = True
    package_name = "onnxruntime-openvino"
elif parse_arg_remove_boolean(sys.argv, "--use_dnnl"):
    package_name = "onnxruntime-dnnl"
elif parse_arg_remove_boolean(sys.argv, "--use_tvm"):
    package_name = "onnxruntime-tvm"
elif parse_arg_remove_boolean(sys.argv, "--use_vitisai"):
    package_name = "onnxruntime-vitisai"
elif parse_arg_remove_boolean(sys.argv, "--use_acl"):
    package_name = "onnxruntime-acl"
elif parse_arg_remove_boolean(sys.argv, "--use_armnn"):
    package_name = "onnxruntime-armnn"
elif parse_arg_remove_boolean(sys.argv, "--use_cann"):
    package_name = "onnxruntime-cann"
elif parse_arg_remove_boolean(sys.argv, "--use_azure"):
    # keep the same name since AzureEP will release with CpuEP by default.
    pass
elif parse_arg_remove_boolean(sys.argv, "--use_qnn"):
    package_name = "onnxruntime-qnn"

# ROCm and MIGraphX ship under the same package name.
if is_rocm or is_migraphx:
    package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly"

# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686
# PEP 599 defines the following platform tags:
# manylinux2014_x86_64
# manylinux2014_i686
# manylinux2014_aarch64
# manylinux2014_armv7l
# manylinux2014_ppc64
# manylinux2014_ppc64le
# manylinux2014_s390x
manylinux_tags = [
    "manylinux1_x86_64",
    "manylinux1_i686",
    "manylinux2010_x86_64",
    "manylinux2010_i686",
    "manylinux2014_x86_64",
    "manylinux2014_i686",
    "manylinux2014_aarch64",
    "manylinux2014_armv7l",
    "manylinux2014_ppc64",
    "manylinux2014_ppc64le",
    "manylinux2014_s390x",
    "manylinux_2_28_x86_64",
    "manylinux_2_28_aarch64",
]
# AUDITWHEEL_PLAT is set inside manylinux build containers; its presence (with
# a recognized tag) switches the packaging steps below into manylinux mode.
is_manylinux = environ.get("AUDITWHEEL_PLAT", None) in manylinux_tags
121

122

123
class build_ext(_build_ext):  # noqa: N801
    """build_ext that installs a prebuilt binary instead of compiling.

    The native module is produced by the separate C++ build, so this
    command just copies the first listed "source" (the prebuilt .so)
    to wherever setuptools expects the extension to end up.
    """

    def build_extension(self, ext):
        source_path = ext.sources[0]
        target_path = self.get_ext_fullpath(ext.name)
        logger.info("copying %s -> %s", source_path, target_path)
        copyfile(source_path, target_path)
128

129

130
# The wheel package is optional at import time; if it is missing, bdist_wheel
# is set to None and the command is simply not registered (see cmd_classes
# wiring further down).
try:
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

    class bdist_wheel(_bdist_wheel):  # noqa: N801
        """Helper functions to create wheel package"""

        # get_tag is only overridden for manylinux OpenVINO builds, which
        # compute their platform tag from the build machine's glibc version.
        if is_openvino and is_manylinux:

            def get_tag(self):
                """Return (interpreter, abi, platform) for the wheel filename."""
                _, _, plat = _bdist_wheel.get_tag(self)
                if platform.system() == "Linux":
                    # Get the right platform tag by querying the linker version
                    glibc_major, glibc_minor = popen("ldd --version | head -1").read().split()[-1].split(".")
                    """# See https://github.com/mayeut/pep600_compliance/blob/master/
                    pep600_compliance/tools/manylinux-policy.json"""
                    if glibc_major == "2" and glibc_minor == "17":
                        plat = "manylinux_2_17_x86_64.manylinux2014_x86_64"
                    else:  # For manylinux2014 and above, no alias is required
                        plat = f"manylinux_{glibc_major}_{glibc_minor}_x86_64"
                tags = next(sys_tags())
                return (tags.interpreter, tags.abi, plat)

        def finalize_options(self):
            _bdist_wheel.finalize_options(self)
            # Outside manylinux containers the wheel contains native binaries,
            # so mark it as platform-specific (not pure Python).
            if not is_manylinux:
                self.root_is_pure = False

        def _rewrite_ld_preload(self, to_preload):
            """Append CDLL preload statements for *to_preload* libraries.

            Writes nothing when *to_preload* is empty.
            """
            with open("onnxruntime/capi/_ld_preload.py", "a") as f:
                if len(to_preload) > 0:
                    f.write("from ctypes import CDLL, RTLD_GLOBAL\n")
                    for library in to_preload:
                        f.write('_{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split(".")[0], library))

        def _rewrite_ld_preload_cuda(self, to_preload):
            """Like _rewrite_ld_preload, but a load failure sets
            ORT_CUDA_UNAVAILABLE instead of raising at import time."""
            with open("onnxruntime/capi/_ld_preload.py", "a") as f:
                if len(to_preload) > 0:
                    f.write("from ctypes import CDLL, RTLD_GLOBAL\n")
                    f.write("try:\n")
                    for library in to_preload:
                        f.write('    _{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split(".")[0], library))
                    f.write("except OSError:\n")
                    f.write("    import os\n")
                    f.write('    os.environ["ORT_CUDA_UNAVAILABLE"] = "1"\n')

        def _rewrite_ld_preload_tensorrt(self, to_preload):
            """Like _rewrite_ld_preload_cuda, but signals TensorRT
            unavailability via ORT_TENSORRT_UNAVAILABLE."""
            with open("onnxruntime/capi/_ld_preload.py", "a", encoding="ascii") as f:
                if len(to_preload) > 0:
                    f.write("from ctypes import CDLL, RTLD_GLOBAL\n")
                    f.write("try:\n")
                    for library in to_preload:
                        f.write('    _{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split(".")[0], library))
                    f.write("except OSError:\n")
                    f.write("    import os\n")
                    f.write('    os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n')

        def run(self):
            """Build the wheel; in manylinux mode also stage the pybind module,
            set rpaths, and repair the wheel with auditwheel.

            NOTE(review): this reads the module-level disable_auditwheel_repair,
            which is defined further down the file — safe only because run()
            executes after the whole module has been evaluated.
            """
            if is_manylinux:
                # Stage the pybind module under the name the Extension below
                # refers to, so setuptools packages it as the extension file.
                source = "onnxruntime/capi/onnxruntime_pybind11_state.so"
                dest = "onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so"
                logger.info("copying %s -> %s", source, dest)
                copyfile(source, dest)

                # In this snapshot these lists stay empty, so the
                # _rewrite_ld_preload* calls below write nothing.
                to_preload = []
                to_preload_cuda = []
                to_preload_tensorrt = []
                to_preload_cann = []

                # GPU/accelerator runtime libraries that must NOT be grafted
                # into the wheel by auditwheel (passed via --exclude below);
                # they are expected to come from the target machine.
                cuda_dependencies = [
                    "libcuda.so.1",
                    "libcublas.so.11",
                    "libcublas.so.12",
                    "libcublasLt.so.11",
                    "libcublasLt.so.12",
                    "libcudart.so.11.0",
                    "libcudart.so.12",
                    "libcudnn.so.8",
                    "libcudnn.so.9",
                    "libcufft.so.10",
                    "libcufft.so.11",
                    "libcurand.so.10",
                    "libcudnn_adv_infer.so.8",
                    "libcudnn_adv_train.so.8",
                    "libcudnn_cnn_infer.so.8",
                    "libcudnn_cnn_train.so.8",
                    "libcudnn_ops_infer.so.8",
                    "libcudnn_ops_train.so.8",
                    "libcudnn_adv.so.9",
                    "libcudnn_cnn.so.9",
                    "libcudnn_engines_precompiled.so.9",
                    "libcudnn_engines_runtime_compiled.so.9",
                    "libcudnn_graph.so.9",
                    "libcudnn_heuristic.so.9",
                    "libcudnn_ops.so.9",
                    "libnvJitLink.so.12",
                    "libnvrtc.so.11",
                    "libnvrtc.so.12",
                    "libnvrtc-builtins.so.11",
                    "libnvrtc-builtins.so.12",
                ]

                rocm_dependencies = [
                    "libamd_comgr.so.2",
                    "libamdhip64.so.5",
                    "libamdhip64.so.6",
                    "libdrm.so.2",
                    "libdrm_amdgpu.so.1",
                    "libelf.so.1",
                    "libhipfft.so.0",
                    "libhiprtc.so.5",
                    "libhiprtc.so.6",
                    "libhsa-runtime64.so.1",
                    "libMIOpen.so.1",
                    "libnuma.so.1",
                    "librccl.so.1",
                    "librocblas.so.3",
                    "librocblas.so.4",
                    "librocfft.so.0",
                    "libroctx64.so.4",
                    "librocm_smi64.so.5",
                    "librocm_smi64.so.6",
                    "libroctracer64.so.4",
                    "libtinfo.so.6",
                    "libmigraphx_c.so.3",
                    "libmigraphx.so.2",
                    "libmigraphx_onnx.so.2",
                    "libmigraphx_tf.so.2",
                ]

                tensorrt_dependencies = ["libnvinfer.so.10", "libnvinfer_plugin.so.10", "libnvonnxparser.so.10"]

                cann_dependencies = ["libascendcl.so", "libacl_op_compiler.so", "libfmk_onnx_parser.so"]

                # Make the OpenVINO provider look up its dependencies next to
                # itself ($ORIGIN) instead of via the build machine's paths.
                dest = "onnxruntime/capi/libonnxruntime_providers_openvino.so"
                if path.isfile(dest):
                    subprocess.run(
                        ["patchelf", "--set-rpath", "$ORIGIN", dest, "--force-rpath"],
                        check=True,
                        stdout=subprocess.PIPE,
                        text=True,
                    )

                self._rewrite_ld_preload(to_preload)
                self._rewrite_ld_preload_cuda(to_preload_cuda)
                self._rewrite_ld_preload_tensorrt(to_preload_tensorrt)
                self._rewrite_ld_preload(to_preload_cann)

            else:
                pass

            _bdist_wheel.run(self)
            if is_manylinux and not disable_auditwheel_repair and not is_openvino:
                assert self.dist_dir is not None
                file = glob(path.join(self.dist_dir, "*linux*.whl"))[0]
                logger.info("repairing %s for manylinux1", file)
                # Exclude every accelerator runtime so auditwheel does not
                # bundle them into the wheel.
                auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file]
                for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies + cann_dependencies:
                    auditwheel_cmd += ["--exclude", i]
                logger.info("Running %s", " ".join([shlex.quote(arg) for arg in auditwheel_cmd]))
                try:
                    subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE)
                finally:
                    # Always drop the unrepaired wheel, even if repair failed.
                    logger.info("removing %s", file)
                    remove(file)

except ImportError as error:
    print("Error importing dependencies:")
    print(error)
    bdist_wheel = None
299

300

301
class InstallCommand(InstallCommandBase):
    """Install command that redirects the install into the platlib directory.

    The package ships compiled binaries, so the pure-lib install location
    is replaced with the platform-specific one.
    """

    def finalize_options(self):
        result = super().finalize_options()
        # Binary package: always install into the platform-specific dir.
        self.install_lib = self.install_platlib
        return result
306

307

308
# Base names of the provider shared libraries; CUDA/ROCm and
# TensorRT/MIGraphX pairs share a slot because they are mutually exclusive.
providers_cuda_or_rocm = "onnxruntime_providers_" + ("rocm" if is_rocm else "cuda")
providers_tensorrt_or_migraphx = "onnxruntime_providers_" + ("migraphx" if is_migraphx else "tensorrt")
providers_openvino = "onnxruntime_providers_openvino"
providers_cann = "onnxruntime_providers_cann"

# Decorate the base names with the platform's shared-library convention.
if platform.system() == "Linux":
    providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so"
    providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so"
    providers_openvino = "lib" + providers_openvino + ".so"
    providers_cann = "lib" + providers_cann + ".so"
elif platform.system() == "Windows":
    providers_cuda_or_rocm = providers_cuda_or_rocm + ".dll"
    providers_tensorrt_or_migraphx = providers_tensorrt_or_migraphx + ".dll"
    providers_openvino = providers_openvino + ".dll"
    providers_cann = providers_cann + ".dll"

# Additional binaries
# dl_libs: libraries shipped in manylinux wheels; libs: candidates for the
# regular (non-manylinux) wheel. Only files that actually exist in
# onnxruntime/capi are packaged (see the glob filters below).
dl_libs = []
libs = []

if platform.system() == "Linux" or platform.system() == "AIX":
    libs = [
        "onnxruntime_pybind11_state.so",
        "libdnnl.so.2",
        "libmklml_intel.so",
        "libmklml_gnu.so",
        "libiomp5.so",
        "mimalloc.so",
        "libonnxruntime.so*",
    ]
    dl_libs = ["libonnxruntime_providers_shared.so"]
    dl_libs.append(providers_cuda_or_rocm)
    dl_libs.append(providers_tensorrt_or_migraphx)
    dl_libs.append(providers_cann)
    dl_libs.append("libonnxruntime.so*")
    # DNNL, TensorRT & OpenVINO EPs are built as shared libs
    libs.extend(["libonnxruntime_providers_shared.so"])
    libs.extend(["libonnxruntime_providers_dnnl.so"])
    libs.extend(["libonnxruntime_providers_openvino.so"])
    libs.extend(["libonnxruntime_providers_vitisai.so"])
    libs.append(providers_cuda_or_rocm)
    libs.append(providers_tensorrt_or_migraphx)
    libs.append(providers_cann)
    if nightly_build:
        libs.extend(["libonnxruntime_pywrapper.so"])
elif platform.system() == "Darwin":
    libs = [
        "onnxruntime_pybind11_state.so",
        "libdnnl.2.dylib",
        "mimalloc.so",
        "libonnxruntime*.dylib",
    ]  # TODO add libmklml and libiomp5 later.
    # DNNL & TensorRT EPs are built as shared libs
    libs.extend(["libonnxruntime_providers_shared.dylib"])
    libs.extend(["libonnxruntime_providers_dnnl.dylib"])
    libs.extend(["libonnxruntime_providers_tensorrt.dylib"])
    libs.extend(["libonnxruntime_providers_cuda.dylib"])
    libs.extend(["libonnxruntime_providers_vitisai.dylib"])
    if nightly_build:
        libs.extend(["libonnxruntime_pywrapper.dylib"])
else:
    # Windows (and anything else): DLL naming.
    libs = [
        "onnxruntime_pybind11_state.pyd",
        "dnnl.dll",
        "mklml.dll",
        "libiomp5md.dll",
        providers_cuda_or_rocm,
        providers_tensorrt_or_migraphx,
        providers_cann,
        "onnxruntime.dll",
    ]
    # DNNL, TensorRT & OpenVINO EPs are built as shared libs
    libs.extend(["onnxruntime_providers_shared.dll"])
    libs.extend(["onnxruntime_providers_dnnl.dll"])
    libs.extend(["onnxruntime_providers_tensorrt.dll"])
    libs.extend(["onnxruntime_providers_openvino.dll"])
    libs.extend(["onnxruntime_providers_cuda.dll"])
    libs.extend(["onnxruntime_providers_vitisai.dll"])
    # DirectML Libs
    libs.extend(["DirectML.dll"])
    # QNN V68/V73 dependencies
    qnn_deps = [
        "QnnCpu.dll",
        "QnnHtp.dll",
        "QnnSaver.dll",
        "QnnSystem.dll",
        "QnnHtpPrepare.dll",
        "QnnHtpV73Stub.dll",
        "libQnnHtpV73Skel.so",
        "libqnnhtpv73.cat",
        "QnnHtpV68Stub.dll",
        "libQnnHtpV68Skel.so",
    ]
    libs.extend(qnn_deps)
    if nightly_build:
        libs.extend(["onnxruntime_pywrapper.dll"])

if is_manylinux:
    if is_openvino:
        # OpenVINO runtime pieces bundled alongside the provider; each gets
        # its rpath rewritten to $ORIGIN so they resolve next to each other.
        ov_libs = [
            "libopenvino_intel_cpu_plugin.so",
            "libopenvino_intel_gpu_plugin.so",
            "libopenvino_auto_plugin.so",
            "libopenvino_hetero_plugin.so",
            "libtbb.so.2",
            "libtbbmalloc.so.2",
            "libopenvino.so",
            "libopenvino_c.so",
            "libopenvino_onnx_frontend.so",
        ]
        for x in ov_libs:
            y = "onnxruntime/capi/" + x
            subprocess.run(
                ["patchelf", "--set-rpath", "$ORIGIN", y, "--force-rpath"],
                check=True,
                stdout=subprocess.PIPE,
                text=True,
            )
            dl_libs.append(x)
        dl_libs.append(providers_openvino)
        dl_libs.append("plugins.xml")
        dl_libs.append("usb-ma2x8x.mvcmd")
    # Only include files that actually exist in onnxruntime/capi.
    data = ["capi/libonnxruntime_pywrapper.so"] if nightly_build else []
    data += [path.join("capi", x) for x in dl_libs if glob(path.join("onnxruntime", "capi", x))]
    # The pybind module staged by bdist_wheel.run() is packaged as the
    # extension itself (see the copy-only build_ext above).
    ext_modules = [
        Extension(
            "onnxruntime.capi.onnxruntime_pybind11_state",
            ["onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so"],
        ),
    ]
else:
    data = [path.join("capi", x) for x in libs if glob(path.join("onnxruntime", "capi", x))]
    ext_modules = []
441

442
# Additional examples
443
examples_names = ["mul_1.onnx", "logreg_iris.onnx", "sigmoid.onnx"]
444
examples = [path.join("datasets", x) for x in examples_names]
445

446
# Extra files such as EULA and ThirdPartyNotices (and Qualcomm License, only for QNN release packages)
447
extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md", "Qualcomm AI Hub Proprietary License.pdf"]
448

449
# Description
450
readme_file = "docs/python/ReadMeOV.rst" if is_openvino else "docs/python/README.rst"
451
README = path.join(getcwd(), readme_file)
452
if not path.exists(README):
453
    this = path.dirname(__file__)
454
    README = path.join(this, readme_file)
455

456
if not path.exists(README):
457
    raise FileNotFoundError("Unable to find 'README.rst'")
458
with open(README, encoding="utf-8") as fdesc:
459
    long_description = fdesc.read()
460

461
# Include files in onnxruntime/external if --enable_external_custom_op_schemas build.sh command
462
# line option is specified.
463
# If the options is not specified this following condition fails as onnxruntime/external folder is not created in the
464
# build flow under the build binary directory.
465
if path.isdir(path.join("onnxruntime", "external")):
466
    # Gather all files under onnxruntime/external directory.
467
    extra.extend(
468
        list(
469
            str(Path(*Path(x).parts[1:]))
470
            for x in list(iglob(path.join(path.join("onnxruntime", "external"), "**/*.*"), recursive=True))
471
        )
472
    )
473

474
# Python packages always included in the wheel; training-only packages are
# appended further down when --enable_training is passed.
packages = [
    "onnxruntime",
    "onnxruntime.backend",
    "onnxruntime.capi",
    "onnxruntime.datasets",
    "onnxruntime.tools",
    "onnxruntime.tools.mobile_helpers",
    "onnxruntime.tools.ort_format_model",
    "onnxruntime.tools.ort_format_model.ort_flatbuffers_py",
    "onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs",
    "onnxruntime.tools.qdq_helpers",
    "onnxruntime.quantization",
    "onnxruntime.quantization.operators",
    "onnxruntime.quantization.CalTableFlatBuffers",
    "onnxruntime.quantization.fusions",
    "onnxruntime.quantization.execution_providers.qnn",
    "onnxruntime.transformers",
    "onnxruntime.transformers.models.bart",
    "onnxruntime.transformers.models.bert",
    "onnxruntime.transformers.models.gpt2",
    "onnxruntime.transformers.models.llama",
    "onnxruntime.transformers.models.longformer",
    "onnxruntime.transformers.models.phi2",
    "onnxruntime.transformers.models.t5",
    "onnxruntime.transformers.models.stable_diffusion",
    "onnxruntime.transformers.models.whisper",
]

package_data = {"onnxruntime.tools.mobile_helpers": ["*.md", "*.config"]}
data_files = []

# May be switched to requirements-training.txt by the training block below.
requirements_file = "requirements.txt"

# Remaining build flags; local_version becomes the PEP 440 local segment
# (e.g. "+cu118") for training packages.
local_version = None
enable_training = parse_arg_remove_boolean(sys.argv, "--enable_training")
enable_training_apis = parse_arg_remove_boolean(sys.argv, "--enable_training_apis")
enable_rocm_profiling = parse_arg_remove_boolean(sys.argv, "--enable_rocm_profiling")
disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, "--disable_auditwheel_repair")
default_training_package_device = parse_arg_remove_boolean(sys.argv, "--default_training_package_device")
513

514
# PyPI trove classifiers for the published package.
classifiers = [
    "Development Status :: 5 - Production/Stable",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Operating System :: POSIX :: Linux",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Mathematics",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3 :: Only",
    "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Operating System :: Microsoft :: Windows",
    "Operating System :: MacOS",
]

# Training builds add the onnxruntime.training package tree, extra source
# files for the torch C++ extensions, and adjust the package name/version.
if enable_training or enable_training_apis:
    packages.append("onnxruntime.training")
    if enable_training:
        packages.extend(
            [
                "onnxruntime.training.amp",
                "onnxruntime.training.experimental",
                "onnxruntime.training.experimental.gradient_graph",
                "onnxruntime.training.optim",
                "onnxruntime.training.ortmodule",
                "onnxruntime.training.ortmodule.experimental",
                "onnxruntime.training.ortmodule.experimental.json_config",
                "onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
                "onnxruntime.training.ortmodule.torch_cpp_extensions",
                "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
                "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
                "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
                "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
                "onnxruntime.training.ortmodule.graph_optimizers",
                "onnxruntime.training.ortmodule.experimental.pipe",
                "onnxruntime.training.ort_triton",
                "onnxruntime.training.ort_triton.kernel",
                "onnxruntime.training.utils",
                "onnxruntime.training.utils.data",
                "onnxruntime.training.utils.hooks",
                "onnxruntime.training.api",
                "onnxruntime.training.onnxblock",
                "onnxruntime.training.onnxblock.loss",
                "onnxruntime.training.onnxblock.optim",
            ]
        )

        # The torch C++ extension sources are compiled on the user's machine,
        # so they ship as package data.
        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc", "*.h"]
        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
        package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
            "*.cpp",
            "*.cu",
            "*.cuh",
            "*.h",
        ]

    requirements_file = "requirements-training.txt"
    # with training, we want to follow this naming convention:
    # stable:
    # onnxruntime-training-1.7.0+cu111-cp36-cp36m-linux_x86_64.whl
    # nightly:
    # onnxruntime-training-1.7.0.dev20210408+cu111-cp36-cp36m-linux_x86_64.whl
    # this is needed immediately by pytorch/ort so that the user is able to
    # install an onnxruntime training package with matching torch cuda version.
    if not is_openvino:
        # To support the package consisting of both openvino and training modules part of it
        package_name = "onnxruntime-training"

        disable_local_version = environ.get("ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION", "0")
        disable_local_version = (
            disable_local_version == "1"
            or disable_local_version.lower() == "true"
            or disable_local_version.lower() == "yes"
        )
        # local version should be disabled for internal feeds.
        if not disable_local_version:
            # we want put default training packages to pypi. pypi does not accept package with a local version.
            if not default_training_package_device or nightly_build:
                if cuda_version:
                    # removing '.' to make Cuda version number in the same form as Pytorch.
                    local_version = "+cu" + cuda_version.replace(".", "")
                elif rocm_version:
                    # removing '.' to make Rocm version number in the same form as Pytorch.
                    local_version = "+rocm" + rocm_version.replace(".", "")
                else:
                    # cpu version for documentation
                    local_version = "+cpu"
        else:
            if not (cuda_version or rocm_version):
                # Training CPU package for ADO feeds is called onnxruntime-training-cpu
                package_name = "onnxruntime-training-cpu"

            if rocm_version:
                # Training ROCM package for ADO feeds is called onnxruntime-training-rocm
                package_name = "onnxruntime-training-rocm"

# The TVM provider ships an extra Python package.
if package_name == "onnxruntime-tvm":
    packages += ["onnxruntime.providers.tvm"]

# Binaries, example models, and license files all ride along as package data.
package_data["onnxruntime"] = data + examples + extra
623

624
version_number = ""
625
with open("VERSION_NUMBER") as f:
626
    version_number = f.readline().strip()
627
if nightly_build:
628
    # https://docs.microsoft.com/en-us/azure/devops/pipelines/build/variables
629
    build_suffix = environ.get("BUILD_BUILDNUMBER")
630
    if build_suffix is None:
631
        # The following line is only for local testing
632
        build_suffix = str(datetime.datetime.now().date().strftime("%Y%m%d"))
633
    else:
634
        build_suffix = build_suffix.replace(".", "")
635

636
    if len(build_suffix) > 8 and len(build_suffix) < 12:
637
        # we want to format the build_suffix to avoid (the 12th run on 20210630 vs the first run on 20210701):
638
        # 2021063012 > 202107011
639
        # in above 2021063012 is treated as the latest which is incorrect.
640
        # we want to convert the format to:
641
        # 20210630012 < 20210701001
642
        # where the first 8 digits are date. the last 3 digits are run count.
643
        # as long as there are less than 1000 runs per day, we will not have the problem.
644
        # to test this code locally, run:
645
        # NIGHTLY_BUILD=1 BUILD_BUILDNUMBER=202107011 python tools/ci_build/build.py --config RelWithDebInfo \
646
        #   --enable_training --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ \
647
        #   --nccl_home /usr/lib/x86_64-linux-gnu/ --build_dir build/Linux --build --build_wheel --skip_tests \
648
        #   --cuda_version 11.1
649
        def check_date_format(date_str):
650
            try:
651
                datetime.datetime.strptime(date_str, "%Y%m%d")
652
                return True
653
            except Exception:
654
                return False
655

656
        def reformat_run_count(count_str):
657
            try:
658
                count = int(count_str)
659
                if count >= 0 and count < 1000:
660
                    return f"{count:03}"
661
                elif count >= 1000:
662
                    raise RuntimeError(f"Too many builds for the same day: {count}")
663
                return ""
664
            except Exception:
665
                return ""
666

667
        build_suffix_is_date_format = check_date_format(build_suffix[:8])
668
        build_suffix_run_count = reformat_run_count(build_suffix[8:])
669
        if build_suffix_is_date_format and build_suffix_run_count:
670
            build_suffix = build_suffix[:8] + build_suffix_run_count
671
    elif len(build_suffix) >= 12:
672
        raise RuntimeError(f'Incorrect build suffix: "{build_suffix}"')
673

674
    if enable_training:
675
        from packaging import version
676
        from packaging.version import Version
677

678
        # with training package, we need to bump up version minor number so that
679
        # nightly releases take precedence over the latest release when --pre is used during pip install.
680
        # eventually this shall be the behavior of all onnxruntime releases.
681
        # alternatively we may bump up version number right after every release.
682
        ort_version = version.parse(version_number)
683
        if isinstance(ort_version, Version):
684
            # TODO: this is the last time we have to do this!!!
685
            # We shall bump up release number right after release cut.
686
            if ort_version.major == 1 and ort_version.minor == 8 and ort_version.micro == 0:
687
                version_number = f"{ort_version.major}.{ort_version.minor + 1}.{ort_version.micro}"
688

689
    version_number = version_number + ".dev" + build_suffix
690

691
if local_version:
692
    version_number = version_number + local_version
693
    if is_rocm and enable_rocm_profiling:
694
        version_number = version_number + ".profiling"
695

696
if wheel_name_suffix:
697
    if not (enable_training and wheel_name_suffix == "gpu"):
698
        # for training packages, local version is used to indicate device types
699
        package_name = f"{package_name}-{wheel_name_suffix}"
700

701
# Register the custom commands; bdist_wheel is None when the `wheel` package
# failed to import above, in which case the command is simply omitted.
cmd_classes = {}
if bdist_wheel is not None:
    cmd_classes["bdist_wheel"] = bdist_wheel
cmd_classes["install"] = InstallCommand
cmd_classes["build_ext"] = build_ext

# Locate the requirements file first relative to the current directory, then
# relative to this file (same fallback pattern as the README lookup above).
requirements_path = path.join(getcwd(), requirements_file)
if not path.exists(requirements_path):
    this = path.dirname(__file__)
    requirements_path = path.join(this, requirements_file)
if not path.exists(requirements_path):
    raise FileNotFoundError("Unable to find " + requirements_file)
with open(requirements_path) as f:
    install_requires = f.read().splitlines()
715

716

717
if enable_training:

    def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
        """Write onnxruntime/capi/build_and_package_info.py recording the
        package name, version, and (if applicable) CUDA/cuDNN or ROCm info.

        The generated module is imported at runtime to report what the wheel
        was built against.
        """
        # The helper module lives in the source tree, not on sys.path.
        sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
        from onnxruntime_collect_build_info import find_cudart_versions

        version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
        with open(version_path, "w") as f:
            f.write(f"package_name = '{package_name}'\n")
            f.write(f"__version__ = '{version_number}'\n")

            if cuda_version:
                f.write(f"cuda_version = '{cuda_version}'\n")

                # cudart_versions are integers
                cudart_versions = find_cudart_versions(build_env=True)
                # Only record the cudart version when it is unambiguous.
                if cudart_versions and len(cudart_versions) == 1:
                    f.write(f"cudart_version = {cudart_versions[0]}\n")
                else:
                    print(
                        "Error getting cudart version. ",
                        (
                            "did not find any cudart library"
                            if not cudart_versions or len(cudart_versions) == 0
                            else "found multiple cudart libraries"
                        ),
                    )
            elif rocm_version:
                f.write(f"rocm_version = '{rocm_version}'\n")

    save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)
748

749
# Setup
# Everything computed above feeds the single setup() invocation.
setup(
    name=package_name,
    version=version_number,
    description="ONNX Runtime is a runtime accelerator for Machine Learning models",
    long_description=long_description,
    author="Microsoft Corporation",
    author_email="onnxruntime@microsoft.com",
    cmdclass=cmd_classes,
    license="MIT License",
    packages=packages,
    ext_modules=ext_modules,
    package_data=package_data,
    url="https://onnxruntime.ai",
    download_url="https://github.com/microsoft/onnxruntime/tags",
    data_files=data_files,
    install_requires=install_requires,
    keywords="onnx machine learning",
    entry_points={
        "console_scripts": [
            "onnxruntime_test = onnxruntime.tools.onnxruntime_test:main",
        ]
    },
    classifiers=classifiers,
)
774

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.