# google-research repository (forks: 0); 141 lines, 5.7 KB.
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of a Rainbow agent with noisy networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from bonus_based_exploration.noisy_networks import noisy_dqn_agent

from dopamine.agents.rainbow import rainbow_agent as base_rainbow_agent
import gin
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import slim as contrib_slim

# Short module-level alias for TF-Slim, used by the network template
# (`slim.conv2d`, `slim.flatten`, `slim.variance_scaling_initializer`).
slim = contrib_slim


@gin.configurable
class NoisyRainbowAgent(base_rainbow_agent.RainbowAgent):
  """A Rainbow agent with noisy networks.

  Identical to the base Rainbow agent except that the two fully-connected
  layers of the value network are noisy layers built with
  `noisy_dqn_agent.fully_connected`. By default, epsilon is always 0.
  """

  def __init__(self,
               sess,
               num_actions,
               num_atoms=51,
               vmax=10.,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=20000,
               update_period=4,
               target_update_period=8000,
               epsilon_fn=lambda w, x, y, z: 0,
               epsilon_decay_period=250000,
               replay_scheme='prioritized',
               tf_device='/cpu:*',
               use_staging=True,
               optimizer=None,
               summary_writer=None,
               summary_writing_frequency=500,
               noise_distribution='independent'):
    """Initializes the agent and constructs the components of its graph.

    Args:
      sess: `tf.Session`, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      num_atoms: int, the number of buckets of the value function distribution.
      vmax: float, the value distribution support is [-vmax, vmax].
      gamma: float, discount factor with the usual RL meaning.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of transitions that should be experienced
        before the agent begins training its value function.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      epsilon_fn: function expecting 4 parameters:
        (decay_period, step, warmup_steps, epsilon). This function should return
        the epsilon value used for exploration during training. The default
        always returns 0.
      epsilon_decay_period: int, length of the epsilon decay schedule.
      replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the
        replay memory.
      tf_device: str, Tensorflow device on which the agent's graph is executed.
      use_staging: bool, when True use a staging area to prefetch the next
        training batch, speeding training up by about 30%.
      optimizer: tf.train.Optimizer, for training the value function. If None
        (the default), a new `tf.train.AdamOptimizer(learning_rate=0.00025,
        epsilon=0.0003125)` is created for this agent.
      summary_writer: SummaryWriter object for outputting training statistics.
        Summary writing disabled if set to None.
      summary_writing_frequency: int, frequency with which summaries will be
        written. Lower values will result in slower training.
      noise_distribution: string, distribution used to sample noise, must be
        `factorised` or `independent`.
    """
    if optimizer is None:
      # Build the default optimizer per instance. As a default argument it
      # would be created once at import time and shared, with its internal
      # state, by every agent constructed without an explicit optimizer.
      optimizer = tf.train.AdamOptimizer(
          learning_rate=0.00025, epsilon=0.0003125)
    # Set before delegating to the base constructor, presumably so the value
    # is available if the base class builds the networks inside `__init__`
    # (TODO confirm against RainbowAgent).
    # NOTE(review): nothing in this file reads `noise_distribution`; the noisy
    # layers in `_network_template` are built without it.
    self.noise_distribution = noise_distribution
    super(NoisyRainbowAgent, self).__init__(
        sess=sess,
        num_actions=num_actions,
        num_atoms=num_atoms,
        vmax=vmax,
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_decay_period=epsilon_decay_period,
        replay_scheme=replay_scheme,
        tf_device=tf_device,
        use_staging=use_staging,
        optimizer=optimizer,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency)

  def _network_template(self, state):
    """Builds the convolutional network used to compute the agent's Q-values.

    Three convolutional layers are followed by two noisy fully-connected
    layers. The last layer emits `num_actions * num_atoms` logits, which are
    reshaped into one categorical distribution per action.

    Args:
      state: `tf.Tensor`, contains the agent's current state. It is cast to
        float32 and divided by 255 (presumably raw 8-bit pixel observations).

    Returns:
      net: _network_type object built from (q_values, logits, probabilities).
    """
    weights_initializer = slim.variance_scaling_initializer(
        factor=1.0 / np.sqrt(3.0), mode='FAN_IN', uniform=True)

    net = tf.cast(state, tf.float32)
    net = tf.div(net, 255.)
    net = slim.conv2d(
        net, 32, [8, 8], stride=4, weights_initializer=weights_initializer)
    net = slim.conv2d(
        net, 64, [4, 4], stride=2, weights_initializer=weights_initializer)
    net = slim.conv2d(
        net, 64, [3, 3], stride=1, weights_initializer=weights_initializer)
    net = slim.flatten(net)
    # NOTE(review): `self.noise_distribution` is not passed to these noisy
    # layers; confirm whether `noisy_dqn_agent.fully_connected` should get it.
    net = noisy_dqn_agent.fully_connected(
        net, 512, scope='fully_connected')
    net = noisy_dqn_agent.fully_connected(
        net,
        self.num_actions * self._num_atoms,
        activation_fn=None,
        scope='fully_connected_1')

    # Shape [batch, num_actions, num_atoms]: one distribution per action.
    logits = tf.reshape(net, [-1, self.num_actions, self._num_atoms])
    # `contrib_layers.softmax` normalizes over the last (atom) axis.
    probabilities = contrib_layers.softmax(logits)
    # Q-value per action: expectation of `self._support` (base-class support
    # values) under that action's distribution, summed over the atom axis.
    q_values = tf.reduce_sum(self._support * probabilities, axis=2)
    return self._get_network_type()(q_values, logits, probabilities)

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give SberTech JSC consent to process your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.