google-research
81 строка · 3.2 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <limits>

#include "sparse/ops/cc/common.h"
#include "sparse/ops/cc/csr2idx_launcher.h"
#include "tensorflow/core/framework/op_kernel.h"

19namespace sgk {20
21using ::tensorflow::Tensor;22using ::tensorflow::TensorShape;23using ::tensorflow::TensorShapeUtils;24using ::tensorflow::errors::InvalidArgument;25
26template <typename Device, typename T>27class Csr2idxOp : public tensorflow::OpKernel {28public:29explicit Csr2idxOp(tensorflow::OpKernelConstruction* context)30: OpKernel(context) {}31
32void Compute(tensorflow::OpKernelContext* context) override {33// Collect the input & output tensors.34const Tensor& m_tensor = context->input(0);35const Tensor& n_tensor = context->input(1);36const Tensor& row_offsets = context->input(2);37const Tensor& column_indices = context->input(3);38
39// Validate the input shapes.40OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),41InvalidArgument("Expected scalar for argument 'm'."));42OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_tensor.shape()),43InvalidArgument("Expected scalar for argument 'n'."));44OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),45InvalidArgument("Expected 1-dimension row_offsets tensor."));46OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),47InvalidArgument("Expected 1-dimension column_indices tensor."));48
49// Get the problem shape.50int m = m_tensor.tensor<int32, 0>().data()[0];51int n = n_tensor.tensor<int32, 0>().data()[0];52int nonzeros = column_indices.dim_size(0);53
54// Validate row offsets size.55OP_REQUIRES(context, row_offsets.dim_size(0) == m + 1,56InvalidArgument("Expected m+1 row offsets."));57
58// Allocate the output tensor.59Tensor* linear_indices = nullptr;60OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{nonzeros},61&linear_indices));62
63// Launch the kernel.64LaunchCsr2idx(context->eigen_device<Device>(), m, n, nonzeros,65AsInt32<1>(row_offsets), AsInt32<1>(column_indices),66AsInt32<1>(linear_indices));67}68};69
70REGISTER_KERNEL_BUILDER(Name("Csr2idx").Device(tensorflow::DEVICE_CPU),71Csr2idxOp<Eigen::ThreadPoolDevice, float>);72
73#ifdef GOOGLE_CUDA74REGISTER_KERNEL_BUILDER(Name("Csr2idx")75.Device(tensorflow::DEVICE_GPU)76.HostMemory("m")77.HostMemory("n"),78Csr2idxOp<Eigen::GpuDevice, float>);79#endif // GOOGLE_CUDA80
81} // namespace sgk82