google-research
73 строки · 2.7 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "sparse/ops/cc/transpose_launcher.h"

#include <cstdint>
#include <limits>
18
19namespace sgk {
20
21void AllocateTransposeWorkspace(
22tensorflow::OpKernelContext *context, const Eigen::ThreadPoolDevice &d,
23int m, int n, int nonzeros, const float *values, const int *row_offsets,
24const int *column_indices, float *output_values, int *output_row_offsets,
25int *output_column_indices, tensorflow::Tensor *workspace) {
26// To transpose the matrix, we blow up the tensor into it's
27// dense, transposed representation and compress it back down.
28tensorflow::TensorShape shape = {m * n};
29OP_REQUIRES_OK(
30context, context->allocate_temp(tensorflow::DT_FLOAT, shape, workspace));
31}
32
33void LaunchTranspose(const Eigen::ThreadPoolDevice &d, int m, int n,
34int nonzeros, const float *values, const int *row_offsets,
35const int *column_indices, float *output_values,
36int *output_row_offsets, int *output_column_indices,
37float *workspace) {
38// Expand the tensor into it's tranposed dense representation.
39//
40// NOTE: We set the invalid values in the tensor to infinity. This
41// This avoids issues with the case where we have zero valued weights
42// in the sparse matrix.
43for (int i = 0; i < m * n; ++i) {
44workspace[i] = std::numeric_limits<float>::infinity();
45}
46for (int i = 0; i < m; ++i) {
47for (int l = row_offsets[i]; l < row_offsets[i + 1]; ++l) {
48int j = column_indices[l];
49workspace[j * m + i] = values[l];
50}
51}
52
53// Compress the matrix back down to it's sparse representation. Note
54// that the matrix is transposed, so 'n' is the number of rows and
55// 'm' is the number of columns.
56int offset = 0;
57output_row_offsets[0] = 0;
58for (int i = 0; i < n; ++i) { // loop over rows.
59for (int j = 0; j < m; ++j) { // loop over columns.
60int idx = i * m + j;
61if (workspace[idx] == std::numeric_limits<float>::infinity()) {
62continue;
63}
64DCHECK_LT(offset, nonzeros);
65output_values[offset] = workspace[idx];
66output_column_indices[offset] = j;
67++offset;
68}
69output_row_offsets[i + 1] = offset;
70}
71}
72
73} // namespace sgk
74