google-research

Форк
0
/
fused_softmax_op.cc 
55 строк · 1.9 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/fused_softmax_launcher.h"
16
#include "tensorflow/core/framework/op_kernel.h"
17

18
namespace sgk {
19

20
using ::tensorflow::Tensor;
21
using ::tensorflow::TensorShapeUtils;
22
using ::tensorflow::errors::InvalidArgument;
23

24
template <typename Device, typename T>
25
class FusedSoftmaxOp : public tensorflow::OpKernel {
26
 public:
27
  explicit FusedSoftmaxOp(tensorflow::OpKernelConstruction* context)
28
      : OpKernel(context) {}
29

30
  void Compute(tensorflow::OpKernelContext* context) override {
31
    // Collect the input tensor.
32
    const Tensor& input = context->input(0);
33

34
    // Validate the input shapes.
35
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input.shape()),
36
                InvalidArgument("Expected 2-dimension input."));
37

38
    // Allocate the output tensor.
39
    Tensor* output = nullptr;
40
    OP_REQUIRES_OK(context,
41
                   context->allocate_output(0, input.shape(), &output));
42

43
    // Launch the kernel.
44
    LaunchFusedSoftmax(context->eigen_device<Device>(), input.dim_size(0),
45
                       input.dim_size(1), input.flat<float>().data(),
46
                       output->flat<float>().data());
47
  }
48
};
49

50
#ifdef GOOGLE_CUDA
51
REGISTER_KERNEL_BUILDER(Name("FusedSoftmax").Device(tensorflow::DEVICE_GPU),
52
                        FusedSoftmaxOp<Eigen::GpuDevice, float>);
53
#endif  // GOOGLE_CUDA
54

55
}  // namespace sgk
56

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.