# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Elephant Rainbow agent with adjustable replay ratios."""

import collections


from dopamine.agents.dqn import dqn_agent
from dopamine.agents.rainbow import rainbow_agent
from dopamine.discrete_domains import legacy_networks

import gin
import numpy as np
import tensorflow.compat.v1 as tf

from experience_replay.replay_memory import prioritized_replay_buffer
from experience_replay.replay_memory.circular_replay_buffer import ReplayElement


def statistics_summaries(name, var):
  """Attach additional statistical summaries to the variable."""
  var = tf.to_float(var)
  with tf.variable_scope(name):
    tf.summary.scalar('mean', tf.reduce_mean(var))
    tf.summary.scalar('stddev', tf.math.reduce_std(var))
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
  tf.summary.histogram(name, var)


@gin.configurable
class ElephantRainbowAgent(dqn_agent.DQNAgent):
  """A compact implementation of an Elephant Rainbow agent."""

  def __init__(self,
               sess,
               num_actions,
               observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
               observation_dtype=dqn_agent.NATURE_DQN_DTYPE,
               stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
               network=legacy_networks.rainbow_network,
               num_atoms=51,
               vmax=10.,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=20000,
               update_period=4,
               target_update_period=8000,
               epsilon_fn=dqn_agent.linearly_decaying_epsilon,
               epsilon_train=0.01,
               epsilon_eval=0.001,
               epsilon_decay_period=250000,
               replay_scheme='prioritized',
               alpha_exponent=0.5,
               beta_exponent=0.5,
               tf_device='/cpu:*',
               use_staging=True,
               optimizer=tf.train.AdamOptimizer(
                   learning_rate=0.00025, epsilon=0.0003125),
               summary_writer=None,
               summary_writing_frequency=2500,
               replay_forgetting='default',
               sample_newest_immediately=False,
               oldest_policy_in_buffer=250000):
    """Initializes the agent and constructs the components of its graph.

    Args:
      sess: `tf.Session`, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      observation_shape: tuple of ints or an int. If single int, the observation
        is assumed to be a 2D square.
      observation_dtype: tf.DType, specifies the type of the observations. Note
        that if your inputs are continuous, you should set this to tf.float32.
      stack_size: int, number of frames to use in state stack.
      network: function expecting three parameters:
        (num_actions, network_type, state). This function will return the
        network_type object containing the tensors output by the network.
        See dopamine.discrete_domains.legacy_networks.rainbow_network as
        an example.
      num_atoms: int, the number of buckets of the value function distribution.
      vmax: float, the value distribution support is [-vmax, vmax].
      gamma: float, discount factor with the usual RL meaning.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of transitions that should be experienced
        before the agent begins training its value function.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      epsilon_fn: function expecting 4 parameters:
        (decay_period, step, warmup_steps, epsilon). This function should return
        the epsilon value used for exploration during training.
      epsilon_train: float, the value to which the agent's epsilon is eventually
        decayed during training.
      epsilon_eval: float, epsilon used when evaluating the agent.
      epsilon_decay_period: int, length of the epsilon decay schedule.
      replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the
        replay memory.
      alpha_exponent: float, alpha hparam in prioritized experience replay.
      beta_exponent: float, beta hparam in prioritized experience replay.
      tf_device: str, Tensorflow device on which the agent's graph is executed.
      use_staging: bool, when True use a staging area to prefetch the next
        training batch, speeding training up by about 30%.
      optimizer: `tf.train.Optimizer`, for training the value function.
      summary_writer: SummaryWriter object for outputting training statistics.
        Summary writing disabled if set to None.
      summary_writing_frequency: int, frequency with which summaries will be
        written. Lower values will result in slower training.
      replay_forgetting: str, the strategy to employ for forgetting old
        trajectories. One of ['default', 'elephant'].
      sample_newest_immediately: bool, when True, immediately trains on the
        newest transition instead of using the max_priority hack.
      oldest_policy_in_buffer: int, the number of gradient updates of the oldest
        policy that has added data to the replay buffer.
    """
    # We need this because some tools convert round floats into ints.
    vmax = float(vmax)
    self._num_atoms = num_atoms
    self._support = tf.linspace(-vmax, vmax, num_atoms)
    self._replay_scheme = replay_scheme
    self._alpha_exponent = alpha_exponent
    self._beta_exponent = beta_exponent
    self._replay_forgetting = replay_forgetting
    self._sample_newest_immediately = sample_newest_immediately
    self._oldest_policy_in_buffer = oldest_policy_in_buffer
    # TODO(b/110897128): Make agent optimizer attribute private.
    self.optimizer = optimizer

    dqn_agent.DQNAgent.__init__(
        self,
        sess=sess,
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=observation_dtype,
        stack_size=stack_size,
        network=network,
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_train=epsilon_train,
        epsilon_eval=epsilon_eval,
        epsilon_decay_period=epsilon_decay_period,
        tf_device=tf_device,
        use_staging=use_staging,
        optimizer=self.optimizer,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency)
    tf.logging.info('\t replay_scheme: %s', replay_scheme)
    tf.logging.info('\t alpha_exponent: %f', alpha_exponent)
    tf.logging.info('\t beta_exponent: %f', beta_exponent)
    tf.logging.info('\t replay_forgetting: %s', replay_forgetting)
    tf.logging.info('\t oldest_policy_in_buffer: %s', oldest_policy_in_buffer)
    self.episode_return = 0.0

    # We maintain attributes to record online and target network updates, which
    # _train_step compares against self.training_steps to schedule gradient
    # updates and target network syncs.
    self._online_network_updates = 0
    self._target_network_updates = 0

    # pylint: disable=protected-access
    buffer_to_oldest_policy_ratio = (
        float(self._replay.memory._replay_capacity) /
        float(self._oldest_policy_in_buffer))
    # pylint: enable=protected-access

    # This ratio is used to adjust other attributes that are explicitly tied to
    # agent steps.  When designed, the Dopamine agents assumed that the replay
    # ratio would remain fixed, and therefore elements such as
    # epsilon_decay_period will not be set appropriately without adjustment.
    self._gin_param_multiplier = (
        buffer_to_oldest_policy_ratio / self.update_period)
    tf.logging.info('\t self._gin_param_multiplier: %f',
                    self._gin_param_multiplier)

    # Adjust agent attributes that are tied to the agent steps.
    self.update_period = self.update_period * self._gin_param_multiplier
    self.target_update_period = (
        self.target_update_period * self._gin_param_multiplier)
    self.epsilon_decay_period = int(self.epsilon_decay_period *
                                    self._gin_param_multiplier)
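    # Illustrative arithmetic (example numbers, not prescribed settings): with
    # a replay capacity of 1,000,000 transitions, oldest_policy_in_buffer of
    # 250,000 and update_period of 4, the multiplier is (1e6 / 2.5e5) / 4 = 1.0
    # and the defaults above are left unchanged; halving oldest_policy_in_buffer
    # doubles the multiplier and hence the adjusted periods.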

    if self._replay_scheme == 'prioritized':
      if self._replay_forgetting == 'elephant':
        raise NotImplementedError

  def _get_network_type(self):
    """Returns the type of the outputs of a value distribution network.

    Returns:
      net_type: _network_type object defining the outputs of the network.
    """
    return collections.namedtuple('c51_network',
                                  ['q_values', 'logits', 'probabilities'])

  def _network_template(self, state):
    """Builds a convolutional network that outputs Q-value distributions.

    Args:
      state: `tf.Tensor`, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
    """
    return self.network(self.num_actions, self._num_atoms, self._support,
                        self._get_network_type(), state)

  def _build_replay_buffer(self, use_staging):
    """Creates the replay buffer used by the agent.

    Args:
      use_staging: bool, if True, uses a staging area to prefetch data for
        faster training.

    Returns:
      A `WrappedPrioritizedReplayBuffer` object.

    Raises:
      ValueError: if given an invalid replay scheme.
    """
    if self._replay_scheme not in ['uniform', 'prioritized']:
      raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme))
    # Both replay schemes use the same data structure, but the 'uniform' scheme
    # sets all priorities to the same value (which yields uniform sampling).
    extra_elements = [ReplayElement('return', (), np.float32)]
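    # The extra 'return' element stores the running undiscounted episode return
    # alongside each transition (see _store_transition); in this agent it is
    # only read back for the summary statistics in _build_train_op.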

    return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        use_staging=use_staging,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype.as_numpy_dtype,
        extra_storage_types=extra_elements,
        replay_forgetting=self._replay_forgetting,
        sample_newest_immediately=self._sample_newest_immediately)

  def _build_target_distribution(self):
    """Builds the C51 target distribution as per Bellemare et al. (2017).

    First, we compute the support of the Bellman target, r + gamma Z', where Z'
    is the support of the next state distribution:

      * Evenly spaced in [-vmax, vmax] if the current state is nonterminal;
      * 0 otherwise (duplicated num_atoms times).

    Second, we compute the next-state probabilities, corresponding to the action
    with highest expected value.

    Finally we project the Bellman target (support + probabilities) onto the
    original support.

    Returns:
      target_distribution: tf.tensor, the target distribution from the replay.
    """
    batch_size = self._replay.batch_size

    # size of rewards: batch_size x 1
    rewards = self._replay.rewards[:, None]

    # size of tiled_support: batch_size x num_atoms
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    # size of target_support: batch_size x num_atoms

    is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
    # Incorporate terminal state to discount factor.
    # size of gamma_with_terminal: batch_size x 1
    gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
    gamma_with_terminal = gamma_with_terminal[:, None]

    target_support = rewards + gamma_with_terminal * tiled_support

    # size of next_qt_argmax: batch_size x 1
    next_qt_argmax = tf.argmax(
        self._replay_next_target_net_outputs.q_values, axis=1)[:, None]
    batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
    # size of batch_indexed_next_qt_argmax: batch_size x 2
    batch_indexed_next_qt_argmax = tf.concat(
        [batch_indices, next_qt_argmax], axis=1)

    # size of next_probabilities: batch_size x num_atoms
    next_probabilities = tf.gather_nd(
        self._replay_next_target_net_outputs.probabilities,
        batch_indexed_next_qt_argmax)
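    # project_distribution redistributes the probability mass located at
    # target_support onto the fixed support self._support: target values are
    # clipped to [-vmax, vmax] and each atom's mass is split linearly between
    # its two nearest neighbours on the original support (Bellemare et al.,
    # 2017).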

    return rainbow_agent.project_distribution(target_support,
                                              next_probabilities, self._support)

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    target_distribution = tf.stop_gradient(
        self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_net_outputs.logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_net_outputs.logits,
                                        reshaped_actions)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution,
        logits=chosen_action_logits)

    # Record returns encountered in the sampled training batches.
    returns = self._replay.transition['return']
    statistics_summaries('returns', returns)
    train_counts = self._replay.transition['train_counts']
    statistics_summaries('train_counts', train_counts)
    steps_until_first_train = self._replay.transition['steps_until_first_train']
    statistics_summaries('steps_until_first_train', steps_until_first_train)
    age = self._replay.transition['age']
    statistics_summaries('age', age)

    if self._replay_scheme == 'prioritized':
      # The original prioritized experience replay uses a linear exponent
      # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 0.5
      # on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) suggested
      # a fixed exponent actually performs better, except on Pong.
      probs = self._replay.transition['sampling_probabilities']
      beta = self._beta_exponent
      tf.summary.histogram('probs', probs)
      loss_weights = 1.0 / tf.math.pow(probs + 1e-10, beta)
      tf.summary.histogram('loss_weights', loss_weights)
      loss_weights /= tf.reduce_max(loss_weights)
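      # These are the usual importance-sampling corrections,
      # w_i = (1 / p_i)^beta, normalized by the largest weight in the batch so
      # that the corrections only ever scale the loss down.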

      # Rainbow and prioritized replay are parametrized by an exponent alpha,
      # but in both cases it is set to 0.5 - for simplicity's sake we leave it
      # as is here, using the more direct tf.sqrt(). Taking the square root
      # "makes sense", as we are dealing with a squared loss.
      # Add a small nonzero value to the loss to avoid 0 priority items. While
      # technically this may be okay, setting all items to 0 priority will cause
      # troubles, and also result in 1.0 / 0.0 = NaN correction terms.
      alpha = self._alpha_exponent
      update_priorities_op = self._replay.tf_set_priority(
          self._replay.indices, tf.math.pow(loss + 1e-10, alpha))

      # Weight the loss by the inverse priorities.
      loss = loss_weights * loss
    else:
      update_priorities_op = tf.no_op()

    update_train_counts_op = self._replay.tf_update_train_counts(
        self._replay.indices)

    with tf.control_dependencies([update_priorities_op,
                                  update_train_counts_op]):
      if self.summary_writer is not None:
        with tf.variable_scope('Losses'):
          tf.summary.scalar('CrossEntropyLoss', tf.reduce_mean(loss))
      # Schaul et al. reports a slightly different rule, where 1/N is also
      # exponentiated by beta. Not doing so seems more reasonable, and did not
      # impact performance in our experiments.
      return self.optimizer.minimize(tf.reduce_mean(loss)), loss

  def begin_episode(self, observation):
    """Returns the agent's first action for this episode.

    Args:
      observation: numpy array, the environment's initial observation.

    Returns:
      int, the selected action.
    """
    self._reset_state()
    self._reset_return()

    if self._replay_forgetting == 'elephant':
      self._replay.memory.sort_replay_buffer_trajectories()

    self._record_observation(observation)

    if not self.eval_mode:
      self._train_step()

    self.action = self._select_action()
    return self.action

  def step(self, reward, observation):
    """Records the most recent transition and returns the agent's next action.

    We store the observation of the last time step since we want to store it
    with the reward.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._last_observation = self._observation
    self._record_observation(observation)

    if not self.eval_mode:
      self._update_return(reward)
      self._store_transition(self._last_observation,
                             self.action,
                             reward,
                             False,
                             self.episode_return)
      self._train_step()

    self.action = self._select_action()
    return self.action

  def end_episode(self, reward):
    """Signals the end of the episode to the agent.

    We store the observation of the current time step, which is the last
    observation of the episode.

    Args:
      reward: float, the last reward from the environment.
    """
    if not self.eval_mode:
      self._update_return(reward)
      self._store_transition(self._observation,
                             self.action,
                             reward,
                             True,
                             self.episode_return)

  def _train_step(self):
    """Runs a single training step.

    Runs a training op if both:
      (1) A minimum number of frames have been added to the replay buffer.
      (2) Fewer than `training_steps / update_period` online updates have been
          performed so far (update_period may be fractional after the
          replay-ratio adjustment in __init__).

    Also, syncs weights from online to target network if training steps is a
    multiple of target update period.
    """
    # Run a train op at the rate of self.update_period if enough training steps
    # have been run. This matches the Nature DQN behaviour.
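    # A while loop, rather than a modulo test, is used so that missed updates
    # are caught up and so that more than one gradient update per agent step is
    # possible when the adjusted update_period drops below 1.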
    if self._replay.memory.add_count > self.min_replay_history:
      while self._online_network_updates * self.update_period < self.training_steps:
        self._sess.run(self._train_op)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          summary = self._sess.run(self._merged_summaries)
          self.summary_writer.add_summary(summary, self.training_steps)
        self._online_network_updates += 1

      while self._target_network_updates * self.target_update_period < self.training_steps:
        self._sess.run(self._sync_qt_ops)
        self._target_network_updates += 1

    self.training_steps += 1

  def _store_transition(self,
                        last_observation,
                        action,
                        reward,
                        is_terminal,
                        episode_return,
                        priority=None):
    """Stores a transition when in training mode.

    Executes a tf session and executes replay buffer ops in order to store the
    following tuple in the replay buffer (last_observation, action, reward,
    is_terminal, priority).

    Args:
      last_observation: Last observation, type determined via observation_type
        parameter in the replay_memory constructor.
      action: An integer, the action taken.
      reward: A float, the reward.
      is_terminal: Boolean indicating if the current state is a terminal state.
      episode_return: A float, the episode undiscounted return so far.
      priority: Float. Priority of sampling the transition. If None, the default
        priority will be used. If replay scheme is uniform, the default priority
        is 1. If the replay scheme is prioritized, the default priority is the
        maximum ever seen [Schaul et al., 2015].
    """
    if priority is None:
      if self._replay_scheme == 'uniform':
        priority = 1.0
      else:
        priority = self._replay.memory.sum_tree.max_recorded_priority
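        # Defaulting to the maximum priority seen so far makes new transitions
        # likely to be sampled soon after insertion; when
        # sample_newest_immediately is True, the buffer instead trains on the
        # newest transition directly (see the constructor docstring).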

    # TODO(liamfedus): This storage mechanism is brittle depending on order.
    # The internal replay buffers should be added via **kwargs not *args.
    if not self.eval_mode:
      self._replay.add(last_observation,
                       action,
                       reward,
                       is_terminal,
                       episode_return,
                       priority)

  def _update_return(self, reward):
    """Updates the running episode return with the given reward."""
    if self.episode_return != self.episode_return + reward:
      self.episode_return += reward

  def _reset_return(self):
    """Reset the episode return."""
    self.episode_return = 0.0

  def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
    """Returns a self-contained bundle of the agent's state.

    This is used for checkpointing. It will return a dictionary containing all
    non-TensorFlow objects (to be saved into a file by the caller), and it saves
    all TensorFlow objects into a checkpoint file.

    Args:
      checkpoint_dir: str, directory where TensorFlow objects will be saved.
      iteration_number: int, iteration number to use for naming the checkpoint
        file.

    Returns:
      A dict containing additional Python objects to be checkpointed by the
        experiment. If the checkpoint directory does not exist, returns None.
    """
    bundle_dictionary = super(ElephantRainbowAgent, self).bundle_and_checkpoint(
        checkpoint_dir, iteration_number)
    bundle_dictionary['_online_network_updates'] = self._online_network_updates
    bundle_dictionary['_target_network_updates'] = self._target_network_updates
    return bundle_dictionary