google-research
92 строки · 3.8 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#include "sparse/ops/cc/common.h"16#include "sparse/ops/cc/sparse_softmax_launcher.h"17#include "tensorflow/core/framework/op_kernel.h"18
19namespace sgk {20
21using ::tensorflow::Tensor;22using ::tensorflow::TensorShapeUtils;23using ::tensorflow::errors::InvalidArgument;24
25template <typename Device, typename T>26class SparseSoftmaxOp : public tensorflow::OpKernel {27public:28explicit SparseSoftmaxOp(tensorflow::OpKernelConstruction* context)29: OpKernel(context) {}30
31void Compute(tensorflow::OpKernelContext* context) override {32// Collect the input & output tensors.33const Tensor& values = context->input(0);34const Tensor& row_indices = context->input(1);35const Tensor& row_offsets = context->input(2);36const Tensor& column_indices = context->input(3);37
38// Validate the input shapes.39OP_REQUIRES(context, TensorShapeUtils::IsVector(row_indices.shape()),40InvalidArgument("Expected 1-dimension row_indices tensor."));41OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),42InvalidArgument("Expected 1-dimension row_offsets tensor."));43OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),44InvalidArgument("Expected 1-dimension column_indices tensor."));45OP_REQUIRES(context, row_indices.dim_size(0) + 1 == row_offsets.dim_size(0),46InvalidArgument("Expected 1 more row index than offset."));47OP_REQUIRES(48context,49TensorShapeUtils::IsVector(values.shape()) || values.dims() == 2,50InvalidArgument("Expected 1-dim or 2-dim values tensor."));51
52// Get the problem shape.53//54// NOTE: The kernel doesn't actually need the n argument. Pass garbage,55// since we can't pull it off the sparse matrix representation.56int m = row_indices.dim_size(0);57int n = -1;58int nonzeros = column_indices.dim_size(0);59int dim_offset = values.dims() - 1;60int replication = dim_offset == 1 ? values.dim_size(0) : 1;61
62// Validate the sparse matrix shape.63OP_REQUIRES(context, values.dim_size(dim_offset) == nonzeros,64InvalidArgument("Num values must equal num col indices."));65
66// Allocate the output tensor.67Tensor* output_values = nullptr;68OP_REQUIRES_OK(context,69context->allocate_output(0, values.shape(), &output_values));70
71// Launch the kernel for each step of computation.72//73// TODO(tgale): This could be accelerated by supported replicated/batched74// execution in the kernel. Running the kernel is a loop like this could75// incur significant overhead from kernel launch latency if the computation76// is cheap.77for (int idx = 0; idx < replication; ++idx) {78LaunchSparseSoftmax(context->eigen_device<Device>(), m, n, nonzeros,79values.flat<float>().data() + nonzeros * idx,80AsInt32<1>(row_indices), AsInt32<1>(row_offsets),81AsInt32<1>(column_indices),82output_values->flat<float>().data() + nonzeros * idx);83}84}85};86
// Register the float instantiation of SparseSoftmaxOp as the GPU kernel for
// the "CsrSoftmax" op. Compiled only when GOOGLE_CUDA is defined.
#if defined(GOOGLE_CUDA)
REGISTER_KERNEL_BUILDER(Name("CsrSoftmax").Device(tensorflow::DEVICE_GPU),
                        SparseSoftmaxOp<Eigen::GpuDevice, float>);
#endif  // defined(GOOGLE_CUDA)

92} // namespace sgk93