train_inner_test.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for task_set.train_inner."""
import json
import os
import tempfile
import numpy as np

from task_set import datasets
from task_set import train_inner
from task_set.tasks import base
import tensorflow.compat.v1 as tf


class DummyTask(base.BaseTask):
  """Dummy task used for tests."""

  def call_split(self, params, split, with_metrics=False):
    r = tf.random_normal(shape=[], dtype=tf.float32)
    # Per-split constant offsets; averaged losses should recover these values.
    offset = {
        datasets.Split.TRAIN: 1.0,
        datasets.Split.VALID_INNER: 2.0,
        datasets.Split.VALID_OUTER: 3.0,
        datasets.Split.TEST: 4.0,
    }
    loss = offset[split] + r

    if with_metrics:
      return loss, {"metric": -1 * loss}
    else:
      return loss

  def get_batch(self, split):
    return None

  def current_params(self):
    return {}

  def gradients(self, loss):
    return {}

  def initial_params(self):
    return {}

  def get_variables(self):
    return []


class TrainInnerTest(tf.test.TestCase):

  def test_compute_averaged_loss(self):
    task = DummyTask()
    params = task.initial_params()
    losses, _ = train_inner.compute_averaged_loss(
        task, params, num_batches=100, with_metrics=False)
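    # `losses` is expected to hold one averaged-loss tensor per split; the
    # unpacking into (train, valid_inner, valid_outer, test) below relies on
    # that order.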

    # Each sess.run re-samples the random noise, giving 10 independent
    # evaluations.
    with self.test_session() as sess:
      all_np_losses = []
      for _ in range(10):
        all_np_losses.append(sess.run(losses))

    tr, vai, vao, te = zip(*all_np_losses)
    # We are averaging over 100 batches with 10 replicated evaluations,
    # so the std. error of the mean should be about 1/sqrt(1000), or 0.03.
    # We use a threshold of 0.15, corresponding to a 5-sigma test.
    self.assertNear(np.mean(tr), 1.0, 0.15)
    self.assertNear(np.mean(vai), 2.0, 0.15)
    self.assertNear(np.mean(vao), 3.0, 0.15)
    self.assertNear(np.mean(te), 4.0, 0.15)

    # Ensure that each set of samples also has nonzero variance.
    self.assertLess(1e-5, np.var(tr), 0.5)
    self.assertLess(1e-5, np.var(vai), 0.5)
    self.assertLess(1e-5, np.var(vao), 0.5)
    self.assertLess(1e-5, np.var(te), 0.5)

    losses, metrics = train_inner.compute_averaged_loss(
        task, params, num_batches=100, with_metrics=True)
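    # `metrics` is expected to be one dict per split (same order as `losses`),
    # each containing the "metric" value produced by call_split.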
    tr_metrics, vai_metrics, vao_metrics, te_metrics = metrics
    with self.test_session() as sess:
      # Here the std. error of the mean is 1/sqrt(100), or 0.1; 5 sigma is 0.5.
      self.assertNear(sess.run(tr_metrics["metric"]), -1.0, 0.5)
      self.assertNear(sess.run(vai_metrics["metric"]), -2.0, 0.5)
      self.assertNear(sess.run(vao_metrics["metric"]), -3.0, 0.5)
      self.assertNear(sess.run(te_metrics["metric"]), -4.0, 0.5)

  def test_train(self):
    tmp_dir = tempfile.mkdtemp()

    # TODO(lmetz) when toy tasks are done, switch this away from an mlp.
    train_inner.train(
        tmp_dir,
        task_name="mlp_family_seed12",
        optimizer_name="adam8p_wide_grid_seed21",
        training_steps=10,
        eval_every_n=5)
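    # With training_steps=10 and eval_every_n=5, evaluations presumably run at
    # steps 0, 5, and 10, which is why three entries are expected below.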

    with tf.gfile.Open(os.path.join(tmp_dir, "result")) as f:
      result_data = json.loads(f.read())

    self.assertEqual(len(result_data), 3)
    # Entries are keyed by eval step; 4 losses (one per split) are logged per
    # step.
    self.assertEqual(len(result_data["5"]), 4)

    with tf.gfile.Open(os.path.join(tmp_dir, "time_per_step")) as f:
      time_per_step_data = json.loads(f.read())

    self.assertIn("mean_last_half", time_per_step_data)
    self.assertIn("mean_time", time_per_step_data)
    self.assertIn("median_time", time_per_step_data)


if __name__ == "__main__":
  tf.test.main()