google-research
79 строк · 2.8 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Tests for the sparse matrix class and utilities."""
from absl.testing import parameterized
import numpy as np
import scipy
import scipy.sparse
import tensorflow.compat.v1 as tf

from sgk.sparse import connectors
from sgk.sparse import initializers
from sgk.sparse import sparse_matrix
26
@parameterized.parameters((4, 4, 0.0), (64, 128, 0.8), (512, 512, 0.64),
                          (273, 519, 0.71))
class SparseMatrixTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `sparse_matrix.SparseMatrix` and `_dense_to_sparse`.

  Every test method runs once per `(m, n, sparsity)` tuple listed in the
  `parameterized.parameters` decorator. `m` and `n` are the matrix
  dimensions. `sparsity` is the fraction of zero entries passed to
  `connectors.Uniform`.
  """

  def testCreateMatrix(self, m, n, sparsity):
    """Checks the component ranks and nonzero count of a new SparseMatrix."""
    matrix = sparse_matrix.SparseMatrix(
        "matrix", [m, n], connector=connectors.Uniform(sparsity))

    # `test_session` is deprecated. `cached_session` is its replacement and,
    # like `test_session` without a graph argument, reuses the test's cached
    # graph-mode session.
    with self.cached_session() as sess:
      sess.run(tf.global_variables_initializer())
      values, row_indices, row_offsets, column_indices = sess.run([
          matrix.values, matrix.row_indices, matrix.row_offsets,
          matrix.column_indices
      ])

      # Check the shape of the matrix: every component is a rank-1 array.
      self.assertLen(values.shape, 1)
      self.assertLen(row_indices.shape, 1)
      self.assertLen(row_offsets.shape, 1)
      self.assertLen(column_indices.shape, 1)

      # Check the sparsity matches the target. The expected zero count is
      # `sparsity * m * n` rounded to the nearest integer, and every other
      # entry is stored as a nonzero value.
      target_nonzeros = m * n - int(round(sparsity * m * n))
      self.assertEqual(values.shape[0], target_nonzeros)

  def testDenseToSparse(self, m, n, sparsity):
    """Checks `_dense_to_sparse` against scipy CSR and the dense input."""
    # Helpers to set up the matrices.
    connector = connectors.Uniform(sparsity)
    initializer = initializers.Uniform()

    # Create a dense matrix in numpy with the specified sparsity.
    matrix = connector(initializer([m, n]))

    # Convert to a sparse numpy matrix.
    values, row_indices, row_offsets, column_indices = (
        sparse_matrix._dense_to_sparse(matrix))  # pylint: disable=protected-access

    # Create a scipy version of the matrix from the same CSR components.
    # `scipy.sparse` must be imported explicitly: a bare `import scipy` does
    # not load the submodule on SciPy versions without lazy submodule loading.
    expected_output = scipy.sparse.csr_matrix(
        (values, column_indices, row_offsets), [m, n])

    # Create the expected row indices: rows ordered by decreasing number of
    # nonzeros, where `np.diff(indptr)` gives each row's length.
    # NOTE(review): ties are broken by np.argsort's default (non-stable) sort,
    # so this assumes `_dense_to_sparse` orders equal-length rows the same
    # way. Confirm.
    expected_row_indices = np.argsort(-1 * np.diff(expected_output.indptr))

    # The sparse components must reproduce the original dense matrix. Without
    # this round-trip check, the comparisons below only test the conversion's
    # output against itself.
    self.assertAllEqual(expected_output.toarray(), matrix)

    # Compare the matrices.
    self.assertAllEqual(expected_output.data, values)
    self.assertAllEqual(expected_output.indptr, row_offsets)
    self.assertAllEqual(expected_output.indices, column_indices)
    self.assertAllEqual(expected_row_indices, row_indices)
77
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  tf.test.main()