// google-research
// 69 lines · 2.4 KB
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
#include <limits>

#include "sparse/ops/cc/bias_relu_launcher.h"
#include "sparse/ops/cc/common.h"
#include "tensorflow/core/framework/op_kernel.h"
18
19namespace sgk {
20
21using ::tensorflow::Tensor;
22using ::tensorflow::TensorShapeUtils;
23using ::tensorflow::errors::InvalidArgument;
24
25template <typename Device, typename T>
26class BiasReluOp : public tensorflow::OpKernel {
27public:
28explicit BiasReluOp(tensorflow::OpKernelConstruction* context)
29: OpKernel(context) {}
30
31void Compute(tensorflow::OpKernelContext* context) override {
32const Tensor& in = context->input(0);
33const Tensor& bias = context->input(1);
34
35// Validate the input shapes.
36OP_REQUIRES(context, in.dims() >= 2 && in.dims() <= 4,
37InvalidArgument("Expected 2-4 dimensional input"));
38OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
39InvalidArgument("Expected 1-dimension bias"));
40OP_REQUIRES(context, bias.dim_size(0) == in.dim_size(1),
41InvalidArgument("Expected one bias value for each channel."));
42
43// Get the problem shape.
44int n = in.dim_size(0);
45int c = in.dim_size(1);
46int d = 1;
47if (in.dims() == 3) {
48d = in.dim_size(2);
49} else if (in.dims() == 4) {
50d = in.dim_size(2) * in.dim_size(3);
51}
52
53// Allocate the output tensor.
54Tensor* out = nullptr;
55OP_REQUIRES_OK(context, context->allocate_output(0, in.shape(), &out));
56
57// Launch the kernel.
58LaunchBiasRelu(context->eigen_device<Device>(), n, c, d,
59in.flat<float>().data(), bias.flat<float>().data(),
60out->flat<float>().data());
61}
62};
63
64#ifdef GOOGLE_CUDA
65REGISTER_KERNEL_BUILDER(Name("BiasRelu").Device(tensorflow::DEVICE_GPU),
66BiasReluOp<Eigen::GpuDevice, float>);
67#endif // GOOGLE_CUDA
68
69} // namespace sgk
70