google-research

Форк
0
/
transpose_op.cc 
104 строки · 4.5 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/common.h"
16
#include "sparse/ops/cc/transpose_launcher.h"
17
#include "tensorflow/core/framework/op_kernel.h"
18

19
namespace sgk {
20

21
using ::tensorflow::Tensor;
22
using ::tensorflow::TensorShape;
23
using ::tensorflow::TensorShapeUtils;
24
using ::tensorflow::errors::InvalidArgument;
25

26
template <typename Device, typename T>
27
class CsrTransposeOp : public tensorflow::OpKernel {
28
 public:
29
  explicit CsrTransposeOp(tensorflow::OpKernelConstruction* context)
30
      : OpKernel(context) {}
31

32
  void Compute(tensorflow::OpKernelContext* context) override {
33
    // Collect the input & output tensors.
34
    const Tensor& m_tensor = context->input(0);
35
    const Tensor& n_tensor = context->input(1);
36
    const Tensor& values = context->input(2);
37
    const Tensor& row_offsets = context->input(3);
38
    const Tensor& column_indices = context->input(4);
39

40
    // Validate the input shapes.
41
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
42
                InvalidArgument("Expected scalar for argument 'm'."));
43
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_tensor.shape()),
44
                InvalidArgument("Expected scalar for argument 'n'."));
45
    OP_REQUIRES(context, TensorShapeUtils::IsVector(values.shape()),
46
                InvalidArgument("Expected 1-dimension values tensor."));
47
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
48
                InvalidArgument("Expected 1-dimension row_offsets tensor."));
49
    OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),
50
                InvalidArgument("Expected 1-dimension column_indices tensor."));
51
    OP_REQUIRES(context, values.dim_size(0) == column_indices.dim_size(0),
52
                InvalidArgument("Expected same number of values and indices"));
53

54
    // Get the problem shape.
55
    int m = m_tensor.tensor<int32, 0>().data()[0];
56
    int n = n_tensor.tensor<int32, 0>().data()[0];
57
    int nonzeros = values.dim_size(0);
58

59
    // Validate row offsets size.
60
    OP_REQUIRES(context, row_offsets.dim_size(0) == m + 1,
61
                InvalidArgument("Expected m+1 row offsets."));
62

63
    // Allocate the output tensor.
64
    Tensor* output_values = nullptr;
65
    Tensor* output_row_offsets = nullptr;
66
    Tensor* output_column_indices = nullptr;
67
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{nonzeros},
68
                                                     &output_values));
69
    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{n + 1},
70
                                                     &output_row_offsets));
71
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape{nonzeros},
72
                                                     &output_column_indices));
73

74
    // (Possibly) get a temporary buffer to work in.
75
    Tensor workspace;
76
    AllocateTransposeWorkspace(
77
        context, context->eigen_device<Device>(), m, n, nonzeros,
78
        values.tensor<float, 1>().data(), AsInt32<1>(row_offsets),
79
        AsInt32<1>(column_indices), output_values->tensor<float, 1>().data(),
80
        AsInt32<1>(output_row_offsets), AsInt32<1>(output_column_indices),
81
        &workspace);
82

83
    // Launch the kernel.
84
    LaunchTranspose(
85
        context->eigen_device<Device>(), m, n, nonzeros,
86
        values.tensor<float, 1>().data(), AsInt32<1>(row_offsets),
87
        AsInt32<1>(column_indices), output_values->tensor<float, 1>().data(),
88
        AsInt32<1>(output_row_offsets), AsInt32<1>(output_column_indices),
89
        workspace.tensor<float, 1>().data());
90
  }
91
};
92

93
REGISTER_KERNEL_BUILDER(Name("CsrTranspose").Device(tensorflow::DEVICE_CPU),
94
                        CsrTransposeOp<Eigen::ThreadPoolDevice, float>);
95

96
#ifdef GOOGLE_CUDA
97
REGISTER_KERNEL_BUILDER(Name("CsrTranspose")
98
                            .Device(tensorflow::DEVICE_GPU)
99
                            .HostMemory("m")
100
                            .HostMemory("n"),
101
                        CsrTransposeOp<Eigen::GpuDevice, float>);
102
#endif  // GOOGLE_CUDA
103

104
}  // namespace sgk
105

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.