# onnxruntime
# 92 lines · 3.7 KB
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

include(CheckLanguage)

# Stop configuration immediately when Python support is switched off.
if(NOT onnxruntime_ENABLE_PYTHON)
  message(FATAL_ERROR "python is required but is not enabled")
endif()

# Directory that holds the kernel explorer sources inside the ONNX Runtime tree.
set(KERNEL_EXPLORER_ROOT "${ONNXRUNTIME_ROOT}/python/tools/kernel_explorer")

# Select the GPU language and the matching BERT contrib-op directory.
# CUDA wins when both onnxruntime_USE_CUDA and onnxruntime_USE_ROCM are set.
if(onnxruntime_USE_CUDA)
  check_language(CUDA)
  set(LANGUAGE CUDA)
  set(BERT_DIR "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/bert")
elseif(onnxruntime_USE_ROCM)
  check_language(HIP)
  set(LANGUAGE HIP)

  # Optional ROCm-side dependencies, pulled in only when requested.
  if(onnxruntime_USE_COMPOSABLE_KERNEL)
    include(composable_kernel)
  endif()
  if(onnxruntime_USE_HIPBLASLT)
    find_package(hipblaslt REQUIRED)
  endif()

  set(BERT_DIR "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/bert")
endif()

# Top-level kernel explorer sources and headers.
# CONFIGURE_DEPENDS makes the build re-check the glob for added/removed files.
file(
  GLOB kernel_explorer_srcs
  CONFIGURE_DEPENDS
    "${KERNEL_EXPLORER_ROOT}/*.cc"
    "${KERNEL_EXPLORER_ROOT}/*.h"
)

# Kernel sources directly under kernels/ (C++ plus .cu/.cuh device code).
file(
  GLOB kernel_explorer_kernel_srcs
  CONFIGURE_DEPENDS
    "${KERNEL_EXPLORER_ROOT}/kernels/*.cc"
    "${KERNEL_EXPLORER_ROOT}/kernels/*.h"
    "${KERNEL_EXPLORER_ROOT}/kernels/*.cu"
    "${KERNEL_EXPLORER_ROOT}/kernels/*.cuh"
)

# Create the kernel_explorer target from the globbed sources.
# onnxruntime_add_shared_library_module is a project helper defined elsewhere.
onnxruntime_add_shared_library_module(kernel_explorer
  ${kernel_explorer_srcs}
  ${kernel_explorer_kernel_srcs}
)

# The output file name gets a leading underscore ("_kernel_explorer<suffix>").
set_target_properties(kernel_explorer PROPERTIES PREFIX "_")

# Reuse the include directories, link libraries and compile definitions of
# onnxruntime_pybind11_state, and add the kernel explorer root to the include path.
target_include_directories(kernel_explorer PUBLIC
  $<TARGET_PROPERTY:onnxruntime_pybind11_state,INCLUDE_DIRECTORIES>
  ${KERNEL_EXPLORER_ROOT}
)
target_link_libraries(kernel_explorer PRIVATE
  $<TARGET_PROPERTY:onnxruntime_pybind11_state,LINK_LIBRARIES>
)
target_compile_definitions(kernel_explorer PRIVATE
  $<TARGET_PROPERTY:onnxruntime_pybind11_state,COMPILE_DEFINITIONS>
)

# -Wno-sign-compare is a GCC/Clang-style warning flag. cl.exe parses "-W..." as
# a warning-level option and rejects it, so skip the flag when the C++ compiler
# is MSVC. GCC, Clang and clang-cl builds receive the flag exactly as before.
target_compile_options(kernel_explorer PRIVATE
  $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-sign-compare>
)

# Attach the backend-specific kernel sources and compile settings.
if(onnxruntime_USE_CUDA)
  file(
    GLOB kernel_explorer_cuda_kernel_srcs
    CONFIGURE_DEPENDS
      "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cc"
      "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.h"
      "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cu"
      "${KERNEL_EXPLORER_ROOT}/kernels/cuda/*.cuh"
  )
  target_sources(kernel_explorer PRIVATE ${kernel_explorer_cuda_kernel_srcs})
  target_include_directories(kernel_explorer PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
elseif(onnxruntime_USE_ROCM)
  file(
    GLOB kernel_explorer_rocm_kernel_srcs
    CONFIGURE_DEPENDS
      "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cc"
      "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.h"
      "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cu"
      "${KERNEL_EXPLORER_ROOT}/kernels/rocm/*.cuh"
  )

  # Project helper defined elsewhere; going by its name it tags these files as
  # HIP sources -- confirm against its definition.
  auto_set_source_files_hip_language(
    ${kernel_explorer_kernel_srcs}
    ${kernel_explorer_rocm_kernel_srcs}
  )
  target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs})
  target_compile_definitions(kernel_explorer PRIVATE
    __HIP_PLATFORM_AMD__=1
    __HIP_PLATFORM_HCC__=1
  )

  # Optional ROCm features: each enabled option adds its USE_* definition.
  if(onnxruntime_USE_COMPOSABLE_KERNEL)
    target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL)
    if(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE)
      target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL_CK_TILE)
    endif()
    target_link_libraries(kernel_explorer PRIVATE onnxruntime_composable_kernel_includes)
  endif()

  if(onnxruntime_USE_TRITON_KERNEL)
    target_compile_definitions(kernel_explorer PRIVATE USE_TRITON_KERNEL)
  endif()

  if(onnxruntime_USE_HIPBLASLT)
    target_compile_definitions(kernel_explorer PRIVATE USE_HIPBLASLT)
  endif()

  if(onnxruntime_USE_ROCBLAS_EXTENSION_API)
    target_compile_definitions(kernel_explorer PRIVATE
      USE_ROCBLAS_EXTENSION_API
      ROCBLAS_NO_DEPRECATED_WARNINGS
      ROCBLAS_BETA_FEATURES_API
    )
  endif()
endif()

# Make sure onnxruntime_pybind11_state is built before kernel_explorer.
add_dependencies(kernel_explorer onnxruntime_pybind11_state)

# Enable CTest and locate a Python interpreter.
# The pytest registration below is currently disabled.
enable_testing()
find_package(Python REQUIRED COMPONENTS Interpreter)
# add_test(NAME test_kernels COMMAND ${Python_EXECUTABLE} -m pytest ..)