// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>

#include "sparse/ops/cc/common.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"

namespace sgk {
// Shape function shared by the `Spmm` and `FusedSpmm` ops.
//
// The output row count is taken from `row_indices` (input 3), which must be
// a vector ([m]). The dense operand (input 6) may be either
//   * rank 2, [k, n]      -> output [m, n], or
//   * rank 3, [d0, k, n]  -> output [d0, m, n] (d0 is presumably a batch
//                            dimension).
// The `transpose_lhs` / `transpose_rhs` attributes are not consulted here.
//
// If the rank of the dense operand is not known statically we cannot tell
// which of the two layouts applies, so the output shape is left unknown
// instead of being forced to rank 2.
tensorflow::Status SpmmShapeFn(
    tensorflow::shape_inference::InferenceContext* c) {
  using tensorflow::shape_inference::ShapeHandle;

  ShapeHandle lhs_rows;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &lhs_rows));

  const ShapeHandle dense_shape = c->input(6);
  if (!c->RankKnown(dense_shape)) {
    c->set_output(0, c->UnknownShape());
    return tensorflow::Status();
  }

  ShapeHandle output_shape;
  if (c->Rank(dense_shape) == 3) {
    // [d0, k, n] -> [d0, m, n].
    output_shape = c->MakeShape({c->Dim(dense_shape, 0), c->Dim(lhs_rows, 0),
                                 c->Dim(dense_shape, 2)});
  } else {
    // [k, n] -> [m, n]; any other known rank is rejected here.
    ShapeHandle rhs_shape;
    TF_RETURN_IF_ERROR(c->WithRank(dense_shape, 2, &rhs_shape));
    output_shape = c->MakeShape({c->Dim(lhs_rows, 0), c->Dim(rhs_shape, 1)});
  }

  c->set_output(0, output_shape);
  return tensorflow::Status();
}
// Sparse-times-dense matrix product. The sparse left-hand operand is passed
// in CSR form through `values`, `row_indices`, `row_offsets` and
// `column_indices`; the output shape is computed by `SpmmShapeFn`.
REGISTER_OP("Spmm")
    .Input("m: int32")
    .Input("k: int32")
    .Input("values: float")
    .Input("row_indices: uint32")
    .Input("row_offsets: uint32")
    .Input("column_indices: uint32")
    .Input("dense_matrix: float")
    .Output("output_matrix: float")
    .Attr("transpose_lhs: bool = false")
    .Attr("transpose_rhs: bool = false")
    .SetShapeFn(SpmmShapeFn)
    .Doc(R"doc(
Compute the product of a sparse matrix and a dense matrix to produce a
dense output matrix. The sparse matrix is stored in compressed sparse
row format.

m: [1], the number of rows in the input sparse matrix.
k: [1], the number of columns in the input sparse matrix.
values: [nonzeros], the nonzero values of the sparse matrix.
row_indices: [m], row indices from 0-{m-1} optionally reordered for load
    balancing.
row_offsets: [m+1], offsets for the rows of the sparse matrix.
column_indices: [nonzeros], column indices for each nonzero in the sparse
    matrix.
dense_matrix: [k, n], dense matrix to multiply the sparse matrix by.
output_matrix: [m, n], output dense matrix to store the result.
)doc");
// Variant of `Spmm` with an extra trailing `bias` input. It reuses
// `SpmmShapeFn`, which never reads `bias` (input 7), so the bias does not
// influence the inferred output shape. How the bias is applied is decided by
// the kernel, which is not part of this file.
REGISTER_OP("FusedSpmm")
    .Input("m: int32")
    .Input("k: int32")
    .Input("values: float")
    .Input("row_indices: uint32")
    .Input("row_offsets: uint32")
    .Input("column_indices: uint32")
    .Input("dense_matrix: float")
    .Input("bias: float")
    .Output("output_matrix: float")
    .Attr("transpose_lhs: bool = false")
    .Attr("transpose_rhs: bool = false")
    .SetShapeFn(SpmmShapeFn);
// Sampled dense-dense matrix product: only the entries selected by the CSR
// sparsity pattern are computed.
REGISTER_OP("Sddmm")
    .Attr("transpose_lhs: bool = false")
    .Attr("transpose_rhs: bool = false")
    .Input("m: int32")
    .Input("n: int32")
    .Input("row_indices: uint32")
    .Input("row_offsets: uint32")
    .Input("column_indices: uint32")
    .Input("lhs_matrix: float")
    .Input("rhs_matrix: float")
    .Output("output_values: float")
    // The output holds one value per nonzero of the pattern, so its length is
    // taken from `column_indices` (input 4). A rank-3 `lhs_matrix` (input 5)
    // contributes its leading dimension: [d0, nonzeros]. A rank-2
    // `lhs_matrix` gives [nonzeros]. If the rank of `lhs_matrix` is unknown,
    // either layout is possible, so the output shape is left unknown.
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
      using tensorflow::shape_inference::ShapeHandle;
      ShapeHandle nonzeros;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &nonzeros));

      const ShapeHandle lhs_shape = c->input(5);
      if (!c->RankKnown(lhs_shape)) {
        c->set_output(0, c->UnknownShape());
        return tensorflow::Status();
      }

      if (c->Rank(lhs_shape) == 3) {
        c->set_output(
            0, c->MakeShape({c->Dim(lhs_shape, 0), c->Dim(nonzeros, 0)}));
        return tensorflow::Status();
      }

      // Unbatched case: `lhs_matrix` is documented as [m, k].
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(lhs_shape, 2, &unused));
      c->set_output(0, nonzeros);
      return tensorflow::Status();
    })
    .Doc(R"doc(
Computes the product of two dense matrices where a subset of the outputs
are requested. Which outputs are to be computed are specified by a sparse
matrix stored in compressed sparse row format.

Currently only supports having the right-hand matrix transposed.

m: [1], the number of rows in the input sparse matrix.
n: [1], the number of columns in the input sparse matrix.
row_indices: [m], row indices from 0-{m-1} optionally reordered for load
    balancing.
row_offsets: [m+1], offsets for the rows of the sparse matrix.
column_indices: [nonzeros], column indices for each nonzero in the sparse
    matrix.
lhs_matrix: [m, k], left-hand, dense matrix operand to the matrix product.
rhs_matrix: [n, k], right-hand, dense matrix operand to the matrix product.
output_values: [nonzeros], the nonzero values of the sparse matrix.
)doc");
// NOTE: We can't tell how many columns are in a compressed sparse row matrix
131
// from the data structures alone. The necessary information is in the host
132
// tensors `m` and `n`, but we can't access this during shape inference.
133
REGISTER_OP("CsrTranspose")
134
    .Input("m: int32")
135
    .Input("n: int32")
136
    .Input("values: float")
137
    .Input("row_offsets: uint32")
138
    .Input("column_indices: uint32")
139
    .Output("output_values : float")
140
    .Output("output_row_offsets : uint32")
141
    .Output("output_column_indices : uint32")
142
    .Doc(R"doc(
143
Transposes a compressed sparse row matrix.
144

145
m: [1], the number of rows in the input sparse matrix.
146
n: [1], the number of columns in the input sparse matrix.
147
values: [nonzeros], the nonzero values of the input sparse matrix.
148
row_offsets: [m+1], offsets for the rows of the input sparse matrix.
149
column_indices: [nonzeros], column indices for each nonzero in the
150
    input sparse matrix.
151
output_values: [nonzeros], the nonzero values of the output sparse matrix.
152
output_row_offsets: [m+1], offsets for the rows of the output sparse matrix.
153
output_column_indices: [nonzeros], column indices for each nonzero in the
154
    output sparse matrix.
155
)doc");
// Converts CSR column indices into flat row-major indices
// (row * n + column). The output has one entry per nonzero, so it takes the
// shape of `column_indices` (input 3).
// NOTE(review): the output is uint32, so m * n presumably has to fit in 32
// bits for the result to be exact; confirm against the kernel.
REGISTER_OP("Csr2idx")
    .Input("m: int32")
    .Input("n: int32")
    .Input("row_offsets: uint32")
    .Input("column_indices: uint32")
    .Output("linear_indices : uint32")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
      using tensorflow::shape_inference::ShapeHandle;
      ShapeHandle nonzeros;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &nonzeros));
      c->set_output(0, nonzeros);
      return tensorflow::Status();
    })
    .Doc(R"doc(
Converts a compressed sparse row matrix to linear format.

Converts `column_index[i]` to `column_index[i] + row_index * n`, where
`row_index` is the row that this column index belongs to. We call this
"index format" or "1-dimensional coordinate format".

m: [1], the number of rows in the input sparse matrix.
n: [1], the number of columns in the input sparse matrix.
row_offsets: [m+1], offsets for the rows of the sparse matrix.
column_indices: [nonzeros], column indices for each nonzero in the
    sparse matrix.
linear_indices: [nonzeros], the linear indices for the sparse matrix.
)doc");
// Shape function shared by `DepthwiseConv` and `FusedDepthwiseConv`.
//
// Only the NCHW layout is handled:
//   input  (input 0): [batch, in_depth, in_rows, in_cols]
//   filter (input 1): [in_depth, filter_rows, filter_cols, depth_multiplier]
//   output          : [batch, in_depth * depth_multiplier, out_rows, out_cols]
// The filter's leading dimension must agree with the input's channel count.
//
// Attributes:
//   strides           - exactly 4 values; only indices 2 (rows) and 3 (cols)
//                       are used.
//   explicit_paddings - exactly 4 values; index 2 is applied both before and
//                       after the rows, index 3 both before and after the
//                       columns.
// Dilation is fixed at 1.
//
// Returns InvalidArgument if either attribute has the wrong length, or if
// the input/filter ranks or channel counts are incompatible.
tensorflow::Status DepthwiseShapeFn(
    tensorflow::shape_inference::InferenceContext* c) {
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::ShapeHandle;

  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return tensorflow::errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }
  // NCHW: the spatial strides live at indices 2 (rows) and 3 (cols).
  const int32 stride_rows = strides[2];
  const int32 stride_cols = strides[3];

  const DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  const DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  const DimensionHandle in_cols_dim = c->Dim(input_shape, 3);

  const DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
  const DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
  const DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Input channels and the filter's leading dimension must be compatible.
  DimensionHandle input_depth;
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 1), c->Dim(filter_shape, 0), &input_depth));

  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  std::vector<int64> padding;
  TF_RETURN_IF_ERROR(c->GetAttr("explicit_paddings", &padding));
  if (padding.size() != 4) {
    return tensorflow::errors::InvalidArgument(
        "DepthwiseConv2D requires the padding attribute to contain 4 values, "
        "but got: ",
        padding.size());
  }
  // The same padding is used before and after each spatial dimension.
  const int64 pad_rows = padding[2];
  const int64 pad_cols = padding[3];

  DimensionHandle output_rows;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, /* dilation_rate = */ 1, stride_rows,
      tensorflow::EXPLICIT, pad_rows, pad_rows, &output_rows));
  DimensionHandle output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, /* dilation_rate = */ 1, stride_cols,
      tensorflow::EXPLICIT, pad_cols, pad_cols, &output_cols));

  c->set_output(0, c->MakeShape({batch_size_dim, output_depth, output_rows,
                                 output_cols}));
  return tensorflow::Status();
}
// Depthwise 2-D convolution with explicit padding. Attribute validation and
// output-shape inference (NCHW only) are done by `DepthwiseShapeFn`.
REGISTER_OP("DepthwiseConv")
    .Attr("T: {float}")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("strides: list(int)")
    .Attr(tensorflow::GetExplicitPaddingsAttrString())
    .SetShapeFn(DepthwiseShapeFn);
// `DepthwiseConv` with an additional `bias` input. It shares
// `DepthwiseShapeFn`, which never inspects `bias` (input 2); how the bias is
// applied is up to the kernel, which is not part of this file.
REGISTER_OP("FusedDepthwiseConv")
    .Attr("T: {float}")
    .Input("input: T")
    .Input("filter: T")
    .Input("bias: T")
    .Output("output: T")
    .Attr("strides: list(int)")
    .Attr(tensorflow::GetExplicitPaddingsAttrString())
    .SetShapeFn(DepthwiseShapeFn);
// Bias + ReLU op (per its name; the kernel is not in this file). `out` has
// exactly the shape of `in`, which must have rank >= 2. The shape of `bias`
// is not checked during shape inference.
REGISTER_OP("BiasRelu")
    .Input("in: float")
    .Input("bias: float")
    .Output("out: float")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* ctx) {
      tensorflow::shape_inference::ShapeHandle in_shape;
      TF_RETURN_IF_ERROR(ctx->WithRankAtLeast(ctx->input(0), 2, &in_shape));
      ctx->set_output(0, in_shape);
      return tensorflow::Status();
    });
// Softmax over the values of a CSR matrix (per the op name; the kernel is not
// in this file). The output has one value per input value, so it takes the
// shape of `input_values`, which must have rank >= 1.
REGISTER_OP("CsrSoftmax")
    .Input("input_values : float")
    .Input("row_indices: uint32")
    .Input("row_offsets: uint32")
    .Input("column_indices: uint32")
    .Output("output_values : float")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* ctx) {
      tensorflow::shape_inference::ShapeHandle values;
      TF_RETURN_IF_ERROR(ctx->WithRankAtLeast(ctx->input(0), 1, &values));
      ctx->set_output(0, values);
      return tensorflow::Status();
    });
// Dense softmax op (the kernel is not in this file). `input` must be a
// matrix (rank 2) and `output` has exactly its shape.
REGISTER_OP("FusedSoftmax")
    .Input("input : float")
    .Output("output : float")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* ctx) {
      tensorflow::shape_inference::ShapeHandle matrix;
      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 2, &matrix));
      ctx->set_output(0, matrix);
      return tensorflow::Status();
    });
}  // namespace sgk