gpt-neox

Форк
0
/
scaled_masked_softmax.cpp 
83 строки · 3.1 Кб
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

// Forward declarations of the CUDA entry points. They are declared here
// and defined elsewhere (presumably in a companion .cu translation unit
// compiled into the same extension — not visible in this file).

// Forward pass: scaled, masked softmax over `input` using `mask`.
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor);

// Backward pass: gradient of the scaled, masked softmax, given the
// upstream gradients and the softmax results saved from the forward pass.
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

// Query for kernel launch geometry: number of batches handled per block
// for the given problem dimensions.
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads);

// Host-side validation wrapper for the forward pass.
//
// Checks that `input` and `mask` are 4D and that `input` is fp16/bf16,
// then hands the work to the CUDA implementation.
//
// input:        4D tensor of attention scores (fp16 or bf16).
// mask:         4D mask tensor.
// scale_factor: scalar applied by the kernel (see fwd_cuda).
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor)
{
    // Precondition checks; the AT_ASSERTM condition expressions are kept
    // verbatim so the stringified failure messages are unchanged.
    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(
        (input.scalar_type() == at::ScalarType::Half) ||
            (input.scalar_type() == at::ScalarType::BFloat16),
        "Only fp16 and bf16 are supported");
    AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
    return fwd_cuda(input, mask, scale_factor);
}

// Host-side validation wrapper for the backward pass.
//
// Checks that both tensors are 4D and fp16/bf16, then hands the work to
// the CUDA implementation.
//
// output_grads:    4D tensor of upstream gradients (fp16 or bf16).
// softmax_results: 4D softmax output saved from the forward pass.
// scale_factor:    scalar matching the one used in the forward pass.
torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results,
                  float scale_factor)
{
    // Fix: the original messages said "expected 3D tensor" although the
    // checks require dim() == 4; report the actual requirement.
    AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                   (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                   (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads)
62
{
63
    return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
64
}
65

66
}  // end namespace scaled_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

// Python bindings: expose the forward/backward softmax wrappers and the
// launch-geometry helper under the extension module's name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    namespace sms = multihead_attn::fused_softmax::scaled_masked_softmax;

    m.def("forward",
          &sms::fwd,
          "Self Multihead Attention scaled, time masked softmax -- Forward.");
    m.def("backward",
          &sms::bwd,
          "Self Multihead Attention scaled, time masked softmax -- Backward.");
    m.def("get_batch_per_block",
          &sms::get_batch_per_block,
          "Return Batch per block size.");
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.