deepspeed

Форк
0
62 строки · 2.0 Кб
1
// Copyright (c) Microsoft Corporation.
2
// SPDX-License-Identifier: Apache-2.0
3

4
// DeepSpeed Team
5

6
#include <torch/extension.h>
7

8
void attention_impl(torch::Tensor& q,
9
                    torch::Tensor& k,
10
                    torch::Tensor& v,
11
                    torch::Tensor& bias1,
12
                    torch::Tensor& bias2,
13
                    torch::Tensor& o,
14
                    torch::Tensor& lse);
15
void attention(torch::Tensor& q,
16
               torch::Tensor& k,
17
               torch::Tensor& v,
18
               torch::Tensor& bias1,
19
               torch::Tensor& bias2,
20
               torch::Tensor& o,
21
               torch::Tensor& lse)
22
{
23
    attention_impl(q, k, v, bias1, bias2, o, lse);
24
}
25

26
void attention_back_impl(torch::Tensor& go,
27
                         torch::Tensor& q,
28
                         torch::Tensor& k,
29
                         torch::Tensor& v,
30
                         torch::Tensor& o,
31
                         torch::Tensor& lse,
32
                         torch::Tensor& delta,
33
                         torch::Tensor& bias1,
34
                         torch::Tensor& bias2,
35
                         torch::Tensor& gq,
36
                         torch::Tensor& gk,
37
                         torch::Tensor& gv,
38
                         torch::Tensor& gb1,
39
                         torch::Tensor& gb2);
40
void attention_bwd(torch::Tensor& go,
41
                   torch::Tensor& q,
42
                   torch::Tensor& k,
43
                   torch::Tensor& v,
44
                   torch::Tensor& o,
45
                   torch::Tensor& lse,
46
                   torch::Tensor& delta,
47
                   torch::Tensor& bias1,
48
                   torch::Tensor& bias2,
49
                   torch::Tensor& gq,
50
                   torch::Tensor& gk,
51
                   torch::Tensor& gv,
52
                   torch::Tensor& gb1,
53
                   torch::Tensor& gb2)
54
{
55
    attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
56
}
57

58
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
59
{
60
    m.def("attention", &attention, "");
61
    m.def("attention_bwd", &attention_bwd, "");
62
}
63

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.