google-research
167 строк · 5.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for models."""
17
18from typing import Optional19
20from absl.testing import absltest21from absl.testing import parameterized22import jax23import jraph24import numpy as np25
26from differentially_private_gnns import models27from differentially_private_gnns import normalizations28
29
def get_dummy_graph(
    add_self_loops: bool,
    symmetrize_edges: bool,
    adjacency_normalization: Optional[str]):
  """Returns a small, edge-normalized dummy GraphsTuple for tests.

  The base graph has 3 nodes and the two directed edges 0 -> 1 and 2 -> 1.
  Node features are the float32 column [[2.], [1.], [1.]]. Edge features are
  initialized to 1 (before normalization) and the single global feature is 0.

  Args:
    add_self_loops: Whether to append an edge i -> i for every node i.
    symmetrize_edges: Whether to append the reverse of each base edge. Reverse
      edges are placed after the base edges and before any self-loops.
    adjacency_normalization: Forwarded unchanged to
      `normalizations.normalize_edges_with_mask`.

  Returns:
    The output of `normalizations.normalize_edges_with_mask` (called with
    `mask=None`) on the assembled graph.
  """
  num_nodes = 3
  base_senders = np.array([0, 2])
  base_receivers = np.array([1, 1])

  # Edge-list pieces, in order: base edges, reversed edges, self-loops.
  sender_chunks = [base_senders]
  receiver_chunks = [base_receivers]
  if symmetrize_edges:
    sender_chunks.append(base_receivers)
    receiver_chunks.append(base_senders)
  if add_self_loops:
    node_ids = np.arange(num_nodes)
    sender_chunks.append(node_ids)
    receiver_chunks.append(node_ids)

  all_senders = np.concatenate(sender_chunks, axis=0)
  all_receivers = np.concatenate(receiver_chunks, axis=0)
  # The edge count is simply the length of the final edge list.
  edge_count = len(all_senders)

  graph = jraph.GraphsTuple(
      n_node=np.asarray([num_nodes]),
      n_edge=np.asarray([edge_count]),
      senders=all_senders,
      receivers=all_receivers,
      nodes=np.array([[2.], [1.], [1.]], dtype=np.float32),
      edges=np.ones((edge_count, 1)),
      globals=np.zeros((1, 1)),
  )
  return normalizations.normalize_edges_with_mask(
      graph, mask=None, adjacency_normalization=adjacency_normalization)
65
def get_adjacency_matrix(graph):
  """Builds a dense 0/1 adjacency matrix from a graph's edge lists.

  Only `graph.n_node[0]` is used to size the matrix, so this expects a
  GraphsTuple that holds a single graph.

  Args:
    graph: An object exposing `n_node`, `senders` and `receivers`.

  Returns:
    A float64 array `adj` of shape [n, n], with n = `graph.n_node[0]`, where
    `adj[s][r] == 1` for every edge s -> r and 0 elsewhere. Repeated edges
    still yield a single 1.
  """
  size = graph.n_node[0]
  dense = np.zeros((size, size))
  for src, dst in zip(graph.senders, graph.receivers):
    dense[src][dst] = 1
  return dense
78
def normalize_adjacency(
    adj,
    adjacency_normalization: Optional[str]):
  """Applies the requested normalization to a dense adjacency matrix.

  Rows index senders and columns index receivers (`adj[s][r]` marks s -> r),
  so row sums are sender degrees and column sums are receiver degrees. All
  degrees are clamped to at least 1 to avoid dividing by zero.

  Args:
    adj: A dense square adjacency matrix.
    adjacency_normalization: One of None, 'inverse-sqrt-degree' or
      'inverse-degree'.

  Returns:
    `adj` itself when no normalization is requested;
    D_s^{-1/2} @ adj @ D_r^{-1/2} for 'inverse-sqrt-degree';
    D_s^{-1} @ adj for 'inverse-degree'.

  Raises:
    ValueError: If `adjacency_normalization` is not one of the above.
  """
  if adjacency_normalization is None:
    return adj

  if adjacency_normalization == 'inverse-sqrt-degree':
    out_deg = np.maximum(np.sum(adj, axis=1), 1.)
    in_deg = np.maximum(np.sum(adj, axis=0), 1.)
    left_scale = np.diag(1 / np.sqrt(out_deg))
    right_scale = np.diag(1 / np.sqrt(in_deg))
    return left_scale @ adj @ right_scale

  if adjacency_normalization == 'inverse-degree':
    out_deg = np.sum(adj, axis=1)
    row_scale = np.diag(1 / np.maximum(out_deg, 1.))
    return row_scale @ adj

  raise ValueError(f'Unsupported normalization {adjacency_normalization}.')
101
class ModelsTest(parameterized.TestCase):
  """Compares model layers against dense reference computations."""

  @parameterized.named_parameters(
      dict(
          testcase_name='inverse-degree-without-self-loops',
          add_self_loops=False,
          adjacency_normalization='inverse-degree'),
      dict(
          testcase_name='inverse-sqrt-degree-without-self-loops',
          add_self_loops=False,
          adjacency_normalization='inverse-sqrt-degree'),
      dict(
          testcase_name='no-normalization-symmetrize',
          adjacency_normalization=None,
          symmetrize_edges=True),
      dict(
          testcase_name='no-normalization-no-symmetrize',
          adjacency_normalization=None,
          symmetrize_edges=False),
      dict(
          testcase_name='inv-sqrt-degree-normalization-symmetrize',
          adjacency_normalization='inverse-sqrt-degree',
          symmetrize_edges=True),
      dict(
          testcase_name='inv-sqrt-degree-normalization-no-symmetrize',
          adjacency_normalization='inverse-sqrt-degree',
          symmetrize_edges=False),
      dict(
          testcase_name='inv-degree-normalization-symmetrize',
          adjacency_normalization='inverse-degree',
          symmetrize_edges=True),
      dict(
          testcase_name='inv-degree-normalization-no-symmetrize',
          adjacency_normalization='inverse-degree',
          symmetrize_edges=False),
  )
  def test_graph_convolution_one_hop(
      self,
      add_self_loops: bool = True,
      symmetrize_edges: bool = False,
      adjacency_normalization: Optional[str] = None):
    """OneHopGraphConvolution output should equal normalized_adj @ nodes.

    Args:
      add_self_loops: Passed to `get_dummy_graph`.
      symmetrize_edges: Passed to `get_dummy_graph`.
      adjacency_normalization: Used both to build the dummy graph and to
        normalize the dense reference adjacency matrix.
    """
    # Create a dummy graph.
    dummy_graph = get_dummy_graph(
        add_self_loops=add_self_loops,
        symmetrize_edges=symmetrize_edges,
        adjacency_normalization=adjacency_normalization)

    # Build 1-hop GCN with an identity update_fn, so its output can be compared
    # directly against the dense reference computed below.
    model = models.OneHopGraphConvolution(update_fn=lambda nodes: nodes)
    rng = jax.random.PRNGKey(0)
    params = model.init(rng, dummy_graph)
    processed_nodes = model.apply(params, dummy_graph).nodes

    # Compute expected node features with a dense reference computation.
    adj = get_adjacency_matrix(dummy_graph)
    normalized_adj = normalize_adjacency(
        adj, adjacency_normalization=adjacency_normalization)
    expected_nodes = normalized_adj @ dummy_graph.nodes

    # Check whether outputs match. The tolerances are np.allclose's defaults.
    # Unlike assertTrue(np.allclose(...)), this reports the mismatching entries
    # on failure and rejects shape mismatches instead of broadcasting.
    np.testing.assert_allclose(
        processed_nodes, expected_nodes, rtol=1e-5, atol=1e-8)
165
if __name__ == '__main__':
  # When run as a script, hand control to absl's test runner for this module.
  absltest.main()