google-research

Форк
0
/
bias_relu_op.cc 
69 строк · 2.4 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/bias_relu_launcher.h"
16
#include "sparse/ops/cc/common.h"
17
#include "tensorflow/core/framework/op_kernel.h"
18

19
namespace sgk {
20

21
using ::tensorflow::Tensor;
22
using ::tensorflow::TensorShapeUtils;
23
using ::tensorflow::errors::InvalidArgument;
24

25
template <typename Device, typename T>
26
class BiasReluOp : public tensorflow::OpKernel {
27
 public:
28
  explicit BiasReluOp(tensorflow::OpKernelConstruction* context)
29
      : OpKernel(context) {}
30

31
  void Compute(tensorflow::OpKernelContext* context) override {
32
    const Tensor& in = context->input(0);
33
    const Tensor& bias = context->input(1);
34

35
    // Validate the input shapes.
36
    OP_REQUIRES(context, in.dims() >= 2 && in.dims() <= 4,
37
                InvalidArgument("Expected 2-4 dimensional input"));
38
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
39
                InvalidArgument("Expected 1-dimension bias"));
40
    OP_REQUIRES(context, bias.dim_size(0) == in.dim_size(1),
41
                InvalidArgument("Expected one bias value for each channel."));
42

43
    // Get the problem shape.
44
    int n = in.dim_size(0);
45
    int c = in.dim_size(1);
46
    int d = 1;
47
    if (in.dims() == 3) {
48
      d = in.dim_size(2);
49
    } else if (in.dims() == 4) {
50
      d = in.dim_size(2) * in.dim_size(3);
51
    }
52

53
    // Allocate the output tensor.
54
    Tensor* out = nullptr;
55
    OP_REQUIRES_OK(context, context->allocate_output(0, in.shape(), &out));
56

57
    // Launch the kernel.
58
    LaunchBiasRelu(context->eigen_device<Device>(), n, c, d,
59
                   in.flat<float>().data(), bias.flat<float>().data(),
60
                   out->flat<float>().data());
61
  }
62
};
63

64
#ifdef GOOGLE_CUDA
65
REGISTER_KERNEL_BUILDER(Name("BiasRelu").Device(tensorflow::DEVICE_GPU),
66
                        BiasReluOp<Eigen::GpuDevice, float>);
67
#endif  // GOOGLE_CUDA
68

69
}  // namespace sgk
70

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.