deepspeed

Форк
0
/
random_ltd.py 
34 строки · 879.0 Байт
1
# Copyright (c) Microsoft Corporation.
2
# SPDX-License-Identifier: Apache-2.0
3

4
# DeepSpeed Team
5

6
from .builder import CUDAOpBuilder
7

8

9
class RandomLTDBuilder(CUDAOpBuilder):
10
    BUILD_VAR = "DS_BUILD_RANDOM_LTD"
11
    NAME = "random_ltd"
12

13
    def __init__(self, name=None):
14
        name = self.NAME if name is None else name
15
        super().__init__(name=name)
16

17
    def absolute_name(self):
18
        return f'deepspeed.ops.{self.NAME}_op'
19

20
    def extra_ldflags(self):
21
        if not self.is_rocm_pytorch():
22
            return ['-lcurand']
23
        else:
24
            return []
25

26
    def sources(self):
27
        return [
28
            'csrc/random_ltd/pt_binding.cpp', 'csrc/random_ltd/gather_scatter.cu',
29
            'csrc/random_ltd/slice_attn_masks.cu', 'csrc/random_ltd/token_sort.cu'
30
        ]
31

32
    def include_paths(self):
33
        includes = ['csrc/includes']
34
        return includes
35

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.