google-research

Форк
0
134 строки · 4.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for sampler."""
17

18
from typing import Any, Callable, Dict, List, Union
19

20
from absl.testing import absltest
21
from absl.testing import parameterized
22
import jax
23
import networkx as nx
24
import numpy as np
25

26
from differentially_private_gnns import sampler
27

28
Subgraphs = Dict[int, Union[List[int], 'Subgraphs']]
29

30

31
def sample_subgraphs(edges: Dict[int, List[int]],
                     num_hops: int) -> 'Subgraphs':
  """Expands (sampled) adjacency lists into nested `num_hops`-hop subgraphs.

  Args:
    edges: Mapping from each node to the list of its neighbors. Every node
      reached within `num_hops - 1` hops of a root must itself be a key of
      `edges`, otherwise a KeyError is raised.
    num_hops: Depth of the subgraphs; any positive integer.

  Returns:
    For `num_hops == 1`, `edges` itself (not a copy). Otherwise a dict mapping
    each root node to `{neighbor: <subgraph of depth num_hops - 1>}`; the
    innermost values are the neighbor lists from `edges` (shared, not copied).

  Raises:
    NotImplementedError: If `num_hops` is not a positive integer.
  """
  if num_hops < 1 or num_hops != int(num_hops):
    raise NotImplementedError(
        f'num_hops must be a positive integer, got {num_hops!r}.')

  if num_hops == 1:
    return edges

  def expand(node, hops_left):
    # A depth-1 subgraph is just the node's own neighbor list.
    if hops_left == 1:
      return edges[node]
    return {
        neighbor: expand(neighbor, hops_left - 1) for neighbor in edges[node]
    }

  return {root_node: expand(root_node, num_hops) for root_node in edges}
46

47

48
def flatten_subgraphs(subgraphs):
  """Turns (possibly nested) sampled subgraphs into flat node lists.

  Args:
    subgraphs: Mapping from a root node to either a list of its neighbors or
      a mapping from each neighbor to that neighbor's own (nested) subgraph,
      as produced by `sample_subgraphs`.

  Returns:
    A dict with the same keys as `subgraphs`. A list-valued entry
    `node -> [n1, n2, ...]` flattens to `[node, n1, n2, ...]`; a dict-valued
    entry flattens to the in-order concatenation of its children's flattened
    lists.
  """

  def _expand(current, subtree):
    if not isinstance(subtree, list):
      # Interior level: splice the flattened children together, in order.
      # NOTE(review): `current` itself is not emitted here (unlike the leaf
      # case below), so a multi-hop root only appears if some neighbor list
      # contains it; confirm this is intended.
      return [
          member
          for child, child_subtree in subtree.items()
          for member in _expand(child, child_subtree)
      ]
    # Leaf level: the node followed by its neighbor list.
    return [current] + subtree

  return {root: _expand(root, tree) for root, tree in subgraphs.items()}
67

68

69
class SampleEdgelistTest(parameterized.TestCase):
  """Occurrence-count tests for sampled adjacency lists.

  Each test samples adjacency lists for a random graph, expands them into
  subgraphs rooted at every node, and checks that no node appears more often
  across the subgraphs rooted at training nodes than a bound derived from
  `max_degree`.
  """

  # Parameter grid shared by both tests. It is evaluated in the class body,
  # so the decorators below can reference it directly.
  _PARAM_GRID = dict(
      rng_key=[0, 1],
      sample_fn=[sampler.sample_adjacency_lists],
      num_nodes=[10, 20, 50],
      edge_probability=[0.1, 0.2, 0.5, 0.8, 1.],
      max_degree=[1, 2, 5, 10, 20])

  @staticmethod
  def _count_train_occurrences(sampled_edges, flat_subgraphs, train_nodes):
    """Counts how often each node appears in training-rooted subgraphs.

    Args:
      sampled_edges: Sampled adjacency lists; its keys define the node set.
      flat_subgraphs: Mapping from root node to its flattened subgraph.
      train_nodes: Set of root nodes whose subgraphs are counted.

    Returns:
      Dict mapping each key of `sampled_edges` to its occurrence count.
    """
    counts = {node: 0 for node in sampled_edges}
    for root_node, subgraph in flat_subgraphs.items():
      if root_node in train_nodes:
        for node in subgraph:
          counts[node] += 1
    return counts

  @parameterized.product(**_PARAM_GRID)
  def test_occurrence_constraints_one_hop(
      self, rng_key, sample_fn,
      num_nodes, edge_probability, max_degree):
    # Seed graph generation as well as the sampler, so that every parameter
    # combination is reproducible across runs.
    graph = nx.erdos_renyi_graph(num_nodes, p=edge_probability, seed=rng_key)
    edges = {node: list(graph.neighbors(node)) for node in graph.nodes}
    # Every other node (0, 2, 4, ...) is a training node.
    train_nodes = set(np.arange(num_nodes, step=2).flat)
    rng = jax.random.PRNGKey(rng_key)
    sampled_edges = sample_fn(edges, train_nodes, max_degree, rng)
    sampled_subgraphs = flatten_subgraphs(
        sample_subgraphs(sampled_edges, num_hops=1))

    occurrence_counts = self._count_train_occurrences(
        sampled_edges, sampled_subgraphs, train_nodes)

    self.assertLen(sampled_edges, num_nodes)
    self.assertLen(sampled_subgraphs, num_nodes)
    for node, count in occurrence_counts.items():
      self.assertLessEqual(count, max_degree + 1, msg=f'node {node}')

  @parameterized.product(**_PARAM_GRID)
  def test_occurrence_constraints_two_hop_disjoint(
      self, rng_key, sample_fn,
      num_nodes, edge_probability, max_degree):
    num_train_nodes = num_nodes // 2
    # The training nodes (the first `num_train_nodes` labels after the union)
    # share no edges with the remaining nodes. Seeded with distinct values so
    # that the runs are reproducible and the two halves are not identical.
    graph = nx.disjoint_union(
        nx.erdos_renyi_graph(
            num_train_nodes, p=edge_probability, seed=2 * rng_key),
        nx.erdos_renyi_graph(
            num_nodes - num_train_nodes, p=edge_probability,
            seed=2 * rng_key + 1))
    edges = {node: list(graph.neighbors(node)) for node in graph.nodes}
    train_nodes = set(np.arange(num_train_nodes).flat)
    rng = jax.random.PRNGKey(rng_key)
    sampled_edges = sample_fn(edges, train_nodes, max_degree, rng)
    sampled_subgraphs = flatten_subgraphs(
        sample_subgraphs(sampled_edges, num_hops=2))

    occurrence_counts = self._count_train_occurrences(
        sampled_edges, sampled_subgraphs, train_nodes)

    self.assertLen(sampled_edges, num_nodes)
    self.assertLen(sampled_subgraphs, num_nodes)
    bound = max_degree * max_degree + max_degree + 1
    for node, count in occurrence_counts.items():
      self.assertLessEqual(count, bound, msg=f'node {node}')
131

132

133
if __name__ == '__main__':
  # Run every test case in this module through absl's test runner.
  absltest.main()
135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.