# Source: google-research repository (page header: "Fork 0", 167 lines, 5.5 KB).
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `differentially_private_gnns.models`."""

18
from typing import Optional
19

20
from absl.testing import absltest
21
from absl.testing import parameterized
22
import jax
23
import jraph
24
import numpy as np
25

26
from differentially_private_gnns import models
27
from differentially_private_gnns import normalizations
28

29

30
def get_dummy_graph(
    add_self_loops: bool,
    symmetrize_edges: bool,
    adjacency_normalization: Optional[str]):
  """Returns a small dummy GraphsTuple with normalized edge weights.

  The base graph has 3 nodes with 1-dimensional float32 features
  [[2.], [1.], [1.]] and two directed edges: 0 -> 1 and 2 -> 1.

  Args:
    add_self_loops: If True, adds an edge i -> i for every node. This happens
      after any symmetrization, so self-loops are never duplicated.
    symmetrize_edges: If True, adds the reverse of every base edge.
    adjacency_normalization: Forwarded to
      `normalizations.normalize_edges_with_mask`; the tests in this file use
      None, 'inverse-degree' and 'inverse-sqrt-degree'.

  Returns:
    The result of `normalizations.normalize_edges_with_mask` applied to the
    dummy graph with `mask=None`.
  """
  senders = np.array([0, 2])
  receivers = np.array([1, 1])
  num_nodes = 3
  node_features = np.array([[2.], [1.], [1.]], dtype=np.float32)

  if symmetrize_edges:
    # The right-hand side is fully evaluated before assignment, so both
    # concatenations see the original (unsymmetrized) arrays.
    senders, receivers = (np.concatenate([senders, receivers], axis=0),
                          np.concatenate([receivers, senders], axis=0))

  if add_self_loops:
    self_loops = np.arange(num_nodes)
    senders = np.concatenate([senders, self_loops], axis=0)
    receivers = np.concatenate([receivers, self_loops], axis=0)

  # Derive the edge count from the final edge list rather than tracking it by
  # hand, so it cannot drift out of sync with `senders`/`receivers`.
  num_edges = len(senders)

  dummy_graph = jraph.GraphsTuple(
      n_node=np.asarray([num_nodes]),
      n_edge=np.asarray([num_edges]),
      senders=senders,
      receivers=receivers,
      nodes=node_features,
      edges=np.ones((num_edges, 1)),
      globals=np.zeros((1, 1)),
  )

  return normalizations.normalize_edges_with_mask(
      dummy_graph, mask=None, adjacency_normalization=adjacency_normalization)


def get_adjacency_matrix(graph) -> np.ndarray:
  """Returns a dense 0/1 adjacency matrix for the given graph.

  Entry `[u, v]` is 1 when at least one edge has sender `u` and receiver `v`,
  and 0 otherwise; repeated edges are not counted twice.

  Args:
    graph: A `jraph.GraphsTuple`, or any object exposing `n_node`, `senders`
      and `receivers`. For a batch of graphs, node indices in `senders` and
      `receivers` are expected to be global across the batch (as produced by
      `jraph.batch`).

  Returns:
    A float array of shape `[total_num_nodes, total_num_nodes]`, where
    `total_num_nodes` is the sum of `graph.n_node`.
  """
  # Summing `n_node` (instead of reading `n_node[0]`) also supports batched
  # graphs; for a single graph the two are identical.
  num_nodes = int(np.sum(graph.n_node))
  adj = np.zeros((num_nodes, num_nodes))

  # Vectorized scatter: every (sender, receiver) pair gets a 1. Repeated pairs
  # just write 1 again, preserving the binary semantics. The integer cast keeps
  # an empty edge list valid as an index array.
  senders = np.asarray(graph.senders, dtype=np.intp)
  receivers = np.asarray(graph.receivers, dtype=np.intp)
  adj[senders, receivers] = 1

  return adj


def normalize_adjacency(
    adj: np.ndarray,
    adjacency_normalization: Optional[str]) -> np.ndarray:
  """Applies the requested normalization to a dense adjacency matrix.

  This is the dense reference used by the tests in this file. Row sums of
  `adj` are treated as sender degrees and column sums as receiver degrees.
  Degrees below 1 are clamped to 1, so nodes without edges do not cause a
  division by zero.

  Args:
    adj: A square adjacency matrix where `adj[u, v]` is the weight of the edge
      from sender `u` to receiver `v`.
    adjacency_normalization: One of
      None: no normalization;
      'inverse-degree': D_s^-1 @ A;
      'inverse-sqrt-degree': D_s^-1/2 @ A @ D_r^-1/2;
      where D_s / D_r are the diagonal sender / receiver degree matrices.

  Returns:
    The normalized adjacency matrix, or `adj` itself when
    `adjacency_normalization` is None.

  Raises:
    ValueError: If `adjacency_normalization` is not one of the values above.
  """
  if adjacency_normalization is None:
    return adj

  def _clamped_degrees(axis):
    # Clamp to 1 so isolated nodes do not divide by zero.
    return np.maximum(np.sum(adj, axis=axis), 1.)

  if adjacency_normalization == 'inverse-sqrt-degree':
    inv_sqrt_sender_degrees = np.diag(1 / np.sqrt(_clamped_degrees(axis=1)))
    inv_sqrt_receiver_degrees = np.diag(1 / np.sqrt(_clamped_degrees(axis=0)))
    return inv_sqrt_sender_degrees @ adj @ inv_sqrt_receiver_degrees
  if adjacency_normalization == 'inverse-degree':
    inv_sender_degrees = np.diag(1 / _clamped_degrees(axis=1))
    return inv_sender_degrees @ adj
  raise ValueError(f'Unsupported normalization {adjacency_normalization}.')


class ModelsTest(parameterized.TestCase):
  """Tests for `models.OneHopGraphConvolution`."""

  @parameterized.named_parameters(
      dict(
          testcase_name='inverse-degree-without-self-loops',
          add_self_loops=False,
          adjacency_normalization='inverse-degree'),
      dict(
          testcase_name='inverse-sqrt-degree-without-self-loops',
          add_self_loops=False,
          adjacency_normalization='inverse-sqrt-degree'),
      dict(
          testcase_name='no-normalization-symmetrize',
          adjacency_normalization=None,
          symmetrize_edges=True),
      dict(
          testcase_name='no-normalization-no-symmetrize',
          adjacency_normalization=None,
          symmetrize_edges=False),
      dict(
          testcase_name='inv-sqrt-degree-normalization-symmetrize',
          adjacency_normalization='inverse-sqrt-degree',
          symmetrize_edges=True),
      dict(
          testcase_name='inv-sqrt-degree-normalization-no-symmetrize',
          adjacency_normalization='inverse-sqrt-degree',
          symmetrize_edges=False),
      dict(
          testcase_name='inv-degree-normalization-symmetrize',
          adjacency_normalization='inverse-degree',
          symmetrize_edges=True),
      dict(
          testcase_name='inv-degree-normalization-no-symmetrize',
          adjacency_normalization='inverse-degree',
          symmetrize_edges=False),
  )
  def test_graph_convolution_one_hop(
      self,
      add_self_loops: bool = True,
      symmetrize_edges: bool = False,
      adjacency_normalization: Optional[str] = None):
    """Compares one GCN hop against a dense-matrix reference computation.

    Arguments not set by a test case fall back to the signature defaults:
    self-loops are added and edges are not symmetrized.
    """
    # Create a dummy graph.
    dummy_graph = get_dummy_graph(
        add_self_loops=add_self_loops,
        symmetrize_edges=symmetrize_edges,
        adjacency_normalization=adjacency_normalization)

    # Build 1-hop GCN with an identity `update_fn`, so the expected output is
    # just the normalized aggregation of the input node features.
    model = models.OneHopGraphConvolution(update_fn=lambda nodes: nodes)
    rng = jax.random.PRNGKey(0)
    params = model.init(rng, dummy_graph)
    processed_nodes = model.apply(params, dummy_graph).nodes

    # Compute expected node features with a dense reference implementation.
    adj = get_adjacency_matrix(dummy_graph)
    normalized_adj = normalize_adjacency(
        adj, adjacency_normalization=adjacency_normalization)
    expected_nodes = normalized_adj @ dummy_graph.nodes

    # Check whether outputs match. The tolerances equal np.allclose's
    # defaults: the input features are float32 while the reference is
    # computed in float64. Unlike assertTrue(np.allclose(...)), this reports
    # the mismatching entries on failure and also checks that shapes agree.
    np.testing.assert_allclose(
        processed_nodes, expected_nodes, rtol=1e-5, atol=1e-8)


if __name__ == '__main__':
  # Run the tests defined in this module via absltest when executed directly.
  absltest.main()

# Page footer captured with this file (GitVerse cookie notice, translated from
# Russian): "Use of cookies. We use cookies in accordance with the Privacy
# Policy and the Cookie Policy. By clicking "Accept", you consent to
# JSC SberTech processing your personal data in order to improve our website
# and the GitVerse service and make them more convenient to use. You can
# disable cookies yourself in your browser settings."