// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sparse/ops/cc/common.h"
#include "sparse/ops/cc/spmm_launcher.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace sgk {

using ::tensorflow::Tensor;
using ::tensorflow::TensorShapeUtils;
using ::tensorflow::errors::InvalidArgument;

template <typename Device, typename T>
class SpmmOp : public tensorflow::OpKernel {
 public:
  explicit SpmmOp(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_lhs", &transpose_lhs_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_rhs", &transpose_rhs_));

    // NOTE: We currently do not support transposition for either argument.
    OP_REQUIRES(context, !transpose_lhs_,
                InvalidArgument("transpose_lhs=True not yet supported."));
    OP_REQUIRES(context, !transpose_rhs_,
                InvalidArgument("transpose_rhs=True not yet supported."));
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    ComputeHelper(context, /*bias_ptr=*/nullptr);
  }

  void ComputeHelper(tensorflow::OpKernelContext* context,
                     const float* bias_ptr) {
    // Collect the input & output tensors.
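    //
    // The sparse lhs arrives decomposed into CSR-style pieces: `values` holds
    // the nonzero entries, `row_offsets` the m + 1 offsets that delimit each
    // row's span in `values`, and `column_indices` the column of each
    // nonzero. `row_indices` carries one index per row (enforced below) and
    // is forwarded untouched to the launcher.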
    const Tensor& m_tensor = context->input(0);
    const Tensor& k_tensor = context->input(1);
    const Tensor& values = context->input(2);
    const Tensor& row_indices = context->input(3);
    const Tensor& row_offsets = context->input(4);
    const Tensor& column_indices = context->input(5);
    const Tensor& dense_matrix = context->input(6);

    // Validate the input shapes.
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
                InvalidArgument("Expected scalar for argument 'm'."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
                InvalidArgument("Expected scalar for argument 'k'."));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsVector(values.shape()) || values.dims() == 2,
        InvalidArgument("Expected 1-dim or 2-dim values tensor."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_indices.shape()),
                InvalidArgument("Expected 1-dimensional row_indices tensor."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
                InvalidArgument("Expected 1-dimensional row_offsets tensor."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(column_indices.shape()),
        InvalidArgument("Expected 1-dimensional column_indices tensor."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrix(dense_matrix.shape()) ||
                    dense_matrix.dims() == 3,
                InvalidArgument("Expected 2- or 3-dim dense matrix tensor."));
    OP_REQUIRES(context, row_indices.dim_size(0) + 1 == row_offsets.dim_size(0),
                InvalidArgument("Expected one more row offset than row index."));

    // TODO(tgale): We can lift this constraint to support arbitrary replication
    // of rhs/lhs matrix. For example, if lhs is a 3-tensor and rhs is a matrix
    // we can compute `lhs.shape[0]` spmms with each kernel using the same rhs
    // matrix.
    OP_REQUIRES(context, values.dims() == dense_matrix.dims() - 1,
                InvalidArgument("Values and rhs must be replicated the same."));

    // Get the problem shape.
    int m = m_tensor.tensor<int32, 0>().data()[0];
    int k = k_tensor.tensor<int32, 0>().data()[0];
    int nonzeros = column_indices.dim_size(0);

    int dim_offset = dense_matrix.dims() - 2;
    int n = dense_matrix.dim_size(dim_offset + 1);
    int replication = dim_offset == 1 ? dense_matrix.dim_size(0) : 1;
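    // With a 2-D rhs of shape [k, n], dim_offset is 0 and replication is 1.
    // With a 3-D rhs of shape [r, k, n], dim_offset is 1 and the kernel runs
    // r independent spmms over a values tensor of shape [r, nonzeros].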

    // Validate the sparse matrix and dense matrix shapes match.
    OP_REQUIRES(context, values.dim_size(dim_offset) == nonzeros,
                InvalidArgument("Num values must equal num col indices."));
    OP_REQUIRES(context, row_indices.dim_size(0) == m,
                InvalidArgument("Num row indices and 'm' must match."));
    OP_REQUIRES(context, dense_matrix.dim_size(dim_offset) == k,
                InvalidArgument("Inner matrix dimensions must match."));

    // If we're going to run multiple spmms, the first dimension of the
    // matrices must match.
    OP_REQUIRES(context, replication == 1 || replication == values.dim_size(0),
                InvalidArgument("First dim of values and rhs must match."));

    // Allocate the output tensor.
    Tensor* output_matrix = nullptr;
    tensorflow::TensorShape output_shape = {m, n};
    if (replication > 1) {
      output_shape = {replication, m, n};
    }
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output_matrix));

    // TODO(tgale): Add type checks on meta-data tensors to make sure our
    // casting is safe.
    //
    // TODO(tgale): This could be accelerated by supporting replicated/batched
    // execution in the kernel. Running the kernel in a loop like this could
    // incur significant overhead from kernel launch latency if the computation
    // is cheap.
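    // Each iteration handles the idx-th replica: `values` advances by
    // `nonzeros` elements, the dense rhs by `k * n`, and the output by
    // `m * n`, while the sparsity pattern (indices and offsets) is shared
    // across all replicas.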
    for (int idx = 0; idx < replication; ++idx) {
      LaunchSpmm(context->eigen_device<Device>(), m, k, n, nonzeros,
                 values.flat<float>().data() + nonzeros * idx,
                 AsInt32<1>(row_indices), AsInt32<1>(row_offsets),
                 AsInt32<1>(column_indices),
                 dense_matrix.flat<float>().data() + k * n * idx, bias_ptr,
                 output_matrix->flat<float>().data() + m * n * idx);
    }
  }

 private:
  bool transpose_lhs_, transpose_rhs_;
};
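
// FusedSpmmOp computes the same product as SpmmOp, but also adds a bias
// with one entry per output row, forwarded to the launcher as `bias_ptr`.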
template <typename Device, typename T>
class FusedSpmmOp : public SpmmOp<Device, T> {
 public:
  explicit FusedSpmmOp(tensorflow::OpKernelConstruction* context)
      : SpmmOp<Device, T>(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    const Tensor& m_tensor = context->input(0);
    const Tensor& bias = context->input(7);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
                InvalidArgument("Expected scalar for argument 'm'."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                InvalidArgument("Expected vector for argument 'bias'."));
    int m = m_tensor.tensor<int32, 0>().data()[0];
    OP_REQUIRES(context, bias.dim_size(0) == m,
                InvalidArgument("Bias size and 'm' must match."));
    this->ComputeHelper(context, bias.tensor<float, 1>().data());
  }
};

REGISTER_KERNEL_BUILDER(Name("Spmm").Device(tensorflow::DEVICE_CPU),
                        SpmmOp<Eigen::ThreadPoolDevice, float>);
REGISTER_KERNEL_BUILDER(Name("FusedSpmm").Device(tensorflow::DEVICE_CPU),
                        FusedSpmmOp<Eigen::ThreadPoolDevice, float>);
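
// On GPU, the scalar 'm' and 'k' inputs are pinned to host memory because
// ComputeHelper dereferences them on the host to recover the problem shape.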
#ifdef GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("Spmm").Device(tensorflow::DEVICE_GPU).HostMemory("m").HostMemory("k"),
    SpmmOp<Eigen::GpuDevice, float>);
REGISTER_KERNEL_BUILDER(Name("FusedSpmm")
                            .Device(tensorflow::DEVICE_GPU)
                            .HostMemory("m")
                            .HostMemory("k"),
                        FusedSpmmOp<Eigen::GpuDevice, float>);
#endif  // GOOGLE_CUDA

}  // namespace sgk
