deepspeed
62 lines · 2.0 KB
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
5
6#include <torch/extension.h>
7
8void attention_impl(torch::Tensor& q,
9torch::Tensor& k,
10torch::Tensor& v,
11torch::Tensor& bias1,
12torch::Tensor& bias2,
13torch::Tensor& o,
14torch::Tensor& lse);
15void attention(torch::Tensor& q,
16torch::Tensor& k,
17torch::Tensor& v,
18torch::Tensor& bias1,
19torch::Tensor& bias2,
20torch::Tensor& o,
21torch::Tensor& lse)
22{
23attention_impl(q, k, v, bias1, bias2, o, lse);
24}
25
26void attention_back_impl(torch::Tensor& go,
27torch::Tensor& q,
28torch::Tensor& k,
29torch::Tensor& v,
30torch::Tensor& o,
31torch::Tensor& lse,
32torch::Tensor& delta,
33torch::Tensor& bias1,
34torch::Tensor& bias2,
35torch::Tensor& gq,
36torch::Tensor& gk,
37torch::Tensor& gv,
38torch::Tensor& gb1,
39torch::Tensor& gb2);
40void attention_bwd(torch::Tensor& go,
41torch::Tensor& q,
42torch::Tensor& k,
43torch::Tensor& v,
44torch::Tensor& o,
45torch::Tensor& lse,
46torch::Tensor& delta,
47torch::Tensor& bias1,
48torch::Tensor& bias2,
49torch::Tensor& gq,
50torch::Tensor& gk,
51torch::Tensor& gv,
52torch::Tensor& gb1,
53torch::Tensor& gb2)
54{
55attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
56}
57
58PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
59{
60m.def("attention", &attention, "");
61m.def("attention_bwd", &attention_bwd, "");
62}
63