# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Elephant DQN agent with adjustable replay ratios."""


from dopamine.agents.dqn import dqn_agent

import gin
import tensorflow.compat.v1 as tf
from experience_replay.replay_memory import prioritized_replay_buffer


def statistics_summaries(name, var):
  """Attach additional statistical summaries to the variable."""
  var = tf.to_float(var)
  with tf.variable_scope(name):
    tf.summary.scalar('mean', tf.reduce_mean(var))
    tf.summary.scalar('stddev', tf.math.reduce_std(var))
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
  tf.summary.histogram(name, var)


@gin.configurable
class ElephantDQNAgent(dqn_agent.DQNAgent):
  """A compact implementation of an Elephant DQN agent."""

  def __init__(self,
               replay_scheme='uniform',
               oldest_policy_in_buffer=250000,
               **kwargs):
    """Initializes the agent and constructs the components of its graph.

    Args:
      replay_scheme: str, 'uniform' or 'prioritized', the sampling scheme used
        by the replay memory.
      oldest_policy_in_buffer: int, the age (in agent steps) of the oldest
        policy whose transitions may still be present in the replay buffer;
        together with the replay capacity this determines the replay-ratio
        adjustment applied below.
      **kwargs: keyword arguments forwarded to `dqn_agent.DQNAgent`.
    """
    self._replay_scheme = replay_scheme
    self._oldest_policy_in_buffer = oldest_policy_in_buffer

    dqn_agent.DQNAgent.__init__(self, **kwargs)
    tf.logging.info('\t replay_scheme: %s', replay_scheme)
    tf.logging.info('\t oldest_policy_in_buffer: %s', oldest_policy_in_buffer)

    # We maintain attributes to record online and target network updates,
    # which are later used to handle non-integer update periods.
    self._online_network_updates = 0
    self._target_network_updates = 0

    # pylint: disable=protected-access
    buffer_to_oldest_policy_ratio = (
        float(self._replay.memory._replay_capacity) /
        float(self._oldest_policy_in_buffer))
    # pylint: enable=protected-access

    # This ratio is used to adjust other attributes that are explicitly tied
    # to agent steps. The Dopamine agents were designed under the assumption
    # of a fixed replay ratio, so attributes such as epsilon_decay_period
    # would not be set appropriately without this adjustment.
    self._gin_param_multiplier = (
        buffer_to_oldest_policy_ratio / self.update_period)
    tf.logging.info('\t self._gin_param_multiplier: %f',
                    self._gin_param_multiplier)

    # Adjust agent attributes that are tied to the agent steps.
    self.update_period *= self._gin_param_multiplier
    self.target_update_period *= self._gin_param_multiplier
    self.epsilon_decay_period *= self._gin_param_multiplier
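
    # A worked example (the numbers here are illustrative assumptions, not
    # values taken from any particular config): with a replay capacity of
    # 1,000,000, oldest_policy_in_buffer=500,000 and an unadjusted
    # update_period of 4, the multiplier is (1e6 / 5e5) / 4 = 0.5, so
    # update_period becomes 2 and the agent performs gradient updates twice
    # as often per environment step.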

  def _build_replay_buffer(self, use_staging):
    """Creates the replay buffer used by the agent.

    Args:
      use_staging: bool, if True, uses a staging area to prefetch data for
        faster training.

    Returns:
      A `WrappedPrioritizedReplayBuffer` object.

    Raises:
      ValueError: if given an invalid replay scheme.
    """
    if self._replay_scheme not in ['uniform', 'prioritized']:
      raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme))
    # Both replay schemes use the same data structure, but the 'uniform' scheme
    # sets all priorities to the same value (which yields uniform sampling).

    return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        use_staging=use_staging,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype.as_numpy_dtype,
        replay_forgetting='default',
        sample_newest_immediately=False)

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.q_values * replay_action_one_hot,
        reduction_indices=1,
        name='replay_chosen_q')

    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)

    if self._replay_scheme == 'prioritized':
      # The original prioritized experience replay uses a linear exponent
      # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 0.5
      # on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) suggested
      # a fixed exponent actually performs better, except on Pong.
      probs = self._replay.transition['sampling_probabilities']
      loss_weights = 1.0 / tf.math.pow(probs + 1e-10, 0.5)
      loss_weights /= tf.reduce_max(loss_weights)
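
      # Illustrative arithmetic (example values, not from the source): with
      # sampling probabilities p = [0.4, 0.1], the unnormalized weights are
      # [1/sqrt(0.4), 1/sqrt(0.1)] ~= [1.58, 3.16]; dividing by their max
      # yields [0.5, 1.0], so frequently sampled transitions are down-weighted
      # while rarely sampled ones keep full weight.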

      # Rainbow and prioritized replay are parametrized by an exponent alpha,
      # but in both cases it is set to 0.5, so here we simply raise the loss
      # to the power 0.5. Taking the square root "makes sense", as we are
      # dealing with a squared loss.
      # Add a small nonzero value to the loss to avoid 0 priority items. While
      # technically this may be okay, setting all items to 0 priority will
      # cause problems, and also result in 1.0 / 0.0 = NaN correction terms.
      update_priorities_op = self._replay.tf_set_priority(
          self._replay.indices, tf.math.pow(loss + 1e-10, 0.5))

      # Weight the loss by the inverse priorities.
      loss = loss_weights * loss
    else:
      update_priorities_op = tf.no_op()

    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
    with tf.control_dependencies([update_priorities_op]):
      # Schaul et al. report a slightly different rule, where 1/N is also
      # exponentiated by beta. Not doing so seems more reasonable, and did not
      # impact performance in our experiments.
      return self.optimizer.minimize(tf.reduce_mean(loss))

  def _train_step(self):
    """Runs a single training step.

    Runs a training op if both:
      (1) A minimum number of frames have been added to the replay buffer.
      (2) Enough training steps have elapsed since the last update, as
          dictated by the (possibly non-integer) `update_period`.

    Also, syncs weights from the online to the target network whenever enough
    training steps have elapsed, as dictated by `target_update_period`.
    """
    # Run train ops at the rate of self.update_period if enough training steps
    # have been run. This matches the Nature DQN behaviour.
    # Online and target network updates are counted separately from
    # self.training_steps so that the non-integer update periods introduced by
    # the replay-ratio adjustment are honoured, as illustrated below.
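    # For example (an illustrative value, not a default from this file): with
    # an adjusted update_period of 2.5 and ignoring the replay warm-up gate,
    # the first loop below fires at training_steps 1, 3, 6, 8, 11, ..., i.e.
    # one gradient update per 2.5 training steps on average.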
    if self._replay.memory.add_count > self.min_replay_history:
      while (self._online_network_updates * self.update_period <
             self.training_steps):
        self._sess.run(self._train_op)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          summary = self._sess.run(self._merged_summaries)
          self.summary_writer.add_summary(summary, self.training_steps)
        self._online_network_updates += 1

      while (self._target_network_updates * self.target_update_period <
             self.training_steps):
        self._sess.run(self._sync_qt_ops)
        self._target_network_updates += 1

    self.training_steps += 1

  def _store_transition(self,
                        last_observation,
                        action,
                        reward,
                        is_terminal,
                        priority=None):
    """Stores a transition when in training mode.

    Runs a tf session to execute the replay buffer ops that store the tuple
    (last_observation, action, reward, is_terminal, priority) in the replay
    buffer.

    Args:
      last_observation: Last observation, type determined via observation_type
        parameter in the replay_memory constructor.
      action: An integer, the action taken.
      reward: A float, the reward.
      is_terminal: Boolean indicating if the current state is a terminal state.
      priority: Float. Priority of sampling the transition. If None, the
        default priority will be used. If the replay scheme is uniform, the
        default priority is 1. If the replay scheme is prioritized, the default
        priority is the maximum ever seen [Schaul et al., 2015].
    """
    if priority is None:
      if self._replay_scheme == 'uniform':
        priority = 1.0
      else:
        priority = self._replay.memory.sum_tree.max_recorded_priority

    if not self.eval_mode:
      self._replay.add(last_observation,
                       action,
                       reward,
                       is_terminal,
                       priority)

  def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
    """Returns a self-contained bundle of the agent's state.

    This is used for checkpointing. It will return a dictionary containing all
    non-TensorFlow objects (to be saved into a file by the caller), and it
    saves all TensorFlow objects into a checkpoint file.

    Args:
      checkpoint_dir: str, directory where TensorFlow objects will be saved.
      iteration_number: int, iteration number to use for naming the checkpoint
        file.

    Returns:
      A dict containing additional Python objects to be checkpointed by the
        experiment. If the checkpoint directory does not exist, returns None.
    """
    bundle_dictionary = super(ElephantDQNAgent, self).bundle_and_checkpoint(
        checkpoint_dir, iteration_number)
    # The parent returns None when the checkpoint directory does not exist;
    # propagate that instead of indexing into None.
    if bundle_dictionary is None:
      return None
    bundle_dictionary['_online_network_updates'] = self._online_network_updates
    bundle_dictionary['_target_network_updates'] = self._target_network_updates
    return bundle_dictionary
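

# Example gin bindings for configuring this agent (a sketch based on the
# constructor arguments above, not taken from the original experiment configs;
# adjust values to your setup):
#
#   ElephantDQNAgent.replay_scheme = 'prioritized'
#   ElephantDQNAgent.oldest_policy_in_buffer = 250000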