subgraph_extractors.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for extracting a subgraph of a larger graph_api.GraphAPI.

This module implements a differentiable L1-regularized PageRank.
It follows the method in Fountoulakis et al., 2016:
`https://arxiv.org/abs/1602.01886`.

The main problem can be written as an L1-penalized quadratic problem (when the
graph is undirected), which we solve using proximal algorithms.
"""

from typing import Optional, Tuple

from flax import struct
import flax.linen as nn

import jax
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

from jaxsel._src import agents
from jaxsel._src import graph_api


@struct.dataclass
class ExtractorConfig:
  """Config for subgraph extractors.

  Attributes:
    max_graph_size: Maximum size of the handled graphs.
    max_subgraph_size: Maximum size of the extracted subgraph.
    rho: L1 penalty strength.
    alpha: Probability of teleporting back to the start node during the walk.
    num_steps: Max number of steps in the ISTA algorithm / random walk.
    ridge: Scale of the ridge penalty for the backwards linear problem.
    agent_config: Configuration for the underlying agent model.
  """
  # TODO(gnegiar): Take agent class/name as argument here.
  max_graph_size: int
  max_subgraph_size: int
  rho: float
  alpha: float
  num_steps: int
  ridge: float
  agent_config: agents.AgentConfig


class SparseISTAExtractor(nn.Module):
  """Performs sparse PageRank.

  PageRank can be seen as a random walk on the graph, where the transition
  probabilities are given by the adjacency matrix of the graph.
  Here, the adjacency matrix is parametrized by an `agent` model. PageRank
  yields a vector of weights on the nodes of the graph.

  Reference: https://arxiv.org/abs/1602.01886

  When rho is 0., we still perform top-k thresholding on the weight vector.

  Attributes:
    config: Configuration for the extractor layer and the underlying
      agent model.
  """
  config: ExtractorConfig

  def setup(self):
    self.agent = agents.SimpleFiLMedAgentModel(self.config.agent_config)

  def _s(self, start_node_id):
    """Encodes the start node as a sparse one-hot vector."""
    nse = self.config.max_subgraph_size
    return jsparse.BCOO(
        (jnp.zeros(nse).at[0].set(1.), jnp.zeros(
            (nse, 1), dtype=int).at[0, 0].set(start_node_id)),
        shape=(self.config.max_graph_size,))

  def _q_minus_grad(self, q, adjacency_matrix, s):
    """Computes q minus the gradient of the quadratic part, with fixed nse."""
    return _sum_with_nse(
        (1 - self.config.alpha) * adjacency_matrix.T @ q,
        -self.config.alpha * s,
        nse=self.config.max_subgraph_size)

  def _sparse_softthresh(self, x):
    """Applies soft-thresholding to the stored values of a sparse vector."""
    return jsparse.BCOO(
        (_softthresh(x.data, self.config.alpha, self.config.rho), x.indices),
        shape=x.shape)

  def _ista_step(self, q, adjacency_matrix, s):
    """Performs one ISTA step on sparse arguments."""
    q_minus_grad = self._q_minus_grad(q, adjacency_matrix, s)
    return self._sparse_softthresh(q_minus_grad)

  def _error(self, q, dense_adjacency_matrix, s):
    """Returns the squared L2 norm of the fixed-point residual."""
    return (_dense_fixed_point(q, dense_adjacency_matrix, s, self.config.alpha,
                               self.config.rho)**2).sum()

  def _extract_dense_submatrix(self, sp_mat, indices):
    """Extracts a dense submatrix of a sparse square matrix at given indices.

    Args:
      sp_mat: A sparse matrix.
      indices: A 1D array of indices into `sp_mat`, or `sp_mat.shape[0]` to
        indicate an empty row/column. Assumed to be deduplicated using
        `BCOO.sum_duplicates` or similar beforehand.

    Returns:
      dense_submat:
        A dense submatrix of `sp_mat`, with a subset of rows and columns from
        `sp_mat`, or with zero rows and columns in place of `sp_mat.shape[0]`
        indices.
    """
    if sp_mat.ndim != 2:
      raise ValueError(
          f"The first argument must be a 2d matrix. Got {sp_mat.ndim}d.")
    if sp_mat.shape[0] != sp_mat.shape[1]:
      raise ValueError(f"sp_mat should be square. Got shape {sp_mat.shape}.")
    if indices.ndim != 1:
      raise ValueError(
          f"indices should be a 1-d array. Got shape {indices.shape}.")

    n_indices = indices.shape[0]
    if n_indices > self.config.max_subgraph_size:
      raise ValueError(f"indices should have at most max_subgraph_size "
                       f"elements. Got shape {indices.shape}.")

    submat_indices = jnp.arange(n_indices)
    i_j = _dstack_product(submat_indices, submat_indices)
    values_to_extract = jax.vmap(_subscript, (None, 0))(sp_mat, indices[i_j])
    dense_submat = values_to_extract.reshape((n_indices, n_indices))
    return dense_submat
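
  # Illustrative note (hypothetical values): with indices = [3, 7, N] where
  # N == sp_mat.shape[0] marks a padded slot, the result is a 3x3 dense
  # matrix whose last row and column are zero, since no stored entry of
  # `sp_mat` has index N.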

  def _make_dense_vector(self, q, indices):
    """Extracts a dense subvector from a sparse vector at given indices."""
    return jax.vmap(_subscript, (None, 0))(q, indices)

  def _ista_solve(self, s, graph):
    """Runs the ISTA solver for at most `num_steps` iterations."""

    def body_fun(mdl, carry):
      step, q = carry
      adjacency_matrix = mdl.agent.fill_sparse_adjacency_matrix(q, graph)
      q = self._ista_step(q, adjacency_matrix, s)
      return step + 1, q

    def cond_fun(mdl, c):
      del mdl
      step, q = c
      del q
      return step < self.config.num_steps

    # Make sure the agent is initialized.
    if self.is_mutable_collection("params"):
      _, q = body_fun(self, (0, s))
    else:
      # Parameters are initialized: run the full solver loop.
      _, q = nn.while_loop(cond_fun, body_fun, self, (0, s))
    return q

  def __call__(self, start_node_id, graph):
    """Performs differentiable subgraph extraction.

    Args:
      start_node_id: initial start node id.
      graph: underlying graph.

    Returns:
      q_star: dense weights over nodes in the sparse subgraph.
      node_features: features associated with the extracted subgraph.
      node_ids: ids of the nodes in the extracted subgraph.
      dense_submat: the adjacency matrix of the extracted subgraph.
      q: sparse weights over nodes. Used for debug purposes.
      adjacency_matrix: the sparse adjacency matrix. Used for debug purposes.
      error: squared L2 norm of q_{t+1} - q_t. Should be 0 at convergence.
    """
    # TODO(gnegiar): remove unnecessary return values
    s = self._s(start_node_id)
    # TODO(gnegiar): Do we need to add a stop_gradient here?
    q = self._ista_solve(s, graph)
    q = jax.lax.stop_gradient(q)
    # TODO(gnegiar): Find a way to avoid re-filling adjacency_matrix.
    # For now, this allows gradients to propagate back to the `agent` model.
    adjacency_matrix = self.agent.fill_sparse_adjacency_matrix(q, graph)
    # Extract dense submatrix.
    dense_q = self._make_dense_vector(q, q.indices.flatten())
    dense_s = self._make_dense_vector(s, q.indices.flatten())
    dense_submat = self._extract_dense_submatrix(adjacency_matrix,
                                                 q.indices.flatten())

    def _fixed_point(q):
      return _dense_fixed_point(q, dense_submat, dense_s, self.config.alpha,
                                self.config.rho)

    def _tangent_solve(g, y):
      """Solves the implicit function theorem linear system.

      Optionally, uses ridge regularization on the normal equation.

      Args:
        g: the linearized zero function in the implicit function theorem. This
          is required by `custom_root`.
        y: the target, here the jvp for what comes after this layer. This is
          required by `custom_root`.

      Returns:
        jvp: the jvp for the subgraph extraction layer.
      """
      linearization = jax.jacobian(g)(y)
      if self.config.ridge != 0.:
        normal_mat_regularized, normal_target = _make_normal_system(
            linearization, y, self.config.ridge)
        jvp = jnp.linalg.solve(normal_mat_regularized, normal_target)
      else:
        jvp = jnp.linalg.solve(linearization, y)
      return jvp

    q_star = jax.lax.custom_root(
        f=_fixed_point,
        initial_guess=dense_q,
        solve=lambda _, q: dense_q,
        tangent_solve=_tangent_solve)

    node_features = jax.vmap(graph.node_features)(q.indices.flatten())

    node_ids = q.indices.flatten()

    error = self._error(q_star, dense_submat, dense_s)
    return q_star, node_features, node_ids, dense_submat, q, adjacency_matrix, error  # pytype: disable=bad-return-type  # jnp-array


# Utility functions


def _abs_top_k(u, k, nse=None):
  """Returns a sparse vector zeroing all but the top k values of `u` in magnitude.

  Args:
    u: BCOO 1d vector to threshold.
    k: number of elements to keep.
    nse: [Optional] Maximal allowed nse. If None is passed, use `u.nse`.

  Note: `nse` is the number of specified (stored) elements in the matrix. This
    number must be fixed, since XLA requires fixed-shape arrays.

  Returns:
    thresholded_u: BCOO 1d vector, where the top k values of `u` in magnitude
      were kept. Has the specified `nse`.
  """
  if nse is None:
    nse = u.nse
  if nse < k:
    raise ValueError(
        f"nse should be at least the number of elements to keep. "
        f"Got nse={nse} and k={k}.")
  k = min(k, u.nse)
  # TODO(gnegiar): Benchmark speedups using jax.lax.approx_max_k
  _, idx = jax.lax.top_k(abs(u.data), k)
  # Pad to wanted nse.
  pad_length = nse - len(idx)
  # Pad data with zeros.
  data = jnp.concatenate((u.data[idx][:nse], jnp.zeros(pad_length)))
  # Pad indices with u.shape[0].
  indices = jnp.concatenate(
      (u.indices[idx][:nse], jnp.full((pad_length, 1), u.shape[0])))

  return jsparse.BCOO((data, indices), shape=u.shape)
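
# Illustrative example (values chosen by hand): for a dense vector
# [0., 3., -5., 1.] stored as a BCOO vector `u` with nse 4,
# `_abs_top_k(u, k=2, nse=3)` keeps -5. and 3., and pads the third slot with
# a zero entry at index `u.shape[0] == 4`.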


def _sum_with_nse(mat, other_mat, nse):
  """Returns the sum of two sparse arrays, with fixed nse.

  If `mat` has nse `a`, and `other_mat` has nse `b`, the nse of the sum is
  at most `a+b`. To satisfy `jax`'s fixed-shape requirement, we impose a
  fixed `nse` on the result.

  This may cause unexpected behavior when the true `nse` of `mat + other_mat`
  is more than `nse`.

  Note: `nse` is the number of specified (stored) elements in the matrix. This
    number must be fixed, since XLA requires fixed-shape arrays.

  Args:
    mat: first array to add.
    other_mat: second array to add.
    nse: max nse of the result.

  Returns:
    sum: array with nse=`nse`.
  """
  result = mat + other_mat
  # Remove duplicate indices in result.
  result = result.sum_duplicates(nse=result.nse)
  # Keep the top `nse` items in magnitude.
  # TODO(gnegiar): print a warning when the clipping removes elements?
  return _abs_top_k(result, k=nse, nse=nse)
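
# Example (hand-computed): adding BCOO vectors with stored entries
# {0: 1., 5: 2.} and {5: 3., 7: 0.5} gives {0: 1., 5: 5., 7: 0.5}; with
# nse=2, `_sum_with_nse` keeps {5: 5., 0: 1.} and silently drops {7: 0.5},
# so callers should size `nse` generously.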


def _dense_q_minus_grad(q, dense_adjacency_matrix, s, alpha):
  """Computes q minus the gradient of the quadratic part, on dense arguments."""
  return (1. - alpha) * dense_adjacency_matrix.T @ q - alpha * s


def _softthresh(x, alpha, rho):
  """Performs soft-thresholding with an alpha*rho threshold."""
  return jnp.sign(x) * jnp.maximum(jnp.abs(x) - alpha * rho, 0.)


def _dense_ista_step(q, dense_adjacency_matrix, s, alpha, rho):
  """Performs a single step of ISTA on dense arguments."""
  q_minus_grad = _dense_q_minus_grad(q, dense_adjacency_matrix, s, alpha)
  return _softthresh(q_minus_grad, alpha, rho)


def _dense_fixed_point(q, dense_adjacency_matrix, s, alpha, rho):
  """Returns the fixed-point residual used in implicit differentiation."""
  q_ = _dense_ista_step(q, dense_adjacency_matrix, s, alpha, rho)
  return q - q_


# TODO(gnegiar): use binary search. Look at jnp.searchsorted.
# https://github.com/google/jax/pull/9108/files
# This would greatly lower memory requirements, at the possible cost of speed.
def _subscript(bcoo_mat, idx):
  """Returns a single element from a sparse matrix at a given index.

  Args:
    bcoo_mat: the sparse matrix to extract from.
    idx: indices of the element to extract. The length of idx should match
      bcoo_mat.ndim.

  Returns:
    bcoo_mat[idx]
  """
  # Handle negative indices.
  idx = jnp.where(idx >= 0, idx, jnp.array(bcoo_mat.shape) + idx)
  data, indices = bcoo_mat.data, bcoo_mat.indices
  # If indices were sorted, the mask could be found via binary search instead.
  mask = jnp.all(indices == idx, axis=-1)
  return jnp.vdot(mask, data)  # Sums duplicate indices.
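
# Illustrative example: for a BCOO matrix `m` with a single stored entry
# m[1, 2] == 3., `_subscript(m, jnp.array([1, 2]))` returns 3., while
# querying any index pair that is not stored returns 0.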


def _dstack_product(x, y):
  """Returns the cartesian product of the elements of x and y vectors.

  Args:
    x: 1d array.
    y: 1d array of the same dtype as x.

  Returns:
    A 2D array containing the elements of [x]x[y].
  Example:
    x = jnp.array([1, 2, 3])
    y = jnp.array([4, 5])

    _dstack_product(x, y)
    >>> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]
  """
  return jnp.dstack(jnp.meshgrid(x, y, indexing="ij")).reshape(-1, 2)


def _make_normal_system(mat, b, ridge):
  """Makes a regularized normal linear system.

  The normal system for `A x = b` is `(A.T A + ridge * Id) x = A.T b`.

  Args:
    mat: the `A` above.
    b: the target in the linear system.
    ridge: amount of regularization to add.

  Returns:
    normal_mat_regularized: corresponds to `(A.T A + ridge * Id)`.
    normal_target: corresponds to `A.T b`.
  """
  normal_mat = mat.T @ mat
  normal_target = mat.T @ b
  # Add regularization to the diagonal.
  normal_mat_regularized = normal_mat.at[jnp.arange(normal_mat.shape[0]),
                                         jnp.arange(normal_mat.shape[1])].add(
                                             ridge)
  return normal_mat_regularized, normal_target
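
# Example (hand-checked, illustrative): for A = [[1., 0.], [0., 0.]] and
# b = [1., 1.], the plain system `A x = b` is singular, while the regularized
# normal system (A.T A + ridge * I) x = A.T b with ridge=1e-2 has the finite
# solution x = [1/1.01, 0.].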