google-research
117 строк · 3.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for train."""
17
18import tempfile
19
20from absl.testing import absltest
21from absl.testing import parameterized
22import chex
23import jax
24import jax.numpy as jnp
25import ml_collections
26import numpy as np
27
28from differentially_private_gnns import dataset_readers
29from differentially_private_gnns import input_pipeline
30from differentially_private_gnns import train
31from differentially_private_gnns.configs import dpgcn
32from differentially_private_gnns.configs import dpmlp
33from differentially_private_gnns.configs import gcn
34from differentially_private_gnns.configs import mlp
35
36
# Maps each config name to the config object returned by that module's
# `get_config()`; every config is built once, at import time.
_ALL_CONFIGS = {
    name: module.get_config()
    for name, module in (
        ('dpgcn', dpgcn),
        ('dpmlp', dpmlp),
        ('gcn', gcn),
        ('mlp', mlp),
    )
}
43
44
def update_dummy_config(config):
  """Overrides fields of `config` in place so it targets the dummy dataset.

  Args:
    config: Config object whose `dataset`, `batch_size`, `max_degree`,
      `num_training_steps` and `num_classes` attributes are overwritten.
      Nothing is returned; the caller's object is modified.
  """
  dummy = dataset_readers.DummyDataset

  config.dataset = 'dummy'
  # Each batch covers half of the dummy training samples.
  config.batch_size = dummy.NUM_DUMMY_TRAINING_SAMPLES // 2
  config.max_degree = 2
  config.num_training_steps = 10
  config.num_classes = dummy.NUM_DUMMY_CLASSES
52
53
class TrainTest(parameterized.TestCase):
  """Tests for the `train` module, run against the dummy dataset."""

  @parameterized.product(
      config_name=['dpmlp', 'dpgcn'], rng_key=[0, 1], max_degree=[0, 1, 2])
  def test_per_example_gradients(self, config_name, rng_key, max_degree):
    """Checks that per-example gradients sum to the batched gradients.

    Args:
      config_name: Key into `_ALL_CONFIGS` selecting the config under test.
      rng_key: Integer seed used to build the JAX PRNG key.
      max_degree: Value assigned to `config.max_degree` for this case.
    """
    # Load dummy config.
    # NOTE(review): the config object is shared through `_ALL_CONFIGS` and is
    # mutated in place here, so overrides persist across parameterized cases.
    # `update_dummy_config` re-applies its fields each time; confirm that no
    # other field needs per-test isolation.
    config = _ALL_CONFIGS[config_name]
    update_dummy_config(config)
    config.max_degree = max_degree

    # Load dummy dataset.
    rng = jax.random.PRNGKey(rng_key)
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    # `jax.tree_util.tree_map` is used instead of the top-level `jax.tree_map`
    # alias, which is deprecated in newer JAX releases.
    graph, labels, _ = jax.tree_util.tree_map(jnp.asarray, dataset)
    labels = jax.nn.one_hot(labels, config.num_classes)
    num_nodes = labels.shape[0]

    # Create subgraphs. The graph is handed to `get_subgraphs` as NumPy arrays
    # and converted back to JAX arrays afterwards.
    graph = jax.tree_util.tree_map(np.asarray, graph)
    subgraphs = train.get_subgraphs(graph, config.pad_subgraphs_to)
    graph = jax.tree_util.tree_map(jnp.asarray, graph)

    # Initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = jnp.asarray([0])
    state = train.create_train_state(init_rng, config, graph, labels, subgraphs,
                                     estimation_indices)

    # Choose indices for batch.
    rng, train_rng = jax.random.split(rng)
    indices = jax.random.choice(train_rng, num_nodes, (config.batch_size,))

    # Compute per-example gradients and sum them over the leading axis.
    per_example_grads = train.compute_updates_for_dp(
        state, graph, labels, subgraphs, indices,
        config.adjacency_normalization)
    per_example_grads_summed = jax.tree_util.tree_map(
        lambda grad: jnp.sum(grad, axis=0), per_example_grads)

    # Compute batched gradients.
    batched_grads = train.compute_updates(state, graph, labels, indices)

    # Check that these gradients match.
    chex.assert_trees_all_close(
        batched_grads, per_example_grads_summed, atol=1e-3, rtol=1e-3)

  @parameterized.parameters('gcn', 'mlp', 'dpgcn', 'dpmlp')
  def test_train_and_evaluate(self, config_name):
    """Checks that a short training run completes without raising.

    Args:
      config_name: Key into `_ALL_CONFIGS` selecting the config under test.
    """
    # Load config for dummy dataset.
    config = _ALL_CONFIGS[config_name]
    update_dummy_config(config)

    # Use a temporary working directory for the run. The context manager
    # removes it afterwards, so repeated test runs do not accumulate
    # directories.
    with tempfile.TemporaryDirectory() as workdir:
      # Training should proceed without any errors.
      train.train_and_evaluate(config, workdir)
114
115
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
118