google-research
56 строк · 1.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ES-ENAS objects."""
17
18from absl.testing import absltest19from absl.testing import parameterized20import pyglove as pg21from es_enas import config as config_util22from es_enas import objects23
24
class ObjectsTest(parameterized.TestCase):
  """Smoke tests for `objects.GeneralTopologyBlackboxObject`."""

  def setUp(self):
    # Initialize the base TestCase first, then build this test's fixtures.
    super().setUp()
    self.base_config = config_util.get_config()

  @parameterized.named_parameters(
      ('Random', 'random'),
      ('PolicyGradient', 'policy_gradient'),
      ('RegularizedEvolution', 'regularized_evolution'),
  )
  def test_GeneralTopologyBlackboxObject(self, controller_str):
    """Builds a blackbox object and runs one topology evaluation.

    The test passes as long as no step raises; the value returned by
    `execute_with_topology` is not inspected.

    Args:
      controller_str: Controller type name, assigned to
        `base_config.controller_type_str` before the config is generated.
    """
    self.base_config.controller_type_str = controller_str
    # NOTE(review): a horizon of 2 presumably keeps rollouts short so the test
    # stays cheap — confirm against how `horizon` is consumed.
    self.base_config.horizon = 2
    self.base_config.environment_name = 'Pendulum'

    self.config = config_util.generate_config(
        self.base_config, current_time_string='TEST')

    self.config.setup_controller_fn()
    self.object = objects.GeneralTopologyBlackboxObject(self.config)

    optimizer = self.config.es_blackbox_optimizer_fn(
        self.object.get_metaparams())

    params = self.object.get_initial()
    # Serialize the controller's proposed DNA to a JSON string, which is the
    # form `execute_with_topology` receives here.
    topology_str = pg.to_json(self.config.controller.propose_dna())
    core_hyperparams = optimizer.get_hyperparameters()
    # NOTE(review): a leading 0 is prepended to the optimizer hyperparameters;
    # its meaning is defined by `execute_with_topology` — confirm there.
    hyperparams = [0] + list(core_hyperparams)
    self.object.execute_with_topology(params, topology_str, hyperparams)
54
if __name__ == '__main__':
  # Run every test case in this module via absl's test runner.
  absltest.main()