google-research
40 строк · 1.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for task_set.registry."""
17
18from task_set import registry19# pylint: disable=unused-import
20from task_set.optimizers import adam21from task_set.tasks import mlp22# pylint: enable=unused-import
23import tensorflow.compat.v1 as tf24
25
class RegistryTest(tf.test.TestCase):
  """Smoke tests that fetch entries from the task_set registries by name."""

  def test_optimizer_registry(self):
    """Builds a minimize op for the "adam_lr_-5.00" optimizer config."""
    opt_name = "adam_lr_-5.00"
    opt = registry.optimizers_registry.get_instance(opt_name)
    # A single scalar variable stands in for the loss. The test only checks
    # that building the minimize step does not raise; its result is unused.
    scalar_loss = tf.get_variable(name="var", dtype=tf.float32, shape=[])
    opt.minimize(scalar_loss)

  def test_task_registry(self):
    """Checks that a task fetched from the registry reports the same name."""
    task_name = "mlp_family_seed10"
    task = registry.task_registry.get_instance(task_name)
    self.assertEqual(task.name, task_name)
38
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner for the tests in this module.
  tf.test.main()