// Repository: google-research
// File size: 55 lines, 1.9 KB
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "sparse/ops/cc/fused_softmax_launcher.h"
16#include "tensorflow/core/framework/op_kernel.h"
17
18namespace sgk {
19
20using ::tensorflow::Tensor;
21using ::tensorflow::TensorShapeUtils;
22using ::tensorflow::errors::InvalidArgument;
23
24template <typename Device, typename T>
25class FusedSoftmaxOp : public tensorflow::OpKernel {
26public:
27explicit FusedSoftmaxOp(tensorflow::OpKernelConstruction* context)
28: OpKernel(context) {}
29
30void Compute(tensorflow::OpKernelContext* context) override {
31// Collect the input tensor.
32const Tensor& input = context->input(0);
33
34// Validate the input shapes.
35OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input.shape()),
36InvalidArgument("Expected 2-dimension input."));
37
38// Allocate the output tensor.
39Tensor* output = nullptr;
40OP_REQUIRES_OK(context,
41context->allocate_output(0, input.shape(), &output));
42
43// Launch the kernel.
44LaunchFusedSoftmax(context->eigen_device<Device>(), input.dim_size(0),
45input.dim_size(1), input.flat<float>().data(),
46output->flat<float>().data());
47}
48};
49
// Register the float instantiation of FusedSoftmaxOp as the GPU kernel for the
// "FusedSoftmax" op. Compiled only when GOOGLE_CUDA is defined.
#if defined(GOOGLE_CUDA)
REGISTER_KERNEL_BUILDER(
    Name("FusedSoftmax").Device(tensorflow::DEVICE_GPU),
    FusedSoftmaxOp<Eigen::GpuDevice, float>);
#endif  // defined(GOOGLE_CUDA)
54
55} // namespace sgk
56