google-research

Форк
0
/
fused_depthwise_op.cc 
194 строки · 8.1 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/common.h"
16
#include "sparse/ops/cc/fused_depthwise_launcher.h"
17
#include "tensorflow/core/framework/kernel_shape_util.h"
18
#include "tensorflow/core/framework/op_kernel.h"
19
#include "tensorflow/core/util/padding.h"
20

21
namespace sgk {
22

23
using ::tensorflow::Tensor;
24
using ::tensorflow::TensorFormat;
25
using ::tensorflow::TensorShape;
26
using ::tensorflow::TensorShapeUtils;
27
using ::tensorflow::errors::InvalidArgument;
28

29
template <typename Device, typename T>
30
class DepthwiseConvOp : public tensorflow::OpKernel {
31
 public:
32
  explicit DepthwiseConvOp(tensorflow::OpKernelConstruction* context)
33
      : OpKernel(context) {
34
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
35

36
    // NOTE: This op only supports NCHW format.
37
    std::string data_format = "NCHW";
38
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
39
                InvalidArgument("Invalid data format"));
40

41
    // NOTE: This kernel only supports matching strides for the H & W
42
    // dimensions, and does not support stride in the channel and batch
43
    // dimensions.
44
    OP_REQUIRES(context, strides_.size() == 4,
45
                InvalidArgument("Sliding window strides field must "
46
                                "specify 4 dimensions"));
47
    stride_ = GetTensorDim(strides_, data_format_, 'H');
48
    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
49
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
50
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
51

52
    OP_REQUIRES(
53
        context, stride_ == stride_w,
54
        InvalidArgument("Current implementation only supports equal length "
55
                        "strides in the row and column dimensions."));
56
    OP_REQUIRES(context, (stride_n == 1 && stride_c == 1),
57
                InvalidArgument("Current implementation does not yet support "
58
                                "strides in the batch and depth dimensions."));
59
    OP_REQUIRES_OK(context, context->GetAttr("explicit_paddings", &padding_));
60
  }
61

62
  void Compute(tensorflow::OpKernelContext* context) override {
63
    ComputeHelper(context, /*bias_ptr=*/nullptr);
64
  }
65

66
  void ComputeHelper(tensorflow::OpKernelContext* context,
67
                     const float* bias_ptr) {
68
    // Input tensor is of the following dimensions:
69
    // [ batch, in_rows, in_cols, in_depth ]
70
    const Tensor& input = context->input(0);
71

72
    // Input filter is of the following dimensions:
73
    // [ filter_rows, filter_cols, in_depth, depth_multiplier]
74
    const Tensor& filter = context->input(1);
75

76
    // For 2D convolution, there should be 4 dimensions.
77
    OP_REQUIRES(context, input.dims() == 4,
78
                InvalidArgument("input must be 4-dimensional",
79
                                input.shape().DebugString()));
80
    OP_REQUIRES(context, filter.dims() == 4,
81
                InvalidArgument("filter must be 4-dimensional: ",
82
                                filter.shape().DebugString()));
83

84
    // in_depth for input and filter must match.
85
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
86
    OP_REQUIRES(context, in_depth == filter.dim_size(0),
87
                InvalidArgument("input and filter must have the same depth: ",
88
                                in_depth, " vs ", filter.dim_size(0)));
89

90
    // The last dimension for filter is depth multiplier.
91
    //
92
    // NOTE: We only support depth_multiplier == 1.
93
    const int32 depth_multiplier = filter.dim_size(3);
94
    OP_REQUIRES(context, depth_multiplier == 1,
95
                InvalidArgument("Depth multiplier must be 1."));
96

97
    // The output depth is input depth x depth multiplier
98
    const int32 out_depth = in_depth * depth_multiplier;
99

100
    // NOTE: We only support 3x3 kernels.
101
    const int32 input_rows = GetTensorDim(input, data_format_, 'H');
102
    const int32 filter_rows = filter.dim_size(1);
103
    const int32 input_cols = GetTensorDim(input, data_format_, 'W');
104
    const int32 filter_cols = filter.dim_size(2);
105
    OP_REQUIRES(context, input_rows == input_cols,
106
                InvalidArgument("Only supports square images."));
107
    OP_REQUIRES(context, filter_rows == 3,
108
                InvalidArgument("Only supports 3x3 kernels."));
109
    OP_REQUIRES(context, filter_cols == 3,
110
                InvalidArgument("Only supports 3x3 kernels."));
111

112
    // The first dimension for input is batch.
113
    const int32 batch = input.dim_size(0);
114

115
    // Get and validate the padding arguments.
116
    int64 pad_rows = GetTensorDim(padding_, data_format_, 'H');
117
    int64 pad_cols = GetTensorDim(padding_, data_format_, 'W');
118
    OP_REQUIRES(context, GetTensorDim(padding_, data_format_, 'C') == 0,
119
                InvalidArgument("Channel padding not supported."));
120
    OP_REQUIRES(context, GetTensorDim(padding_, data_format_, 'N') == 0,
121
                InvalidArgument("Batch padding not supported."));
122
    OP_REQUIRES(context, pad_rows == pad_cols,
123
                InvalidArgument("Height and width padding must match."));
124

125
    int64 out_rows = 0, out_cols = 0;
126
    OP_REQUIRES_OK(
127
        context, tensorflow::GetWindowedOutputSizeVerbose(
128
                     input_rows, filter_rows, /* dilation_rate = */ 1, stride_,
129
                     /* padding_type = */ tensorflow::EXPLICIT, &out_rows,
130
                     &pad_rows, &pad_rows));
131
    OP_REQUIRES_OK(
132
        context, tensorflow::GetWindowedOutputSizeVerbose(
133
                     input_cols, filter_cols, /* dilation_rate = */ 1, stride_,
134
                     /* padding_type = */ tensorflow::EXPLICIT, &out_cols,
135
                     &pad_cols, &pad_cols));
136

137
    // Setup and allocate the output tensor.
138
    TensorShape out_shape = {batch, out_depth, out_rows, out_cols};
139
    Tensor* output = nullptr;
140
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
141

142
    // If there is nothing to compute, return.
143
    if (out_shape.num_elements() == 0) {
144
      return;
145
    }
146

147
    LaunchFusedDepthwiseConv(
148
        context->eigen_device<Device>(), batch, in_depth, input_rows,
149
        input_cols, input.template flat<T>().data(), filter_rows, pad_rows,
150
        stride_, filter.template flat<T>().data(), bias_ptr,
151
        output->template flat<T>().data());
152
  }
153

154
 protected:
155
  std::vector<int32> strides_;
156
  std::vector<int64> padding_;
157
  TensorFormat data_format_;
158
  int64 stride_;  // in height/width dimension.
159
};
160

161
template <typename Device, typename T>
162
class FusedDepthwiseConvOp : public DepthwiseConvOp<Device, T> {
163
 public:
164
  explicit FusedDepthwiseConvOp(tensorflow::OpKernelConstruction* context)
165
      : DepthwiseConvOp<Device, T>(context) {}
166

167
  void Compute(tensorflow::OpKernelContext* context) override {
168
    // Input bias is of the following dimensions:
169
    // [in_depth * depth_multiplier]
170
    const Tensor& bias = context->input(2);
171

172
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
173
                InvalidArgument("Bias must be 1-dimensional."));
174

175
    const int64 out_depth =
176
        GetTensorDim(context->input(0), this->data_format_, 'C');
177
    OP_REQUIRES(context, out_depth == bias.dim_size(0),
178
                InvalidArgument("Bias must match output depth."));
179
    this->ComputeHelper(context, bias.template flat<T>().data());
180
  }
181
};
182

183
#ifdef GOOGLE_CUDA
// Both kernels are registered for the GPU device with T constrained to float.

// Depthwise convolution without a bias input.
REGISTER_KERNEL_BUILDER(
    Name("DepthwiseConv")
        .Device(tensorflow::DEVICE_GPU)
        .TypeConstraint<float>("T"),
    DepthwiseConvOp<Eigen::GpuDevice, float>);

// Depthwise convolution that additionally consumes a bias vector.
REGISTER_KERNEL_BUILDER(
    Name("FusedDepthwiseConv")
        .Device(tensorflow::DEVICE_GPU)
        .TypeConstraint<float>("T"),
    FusedDepthwiseConvOp<Eigen::GpuDevice, float>);
#endif  // GOOGLE_CUDA
193

194
}  // namespace sgk
195

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.