// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sparse/ops/cc/common.h"
#include "sparse/ops/cc/sddmm_launcher.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace sgk {

using ::tensorflow::Tensor;
using ::tensorflow::TensorShapeUtils;
using ::tensorflow::errors::InvalidArgument;

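// Sampled dense-dense matrix multiplication (SDDMM): for every nonzero
// position (i, j) of an m x n sparse matrix given in compressed sparse row
// (CSR) form, the op computes dot(lhs[i, :], rhs[j, :]). The rhs operand is
// supplied pre-transposed as an [n, k] matrix, which is why the constructor
// below insists on transpose_rhs=True.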
template <typename Device, typename T>
class SddmmOp : public tensorflow::OpKernel {
 public:
  explicit SddmmOp(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_lhs", &transpose_lhs_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_rhs", &transpose_rhs_));

    // NOTE: The kernel currently supports only an untransposed lhs and a
    // transposed rhs, i.e. both operands stored row-major with the
    // contraction dimension (k) innermost.
    OP_REQUIRES(context, !transpose_lhs_,
                InvalidArgument("transpose_lhs=True not yet supported."));
    OP_REQUIRES(context, transpose_rhs_,
                InvalidArgument("transpose_rhs=False not yet supported."));
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    // Collect the input & output tensors.
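    //
    // Inputs 0-1 are the sparse matrix dimensions, inputs 2-4 describe its
    // sparsity pattern in compressed sparse row (CSR) form, and inputs 5-6
    // are the dense operands.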
    const Tensor& m_tensor = context->input(0);
    const Tensor& n_tensor = context->input(1);
    const Tensor& row_indices = context->input(2);
    const Tensor& row_offsets = context->input(3);
    const Tensor& column_indices = context->input(4);
    const Tensor& lhs_matrix = context->input(5);
    const Tensor& rhs_matrix = context->input(6);

    // Validate the input shapes.
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
                InvalidArgument("Expected scalar for argument 'm'."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_tensor.shape()),
                InvalidArgument("Expected scalar for argument 'n'."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_indices.shape()),
                InvalidArgument("Expected 1-dimension row_indices tensor."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
                InvalidArgument("Expected 1-dimension row_offsets tensor."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),
                InvalidArgument("Expected 1-dimension column_indices tensor."));
    OP_REQUIRES(context, row_indices.dim_size(0) + 1 == row_offsets.dim_size(0),
                InvalidArgument("Expected one more row offset than row index."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrix(lhs_matrix.shape()) ||
                    lhs_matrix.dims() == 3,
                InvalidArgument("Expected 2-dim or 3-dim lhs matrix tensor."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrix(rhs_matrix.shape()) ||
                    rhs_matrix.dims() == 3,
                InvalidArgument("Expected 2-dim or 3-dim rhs matrix tensor."));
    // NOTE: Compare the innermost (contraction) dimensions; indexing dim 1
    // directly would be wrong for rank-3 (replicated) operands.
    OP_REQUIRES(context,
                lhs_matrix.dim_size(lhs_matrix.dims() - 1) ==
                    rhs_matrix.dim_size(rhs_matrix.dims() - 1),
                InvalidArgument("Last dim of input matrices must match."));

    // TODO(tgale): We can lift this constraint to support arbitrary
    // replication of the rhs/lhs matrix. For example, if lhs is a 3-tensor
    // and rhs is a matrix we can compute `lhs.shape[0]` sddmms with each
    // kernel using the same rhs matrix.
    OP_REQUIRES(context, rhs_matrix.dims() == lhs_matrix.dims(),
                InvalidArgument("rhs and lhs must match number of dims."));

    // Get the problem shape.
    int m = m_tensor.tensor<int32, 0>().data()[0];
    int n = n_tensor.tensor<int32, 0>().data()[0];
    int nonzeros = column_indices.dim_size(0);

    int dim_offset = lhs_matrix.dims() - 2;
    int k = lhs_matrix.dim_size(dim_offset + 1);
    int replication = dim_offset == 1 ? lhs_matrix.dim_size(0) : 1;
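    // Example (illustrative): with m=4, n=8, k=16, and 10 nonzeros, lhs is
    // shaped [4, 16], rhs [8, 16], and the output holds 10 values. With a
    // replication factor of 2, lhs is [2, 4, 16], rhs is [2, 8, 16], and
    // the output is shaped [2, 10].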

    // Validate the sparse matrix shape.
    OP_REQUIRES(context, row_indices.dim_size(0) == m,
                InvalidArgument("Num row indices and 'm' must match."));
    OP_REQUIRES(context, lhs_matrix.dim_size(dim_offset) == m,
                InvalidArgument("First dim of lhs must match output rows."));
    OP_REQUIRES(context, rhs_matrix.dim_size(dim_offset) == n,
                InvalidArgument("First dim of rhs must match output cols."));

    // If we're going to run multiple sddmms, the first dimension of the
    // matrices must match.
    OP_REQUIRES(context,
                replication == 1 || replication == rhs_matrix.dim_size(0),
                InvalidArgument("First dim of lhs & rhs must match."));

    // Allocate the output tensor.
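    // One value per nonzero, with a leading replication dimension when the
    // dense operands are batched.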
    Tensor* output_values = nullptr;
    tensorflow::TensorShape output_shape = {nonzeros};
    if (replication > 1) {
      output_shape = {replication, nonzeros};
    }
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output_values));

    // Launch the kernel.
    //
    // TODO(tgale): This could be accelerated by supporting replicated/batched
    // execution in the kernel. Running the kernel in a loop like this could
    // incur significant overhead from kernel launch latency if the
    // computation is cheap.
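    // Each iteration computes one independent sddmm, advancing the dense
    // operands by one [m, k] / [n, k] slice and the output by 'nonzeros'
    // values.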
    for (int idx = 0; idx < replication; ++idx) {
      LaunchSddmm(context->eigen_device<Device>(), m, k, n, nonzeros,
                  AsInt32<1>(row_indices), AsInt32<1>(row_offsets),
                  AsInt32<1>(column_indices),
                  lhs_matrix.flat<float>().data() + m * k * idx,
                  rhs_matrix.flat<float>().data() + k * n * idx,
                  output_values->flat<float>().data() + nonzeros * idx);
    }
  }

 private:
  bool transpose_lhs_, transpose_rhs_;
};

REGISTER_KERNEL_BUILDER(Name("Sddmm").Device(tensorflow::DEVICE_CPU),
                        SddmmOp<Eigen::ThreadPoolDevice, float>);

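// On the GPU, the scalar 'm' and 'n' inputs stay in host memory so the
// host-side Compute can read the problem shape when configuring the
// kernel launch.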
#ifdef GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("Sddmm")
                            .Device(tensorflow::DEVICE_GPU)
                            .HostMemory("m")
                            .HostMemory("n"),
                        SddmmOp<Eigen::GpuDevice, float>);
#endif  // GOOGLE_CUDA

}  // namespace sgk