colossalai
41 строка · 1.0 Кб
1import platform2
3from ..cpp_extension import _CppExtension4
5
class CpuAdamArmExtension(_CppExtension):
    """C++ extension descriptor for the ``cpu_adam_arm`` kernel.

    Supplies the hardware checks, source files and compiler flags for the
    ARM CPU Adam kernel (``arm/cpu_adam_arm.cpp``). The kernel is only usable
    on hosts whose ``platform.machine()`` reports ``"aarch64"``.
    """

    def __init__(self):
        # NOTE(review): the base class presumably stores this as ``self.name``
        # (used in the error message below) — confirm in _CppExtension.
        super().__init__(name="cpu_adam_arm")

    def is_hardware_available(self) -> bool:
        """Return True when the host CPU architecture is ``aarch64``.

        Only ARM (aarch64) hosts are allowed for this kernel.
        """
        return platform.machine() == "aarch64"

    def assert_hardware_compatible(self) -> None:
        """Raise if the host CPU architecture is not ``aarch64``.

        Raises:
            AssertionError: when ``platform.machine()`` is not ``"aarch64"``.
                An explicit ``raise`` is used instead of an ``assert``
                statement so the check still runs under ``python -O``; the
                exception type seen by callers is unchanged.
        """
        arch = platform.machine()
        if arch != "aarch64":
            raise AssertionError(
                f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"
            )

    # Build-configuration hooks: source files, include dirs, C++ flags and
    # NVCC flags (presumably the four hooks _CppExtension expects — confirm).
    def sources_files(self):
        """Return the list of C++ source files to compile for this kernel."""
        # NOTE(review): csrc_abs_path presumably resolves the path against the
        # extension's csrc directory — confirm in the base class.
        return [
            self.csrc_abs_path("arm/cpu_adam_arm.cpp"),
        ]

    def include_dirs(self):
        """Return extra include directories (none are needed)."""
        return []

    def cxx_flags(self):
        """Return the C++ compiler flags used to build the kernel.

        ``-O3`` comes first, followed by the base class's
        ``version_dependent_macros``, then the extension-specific flags.
        """
        # Only one -std flag: previously both -std=c++14 and -std=c++17 were
        # passed; GCC/Clang honour only the last one, so C++17 was already the
        # effective standard and the c++14 flag was dead and misleading.
        extra_cxx_flags = [
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",  # the kernel is built with OpenMP support
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        """Return NVCC flags; empty because this is a CPU-only kernel."""
        return []