google-research
127 строк · 3.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for task_set.train_inner."""
17import json18import os19import tempfile20import numpy as np21
22from task_set import datasets23from task_set import train_inner24from task_set.tasks import base25import tensorflow.compat.v1 as tf26
27
class DummyTask(base.BaseTask):
  """Minimal BaseTask stub with a known, split-dependent noisy loss.

  Each split's loss is a fixed offset (1.0 for TRAIN through 4.0 for TEST)
  plus standard-normal noise, so tests can check averaged losses against
  exact expected means. All parameter/gradient hooks are empty no-ops.
  """

  def call_split(self, params, split, with_metrics=False):
    """Returns a scalar noisy loss whose mean identifies `split`.

    Args:
      params: unused; present to satisfy the BaseTask interface.
      split: a `datasets.Split` value selecting the loss offset.
      with_metrics: if True, also return a metrics dict.

    Returns:
      A scalar loss tensor, or `(loss, {"metric": -loss})` when
      `with_metrics` is set.
    """
    noise = tf.random_normal(shape=[], dtype=tf.float32)
    split_means = {
        datasets.Split.TRAIN: 1.0,
        datasets.Split.VALID_INNER: 2.0,
        datasets.Split.VALID_OUTER: 3.0,
        datasets.Split.TEST: 4.0,
    }
    loss = split_means[split] + noise
    if not with_metrics:
      return loss
    return loss, {"metric": -1 * loss}

  def get_batch(self, split):
    # No real data pipeline; the loss above ignores batches entirely.
    return None

  def current_params(self):
    # No trainable parameters.
    return {}

  def gradients(self, loss):
    # Nothing to differentiate — parameter dict is empty.
    return {}

  def initial_params(self):
    return {}

  def get_variables(self):
    return []
61
class TrainInnerTest(tf.test.TestCase):
  """Tests for train_inner loss averaging and the end-to-end train loop."""

  def test_compute_averaged_loss(self):
    """Averaged per-split losses should match the DummyTask offsets."""
    task = DummyTask()
    params = task.initial_params()
    losses, _ = train_inner.compute_averaged_loss(
        task, params, num_batches=100, with_metrics=False)

    with self.test_session() as sess:
      all_np_losses = []
      for _ in range(10):
        all_np_losses.append(sess.run(losses))

    tr, vai, vao, te = zip(*all_np_losses)
    # We are averaging over 100 with 10 replications evaluations.
    # This means the std. error of the mean should be 1/sqrt(1000) or 0.03.
    # We use a threshold of 0.15, corresponding to a 5-sigma test.
    self.assertNear(np.mean(tr), 1.0, 0.15)
    self.assertNear(np.mean(vai), 2.0, 0.15)
    self.assertNear(np.mean(vao), 3.0, 0.15)
    self.assertNear(np.mean(te), 4.0, 0.15)

    # Ensure that each sample is also different (non-degenerate variance).
    # Bug fix: the original `self.assertLess(1e-5, np.var(tr), 0.5)` passed
    # 0.5 as assertLess's `msg` argument, so the upper bound was silently
    # never checked. Check both bounds explicitly for each split.
    for variance in (np.var(tr), np.var(vai), np.var(vao), np.var(te)):
      self.assertGreater(variance, 1e-5)
      self.assertLess(variance, 0.5)

    losses, metrics = train_inner.compute_averaged_loss(
        task, params, num_batches=100, with_metrics=True)
    tr_metrics, vai_metrics, vao_metrics, te_metrics = metrics
    with self.test_session() as sess:
      # This std. error is 1/sqrt(100), or 0.1. 5 std out is 0.5.
      self.assertNear(sess.run(tr_metrics["metric"]), -1.0, 0.5)
      self.assertNear(sess.run(vai_metrics["metric"]), -2.0, 0.5)
      self.assertNear(sess.run(vao_metrics["metric"]), -3.0, 0.5)
      self.assertNear(sess.run(te_metrics["metric"]), -4.0, 0.5)

  def test_train(self):
    """train() should write result and timing files with expected keys."""
    tmp_dir = tempfile.mkdtemp()

    # TODO(lmetz) when toy tasks are done, switch this away from an mlp.
    train_inner.train(
        tmp_dir,
        task_name="mlp_family_seed12",
        optimizer_name="adam8p_wide_grid_seed21",
        training_steps=10,
        eval_every_n=5)

    with tf.gfile.Open(os.path.join(tmp_dir, "result")) as f:
      result_data = json.loads(f.read())

    # Evaluations at steps 0, 5, and 10.
    self.assertEqual(len(result_data), 3)
    # 4 losses logged out per timestep.
    self.assertEqual(len(result_data["5"]), 4)

    with tf.gfile.Open(os.path.join(tmp_dir, "time_per_step")) as f:
      time_per_step_data = json.loads(f.read())

    self.assertIn("mean_last_half", time_per_step_data)
    self.assertIn("mean_time", time_per_step_data)
    self.assertIn("median_time", time_per_step_data)
125
# Run all tests in this module when executed directly.
if __name__ == "__main__":
  tf.test.main()