# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Tests for muzero.core."""
17
18import tensorflow as tf19from muzero import core20
21
class CoreTest(tf.test.TestCase):

  def test_make_target(self):
    num_unroll_steps = 3
    td_steps = -1
    rewards = [1., 2., 3., 4.]
    # Assume 4 different actions.
    policy_distributions = [
        [0.7, 0.1, 0.1, 0.1],
        [0.1, 0.7, 0.1, 0.1],
        [0.1, 0.1, 0.7, 0.1],
        [0.1, 0.1, 0.1, 0.7],
    ]
    discount = 0.9

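    # Targets from the start of the episode (checked below): with td_steps = -1
    # the value target at each unroll step is the full discounted return of the
    # remaining rewards, the reward target at step k > 0 is the reward observed
    # when entering step k, and the leading reward entry is masked out.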
    target = core.Episode.make_target(
        state_index=0,
        num_unroll_steps=num_unroll_steps,
        td_steps=td_steps,
        rewards=rewards,
        policy_distributions=policy_distributions,
        discount=discount)
    self.assertEqual(core.Target(
        value_mask=(1., 1., 1., 1.),
        reward_mask=(0., 1., 1., 1.),
        policy_mask=(1., 1., 1., 1.),
        value=(rewards[0] + rewards[1] * discount
               + rewards[2] * discount**2 + rewards[3] * discount**3,
               rewards[1] + rewards[2] * discount + rewards[3] * discount**2,
               rewards[2] + rewards[3] * discount,
               rewards[3]),
        reward=(rewards[3], rewards[0], rewards[1], rewards[2]),
        visits=tuple(policy_distributions)), target)

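    # Targets starting near the end of the episode (checked below): unroll
    # steps that fall past the last state get a value target of 0, their
    # reward/policy entries are masked out, and the policy targets are padded
    # with a dummy distribution.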
    target = core.Episode.make_target(
        state_index=2,
        num_unroll_steps=num_unroll_steps,
        td_steps=td_steps,
        rewards=rewards,
        policy_distributions=policy_distributions,
        discount=discount)
    self.assertEqual(
        core.Target(
            value_mask=(1., 1., 1., 1.),
            reward_mask=(0., 1., 1., 0.),
            policy_mask=(1., 1., 0., 0.),
            value=(rewards[2] + rewards[3] * discount, rewards[3], 0., 0.),
            reward=(rewards[1], rewards[2], rewards[3], 0.),
            visits=tuple(policy_distributions[2:] +
                         [policy_distributions[0]] * 2)), target)

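  # Sanity notes for test_encode_decode (derived from the expected outputs
  # below): core.ValueEncoder represents a scalar as a "two-hot" distribution
  # over num_steps evenly spaced support points in [min_value, max_value];
  # out-of-range inputs such as 5.0 are clipped to the support, and decode()
  # returns the expectation over the support.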
  def test_encode_decode(self):
    encoder = core.ValueEncoder(
        min_value=-2,
        max_value=2,
        num_steps=5,
        use_contractive_mapping=False)
    encoded = encoder.encode(tf.constant([-0.5, 0.9, 5.0]))
    self.assertAllClose([[0, 0.5, 0.5, 0, 0],
                         [0, 0, 0.1, 0.9, 0],
                         [0, 0, 0, 0, 1]], encoded)
    self.assertAllClose([-0.5, 0.9, 2.0], encoder.decode(encoded))

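    # The second encoder applies a contractive mapping before bucketing. As an
    # assumption based on the MuZero paper (not on this test alone), the
    # transform is the invertible h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x,
    # applied to both the values and the [min_value, max_value] support, which
    # is why -0.5 and 0.9 land on different buckets than above.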
    encoder = core.ValueEncoder(
        min_value=-2,
        max_value=2,
        num_steps=5,
        use_contractive_mapping=True)
    encoded = encoder.encode(tf.constant([-0.5, 0.9, 5.0]))
    # Scaling transformation with contractive mapping.
    self.assertAllClose([[0, 0.61, 0.39, 0, 0],
                         [0, 0, 0, 0.97, 0.03],
                         [0, 0, 0, 0, 1]], encoded, atol=0.01)
    self.assertAllClose([-0.5, 0.9, 2.0], encoder.decode(encoded), atol=0.001)


if __name__ == '__main__':
  tf.test.main()