google-research
172 lines · 7.6 KB
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "sparse/ops/cc/common.h"
16#include "sparse/ops/cc/spmm_launcher.h"
17#include "tensorflow/core/framework/op_kernel.h"
18
19namespace sgk {
20
21using ::tensorflow::Tensor;
22using ::tensorflow::TensorShapeUtils;
23using ::tensorflow::errors::InvalidArgument;
24
25template <typename Device, typename T>
26class SpmmOp : public tensorflow::OpKernel {
27public:
28explicit SpmmOp(tensorflow::OpKernelConstruction* context)
29: OpKernel(context) {
30OP_REQUIRES_OK(context, context->GetAttr("transpose_lhs", &transpose_lhs_));
31OP_REQUIRES_OK(context, context->GetAttr("transpose_rhs", &transpose_rhs_));
32
33// NOTE: We currently do not support transposition for either argument.
34OP_REQUIRES(context, !transpose_lhs_,
35InvalidArgument("transpose_lhs=True not yet supported."));
36OP_REQUIRES(context, !transpose_rhs_,
37InvalidArgument("transpose_rhs=True not yet supported."));
38}
39
40void Compute(tensorflow::OpKernelContext* context) override {
41ComputeHelper(context, /*bias_ptr=*/nullptr);
42}
43
44void ComputeHelper(tensorflow::OpKernelContext* context,
45const float* bias_ptr) {
46// Collect the input & output tensors.
47const Tensor& m_tensor = context->input(0);
48const Tensor& k_tensor = context->input(1);
49const Tensor& values = context->input(2);
50const Tensor& row_indices = context->input(3);
51const Tensor& row_offsets = context->input(4);
52const Tensor& column_indices = context->input(5);
53const Tensor& dense_matrix = context->input(6);
54
55// Validate the input shapes.
56OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
57InvalidArgument("Expected scalar for argument 'm'."));
58OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
59InvalidArgument("Expected scalar for argument 'k'."));
60OP_REQUIRES(
61context,
62TensorShapeUtils::IsVector(values.shape()) || values.dims() == 2,
63InvalidArgument("Expected 1-dim or 2-dim values tensor."));
64OP_REQUIRES(context, TensorShapeUtils::IsVector(row_indices.shape()),
65InvalidArgument("Expected 1-dimension row_indices tensor."));
66OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
67InvalidArgument("Expected 1-dimension row_offsets tensor."));
68OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),
69InvalidArgument("Expected 1-dimension column_indices tensor."));
70OP_REQUIRES(context,
71TensorShapeUtils::IsMatrix(dense_matrix.shape()) ||
72dense_matrix.dims() == 3,
73InvalidArgument("Expected 2 or 3-dim dense matrix tensor."));
74OP_REQUIRES(context, row_indices.dim_size(0) + 1 == row_offsets.dim_size(0),
75InvalidArgument("Expected one more row index than offset."));
76
77// TODO(tgale): We can lift this constraint to support arbitrary replication
78// of rhs/lhs matrix. For example, if lhs is a 3-tensor and rhs is a matrix
79// we can compute `lhs.shape[0]` spmms with each kernel using the same rhs
80// matrix.
81OP_REQUIRES(context, values.dims() == dense_matrix.dims() - 1,
82InvalidArgument("Values and rhs must be replicated the same."));
83
84// Get the problem shape.
85int m = m_tensor.tensor<int32, 0>().data()[0];
86int k = k_tensor.tensor<int32, 0>().data()[0];
87int nonzeros = column_indices.dim_size(0);
88
89int dim_offset = dense_matrix.dims() - 2;
90int n = dense_matrix.dim_size(dim_offset + 1);
91int replication = dim_offset == 1 ? dense_matrix.dim_size(0) : 1;
92
93// Validate the sparse matrix and dense matrix shapes match.
94OP_REQUIRES(context, values.dim_size(dim_offset) == nonzeros,
95InvalidArgument("Num values must equal num col indices."));
96OP_REQUIRES(context, row_indices.dim_size(0) == m,
97InvalidArgument("Num row indices and 'm' must match."));
98OP_REQUIRES(context, dense_matrix.dim_size(dim_offset) == k,
99InvalidArgument("Inner matrix dimensions must match."));
100
101// If we're going to run multiple spmms, the first dimension of the
102// matrices must match.
103OP_REQUIRES(context, replication == 1 || replication == values.dim_size(0),
104InvalidArgument("First dim of values and rhs must match"));
105
106// Allocate the output tensor.
107Tensor* output_matrix = nullptr;
108tensorflow::TensorShape output_shape = {m, n};
109if (replication > 1) {
110output_shape = {replication, m, n};
111}
112OP_REQUIRES_OK(context,
113context->allocate_output(0, output_shape, &output_matrix));
114
115// TODO(tgale): Add type checks on meta-data tensors to make sure our
116// casting is safe.
117//
118// TODO(tgale): This could be accelerated by supported replicated/batched
119// execution in the kernel. Running the kernel is a loop like this could
120// incur significant overhead from kernel launch latency if the computation
121// is cheap.
122for (int idx = 0; idx < replication; ++idx) {
123LaunchSpmm(context->eigen_device<Device>(), m, k, n, nonzeros,
124values.flat<float>().data() + nonzeros * idx,
125AsInt32<1>(row_indices), AsInt32<1>(row_offsets),
126AsInt32<1>(column_indices),
127dense_matrix.flat<float>().data() + k * n * idx, bias_ptr,
128output_matrix->flat<float>().data() + m * n * idx);
129}
130}
131
132private:
133bool transpose_lhs_, transpose_rhs_;
134};
135
136template <typename Device, typename T>
137class FusedSpmmOp : public SpmmOp<Device, T> {
138public:
139explicit FusedSpmmOp(tensorflow::OpKernelConstruction* context)
140: SpmmOp<Device, T>(context) {}
141
142void Compute(tensorflow::OpKernelContext* context) override {
143const Tensor& m_tensor = context->input(0);
144const Tensor& bias = context->input(7);
145OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
146InvalidArgument("Expected scalar for argument 'm'."));
147OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
148InvalidArgument("Expected vector for argument 'bias'."));
149int m = m_tensor.tensor<int32, 0>().data()[0];
150OP_REQUIRES(context, bias.dim_size(0) == m,
151InvalidArgument("Num biases size and 'm' must match."));
152this->ComputeHelper(context, bias.tensor<float, 1>().data());
153}
154};
155
// CPU kernel registrations. Only float is supported.
REGISTER_KERNEL_BUILDER(Name("Spmm").Device(tensorflow::DEVICE_CPU),
                        SpmmOp<Eigen::ThreadPoolDevice, float>);
REGISTER_KERNEL_BUILDER(Name("FusedSpmm").Device(tensorflow::DEVICE_CPU),
                        FusedSpmmOp<Eigen::ThreadPoolDevice, float>);

#ifdef GOOGLE_CUDA
// GPU kernel registrations. The scalar problem-size inputs 'm' and 'k' are
// pinned to host memory so the kernel can read them directly in Compute
// (they are dereferenced on the CPU, not passed to the device).
REGISTER_KERNEL_BUILDER(
    Name("Spmm").Device(tensorflow::DEVICE_GPU).HostMemory("m").HostMemory("k"),
    SpmmOp<Eigen::GpuDevice, float>);
REGISTER_KERNEL_BUILDER(Name("FusedSpmm")
                            .Device(tensorflow::DEVICE_GPU)
                            .HostMemory("m")
                            .HostMemory("k"),
                        FusedSpmmOp<Eigen::GpuDevice, float>);
#endif  // GOOGLE_CUDA
171
172} // namespace sgk
173