google-research
134 строки · 4.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for sampler."""
17
18from typing import Any, Callable, Dict, List, Union
19
20from absl.testing import absltest
21from absl.testing import parameterized
22import jax
23import networkx as nx
24import numpy as np
25
26from differentially_private_gnns import sampler
27
# Nested subgraph structure: maps a node either to its list of neighbors (a
# leaf) or, recursively, to a mapping {neighbor: subgraph of that neighbor}.
Subgraphs = Dict[int, Union[List[int], 'Subgraphs']]
29
30
def sample_subgraphs(edges: Dict[int, List[int]], num_hops: int) -> 'Subgraphs':
  """Expands adjacency lists into nested subgraphs of depth `num_hops`.

  Args:
    edges: mapping from each node to the list of its (sampled) neighbors.
      Every neighbor that gets expanded (i.e. any node reached before the
      last hop) must itself be a key of `edges`.
    num_hops: depth of the subgraph around each root node; must be >= 1.

  Returns:
    For `num_hops == 1`, `edges` itself (not a copy). Otherwise a dict mapping
    each root node to {neighbor: subgraph of depth `num_hops - 1` rooted at
    that neighbor}, whose innermost values are neighbor lists taken from
    `edges`.

  Raises:
    NotImplementedError: if `num_hops` is smaller than 1.
  """
  if num_hops < 1:
    raise NotImplementedError(
        f'num_hops must be at least 1, got {num_hops}.')

  if num_hops == 1:
    return edges

  def expand(node, depth):
    # At depth 1 the subgraph of `node` is just its neighbor list.
    if depth == 1:
      return edges[node]
    return {neighbor: expand(neighbor, depth - 1) for neighbor in edges[node]}

  return {root_node: expand(root_node, num_hops) for root_node in edges}
46
47
def flatten_subgraphs(subgraphs: 'Subgraphs') -> Dict[int, List[int]]:
  """Flattens nested subgraphs into one flat node list per root.

  Each subgraph is listed depth-first: the node itself first, then, for each
  neighbor in order, the flattening of that neighbor's subgraph. A node
  reachable along several paths appears once per path, so the result is
  suitable for counting occurrences.

  Args:
    subgraphs: mapping from each root node to either its neighbor list or a
      mapping {neighbor: subgraph of that neighbor}, nested to any depth.

  Returns:
    A dict mapping each root node to the flat list of nodes in its subgraph.
  """

  def flatten_subgraph(node, node_subgraph):
    # Base case: a leaf is the neighbor list of `node`.
    if isinstance(node_subgraph, list):
      return [node, *node_subgraph]

    # Recurse on neighbours. Start with `node` itself, as the base case does;
    # otherwise the root (and, beyond two hops, every intermediate node) would
    # be missing from the flattened list.
    flattened = [node]
    for neighbor, neighbor_subgraph in node_subgraph.items():
      flattened.extend(flatten_subgraph(neighbor, neighbor_subgraph))
    return flattened

  return {
      node: flatten_subgraph(node, node_subgraph)
      for node, node_subgraph in subgraphs.items()
  }
67
68
class SampleEdgelistTest(parameterized.TestCase):
  """Tests how often nodes occur in subgraphs built from sampled edges."""

  def _check_occurrence_constraints(self, edges, train_nodes, num_nodes,
                                    sample_fn, max_degree, rng_key, num_hops,
                                    max_occurrences):
    """Samples `edges` and bounds each node's occurrences in train subgraphs.

    Args:
      edges: full adjacency lists of the graph, keyed by node.
      train_nodes: set of nodes whose subgraphs are counted.
      num_nodes: expected number of nodes in the sampled output.
      sample_fn: sampler under test, called as
        `sample_fn(edges, train_nodes, max_degree, rng)`.
      max_degree: maximum degree passed to the sampler.
      rng_key: seed passed to `jax.random.PRNGKey`.
      num_hops: depth of the subgraphs built from the sampled edges.
      max_occurrences: upper bound on how many times (with multiplicity) any
        node may appear across all flattened train-node subgraphs.
    """
    rng = jax.random.PRNGKey(rng_key)
    sampled_edges = sample_fn(edges, train_nodes, max_degree, rng)
    sampled_subgraphs = flatten_subgraphs(
        sample_subgraphs(sampled_edges, num_hops=num_hops))

    # Check sizes before counting, so a missing node fails with a clear
    # assertion message rather than a KeyError in the loop below.
    self.assertLen(sampled_edges, num_nodes)
    self.assertLen(sampled_subgraphs, num_nodes)

    occurrence_counts = {node: 0 for node in sampled_edges}
    for root_node, subgraph in sampled_subgraphs.items():
      if root_node in train_nodes:
        for node in subgraph:
          occurrence_counts[node] += 1

    for count in occurrence_counts.values():
      self.assertLessEqual(count, max_occurrences)

  @parameterized.product(
      rng_key=[0, 1],
      sample_fn=[sampler.sample_adjacency_lists],
      num_nodes=[10, 20, 50],
      edge_probability=[0.1, 0.2, 0.5, 0.8, 1.],
      max_degree=[1, 2, 5, 10, 20])
  def test_occurrence_constraints_one_hop(
      self, rng_key, sample_fn,
      num_nodes, edge_probability, max_degree):

    graph = nx.erdos_renyi_graph(num_nodes, p=edge_probability)
    edges = {node: list(graph.neighbors(node)) for node in graph.nodes}
    # Every other node (0, 2, 4, ...) is a training node.
    train_nodes = set(np.arange(num_nodes, step=2).flat)
    # NOTE(review): bound = 1 (the root itself) + max_degree; this assumes the
    # sampler lets each node appear in at most `max_degree` sampled adjacency
    # lists — confirm against sampler.sample_adjacency_lists.
    self._check_occurrence_constraints(
        edges, train_nodes, num_nodes, sample_fn, max_degree, rng_key,
        num_hops=1, max_occurrences=max_degree + 1)

  @parameterized.product(
      rng_key=[0, 1],
      sample_fn=[sampler.sample_adjacency_lists],
      num_nodes=[10, 20, 50],
      edge_probability=[0.1, 0.2, 0.5, 0.8, 1.],
      max_degree=[1, 2, 5, 10, 20])
  def test_occurrence_constraints_two_hop_disjoint(
      self, rng_key, sample_fn,
      num_nodes, edge_probability, max_degree):

    num_train_nodes = num_nodes // 2
    # disjoint_union relabels nodes consecutively, so the first graph's nodes
    # are 0..num_train_nodes-1 and no edge joins train and non-train nodes.
    graph = nx.disjoint_union(
        nx.erdos_renyi_graph(num_train_nodes, p=edge_probability),
        nx.erdos_renyi_graph(num_nodes - num_train_nodes, p=edge_probability))
    edges = {node: list(graph.neighbors(node)) for node in graph.nodes}
    train_nodes = set(np.arange(num_train_nodes).flat)
    # NOTE(review): bound = 1 (root) + max_degree (one hop) + max_degree**2
    # (two hops), under the same per-node sampler assumption as the one-hop
    # test — confirm against sampler.sample_adjacency_lists.
    self._check_occurrence_constraints(
        edges, train_nodes, num_nodes, sample_fn, max_degree, rng_key,
        num_hops=2,
        max_occurrences=max_degree * max_degree + max_degree + 1)
131
132
# Run every test case in this module through absltest's runner when the file
# is executed directly.
if __name__ == '__main__':
  absltest.main()
135