# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for subgraph_extractors."""

import functools

from absl.testing import absltest

import jax
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

import tree_math as tm

from jaxsel import agents
from jaxsel import subgraph_extractors
from jaxsel.tests import base_graph_test


def check_sparse_against_dense(sparse, dense):
  """Asserts that a sparse BCOO array densifies to the expected dense array."""
  assert jnp.allclose(sparse.todense(), dense)


def diff_norm(tree, other_tree):
  """Returns the summed leaf-wise norm of the difference of two pytrees."""
  tree_diff = jax.tree_map(lambda x, y: x - y, tree, other_tree)
  return jax.tree_util.tree_reduce(
      lambda acc, x: acc + jnp.linalg.norm(x), tree_diff, 0.)
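

# Example usage (illustrative values, not from the original suite):
#   diff_norm({'w': jnp.array([3., 4.])}, {'w': jnp.zeros(2)})  # == 5.0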


# TODO(gnegiar): add test comparing sparse implementation to
# dense implementation on a small problem with known solution.
# TODO(gnegiar): write test case where max_subgraph_size is too small.
class ISTASubgraphExtractorsTest(base_graph_test.BaseGraphTest):

  def test_abs_top_k(self):
    u = jsparse.BCOO.fromdense(jnp.array([0., 0., 1., 10., -5., 2.]), nse=4)

    k = 3
    nse = 5

    # _abs_top_k keeps the k = 3 entries of largest magnitude
    # (10., -5., 2.) and zeroes out the rest.
    topk_u = subgraph_extractors._abs_top_k(u, k, nse)
    expected = jnp.array([0., 0., 0., 10., -5., 2.])
    check_sparse_against_dense(topk_u, expected)

    # The output is padded to nse stored elements.
    assert topk_u.nse == nse

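  # Illustrative sketch (not part of the original suite): _abs_top_k should
  # agree with a dense reference built from jnp.argsort on magnitudes. The
  # test name and values below are chosen for illustration only.
  def test_abs_top_k_matches_dense_reference(self):
    u_dense = jnp.array([0., 0., 1., 10., -5., 2.])
    k = 3
    # Dense reference: keep the k largest-magnitude entries, zero the rest.
    keep = jnp.argsort(-jnp.abs(u_dense))[:k]
    expected = jnp.zeros_like(u_dense).at[keep].set(u_dense[keep])

    u = jsparse.BCOO.fromdense(u_dense, nse=4)
    topk_u = subgraph_extractors._abs_top_k(u, k, 5)
    check_sparse_against_dense(topk_u, expected)
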
  def test_dense_submatrix_extraction(self):

    # Dense matrix, used to build the sparse matrix.
    mat = jnp.array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.9, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.9, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.9, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
    n = 10
    assert n == mat.shape[0]
    k = 6  # Leave some wiggle room: nse > number of true nonzero elements.
    # Sparse matrix to extract from.
    mat_sp = jsparse.BCOO.fromdense(mat, nse=k**2)

    # Sparse vector, used to index the submatrix to extract.
    v = jnp.array([0.0, 0.0, 0.9, 0.7, 0.0, 0.9, 0.4, 0.0, 0.7, 0.0])
    v_sp = jsparse.BCOO.fromdense(v, nse=k)
    indices = v_sp.indices.flatten()
    # v has only 5 nonzeros, so fromdense pads the sixth index slot with an
    # out-of-bounds value (10 == n).
    assert jnp.allclose(indices, jnp.array([2, 3, 5, 6, 8, 10]))

    in_bounds_indices = indices[indices < n]
    assert jnp.allclose(in_bounds_indices, jnp.array([2, 3, 5, 6, 8]))

    # Extract the submatrix.
    submat = self.extractor._extract_dense_submatrix(mat_sp, indices)
    # Check the extracted matrix's shape.
    assert submat.shape == (len(indices), len(indices))

    # Check that the extracted values are correct: the expected block is
    # mat[in_bounds_indices][:, in_bounds_indices].
    expected = jnp.array([
        [0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.9, 0.0],
        [0.0, 0.0, 0.4, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.4],
        [0.4, 0.0, 0.0, 0.0, 0.0],
    ])

    assert jnp.allclose(
        submat[:len(in_bounds_indices), :len(in_bounds_indices)], expected)

  def test_qstar(self):
    """Tests ability to extract a subgraph, with jit."""
    rng_extractor, self.rng = jax.random.split(self.rng)
    params = self.extractor.init(rng_extractor, self.start_node_id, self.graph)
    qstar, _, _, dense_submat, _, _, error = self.extractor.apply(
        params, self.start_node_id, self.graph)

    # Tests the output shape.
    assert qstar.shape == (self.extractor.config.max_subgraph_size,)

    # Seed vector for the start node.
    s = self.extractor._s(self.start_node_id)
    # Tests convergence of the sparse PageRank; on failure, report the
    # residual of the dense fixed-point map as a reference.
    assert error < 1e-4, subgraph_extractors._dense_fixed_point(
        qstar, dense_submat, s, self.extractor.config.alpha,
        self.extractor.config.rho)

  def test_backprop(self):
    """Tests backpropagation through the subgraph selection layer.

    Verifies the JAX gradient numerically against central finite differences.
    """

    rng_extractor, self.rng = jax.random.split(self.rng)
    params = self.extractor.init(rng_extractor, self.start_node_id, self.graph)

    def get_qstar_sum(params):
      return self.extractor.apply(params, self.start_node_id,
                                  self.graph)[0].sum()

    eps = 1e-5  # Magnitude of the finite-difference step.

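    # The check below compares the directional derivative from autodiff with
    # its central finite-difference approximation, for a random unit
    # direction d:
    #   <grad f(x), d> ~= (f(x + eps * d) - f(x - eps * d)) / (2 * eps)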
    grad = jax.grad(get_qstar_sum)(params)
    delta_rng, self.rng = jax.random.split(self.rng)
    # Take a random direction.
    delta_params = tm.Vector(
        self.extractor.init(delta_rng, self.start_node_id, self.graph))
    # Normalize each leaf.
    delta_params = jax.tree_map(lambda x: x / max(1e-9, jnp.linalg.norm(x)),
                                delta_params).tree

    # Directional derivative given by jax.
    deriv_jax = tm.Vector(jax.tree_map(jnp.vdot, delta_params, grad)).sum()

    # Directional derivative given by central finite differences.
    params_plus_eps = (tm.Vector(params) + eps * tm.Vector(delta_params)).tree
    params_minus_eps = (tm.Vector(params) - eps * tm.Vector(delta_params)).tree
    deriv_diff = ((tm.Vector(get_qstar_sum(params_plus_eps)) -
                   tm.Vector(get_qstar_sum(params_minus_eps))) /
                  (2 * eps)).tree

    err = diff_norm(deriv_jax, deriv_diff)
    assert jax.tree_util.tree_all(
        jax.tree_map(
            functools.partial(jnp.allclose, atol=1e-3), deriv_jax,
            deriv_diff)), f"Difference between FDM and autograd is {err}"

  def test_convert_to_bcoo_indices(self):
    node_id = 0
    n_neighbors = 5
    neighbor_node_ids = jnp.arange(n_neighbors)
    indices = agents._make_adjacency_mat_row_indices(node_id, neighbor_node_ids)
    # BCOO indices for a 2-D matrix are (row, col) pairs; the first pair
    # should point at entry (node_id, neighbor 0) == (0, 0).
    assert (indices[0] == jnp.array([0, 0])).all()

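  # Illustrative sketch (an assumption, not part of the original suite): if
  # the helper returns (node_id, neighbor_id) pairs, they can be used
  # directly as BCOO indices for one row of a sparse adjacency matrix.
  def test_bcoo_row_from_indices(self):
    n = 5
    indices = agents._make_adjacency_mat_row_indices(0, jnp.arange(n))
    row = jsparse.BCOO((jnp.ones(n), indices), shape=(n, n))
    assert jnp.allclose(row.todense()[0], jnp.ones(n))
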
  def test_make_dense_vector(self):
    q_dense = jnp.zeros(10).at[6].set(1.)
    q_sparse = jsparse.BCOO.fromdense(q_dense, nse=4)
    # Gathering at the sparse vector's own indices yields its stored values:
    # the single true nonzero followed by the padding slots.
    extracted_q = self.extractor._make_dense_vector(q_sparse, q_sparse.indices)
    assert (extracted_q == jnp.array([1., 0., 0., 0.])).all()

    # Gathering at all indices recovers the full dense vector.
    full_q = self.extractor._make_dense_vector(q_sparse,
                                               jnp.arange(q_dense.shape[0]))
    assert (full_q == q_dense).all()


if __name__ == "__main__":
  absltest.main()
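
# To run this suite (assuming the google-research repo layout, with this file
# under jaxsel/tests/ next to base_graph_test.py):
#   python -m jaxsel.tests.subgraph_extractors_test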