deepspeed

Форк
0
/
cpu_adagrad.py 
43 строки · 1.1 Кб
1
# Copyright (c) Microsoft Corporation.
2
# SPDX-License-Identifier: Apache-2.0
3

4
# DeepSpeed Team
5

6
import os
7
from .builder import TorchCPUOpBuilder
8

9

10
class CPUAdagradBuilder(TorchCPUOpBuilder):
11
    BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
12
    NAME = "cpu_adagrad"
13

14
    def __init__(self):
15
        super().__init__(name=self.NAME)
16

17
    def absolute_name(self):
18
        return f'deepspeed.ops.adagrad.{self.NAME}_op'
19

20
    def sources(self):
21
        if self.build_for_cpu:
22
            return ['csrc/adagrad/cpu_adagrad.cpp']
23

24
        return ['csrc/adagrad/cpu_adagrad.cpp', 'csrc/common/custom_cuda_kernel.cu']
25

26
    def libraries_args(self):
27
        args = super().libraries_args()
28
        if self.build_for_cpu:
29
            return args
30

31
        if not self.is_rocm_pytorch():
32
            args += ['curand']
33
        return args
34

35
    def include_paths(self):
36
        import torch
37
        if self.build_for_cpu:
38
            CUDA_INCLUDE = []
39
        elif not self.is_rocm_pytorch():
40
            CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
41
        else:
42
            CUDA_INCLUDE = []
43
        return ['csrc/includes'] + CUDA_INCLUDE
44

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.