deepspeed
34 строки · 879.0 Байт
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

6from .builder import CUDAOpBuilder7
8
class RandomLTDBuilder(CUDAOpBuilder):
    """Build configuration for the ``random_ltd`` CUDA extension.

    Supplies the op name, import path, linker flags, source files and
    include directories that ``CUDAOpBuilder`` uses to compile the
    ``csrc/random_ltd`` C++/CUDA sources.
    """

    # Name of the build toggle consulted by the base builder.
    # NOTE(review): exact semantics are defined in CUDAOpBuilder; confirm there.
    BUILD_VAR = "DS_BUILD_RANDOM_LTD"
    # Default op name; also used to form the module path in absolute_name().
    NAME = "random_ltd"

    def __init__(self, name=None):
        """Initialise the builder.

        Args:
            name: Optional op name override; defaults to ``NAME`` when None.
        """
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        """Return the dotted import path of the built op.

        Note that this always uses the class-level ``NAME``
        (``'deepspeed.ops.random_ltd_op'``), even if a different ``name``
        was passed to ``__init__``.
        """
        return f'deepspeed.ops.{self.NAME}_op'

    def extra_ldflags(self):
        """Return additional linker flags for the extension.

        Returns:
            ``['-lcurand']`` on non-ROCm PyTorch builds, ``[]`` on ROCm.
        """
        # cuRAND is only linked explicitly when not building against ROCm.
        return [] if self.is_rocm_pytorch() else ['-lcurand']

    def sources(self):
        """Return the C++/CUDA source files compiled into the extension."""
        return [
            'csrc/random_ltd/pt_binding.cpp',
            'csrc/random_ltd/gather_scatter.cu',
            'csrc/random_ltd/slice_attn_masks.cu',
            'csrc/random_ltd/token_sort.cu',
        ]

    def include_paths(self):
        """Return the header search directories for the extension."""
        return ['csrc/includes']