google-research

Форк
0
/
csr2idx_op.cc 
81 строка · 3.2 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/common.h"
16
#include "sparse/ops/cc/csr2idx_launcher.h"
17
#include "tensorflow/core/framework/op_kernel.h"
18

19
namespace sgk {
20

21
using ::tensorflow::Tensor;
22
using ::tensorflow::TensorShape;
23
using ::tensorflow::TensorShapeUtils;
24
using ::tensorflow::errors::InvalidArgument;
25

26
template <typename Device, typename T>
27
class Csr2idxOp : public tensorflow::OpKernel {
28
 public:
29
  explicit Csr2idxOp(tensorflow::OpKernelConstruction* context)
30
      : OpKernel(context) {}
31

32
  void Compute(tensorflow::OpKernelContext* context) override {
33
    // Collect the input & output tensors.
34
    const Tensor& m_tensor = context->input(0);
35
    const Tensor& n_tensor = context->input(1);
36
    const Tensor& row_offsets = context->input(2);
37
    const Tensor& column_indices = context->input(3);
38

39
    // Validate the input shapes.
40
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(m_tensor.shape()),
41
                InvalidArgument("Expected scalar for argument 'm'."));
42
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_tensor.shape()),
43
                InvalidArgument("Expected scalar for argument 'n'."));
44
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
45
                InvalidArgument("Expected 1-dimension row_offsets tensor."));
46
    OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),
47
                InvalidArgument("Expected 1-dimension column_indices tensor."));
48

49
    // Get the problem shape.
50
    int m = m_tensor.tensor<int32, 0>().data()[0];
51
    int n = n_tensor.tensor<int32, 0>().data()[0];
52
    int nonzeros = column_indices.dim_size(0);
53

54
    // Validate row offsets size.
55
    OP_REQUIRES(context, row_offsets.dim_size(0) == m + 1,
56
                InvalidArgument("Expected m+1 row offsets."));
57

58
    // Allocate the output tensor.
59
    Tensor* linear_indices = nullptr;
60
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{nonzeros},
61
                                                     &linear_indices));
62

63
    // Launch the kernel.
64
    LaunchCsr2idx(context->eigen_device<Device>(), m, n, nonzeros,
65
                  AsInt32<1>(row_offsets), AsInt32<1>(column_indices),
66
                  AsInt32<1>(linear_indices));
67
  }
68
};
69

70
REGISTER_KERNEL_BUILDER(Name("Csr2idx").Device(tensorflow::DEVICE_CPU),
71
                        Csr2idxOp<Eigen::ThreadPoolDevice, float>);
72

73
#ifdef GOOGLE_CUDA
74
REGISTER_KERNEL_BUILDER(Name("Csr2idx")
75
                            .Device(tensorflow::DEVICE_GPU)
76
                            .HostMemory("m")
77
                            .HostMemory("n"),
78
                        Csr2idxOp<Eigen::GpuDevice, float>);
79
#endif  // GOOGLE_CUDA
80

81
}  // namespace sgk
82

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.