google-research

Форк
0
117 строк · 3.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for train."""
17

18
import tempfile
19

20
from absl.testing import absltest
21
from absl.testing import parameterized
22
import chex
23
import jax
24
import jax.numpy as jnp
25
import ml_collections
26
import numpy as np
27

28
from differentially_private_gnns import dataset_readers
29
from differentially_private_gnns import input_pipeline
30
from differentially_private_gnns import train
31
from differentially_private_gnns.configs import dpgcn
32
from differentially_private_gnns.configs import dpmlp
33
from differentially_private_gnns.configs import gcn
34
from differentially_private_gnns.configs import mlp
35

36

37
_ALL_CONFIGS = {
38
    'dpgcn': dpgcn.get_config(),
39
    'dpmlp': dpmlp.get_config(),
40
    'gcn': gcn.get_config(),
41
    'mlp': mlp.get_config(),
42
}
43

44

45
def update_dummy_config(config):
46
  """Updates the dummy config."""
47
  config.dataset = 'dummy'
48
  config.batch_size = dataset_readers.DummyDataset.NUM_DUMMY_TRAINING_SAMPLES // 2
49
  config.max_degree = 2
50
  config.num_training_steps = 10
51
  config.num_classes = dataset_readers.DummyDataset.NUM_DUMMY_CLASSES
52

53

54
class TrainTest(parameterized.TestCase):
55

56
  @parameterized.product(
57
      config_name=['dpmlp', 'dpgcn'], rng_key=[0, 1], max_degree=[0, 1, 2])
58
  def test_per_example_gradients(self, config_name, rng_key,
59
                                 max_degree):
60
    # Load dummy config.
61
    config = _ALL_CONFIGS[config_name]
62
    update_dummy_config(config)
63
    config.max_degree = max_degree
64

65
    # Load dummy dataset.
66
    rng = jax.random.PRNGKey(rng_key)
67
    rng, dataset_rng = jax.random.split(rng)
68
    dataset = input_pipeline.get_dataset(config, dataset_rng)
69
    graph, labels, _ = jax.tree_map(jnp.asarray, dataset)
70
    labels = jax.nn.one_hot(labels, config.num_classes)
71
    num_nodes = labels.shape[0]
72

73
    # Create subgraphs.
74
    graph = jax.tree_map(np.asarray, graph)
75
    subgraphs = train.get_subgraphs(graph, config.pad_subgraphs_to)
76
    graph = jax.tree_map(jnp.asarray, graph)
77

78
    # Initialize model.
79
    rng, init_rng = jax.random.split(rng)
80
    estimation_indices = jnp.asarray([0])
81
    state = train.create_train_state(init_rng, config, graph, labels, subgraphs,
82
                                     estimation_indices)
83

84
    # Choose indices for batch.
85
    rng, train_rng = jax.random.split(rng)
86
    indices = jax.random.choice(train_rng, num_nodes, (config.batch_size,))
87

88
    # Compute per-example gradients.
89
    per_example_grads = train.compute_updates_for_dp(
90
        state, graph, labels, subgraphs, indices,
91
        config.adjacency_normalization)
92
    per_example_grads_summed = jax.tree_map(lambda grad: jnp.sum(grad, axis=0),
93
                                            per_example_grads)
94

95
    # Compute batched gradients.
96
    batched_grads = train.compute_updates(state, graph, labels, indices)
97

98
    # Check that these gradients match.
99
    chex.assert_trees_all_close(
100
        batched_grads, per_example_grads_summed, atol=1e-3, rtol=1e-3)
101

102
  @parameterized.parameters('gcn', 'mlp', 'dpgcn', 'dpmlp')
103
  def test_train_and_evaluate(self, config_name):
104

105
    # Load config for dummy dataset.
106
    config = _ALL_CONFIGS[config_name]
107
    update_dummy_config(config)
108

109
    # Create a temporary directory where metrics are written.
110
    workdir = tempfile.mkdtemp()
111

112
    # Training should proceed without any errors.
113
    train.train_and_evaluate(config, workdir)
114

115

116
if __name__ == '__main__':
117
  absltest.main()
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.