google-research
117 строк · 3.6 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Input pipeline for DP-GNN training."""
17
18from typing import Dict, Tuple19
20import chex21import jraph22import ml_collections23import numpy as np24
25from differentially_private_gnns import dataset_readers26from differentially_private_gnns import normalizations27from differentially_private_gnns import sampler28
29
def add_reverse_edges(graph):
    """Append the reverse of every edge to the graph's edge lists.

    For each existing edge (sender -> receiver), an edge
    (receiver -> sender) is appended after the original edges.

    Args:
      graph: Object exposing `senders` and `receivers` attributes. It is
        modified in place.

    Returns:
      The same `graph` object, with `senders` and `receivers` replaced by
      the concatenated arrays.
    """
    forward_ends, backward_ends = graph.senders, graph.receivers
    # Both concatenations are evaluated before either attribute is
    # overwritten, so the second one still sees the original edge lists.
    graph.senders, graph.receivers = (
        np.concatenate((forward_ends, backward_ends)),
        np.concatenate((backward_ends, forward_ends)),
    )
    return graph
42
def subsample_graph(graph, max_degree, rng):
    """Replace the graph's edges with a sampled subset.

    Adjacency lists are built with `sampler.get_adjacency_lists`, then passed
    to `sampler.sample_adjacency_lists` together with `graph.train_nodes`,
    `max_degree` and `rng`. The sampled adjacency lists are flattened back
    into parallel sender/receiver lists.

    Args:
      graph: Object exposing `senders`, `receivers` and `train_nodes`. It is
        modified in place. The original docstring states that the input
        graph is expected to be undirected.
      max_degree: Forwarded to the sampler (presumably a per-node degree
        bound; confirm against `sampler.sample_adjacency_lists`).
      rng: Random state forwarded to the sampler.

    Returns:
      The same `graph` object, with `senders` and `receivers` set to plain
      Python lists describing the sampled edges.
    """
    adjacency = sampler.get_adjacency_lists(graph)
    sampled = sampler.sample_adjacency_lists(
        adjacency, graph.train_nodes, max_degree, rng)

    # Flatten {u: [v, ...]} into (u, v) pairs, preserving iteration order.
    edge_pairs = [(u, v) for u in sampled for v in sampled[u]]

    graph.senders = [u for u, _ in edge_pairs]
    graph.receivers = [v for _, v in edge_pairs]
    return graph
60
def compute_masks_for_splits(graph):
    """Build one boolean node mask per data split.

    Args:
      graph: Object exposing a `num_nodes()` method and the index
        collections `train_nodes`, `validation_nodes` and `test_nodes`.

    Returns:
      A dict with keys 'train', 'validation' and 'test' (in that order).
      Each value is a boolean array of length `graph.num_nodes()` that is
      True exactly at the indices listed for that split.
    """
    total_nodes = graph.num_nodes()
    nodes_by_split = {
        'train': graph.train_nodes,
        'validation': graph.validation_nodes,
        'test': graph.test_nodes,
    }

    masks = {}
    for split_name, member_ids in nodes_by_split.items():
        membership = np.zeros(total_nodes, dtype=bool)
        membership[member_ids] = True
        masks[split_name] = membership
    return masks
74
def convert_to_graphstuple(graph):
    """Pack the whole dataset into a single `jraph.GraphsTuple` plus labels.

    Args:
      graph: Object exposing `node_features`, `senders`, `receivers`,
        `node_labels` and the methods `num_nodes()` and `num_edges()`.

    Returns:
      A pair `(graphs_tuple, labels)`. `graphs_tuple` holds one graph whose
      edge features are all ones (same shape as `graph.senders`) and whose
      globals are a single zero. `labels` is `graph.node_labels` as an
      array.
    """
    node_features = np.asarray(graph.node_features)
    # Every edge gets a unit feature.
    unit_edges = np.ones_like(graph.senders)
    senders = np.asarray(graph.senders)
    receivers = np.asarray(graph.receivers)
    empty_globals = np.zeros(1)
    node_count = np.asarray([graph.num_nodes()])
    edge_count = np.asarray([graph.num_edges()])

    graphs_tuple = jraph.GraphsTuple(  # pytype: disable=wrong-arg-types  # jax-ndarray
        nodes=node_features,
        edges=unit_edges,
        senders=senders,
        receivers=receivers,
        globals=empty_globals,
        n_node=node_count,
        n_edge=edge_count,
    )
    labels = np.asarray(graph.node_labels)
    return graphs_tuple, labels
88
def add_self_loops(graph):
    """Return a copy of `graph` with one self-loop prepended per node.

    The node count comes from `normalizations.compute_num_nodes`. Self-loop
    edges (i -> i) are placed before the existing edges, and the edge
    features are reset to ones for the enlarged edge set.

    Args:
      graph: Object supporting `_replace` (e.g. a namedtuple-style graph)
        with `senders` and `receivers` fields.

    Returns:
      The result of `graph._replace(...)` with updated `senders`,
      `receivers`, `edges` and `n_edge`.
    """
    node_ids = np.arange(normalizations.compute_num_nodes(graph))
    existing_senders = np.asarray(graph.senders, dtype=np.int32)
    existing_receivers = np.asarray(graph.receivers, dtype=np.int32)

    # np.concatenate copies its inputs, so sharing `node_ids` is safe.
    looped_senders = np.concatenate((node_ids, existing_senders))
    looped_receivers = np.concatenate((node_ids, existing_receivers))

    return graph._replace(
        senders=looped_senders,
        receivers=looped_receivers,
        edges=np.ones_like(looped_senders),
        n_edge=np.asarray([looped_senders.shape[0]]),
    )
103
def get_dataset(config, rng):
    """Load a graph dataset and run it through the preprocessing pipeline.

    Steps, in order: read the dataset, add reverse edges, subsample edges,
    compute split masks, convert to a graphs tuple, add self-loops, and
    normalize edges.

    Args:
      config: Configuration object; the fields read here are `dataset`,
        `dataset_path`, `max_degree` and `adjacency_normalization`.
      rng: Random state forwarded to `subsample_graph`.

    Returns:
      A triple `(graph, labels, masks)`: the normalized graphs tuple, the
      node labels, and the dict of per-split boolean masks.
    """
    dataset = dataset_readers.get_dataset(config.dataset, config.dataset_path)
    dataset = add_reverse_edges(dataset)
    dataset = subsample_graph(dataset, config.max_degree, rng)

    # Masks are computed from the dataset object, before it is converted.
    masks = compute_masks_for_splits(dataset)

    graphs_tuple, labels = convert_to_graphstuple(dataset)
    graphs_tuple = add_self_loops(graphs_tuple)
    graphs_tuple = normalizations.normalize_edges_with_mask(
        graphs_tuple,
        mask=None,
        adjacency_normalization=config.adjacency_normalization,
    )
    return graphs_tuple, labels, masks