# google-research
# 418 lines · 13.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for extracting a subgraph of a larger graph_api.GraphAPI.

This module implements a differentiable L1-regularized PageRank.
It follows the method in Fountoulakis et al., 2016.
`https://arxiv.org/abs/1602.01886`.

The main problem can be written as a L1 penalized quadratic problem (when the
graph is undirected), which we solve using proximal algorithms.
"""

from typing import Optional, Tuple

from flax import struct
import flax.linen as nn

import jax
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

from jaxsel._src import agents
from jaxsel._src import graph_api


@struct.dataclass
class ExtractorConfig:
  """Config for subgraph extractors.

  Attributes:
    max_graph_size: Maximum size of the handled graphs.
    max_subgraph_size: Maximum size of the extracted subgraph.
    rho: L1 penalty strength.
    alpha: Probability to teleport back to the start node during the walk.
    num_steps: Max number of steps in the ISTA algorithm/in the random walk.
    ridge: Scale of ridge penalty for the backwards linear problem.
    agent_config: Configuration for the underlying agent model.
  """
  # TODO(gnegiar): Take agent class/name as argument here.
  max_graph_size: int
  max_subgraph_size: int
  rho: float
  alpha: float
  num_steps: int
  ridge: float
  agent_config: agents.AgentConfig


class SparseISTAExtractor(nn.Module):
  """Performs sparse PageRank.

  PageRank can be seen as a random walk on the graph, where the transition
  probabilities are given by the adjacency matrix on the graph.
  Here, the adjacency matrix is parametrized by an `agent` model. PageRank
  yields a vector of weights on the nodes of the graph.

  Reference: https://arxiv.org/abs/1602.01886

  When rho is 0., we still perform topk thresholding on the weight vector.

  Attributes:
    config: Configuration for the extractor layer and the underlying
      agent model.
  """
  config: ExtractorConfig

  def setup(self):
    self.agent = agents.SimpleFiLMedAgentModel(self.config.agent_config)

  def _s(self, start_node_id):
    """Encodes the start node as a sparse one-hot vector over the graph."""
    nse = self.config.max_subgraph_size
    return jsparse.BCOO(
        (jnp.zeros(nse).at[0].set(1.),
         jnp.zeros((nse, 1), dtype=int).at[0, 0].set(start_node_id)),
        shape=(self.config.max_graph_size,))

  def _q_minus_grad(self, q, adjacency_matrix, s):
    """Computes `(1 - alpha) * A^T q - alpha * s`, truncated to a fixed nse.

    This is `q` minus the gradient of the quadratic objective, i.e. the
    pre-thresholding iterate of ISTA.
    """
    return _sum_with_nse(
        (1 - self.config.alpha) * adjacency_matrix.T @ q,
        -self.config.alpha * s,
        nse=self.config.max_subgraph_size)

  def _sparse_softthresh(self, x):
    """Applies soft-thresholding to the stored values of a sparse vector."""
    return jsparse.BCOO(
        (_softthresh(x.data, self.config.alpha, self.config.rho), x.indices),
        shape=x.shape)

  def _ista_step(self, q, adjacency_matrix, s):
    """Performs one ISTA step: a gradient step followed by soft-thresholding."""
    q_minus_grad = self._q_minus_grad(q, adjacency_matrix, s)
    return self._sparse_softthresh(q_minus_grad)

  def _error(self, q, dense_adjacency_matrix, s):
    """Squared L2 norm of the ISTA fixed-point residual. Zero at convergence."""
    return (_dense_fixed_point(q, dense_adjacency_matrix, s, self.config.alpha,
                               self.config.rho)**2).sum()

  def _extract_dense_submatrix(self, sp_mat, indices):
    """Extracts a dense submatrix of a sparse square matrix at given indices.

    Args:
      sp_mat: A sparse matrix.
      indices: A 1D array of indices into `sp_mat`, or `sp_mat.shape[0]` to
        indicate an empty row/column. Assumed to be deduplicated using
        `BCOO.sum_duplicates` or similar beforehand.

    Returns:
      dense_submat: A dense submatrix of `sp_mat`, with a subset of rows and
        columns from `sp_mat`, or with zero rows and columns in place of
        padding (`sp_mat.shape[0]`) indices.

    Raises:
      ValueError: if `sp_mat` is not a 2d square matrix, if `indices` is not
        1d, or if `indices` is longer than `max_subgraph_size`.
    """
    if sp_mat.ndim != 2:
      raise ValueError(
          f"The first argument must be a 2d matrix. Got {sp_mat.ndim}.")
    if sp_mat.shape[0] != sp_mat.shape[1]:
      raise ValueError(f"sp_mat should be square. Got shape {sp_mat.shape}.")
    if indices.ndim != 1:
      raise ValueError(
          f"indices should be a 1-d array. Got shape {indices.shape}.")

    n_indices = indices.shape[0]
    if n_indices > self.config.max_subgraph_size:
      raise ValueError(f"indices should be smaller than the max_subgraph_size."
                       f"Got shape {indices.shape}.")

    submat_indices = jnp.arange(n_indices)
    i_j = _dstack_product(submat_indices, submat_indices)
    values_to_extract = jax.vmap(_subscript, (None, 0))(sp_mat, indices[i_j])
    dense_submat = values_to_extract.reshape((n_indices, n_indices))
    return dense_submat

  def _make_dense_vector(self, q, indices):
    """Extracts a dense subvector from a sparse vector at given indices."""
    return jax.vmap(_subscript, (None, 0))(q, indices)

  def _ista_solve(self, s, graph):
    """Runs the ISTA solver."""

    def body_fun(mdl, carry):
      step, q = carry
      adjacency_matrix = mdl.agent.fill_sparse_adjacency_matrix(q, graph)
      q = self._ista_step(q, adjacency_matrix, s)
      return step + 1, q

    def cond_fun(mdl, c):
      del mdl
      step, q = c
      del q
      return step < self.config.num_steps

    # Make sure the agent is initialized
    if self.is_mutable_collection("params"):
      _, q = body_fun(self, (0, s))
    else:
      # Things are initialized
      _, q = nn.while_loop(cond_fun, body_fun, self, (0, s))
    return q

  def __call__(self, start_node_id, graph):
    """Performs differentiable subgraph extraction.

    Args:
      start_node_id: initial start node id
      graph: underlying graph

    Returns:
      q_star: dense weights over nodes in the sparse subgraph
      node_features: features associated with the extracted subgraph
      node_ids: ids of the nodes kept in the extracted subgraph
      dense_submat: the adjacency matrix of the extracted subgraph
      q: sparse weights over nodes. Used for debug purposes.
      adjacency_matrix: the sparse adjacency matrix. Used for debug purposes.
      error: L2 norm of q_t+1 - q_t. Should be 0 at convergence.
    """
    # TODO(gnegiar): remove unnecessary return values
    s = self._s(start_node_id)
    # TODO(gnegiar): Do we need to add a stop_gradient here
    q = self._ista_solve(s, graph)
    q = jax.lax.stop_gradient(q)
    # TODO(gnegiar): Find a way to avoid re-filling adjacency_matrix
    # For now, this allows to propagate gradients back to the `agent` model
    adjacency_matrix = self.agent.fill_sparse_adjacency_matrix(q, graph)
    # Extract dense submatrix
    dense_q = self._make_dense_vector(q, q.indices.flatten())
    dense_s = self._make_dense_vector(s, q.indices.flatten())
    dense_submat = self._extract_dense_submatrix(adjacency_matrix,
                                                 q.indices.flatten())

    def _fixed_point(q):
      return _dense_fixed_point(q, dense_submat, dense_s, self.config.alpha,
                                self.config.rho)

    def _tangent_solve(g, y):
      """Solve implicit function theorem linear system.

      Optionally, use ridge regularization on the normal equation.

      Args:
        g: the linearized zero function in the implicit function theorem. This
          is required by `custom_root`.
        y: the target, here the jvp for what comes after this layer. This is
          required by `custom_root`.

      Returns:
        jvp: the jvp for the subgraph extraction layer.
      """
      linearization = jax.jacobian(g)(y)
      if self.config.ridge != 0.:
        normal_mat_regularized, normal_target = _make_normal_system(
            linearization, y, self.config.ridge)
        jvp = jnp.linalg.solve(normal_mat_regularized, normal_target)
      else:
        jvp = jnp.linalg.solve(linearization, y)
      return jvp

    # Implicit differentiation through the ISTA fixed point: the forward
    # solution is the (stop-gradient) dense_q; gradients flow via the
    # implicit function theorem in `_tangent_solve`.
    q_star = jax.lax.custom_root(
        f=_fixed_point,
        initial_guess=dense_q,
        solve=lambda _, q: dense_q,
        tangent_solve=_tangent_solve)

    node_features = jax.vmap(graph.node_features)(q.indices.flatten())

    node_ids = q.indices.flatten()

    error = self._error(q_star, dense_submat, dense_s)
    return q_star, node_features, node_ids, dense_submat, q, adjacency_matrix, error  # pytype: disable=bad-return-type # jnp-array


# Utility functions


252def _abs_top_k(u,
253k,
254nse = None):
255"""Returns a sparse vector zeroing all but the top k values of `u` in magnitude.
256
257Args:
258u: BCOO 1d vector to threshold.
259k: number of elements to keep.
260nse: [Optional] Maximal allowed nse. If None passed, use `u.nse`.
261FYI: `nse` means the number of nonzero elements in the matrix. This number
262must be fixed, due to XLA requiring fixed shaped arrays.
263
264Returns:
265thresholded_u: BCOO 1d vector, where the top k values of `u` in magnitude
266were kept. Has the specified `nse`.
267"""
268if nse is None:
269nse = u.nse
270if nse < k:
271raise ValueError(
272f"nse should be larger than the number of elements to keep. "
273f"Got nse={nse} and k={k}.")
274k = min(k, u.nse)
275# TODO(gnegiar): Benchmark speedups using jax.lax.approx_max_k
276_, idx = jax.lax.top_k(abs(u.data), k)
277# Pad to wanted nse
278pad_length = nse - len(idx)
279# Pad data with zeros
280data = jnp.concatenate((u.data[idx][:nse], jnp.zeros(pad_length)))
281# Pad indices with u.shape[0]
282indices = jnp.concatenate(
283(u.indices[idx][:nse], jnp.full((pad_length, 1), u.shape[0])))
284
285return jsparse.BCOO((data, indices), shape=u.shape)
286
287
def _sum_with_nse(mat, other_mat, nse):
  """Returns the sum of two sparse arrays, with fixed nse.

  If `mat` has nse `a`, and `other_mat` has nse `b`, the nse of the sum is
  `a+b` at most. To satisfy `jax`'s fixed shape desiderata,
  we impose a fixed `nse` on the result.

  This may cause unexpected behavior when the true `nse` of `a+b` is more than
  `nse`.

  FYI: `nse` means the number of nonzero elements in the matrix. This number
  must be fixed, due to XLA requiring fixed shaped arrays.

  Args:
    mat: first array to add
    other_mat: second array to add
    nse: max nse of the result

  Returns:
    sum: array with nse=`nse`
  """
  result = mat + other_mat
  # Remove duplicate indices in result.
  result = result.sum_duplicates(nse=result.nse)
  # Return the topk `nse` items in magnitude.
  # TODO(gnegiar): print a warning when the clipping removes elements?
  return _abs_top_k(result, k=nse, nse=nse)


def _dense_q_minus_grad(
    q,
    dense_adjacency_matrix,
    s,
    alpha,
):
  """Computes q minus the gradient of the quadratic objective, densely."""
  return (1. - alpha) * dense_adjacency_matrix.T @ q - alpha * s


def _softthresh(x, alpha, rho):
  """Performs soft-thresholding with alpha*rho threshold."""
  return jnp.sign(x) * jnp.maximum(jnp.abs(x) - alpha * rho, 0.)


def _dense_ista_step(
    q,
    dense_adjacency_matrix,
    s,
    alpha,
    rho,
):
  """Performs a single step of ISTA for dense arguments."""
  q_minus_grad = _dense_q_minus_grad(q, dense_adjacency_matrix, s, alpha)
  return _softthresh(q_minus_grad, alpha, rho)


def _dense_fixed_point(
    q,
    dense_adjacency_matrix,
    s,
    alpha,
    rho,
):
  """Returns the equation to be used in implicit differentiation.

  The residual `q - ista_step(q)` is zero exactly at an ISTA fixed point.
  """
  q_ = _dense_ista_step(q, dense_adjacency_matrix, s, alpha, rho)
  return q - q_


# TODO(gnegiar): use binary search. Look at jnp.searchsorted.
# https://github.com/google/jax/pull/9108/files
# This would greatly lower memory requirements, at the possible cost of speed.
def _subscript(bcoo_mat, idx):
  """Returns a single element from a sparse matrix at a given index.

  Args:
    bcoo_mat: the sparse matrix to extract from.
    idx: indices to extract the element. The length of idx should match
      bcoo_mat.ndim.

  Returns:
    bcoo_mat[idx]
  """
  # Handle negative indices
  idx = jnp.where(idx >= 0, idx, jnp.array(bcoo_mat.shape) - (-idx))
  data, indices = bcoo_mat.data, bcoo_mat.indices
  # If indices are sorted, the mask should be findable via binary search.
  mask = jnp.all(indices == idx, axis=-1)
  return jnp.vdot(mask, data)  # Sum duplicate indices


def _dstack_product(x, y):
  """Returns the cartesian product of the elements of x and y vectors.

  With `indexing="ij"`, `x` varies slowest: the products are emitted in
  row-major (lexicographic) order. The original docstring example showed the
  `indexing="xy"` ordering, which does not match the implementation.

  Args:
    x: 1d array
    y: 1d array of the same dtype as x.

  Returns:
    a 2D array containing the elements of [x]x[y].
    Example:
      x = jnp.array([1, 2, 3])
      y = jnp.array([4, 5])

      _dstack_product(x, y)
      >>> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]
  """
  return jnp.dstack(jnp.meshgrid(x, y, indexing="ij")).reshape(-1, 2)


def _make_normal_system(mat, b, ridge):
  """Makes regularized normal linear system.

  The normal system to `A x = b` is `(A.T A + ridge * Id) x = A.T b`.

  Args:
    mat: the A in above
    b: the target in the linear system
    ridge: amount of regularization to add.

  Returns:
    normal_mat_regularized: corresponds to `(A.T A + ridge * Id)`
    normal_target: corresponds to `A.T b`
  """
  normal_mat = mat.T @ mat
  normal_target = mat.T @ b
  # Add regularization to diagonal
  normal_mat_regularized = normal_mat.at[jnp.arange(normal_mat.shape[0]),
                                         jnp.arange(normal_mat.shape[1])].add(
                                             ridge)
  return normal_mat_regularized, normal_target