# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Elephant Rainbow agent with adjustable replay ratios."""
17
18import collections
19
20
21from dopamine.agents.dqn import dqn_agent
22from dopamine.agents.rainbow import rainbow_agent
23from dopamine.discrete_domains import legacy_networks
24
25import gin
26import numpy as np
27import tensorflow.compat.v1 as tf
28
29from experience_replay.replay_memory import prioritized_replay_buffer
30from experience_replay.replay_memory.circular_replay_buffer import ReplayElement
31

def statistics_summaries(name, var):
  """Attach additional statistical summaries to the variable."""
  var = tf.to_float(var)
  with tf.variable_scope(name):
    tf.summary.scalar('mean', tf.reduce_mean(var))
    tf.summary.scalar('stddev', tf.math.reduce_std(var))
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
    tf.summary.histogram(name, var)


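# Example gin bindings for this agent (illustrative only; values are not
# prescriptive):
#   ElephantRainbowAgent.replay_scheme = 'uniform'
#   ElephantRainbowAgent.replay_forgetting = 'elephant'
#   ElephantRainbowAgent.oldest_policy_in_buffer = 250000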
@gin.configurable
class ElephantRainbowAgent(dqn_agent.DQNAgent):
  """A compact implementation of an Elephant Rainbow agent."""

  def __init__(self,
               sess,
               num_actions,
               observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
               observation_dtype=dqn_agent.NATURE_DQN_DTYPE,
               stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
               network=legacy_networks.rainbow_network,
               num_atoms=51,
               vmax=10.,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=20000,
               update_period=4,
               target_update_period=8000,
               epsilon_fn=dqn_agent.linearly_decaying_epsilon,
               epsilon_train=0.01,
               epsilon_eval=0.001,
               epsilon_decay_period=250000,
               replay_scheme='prioritized',
               alpha_exponent=0.5,
               beta_exponent=0.5,
               tf_device='/cpu:*',
               use_staging=True,
               optimizer=tf.train.AdamOptimizer(
                   learning_rate=0.00025, epsilon=0.0003125),
               summary_writer=None,
               summary_writing_frequency=2500,
               replay_forgetting='default',
               sample_newest_immediately=False,
               oldest_policy_in_buffer=250000):
78"""Initializes the agent and constructs the components of its graph.
79
80Args:
81sess: `tf.Session`, for executing ops.
82num_actions: int, number of actions the agent can take at any state.
83observation_shape: tuple of ints or an int. If single int, the observation
84is assumed to be a 2D square.
85observation_dtype: tf.DType, specifies the type of the observations. Note
86that if your inputs are continuous, you should set this to tf.float32.
87stack_size: int, number of frames to use in state stack.
88network: function expecting three parameters:
89(num_actions, network_type, state). This function will return the
90network_type object containing the tensors output by the network.
91See dopamine.discrete_domains.legacy_networks.rainbow_network as
92an example.
93num_atoms: int, the number of buckets of the value function distribution.
94vmax: float, the value distribution support is [-vmax, vmax].
95gamma: float, discount factor with the usual RL meaning.
96update_horizon: int, horizon at which updates are performed, the 'n' in
97n-step update.
98min_replay_history: int, number of transitions that should be experienced
99before the agent begins training its value function.
100update_period: int, period between DQN updates.
101target_update_period: int, update period for the target network.
102epsilon_fn: function expecting 4 parameters:
103(decay_period, step, warmup_steps, epsilon). This function should return
104the epsilon value used for exploration during training.
105epsilon_train: float, the value to which the agent's epsilon is eventually
106decayed during training.
107epsilon_eval: float, epsilon used when evaluating the agent.
108epsilon_decay_period: int, length of the epsilon decay schedule.
109replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the
110replay memory.
111alpha_exponent: float, alpha hparam in prioritized experience replay.
112beta_exponent: float, beta hparam in prioritized experience replay.
113tf_device: str, Tensorflow device on which the agent's graph is executed.
114use_staging: bool, when True use a staging area to prefetch the next
115training batch, speeding training up by about 30%.
116optimizer: `tf.train.Optimizer`, for training the value function.
117summary_writer: SummaryWriter object for outputting training statistics.
118Summary writing disabled if set to None.
119summary_writing_frequency: int, frequency with which summaries will be
120written. Lower values will result in slower training.
121replay_forgetting: str, What strategy to employ for forgetting old
122trajectories. One of ['default', 'elephant'].
123sample_newest_immediately: bool, when True, immediately trains on the
124newest transition instead of using the max_priority hack.
125oldest_policy_in_buffer: int, the number of gradient updates of the oldest
126policy that has added data to the replay buffer.
127"""
    # We need this because some tools convert round floats into ints.
    vmax = float(vmax)
    self._num_atoms = num_atoms
    self._support = tf.linspace(-vmax, vmax, num_atoms)
    self._replay_scheme = replay_scheme
    self._alpha_exponent = alpha_exponent
    self._beta_exponent = beta_exponent
    self._replay_forgetting = replay_forgetting
    self._sample_newest_immediately = sample_newest_immediately
    self._oldest_policy_in_buffer = oldest_policy_in_buffer
    # TODO(b/110897128): Make agent optimizer attribute private.
    self.optimizer = optimizer

    dqn_agent.DQNAgent.__init__(
        self,
        sess=sess,
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=observation_dtype,
        stack_size=stack_size,
        network=network,
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_train=epsilon_train,
        epsilon_eval=epsilon_eval,
        epsilon_decay_period=epsilon_decay_period,
        tf_device=tf_device,
        use_staging=use_staging,
        optimizer=self.optimizer,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency)
    tf.logging.info('\t replay_scheme: %s', replay_scheme)
    tf.logging.info('\t alpha_exponent: %f', alpha_exponent)
    tf.logging.info('\t beta_exponent: %f', beta_exponent)
    tf.logging.info('\t replay_forgetting: %s', replay_forgetting)
    tf.logging.info('\t oldest_policy_in_buffer: %s', oldest_policy_in_buffer)
    self.episode_return = 0.0

    # We maintain attributes to record online and target network updates,
    # which are used by _train_step to schedule updates now that
    # update_period and target_update_period may be non-integer.
    self._online_network_updates = 0
    self._target_network_updates = 0

    # pylint: disable=protected-access
    buffer_to_oldest_policy_ratio = (
        float(self._replay.memory._replay_capacity) /
        float(self._oldest_policy_in_buffer))
    # pylint: enable=protected-access

    # This ratio is used to adjust other attributes that are explicitly tied
    # to agent steps. The Dopamine agents were designed under the assumption
    # that the replay ratio remains fixed, so elements such as
    # epsilon_decay_period would not be set appropriately without adjustment.
    self._gin_param_multiplier = (
        buffer_to_oldest_policy_ratio / self.update_period)
    tf.logging.info('\t self._gin_param_multiplier: %f',
                    self._gin_param_multiplier)

    # Adjust agent attributes that are tied to the agent steps.
    self.update_period = self.update_period * self._gin_param_multiplier
    self.target_update_period = (
        self.target_update_period * self._gin_param_multiplier)
    self.epsilon_decay_period = int(self.epsilon_decay_period *
                                    self._gin_param_multiplier)
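    # Worked example (assuming Dopamine's default replay capacity of 1e6
    # transitions): oldest_policy_in_buffer=250000 gives
    # buffer_to_oldest_policy_ratio = 4.0, so with update_period=4 the
    # multiplier is 1.0 and the schedule is unchanged. Halving
    # oldest_policy_in_buffer to 125000 doubles the multiplier, doubling the
    # adjusted update_period, target_update_period and epsilon_decay_period.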

    if self._replay_scheme == 'prioritized':
      if self._replay_forgetting == 'elephant':
        raise NotImplementedError

  def _get_network_type(self):
    """Returns the type of the outputs of a value distribution network.

    Returns:
      net_type: _network_type object defining the outputs of the network.
    """
    return collections.namedtuple('c51_network',
                                  ['q_values', 'logits', 'probabilities'])

  def _network_template(self, state):
    """Builds a convolutional network that outputs Q-value distributions.

    Args:
      state: `tf.Tensor`, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
    """
    return self.network(self.num_actions, self._num_atoms, self._support,
                        self._get_network_type(), state)

  def _build_replay_buffer(self, use_staging):
    """Creates the replay buffer used by the agent.

    Args:
      use_staging: bool, if True, uses a staging area to prefetch data for
        faster training.

    Returns:
      A `WrappedPrioritizedReplayBuffer` object.

    Raises:
      ValueError: if given an invalid replay scheme.
    """
    if self._replay_scheme not in ['uniform', 'prioritized']:
      raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme))
    # Both replay schemes use the same data structure, but the 'uniform' scheme
    # sets all priorities to the same value (which yields uniform sampling).
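    # In addition to the standard transition fields, each transition stores
    # the (undiscounted) running return of its episode; these values are
    # summarized during training (see _build_train_op).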
    extra_elements = [ReplayElement('return', (), np.float32)]

    return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        use_staging=use_staging,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype.as_numpy_dtype,
        extra_storage_types=extra_elements,
        replay_forgetting=self._replay_forgetting,
        sample_newest_immediately=self._sample_newest_immediately)

  def _build_target_distribution(self):
    """Builds the C51 target distribution as per Bellemare et al. (2017).

    First, we compute the support of the Bellman target, r + gamma Z', where
    Z' is the support of the next state distribution:

      * Evenly spaced in [-vmax, vmax] if the current state is nonterminal;
      * 0 otherwise (duplicated num_atoms times).

    Second, we compute the next-state probabilities, corresponding to the
    action with highest expected value.

    Finally we project the Bellman target (support + probabilities) onto the
    original support.

    Returns:
      target_distribution: tf.tensor, the target distribution from the replay.
    """
    batch_size = self._replay.batch_size

    # size of rewards: batch_size x 1
    rewards = self._replay.rewards[:, None]

    # size of tiled_support: batch_size x num_atoms
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    # size of target_support: batch_size x num_atoms

    is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
    # Incorporate terminal state to discount factor.
    # size of gamma_with_terminal: batch_size x 1
    gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
    gamma_with_terminal = gamma_with_terminal[:, None]

    target_support = rewards + gamma_with_terminal * tiled_support

    # size of next_qt_argmax: 1 x batch_size
    next_qt_argmax = tf.argmax(
        self._replay_next_target_net_outputs.q_values, axis=1)[:, None]
    batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
    # size of next_qt_argmax: batch_size x 2
    batch_indexed_next_qt_argmax = tf.concat(
        [batch_indices, next_qt_argmax], axis=1)

    # size of next_probabilities: batch_size x num_atoms
    next_probabilities = tf.gather_nd(
        self._replay_next_target_net_outputs.probabilities,
        batch_indexed_next_qt_argmax)

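    # project_distribution (from the Dopamine Rainbow agent) redistributes the
    # probability mass sitting on the shifted support
    # (rewards + gamma * z) back onto the fixed atoms in self._support, i.e.
    # the C51 projection step.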
    return rainbow_agent.project_distribution(target_support,
                                              next_probabilities,
                                              self._support)

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    target_distribution = tf.stop_gradient(
        self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_net_outputs.logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_net_outputs.logits,
                                        reshaped_actions)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution,
        logits=chosen_action_logits)

    # Record returns encountered in the sampled training batches.
    returns = self._replay.transition['return']
    statistics_summaries('returns', returns)
    train_counts = self._replay.transition['train_counts']
    statistics_summaries('train_counts', train_counts)
    steps_until_first_train = self._replay.transition['steps_until_first_train']
    statistics_summaries('steps_until_first_train', steps_until_first_train)
    age = self._replay.transition['age']
    statistics_summaries('age', age)

    if self._replay_scheme == 'prioritized':
      # The original prioritized experience replay uses a linear exponent
      # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 0.5
      # on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) suggested
      # a fixed exponent actually performs better, except on Pong.
      probs = self._replay.transition['sampling_probabilities']
      beta = self._beta_exponent
      tf.summary.histogram('probs', probs)
      loss_weights = 1.0 / tf.math.pow(probs + 1e-10, beta)
      tf.summary.histogram('loss_weights', loss_weights)
      loss_weights /= tf.reduce_max(loss_weights)

      # Rainbow and prioritized replay are parametrized by an exponent alpha,
      # which in both cases defaults to 0.5; here it is exposed as
      # self._alpha_exponent. Raising the loss to the 0.5 power is equivalent
      # to taking its square root, which "makes sense" as we are dealing with
      # a squared loss.
      # Add a small nonzero value to the loss to avoid 0 priority items. While
      # technically this may be okay, setting all items to 0 priority will
      # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms.
      alpha = self._alpha_exponent
      update_priorities_op = self._replay.tf_set_priority(
          self._replay.indices, tf.math.pow(loss + 1e-10, alpha))

      # Weight the loss by the inverse priorities.
      loss = loss_weights * loss
    else:
      update_priorities_op = tf.no_op()

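    # Bump the per-transition train counts for the transitions sampled in this
    # batch; the 'train_counts' summaries above are computed from these counts.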
    update_train_counts_op = self._replay.tf_update_train_counts(
        self._replay.indices)

    with tf.control_dependencies([update_priorities_op,
                                  update_train_counts_op]):
      if self.summary_writer is not None:
        with tf.variable_scope('Losses'):
          tf.summary.scalar('CrossEntropyLoss', tf.reduce_mean(loss))
      # Schaul et al. reports a slightly different rule, where 1/N is also
      # exponentiated by beta. Not doing so seems more reasonable, and did not
      # impact performance in our experiments.
      return self.optimizer.minimize(tf.reduce_mean(loss)), loss

  def begin_episode(self, observation):
    """Returns the agent's first action for this episode.

    Args:
      observation: numpy array, the environment's initial observation.

    Returns:
      int, the selected action.
    """
    self._reset_state()
    self._reset_return()

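    # With elephant forgetting, the replay memory re-orders its stored
    # trajectories at every episode boundary (see
    # sort_replay_buffer_trajectories in the replay memory implementation).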
    if self._replay_forgetting == 'elephant':
      self._replay.memory.sort_replay_buffer_trajectories()

    self._record_observation(observation)

    if not self.eval_mode:
      self._train_step()

    self.action = self._select_action()
    return self.action

  def step(self, reward, observation):
    """Records the most recent transition and returns the agent's next action.

    We store the observation of the last time step since we want to store it
    with the reward.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._last_observation = self._observation
    self._record_observation(observation)

    if not self.eval_mode:
      self._update_return(reward)
      self._store_transition(self._last_observation,
                             self.action,
                             reward,
                             False,
                             self.episode_return)
      self._train_step()

    self.action = self._select_action()
    return self.action

  def end_episode(self, reward):
    """Signals the end of the episode to the agent.

    We store the observation of the current time step, which is the last
    observation of the episode.

    Args:
      reward: float, the last reward from the environment.
    """
    if not self.eval_mode:
      self._update_return(reward)
      self._store_transition(self._observation,
                             self.action,
                             reward,
                             True,
                             self.episode_return)

  def _train_step(self):
    """Runs a single training step.

    Runs a training op if both:
      (1) A minimum number of frames have been added to the replay buffer.
      (2) `training_steps` is a multiple of `update_period`.

    Also, syncs weights from online to target network if training steps is a
    multiple of target update period.
    """
    # Run a train op at the rate of self.update_period if enough training steps
    # have been run. This matches the Nature DQN behaviour.
    if self._replay.memory.add_count > self.min_replay_history:
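      # update_period and target_update_period may be fractional after the
      # gin-multiplier adjustment in __init__, so rather than a modulo test we
      # keep running counters and train until counter * period catches up with
      # training_steps.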
      while (self._online_network_updates * self.update_period <
             self.training_steps):
        self._sess.run(self._train_op)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          summary = self._sess.run(self._merged_summaries)
          self.summary_writer.add_summary(summary, self.training_steps)
        self._online_network_updates += 1

      while (self._target_network_updates * self.target_update_period <
             self.training_steps):
        self._sess.run(self._sync_qt_ops)
        self._target_network_updates += 1

    self.training_steps += 1

  def _store_transition(self,
                        last_observation,
                        action,
                        reward,
                        is_terminal,
                        episode_return,
                        priority=None):
    """Stores a transition when in training mode.

    Executes a tf session and executes replay buffer ops in order to store the
    following tuple in the replay buffer (last_observation, action, reward,
    is_terminal, priority).

    Args:
      last_observation: Last observation, type determined via observation_type
        parameter in the replay_memory constructor.
      action: An integer, the action taken.
      reward: A float, the reward.
      is_terminal: Boolean indicating if the current state is a terminal state.
      episode_return: A float, the episode's undiscounted return so far.
      priority: Float. Priority of sampling the transition. If None, the
        default priority will be used. If replay scheme is uniform, the
        default priority is 1. If the replay scheme is prioritized, the
        default priority is the maximum ever seen [Schaul et al., 2015].
    """
    if priority is None:
      if self._replay_scheme == 'uniform':
        priority = 1.0
      else:
        priority = self._replay.memory.sum_tree.max_recorded_priority

    # TODO(liamfedus): This storage mechanism is brittle depending on order.
    # The internal replay buffers should be added via **kwargs not *args.
    if not self.eval_mode:
      self._replay.add(last_observation,
                       action,
                       reward,
                       is_terminal,
                       episode_return,
                       priority)

  def _update_return(self, reward):
    """Updates the current episode return based on the reward."""
    if self.episode_return != self.episode_return + reward:
      self.episode_return += reward

  def _reset_return(self):
    """Reset the episode return."""
    self.episode_return = 0.0

  def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
    """Returns a self-contained bundle of the agent's state.

    This is used for checkpointing. It will return a dictionary containing all
    non-TensorFlow objects (to be saved into a file by the caller), and it
    saves all TensorFlow objects into a checkpoint file.

    Args:
      checkpoint_dir: str, directory where TensorFlow objects will be saved.
      iteration_number: int, iteration number to use for naming the checkpoint
        file.

    Returns:
      A dict containing additional Python objects to be checkpointed by the
      experiment. If the checkpoint directory does not exist, returns None.
    """
    bundle_dictionary = super(
        ElephantRainbowAgent, self).bundle_and_checkpoint(
            checkpoint_dir, iteration_number)
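    # Persist the update counters so that the while-loop schedule in
    # _train_step resumes correctly after a restart.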
    bundle_dictionary['_online_network_updates'] = self._online_network_updates
    bundle_dictionary['_target_network_updates'] = self._target_network_updates
    return bundle_dictionary