google-research

Форк
0
/
sparse_softmax_op.cc 
92 строки · 3.8 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "sparse/ops/cc/common.h"
16
#include "sparse/ops/cc/sparse_softmax_launcher.h"
17
#include "tensorflow/core/framework/op_kernel.h"
18

19
namespace sgk {
20

21
using ::tensorflow::Tensor;
22
using ::tensorflow::TensorShapeUtils;
23
using ::tensorflow::errors::InvalidArgument;
24

25
template <typename Device, typename T>
26
class SparseSoftmaxOp : public tensorflow::OpKernel {
27
 public:
28
  explicit SparseSoftmaxOp(tensorflow::OpKernelConstruction* context)
29
      : OpKernel(context) {}
30

31
  void Compute(tensorflow::OpKernelContext* context) override {
32
    // Collect the input & output tensors.
33
    const Tensor& values = context->input(0);
34
    const Tensor& row_indices = context->input(1);
35
    const Tensor& row_offsets = context->input(2);
36
    const Tensor& column_indices = context->input(3);
37

38
    // Validate the input shapes.
39
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_indices.shape()),
40
                InvalidArgument("Expected 1-dimension row_indices tensor."));
41
    OP_REQUIRES(context, TensorShapeUtils::IsVector(row_offsets.shape()),
42
                InvalidArgument("Expected 1-dimension row_offsets tensor."));
43
    OP_REQUIRES(context, TensorShapeUtils::IsVector(column_indices.shape()),
44
                InvalidArgument("Expected 1-dimension column_indices tensor."));
45
    OP_REQUIRES(context, row_indices.dim_size(0) + 1 == row_offsets.dim_size(0),
46
                InvalidArgument("Expected 1 more row index than offset."));
47
    OP_REQUIRES(
48
        context,
49
        TensorShapeUtils::IsVector(values.shape()) || values.dims() == 2,
50
        InvalidArgument("Expected 1-dim or 2-dim values tensor."));
51

52
    // Get the problem shape.
53
    //
54
    // NOTE: The kernel doesn't actually need the n argument. Pass garbage,
55
    // since we can't pull it off the sparse matrix representation.
56
    int m = row_indices.dim_size(0);
57
    int n = -1;
58
    int nonzeros = column_indices.dim_size(0);
59
    int dim_offset = values.dims() - 1;
60
    int replication = dim_offset == 1 ? values.dim_size(0) : 1;
61

62
    // Validate the sparse matrix shape.
63
    OP_REQUIRES(context, values.dim_size(dim_offset) == nonzeros,
64
                InvalidArgument("Num values must equal num col indices."));
65

66
    // Allocate the output tensor.
67
    Tensor* output_values = nullptr;
68
    OP_REQUIRES_OK(context,
69
                   context->allocate_output(0, values.shape(), &output_values));
70

71
    // Launch the kernel for each step of computation.
72
    //
73
    // TODO(tgale): This could be accelerated by supported replicated/batched
74
    // execution in the kernel. Running the kernel is a loop like this could
75
    // incur significant overhead from kernel launch latency if the computation
76
    // is cheap.
77
    for (int idx = 0; idx < replication; ++idx) {
78
      LaunchSparseSoftmax(context->eigen_device<Device>(), m, n, nonzeros,
79
                          values.flat<float>().data() + nonzeros * idx,
80
                          AsInt32<1>(row_indices), AsInt32<1>(row_offsets),
81
                          AsInt32<1>(column_indices),
82
                          output_values->flat<float>().data() + nonzeros * idx);
83
    }
84
  }
85
};
86

87
#ifdef GOOGLE_CUDA
88
REGISTER_KERNEL_BUILDER(Name("CsrSoftmax").Device(tensorflow::DEVICE_GPU),
89
                        SparseSoftmaxOp<Eigen::GpuDevice, float>);
90
#endif  // GOOGLE_CUDA
91

92
}  // namespace sgk
93

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.