onnxruntime
35 строк · 1.4 Кб
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# A Python interpreter is needed at build time to run the Triton kernel
# compiler script invoked by compile_triton_kernel() below.
find_package(Python3 REQUIRED COMPONENTS Interpreter)

# Triton kernel scripts to compile, given relative to REPO_ROOT
# (compile_triton_kernel() prefixes each entry with "${REPO_ROOT}/").
# Only ROCm builds contribute scripts; otherwise the variable is not touched.
if(onnxruntime_USE_ROCM)
  set(triton_kernel_scripts
    "onnxruntime/core/providers/rocm/math/softmax_triton.py"
    "onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
  )
endif()

# compile_triton_kernel(<out_obj_var> <out_header_dir_var>)
#
# Registers a build rule that runs
# ${REPO_ROOT}/tools/ci_build/compile_triton.py on every script listed in
# `triton_kernel_scripts` (read from the calling scope, entries relative to
# REPO_ROOT). The rule generates triton_kernel_infos.a and
# triton_kernel_infos.h in ${CMAKE_CURRENT_BINARY_DIR}/triton_kernels.
# It also creates the custom target `onnxruntime_triton_kernel`, which depends
# on both generated files.
#
# Requires the Python3::Interpreter target, which is created by the
# find_package(Python3 ...) call in this file.
#
# Arguments:
#   out_triton_kernel_obj_file   - name of a variable that is set in the
#                                  caller's scope to the path of the
#                                  generated triton_kernel_infos.a
#   out_triton_kernel_header_dir - name of a variable that is set in the
#                                  caller's scope to the directory that holds
#                                  triton_kernel_infos.h
#
# Example:
#   compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir)
#
# NOTE(review): if triton_kernel_scripts is empty (e.g. onnxruntime_USE_ROCM
# is off), the compiler is invoked with no scripts. Confirm that callers only
# use this function when scripts are configured.
function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir)
  set(triton_kernel_compiler "${REPO_ROOT}/tools/ci_build/compile_triton.py")
  set(out_dir "${CMAKE_CURRENT_BINARY_DIR}/triton_kernels")
  set(out_obj_file "${out_dir}/triton_kernel_infos.a")
  set(header_file "${out_dir}/triton_kernel_infos.h")

  # A function works on a copy of the caller's variables, so this turns only
  # the local copy of the script list into absolute paths.
  list(TRANSFORM triton_kernel_scripts PREPEND "${REPO_ROOT}/")

  add_custom_command(
    OUTPUT ${out_obj_file} ${header_file}
    # Ensure the output directory exists before the compiler writes into it.
    COMMAND ${CMAKE_COMMAND} -E make_directory ${out_dir}
    COMMAND Python3::Interpreter ${triton_kernel_compiler}
            --header ${header_file}
            --script_files ${triton_kernel_scripts}
            --obj_file ${out_obj_file}
    DEPENDS ${triton_kernel_scripts} ${triton_kernel_compiler}
    COMMENT "Triton compile generates: ${out_obj_file}"
    # Portable argument quoting (paths may contain spaces).
    VERBATIM
  )
  add_custom_target(onnxruntime_triton_kernel DEPENDS ${out_obj_file} ${header_file})

  set(${out_triton_kernel_obj_file} "${out_obj_file}" PARENT_SCOPE)
  set(${out_triton_kernel_header_dir} "${out_dir}" PARENT_SCOPE)
endfunction()