google-research

Форк
0
144 строки · 6.4 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/common.h"
16
#include "sparse/ops/cc/sddmm_launcher.h"
17
#include "tensorflow/core/framework/op_kernel.h"
18

19
namespace sgk {

// Bring the TensorFlow types used throughout this file into scope.
using ::tensorflow::Tensor;
using ::tensorflow::TensorShapeUtils;
using ::tensorflow::errors::InvalidArgument;
24

25
template <typename Device, typename T>
26
class SddmmOp : public tensorflow::OpKernel {
27
 public:
28
  explicit SddmmOp(tensorflow::OpKernelConstruction* context)
29
      : OpKernel(context) {
30
    OP_REQUIRES_OK(context, context->GetAttr("transpose_lhs", &transpose_lhs_));
31
    OP_REQUIRES_OK(context, context->GetAttr("transpose_rhs", &transpose_rhs_));
32

33
    // NOTE: We currently do not support transposition for either argument.
34
    OP_REQUIRES(context, !transpose_lhs_,
35
                InvalidArgument("transpose_lhs=True not yet supported."));
36
    OP_REQUIRES(context, transpose_rhs_,
37
                InvalidArgument("transpose_rhs=False not yet supported."));
38
  }
39

40
  void Compute(tensorflow::OpKernelContext* context) override {
41
    // Collect the input & output tensors.
42
    const Tensor& m_tensor = context->input(0);
43
    const Tensor& n_tensor = context->input(1);
44
    const Tensor& row_indices = context->input(2);
45
    const Tensor& row_offsets = context->input(3);
46
    const Tensor& column_indices = context->input(4);
47
    const Tensor& lhs_matrix = context->input(5);
48
    const Tensor& rhs_matrix = context->input(6);
49

50
    // Validate the input shapes.
51
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
52
                InvalidArgument("Expected scalar for argument 'm'."));
53
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_tensor.shape()),
54
                InvalidArgument("Expected scalar for argument 'n'."));
55
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_indices.shape()),
56
                InvalidArgument("Expected 1-dimension row_indices tensor."));
57
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
58
                InvalidArgument("Expected 1-dimension row_offsets tensor."));
59
    OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),
60
                InvalidArgument("Expected 1-dimension column_indices tensor."));
61
    OP_REQUIRES(context, row_indices.dim_size(0) + 1 == row_offsets.dim_size(0),
62
                InvalidArgument("Expected 1 more row index than offset."));
63
    OP_REQUIRES(context, lhs_matrix.dim_size(1) == rhs_matrix.dim_size(1),
64
                InvalidArgument("Last dim of input matrices must match."));
65
    OP_REQUIRES(context,
66
                TensorShapeUtils::IsMatrix(lhs_matrix.shape()) ||
67
                    lhs_matrix.dims() == 3,
68
                InvalidArgument("Expected 2-dim or 3-dim lhs matrix tensor."));
69
    OP_REQUIRES(context,
70
                TensorShapeUtils::IsMatrix(rhs_matrix.shape()) ||
71
                    rhs_matrix.dims() == 3,
72
                InvalidArgument("Expected 2-dim or 3-dim rhs matrix tensor."));
73

74
    // TODO(tgale): We can lift this constraint to support arbitrary replication
75
    // of rhs/lhs matrix. For example, if lhs is a 3-tensor and rhs is a matrix
76
    // we can compute `lhs.shape[0]` sddmms with each kernel using the same rhs
77
    // matrix.
78
    OP_REQUIRES(context, rhs_matrix.dims() == lhs_matrix.dims(),
79
                InvalidArgument("rhs and lhs must match number of dims."));
80

81
    // Get the problem shape.
82
    int m = m_tensor.tensor<int32, 0>().data()[0];
83
    int n = n_tensor.tensor<int32, 0>().data()[0];
84
    int nonzeros = column_indices.dim_size(0);
85

86
    int dim_offset = lhs_matrix.dims() - 2;
87
    int k = lhs_matrix.dim_size(dim_offset + 1);
88
    int replication = dim_offset == 1 ? lhs_matrix.dim_size(0) : 1;
89

90
    // Validate the sparse matrix shape.
91
    OP_REQUIRES(context, row_indices.dim_size(0) == m,
92
                InvalidArgument("Num row indices and 'm' must match."));
93
    OP_REQUIRES(context, lhs_matrix.dim_size(dim_offset) == m,
94
                InvalidArgument("First dim of lhs must match output rows."));
95
    OP_REQUIRES(context, rhs_matrix.dim_size(dim_offset) == n,
96
                InvalidArgument("First dim of lhs must match output cols."));
97

98
    // If we're going to run multiple sddmms, the first dimension of the
99
    // matrices must match.
100
    OP_REQUIRES(context,
101
                replication == 1 || replication == rhs_matrix.dim_size(0),
102
                InvalidArgument("First dim of lhs & rhs must match"));
103

104
    // Allocate the output tensor.
105
    Tensor* output_values = nullptr;
106
    tensorflow::TensorShape output_shape = {nonzeros};
107
    if (replication > 1) {
108
      output_shape = {replication, nonzeros};
109
    }
110
    OP_REQUIRES_OK(context,
111
                   context->allocate_output(0, output_shape, &output_values));
112

113
    // Launch the kernel.
114
    //
115
    // TODO(tgale): This could be accelerated by supported replicated/batched
116
    // execution in the kernel. Running the kernel is a loop like this could
117
    // incur significant overhead from kernel launch latency if the computation
118
    // is cheap.
119
    for (int idx = 0; idx < replication; ++idx) {
120
      LaunchSddmm(context->eigen_device<Device>(), m, k, n, nonzeros,
121
                  AsInt32<1>(row_indices), AsInt32<1>(row_offsets),
122
                  AsInt32<1>(column_indices),
123
                  lhs_matrix.flat<float>().data() + m * k * idx,
124
                  rhs_matrix.flat<float>().data() + k * n * idx,
125
                  output_values->flat<float>().data() + nonzeros * idx);
126
    }
127
  }
128

129
 private:
130
  bool transpose_lhs_, transpose_rhs_;
131
};
132

133
// CPU kernel registration: Eigen thread-pool device, float values.
REGISTER_KERNEL_BUILDER(Name("Sddmm").Device(tensorflow::DEVICE_CPU),
                        SddmmOp<Eigen::ThreadPoolDevice, float>);

#ifdef GOOGLE_CUDA
// GPU kernel registration. The scalar shape arguments 'm' and 'n' are pinned
// to host memory so Compute() can read them directly without a device copy.
REGISTER_KERNEL_BUILDER(Name("Sddmm")
                            .Device(tensorflow::DEVICE_GPU)
                            .HostMemory("m")
                            .HostMemory("n"),
                        SddmmOp<Eigen::GpuDevice, float>);
#endif  // GOOGLE_CUDA
143

144
}  // namespace sgk
145

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.