google-research
178 строк · 6.6 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Tests for subgraph_extractors."""
17
18import functools
19
20from absl.testing import absltest
21
22import jax
23from jax.experimental import sparse as jsparse
24import jax.numpy as jnp
25
26import tree_math as tm
27
28from jaxsel import agents
29from jaxsel import subgraph_extractors
30from jaxsel.tests import base_graph_test
31
32
def check_sparse_against_dense(sparse, dense):
  """Asserts that densifying `sparse` gives values all-close to `dense`.

  Args:
    sparse: Any object exposing a `todense()` method (e.g. `jsparse.BCOO`).
    dense: Array-like of expected values, compared with `jnp.allclose`.
  """
  densified = sparse.todense()
  assert jnp.allclose(densified, dense)
35
36
def diff_norm(tree, other_tree):
  """Returns the sum over leaves of the L2 norm of `tree - other_tree`.

  Args:
    tree: A pytree of arrays (a bare array or scalar is also accepted).
    other_tree: A pytree with the same structure as `tree`.

  Returns:
    The sum of `jnp.linalg.norm(leaf_diff)` over all leaves of the
    difference tree (0 for an empty tree).
  """
  tree_diff = jax.tree_util.tree_map(lambda x, y: x - y, tree, other_tree)
  leaf_norms = jax.tree_util.tree_leaves(
      jax.tree_util.tree_map(jnp.linalg.norm, tree_diff))
  # Summing the leaf list works for any pytree; calling `.sum()` on the mapped
  # tree only worked when the tree was a single array.
  return sum(leaf_norms)
40
41
# TODO(gnegiar): add test comparing sparse implementation to
# dense implementation on a small problem with known solution.
# TODO(gnegiar): write test case where max_subgraph_size is too small
class ISTASubgraphExtractorsTest(base_graph_test.BaseGraphTest):
  """Tests for the subgraph extractor exposed as `self.extractor`.

  `self.extractor`, `self.graph`, `self.start_node_id` and `self.rng` are not
  defined in this class; they come from `base_graph_test.BaseGraphTest`
  (presumably set up per test -- confirm in the base class).
  """

  def test_abs_top_k(self):
    """`_abs_top_k` keeps the k largest-magnitude entries of a BCOO vector."""
    u = jsparse.BCOO.fromdense(jnp.array([0., 0., 1., 10., -5., 2.]), nse=4)

    k = 3
    nse = 5

    topk_u = subgraph_extractors._abs_top_k(u, k, nse)
    # |10|, |-5| and |2| are the three largest magnitudes; 1. is dropped.
    expected = jnp.array([0., 0., 0., 10., -5., 2.])
    check_sparse_against_dense(topk_u, expected)

    # The result is stored with `nse` slots even though only k are nonzero.
    assert topk_u.nse == nse

  def test_dense_submatrix_extraction(self):
    """`_extract_dense_submatrix` gathers mat[indices][:, indices] densely."""

    # Dense matrix, used to build the sparse matrix.
    mat = jnp.array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.9, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.9, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.9, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
    n = 10
    assert n == mat.shape[0]
    k = 6  # Leave some wiggle room: nse > |true non zeros elements|.
    # Sparse matrix to extract from.
    mat_sp = jsparse.BCOO.fromdense(mat, nse=k**2)

    # Sparse vector, used to index the submatrix to extract.
    v = jnp.array([0.0, 0.0, 0.9, 0.7, 0.0, 0.9, 0.4, 0.0, 0.7, 0.0])
    v_sp = jsparse.BCOO.fromdense(v, nse=k)
    indices = v_sp.indices.flatten()
    # v has 5 nonzeros but nse=6: the extra slot is padding, which
    # BCOO.fromdense fills with the out-of-bounds index n.
    assert jnp.array_equal(indices, jnp.array([2, 3, 5, 6, 8, 10]))

    in_bounds_indices = indices[indices < n]
    assert jnp.array_equal(in_bounds_indices, jnp.array([2, 3, 5, 6, 8]))

    # Extract submatrix.
    submat = self.extractor._extract_dense_submatrix(mat_sp, indices)
    # One row and one column per requested index, padding included.
    assert submat.shape == (len(indices), len(indices))

    # Check that the extracted values are correct: entry (a, b) equals
    # mat[in_bounds_indices[a], in_bounds_indices[b]].
    expected = jnp.array([
        [0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.9, 0.0],
        [0.0, 0.0, 0.4, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.4],
        [0.4, 0.0, 0.0, 0.0, 0.0],
    ])

    n_in_bounds = len(in_bounds_indices)
    assert jnp.allclose(submat[:n_in_bounds][:, :n_in_bounds], expected)

  def test_qstar(self):
    """Tests ability to extract a subgraph, with jit."""
    rng_extractor, self.rng = jax.random.split(self.rng)
    params = self.extractor.init(rng_extractor, self.start_node_id, self.graph)
    # Of `apply`'s seven outputs only qstar (0), the dense submatrix (3) and
    # the convergence error (6) are used here.
    qstar, _, _, dense_submat, _, _, error = self.extractor.apply(
        params, self.start_node_id, self.graph)

    # Tests output shape: one entry per slot of the configured maximum size.
    assert qstar.shape == (self.extractor.config.max_subgraph_size,)

    s = self.extractor._s(self.start_node_id)
    # Tests convergence of the sparse PageRank. The assertion message, only
    # evaluated on failure, is the output of `_dense_fixed_point` on the
    # extracted dense subproblem, as a debugging aid.
    assert error < 1e-4, subgraph_extractors._dense_fixed_point(
        qstar, dense_submat, s, self.extractor.config.alpha,
        self.extractor.config.rho)

  def test_backprop(self):
    """Tests ability to backpropagate through the Subgraph Selection layer.

    Compares the directional derivative of `sum(qstar)` given by `jax.grad`
    with a central finite difference along a random direction.
    """

    rng_extractor, self.rng = jax.random.split(self.rng)
    params = self.extractor.init(rng_extractor, self.start_node_id, self.graph)

    def get_qstar_sum(params):
      return self.extractor.apply(params, self.start_node_id,
                                  self.graph)[0].sum()

    # Step size of the central finite difference.
    # NOTE(review): in float32 the cancellation error of a 1e-5 central
    # difference can approach the 1e-3 tolerance below; confirm the fixture
    # runs with x64 enabled, or increase eps if this test flakes.
    eps = 1e-5

    grad = jax.grad(get_qstar_sum)(params)
    delta_rng, self.rng = jax.random.split(self.rng)
    # Take a random direction: a freshly initialised set of parameters.
    delta_params = self.extractor.init(delta_rng, self.start_node_id,
                                       self.graph)
    # Normalize each leaf independently (guarding against all-zero leaves).
    # Both derivatives below use this same direction, so per-leaf scaling
    # does not affect the comparison.
    delta_params = jax.tree_util.tree_map(
        lambda x: x / jnp.maximum(1e-9, jnp.linalg.norm(x)), delta_params)

    # Directional derivative given by jax: <delta_params, grad>.
    deriv_jax = tm.Vector(
        jax.tree_util.tree_map(jnp.vdot, delta_params, grad)).sum()

    # Directional derivative given by central finite differences:
    # (f(p + eps * d) - f(p - eps * d)) / (2 * eps).
    agent_plus_eps = (tm.Vector(params) + eps * tm.Vector(delta_params)).tree
    agent_minus_eps = (tm.Vector(params) - eps * tm.Vector(delta_params)).tree
    # The parentheses around `2 * eps` matter: `/ 2 * eps` would multiply by
    # eps and make the finite-difference estimate vanish.
    deriv_diff = (get_qstar_sum(agent_plus_eps) -
                  get_qstar_sum(agent_minus_eps)) / (2 * eps)

    err = diff_norm(deriv_jax, deriv_diff)
    assert jnp.allclose(deriv_jax, deriv_diff, atol=1e-3), (
        f"Difference between FDM and autograd is {err}")

  def test_convert_to_bcoo_indices(self):
    """Checks `_make_adjacency_mat_row_indices` for node 0, neighbors 0..4."""
    node_id = 0
    n_neighbors = 5
    neighbor_node_ids = jnp.arange(n_neighbors)
    indices = agents._make_adjacency_mat_row_indices(node_id, neighbor_node_ids)
    # Presumably each entry is a (row=node_id, col=neighbor) pair; the first
    # one is therefore (0, 0).
    assert (indices[0] == jnp.array([0, 0])).all()

  def test_make_dense_vector(self):
    """`_make_dense_vector` reads a sparse vector's values at given indices."""
    q_dense = jnp.zeros(10).at[6].set(1.)
    q_sparse = jsparse.BCOO.fromdense(q_dense, nse=4)
    # Reading at the sparse vector's own index slots: the first slot holds the
    # single nonzero (index 6); the padding slots come back as 0.
    extracted_q = self.extractor._make_dense_vector(q_sparse, q_sparse.indices)
    assert (extracted_q == jnp.array([1., 0., 0., 0.])).all()

    # Reading at every position reconstructs the full dense vector.
    full_q = self.extractor._make_dense_vector(q_sparse,
                                               jnp.arange(q_dense.shape[0]))
    assert (full_q == q_dense).all()
175
176
if __name__ == "__main__":
  # Executed as a script (not imported): hand control to the absltest runner.
  absltest.main()
179