google-research

Форк
0
117 строк · 3.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Input pipeline for DP-GNN training."""
17

18
from typing import Dict, Tuple
19

20
import chex
21
import jraph
22
import ml_collections
23
import numpy as np
24

25
from differentially_private_gnns import dataset_readers
26
from differentially_private_gnns import normalizations
27
from differentially_private_gnns import sampler
28

29

30
def add_reverse_edges(
    graph):
  """Appends every edge again with its direction flipped.

  The graph is modified in place. `graph.senders` and `graph.receivers` are
  replaced by NumPy arrays holding the original edges followed by their
  reversals. An edge (u, v) at position i therefore also appears as (v, u) at
  position i + E, where E is the original number of edges. Edges that already
  exist in both directions are not deduplicated.

  Args:
    graph: Object with assignable `senders` and `receivers` attributes
      (array-like, presumably of equal length).

  Returns:
    The same `graph` object, after modification.
  """
  original_senders = graph.senders
  original_receivers = graph.receivers
  graph.senders = np.concatenate((original_senders, original_receivers))
  graph.receivers = np.concatenate((original_receivers, original_senders))
  return graph
41

42

43
def subsample_graph(graph, max_degree,
                    rng):
  """Subsamples the undirected input graph's adjacency lists.

  Adjacency lists are extracted with `sampler.get_adjacency_lists`. They are
  then resampled by `sampler.sample_adjacency_lists`, which receives
  `graph.train_nodes`, `max_degree` and `rng`; see that function for the exact
  sampling semantics. The sampled lists are flattened back into parallel edge
  lists, and `graph.senders` / `graph.receivers` are overwritten in place with
  plain Python lists.

  Args:
    graph: The input graph; its `senders` and `receivers` are replaced.
    max_degree: Degree bound forwarded to the sampler.
    rng: Randomness source forwarded to the sampler.

  Returns:
    The same `graph` object, after modification.
  """
  adjacency = sampler.get_adjacency_lists(graph)
  adjacency = sampler.sample_adjacency_lists(
      adjacency, graph.train_nodes, max_degree, rng)
  # Flatten the mapping {node: neighbours} into (sender, receiver) pairs,
  # preserving the mapping's iteration order.
  edge_pairs = [(src, dst) for src in adjacency for dst in adjacency[src]]
  graph.senders = [src for src, _ in edge_pairs]
  graph.receivers = [dst for _, dst in edge_pairs]
  return graph
59

60

61
def compute_masks_for_splits(
    graph):
  """Builds one boolean node mask per data split.

  Args:
    graph: Object providing `num_nodes()` and the node-index collections
      `train_nodes`, `validation_nodes` and `test_nodes`.

  Returns:
    A dict with keys 'train', 'validation' and 'test', in that order. Each key
    maps to a boolean array of length `graph.num_nodes()` that is True exactly
    at the indices of that split's nodes.
  """
  total_nodes = graph.num_nodes()
  nodes_per_split = (
      ('train', graph.train_nodes),
      ('validation', graph.validation_nodes),
      ('test', graph.test_nodes),
  )

  def _mask_for(node_ids):
    mask = np.zeros(total_nodes, dtype=bool)
    mask[node_ids] = True
    return mask

  return {name: _mask_for(ids) for name, ids in nodes_per_split}
73

74

75
def convert_to_graphstuple(
    graph):
  """Packs a whole dataset graph into one `jraph.GraphsTuple` plus labels.

  Every edge gets a feature of one (one entry per sender), the global feature
  is a single zero, and `n_node` / `n_edge` hold the graph's node and edge
  counts as one-element arrays.

  Args:
    graph: Dataset graph exposing `node_features`, `senders`, `receivers`,
      `node_labels`, `num_nodes()` and `num_edges()`.

  Returns:
    A pair `(graphs_tuple, labels)`, where `labels` is `graph.node_labels`
    converted to a NumPy array.
  """
  senders = np.asarray(graph.senders)
  receivers = np.asarray(graph.receivers)
  graphs_tuple = jraph.GraphsTuple(  # pytype: disable=wrong-arg-types  # jax-ndarray
      nodes=np.asarray(graph.node_features),
      edges=np.ones_like(senders),
      senders=senders,
      receivers=receivers,
      globals=np.zeros(1),
      n_node=np.asarray([graph.num_nodes()]),
      n_edge=np.asarray([graph.num_edges()]),
  )
  labels = np.asarray(graph.node_labels)
  return graphs_tuple, labels
87

88

89
def add_self_loops(graph):
  """Prepends a self-loop (i, i) for every node i to the graph's edge list.

  All edge features are reset to ones and `n_edge` is set to the new total
  edge count. Existing self-loops are not deduplicated, so a node that already
  had one ends up with two.

  Args:
    graph: Namedtuple-style graph (a `jraph.GraphsTuple` in this pipeline)
      with `senders` and `receivers`; updated via `_replace`.

  Returns:
    A new graph produced by `graph._replace(...)` with self-loops added.
  """
  num_nodes = normalizations.compute_num_nodes(graph)
  # Build the self-loop indices as int32 as well. `np.arange` defaults to
  # int64 on most platforms, and concatenating that with the int32-cast edge
  # indices would silently promote the result back to int64, defeating the
  # explicit casts below.
  self_loops = np.arange(num_nodes, dtype=np.int32)
  senders = np.concatenate(
      (self_loops, np.asarray(graph.senders, dtype=np.int32)))
  receivers = np.concatenate(
      (self_loops, np.asarray(graph.receivers, dtype=np.int32)))

  return graph._replace(
      senders=senders,
      receivers=receivers,
      edges=np.ones_like(senders),
      n_edge=np.asarray([senders.shape[0]]))
102

103

104
def get_dataset(
    config,
    rng,
):
  """Loads and preprocesses the graph dataset described by `config`.

  Steps, in order:
    1. Read the raw graph with `dataset_readers.get_dataset`.
    2. Add reverse edges.
    3. Subsample adjacency lists using `config.max_degree` and `rng`.
    4. Compute the train/validation/test node masks.
    5. Convert to a `jraph.GraphsTuple` and extract the node labels.
    6. Add self-loops.
    7. Normalize edges with `normalizations.normalize_edges_with_mask`
       (no mask, using `config.adjacency_normalization`).

  Args:
    config: Configuration read for `dataset`, `dataset_path`, `max_degree`
      and `adjacency_normalization`.
    rng: Randomness source forwarded to the subsampling step.

  Returns:
    A tuple `(graph, labels, masks)` holding the processed `GraphsTuple`, the
    node labels array, and the dict of per-split boolean masks.
  """
  raw_graph = dataset_readers.get_dataset(config.dataset, config.dataset_path)
  symmetric_graph = add_reverse_edges(raw_graph)
  sampled_graph = subsample_graph(symmetric_graph, config.max_degree, rng)
  split_masks = compute_masks_for_splits(sampled_graph)
  graphs_tuple, node_labels = convert_to_graphstuple(sampled_graph)
  graphs_tuple = add_self_loops(graphs_tuple)
  normalized_graph = normalizations.normalize_edges_with_mask(
      graphs_tuple,
      mask=None,
      adjacency_normalization=config.adjacency_normalization)
  return normalized_graph, node_labels, split_masks
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.