google-research

Форк
0
/
core_test.py 
99 строк · 3.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for muzero.core."""
17

18
import tensorflow as tf
19
from muzero import core
20

21

22
class CoreTest(tf.test.TestCase):
23

24
  def test_make_target(self):
25
    num_unroll_steps = 3
26
    td_steps = -1
27
    rewards = [1., 2., 3., 4.]
28
    # Assume 4 different actions.
29
    policy_distributions = [
30
        [0.7, 0.1, 0.1, 0.1],
31
        [0.1, 0.7, 0.1, 0.1],
32
        [0.1, 0.1, 0.7, 0.1],
33
        [0.1, 0.1, 0.1, 0.7],
34
    ]
35
    discount = 0.9
36

37
    target = core.Episode.make_target(
38
        state_index=0,
39
        num_unroll_steps=num_unroll_steps,
40
        td_steps=td_steps,
41
        rewards=rewards,
42
        policy_distributions=policy_distributions,
43
        discount=discount)
44
    self.assertEqual(core.Target(
45
        value_mask=(1., 1., 1., 1.),
46
        reward_mask=(0., 1., 1., 1.),
47
        policy_mask=(1., 1., 1., 1.),
48
        value=(rewards[0] + rewards[1] * discount \
49
                + rewards[2] * discount**2 + rewards[3] * discount**3,
50
               rewards[1] + rewards[2] * discount + rewards[3] * discount**2,
51
               rewards[2] + rewards[3] * discount,
52
               rewards[3]),
53
        reward=(rewards[3], rewards[0], rewards[1], rewards[2]),
54
        visits=tuple(policy_distributions)), target)
55

56
    target = core.Episode.make_target(
57
        state_index=2,
58
        num_unroll_steps=num_unroll_steps,
59
        td_steps=td_steps,
60
        rewards=rewards,
61
        policy_distributions=policy_distributions,
62
        discount=discount)
63
    self.assertEqual(
64
        core.Target(
65
            value_mask=(1., 1., 1., 1.),
66
            reward_mask=(0., 1., 1., 0.),
67
            policy_mask=(1., 1., 0., 0.),
68
            value=(rewards[2] + rewards[3] * discount, rewards[3], 0., 0.),
69
            reward=(rewards[1], rewards[2], rewards[3], 0.),
70
            visits=tuple(policy_distributions[2:] +
71
                         [policy_distributions[0]] * 2)), target)
72

73
  def test_encode_decode(self):
74
    encoder = core.ValueEncoder(
75
        min_value=-2,
76
        max_value=2,
77
        num_steps=5,
78
        use_contractive_mapping=False)
79
    encoded = encoder.encode(tf.constant([-0.5, 0.9, 5.0]))
80
    self.assertAllClose([[0, 0.5, 0.5, 0, 0],
81
                         [0, 0, 0.1, 0.9, 0],
82
                         [0, 0, 0, 0, 1]], encoded)
83
    self.assertAllClose([-0.5, 0.9, 2.0], encoder.decode(encoded))
84

85
    encoder = core.ValueEncoder(
86
        min_value=-2,
87
        max_value=2,
88
        num_steps=5,
89
        use_contractive_mapping=True)
90
    encoded = encoder.encode(tf.constant([-0.5, 0.9, 5.0]))
91
    # Scaling transformation with contractive mapping
92
    self.assertAllClose([[0, 0.61, 0.39, 0, 0],
93
                         [0, 0, 0, 0.97, 0.03],
94
                         [0, 0, 0, 0, 1]], encoded, atol=0.01)
95
    self.assertAllClose([-0.5, 0.9, 2.0], encoder.decode(encoded), atol=0.001)
96

97

98
if __name__ == '__main__':
99
  tf.test.main()
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.