# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Elephant DQN agent with adjustable replay ratios."""


from dopamine.agents.dqn import dqn_agent

import gin
import tensorflow.compat.v1 as tf
from experience_replay.replay_memory import prioritized_replay_buffer


def statistics_summaries(name, var):
  """Attach additional statistical summaries to the variable."""
  var = tf.to_float(var)
  with tf.variable_scope(name):
    tf.summary.scalar('mean', tf.reduce_mean(var))
    tf.summary.scalar('stddev', tf.math.reduce_std(var))
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
  tf.summary.histogram(name, var)


@gin.configurable
class ElephantDQNAgent(dqn_agent.DQNAgent):
  """A compact implementation of an Elephant DQN agent."""

  def __init__(self,
               replay_scheme='uniform',
               oldest_policy_in_buffer=250000,
               **kwargs):
    """Initializes the agent and constructs the components of its graph."""
    self._replay_scheme = replay_scheme
    self._oldest_policy_in_buffer = oldest_policy_in_buffer

    dqn_agent.DQNAgent.__init__(self, **kwargs)
    tf.logging.info('\t replay_scheme: %s', replay_scheme)
    tf.logging.info('\t oldest_policy_in_buffer: %s', oldest_policy_in_buffer)

    # We maintain attributes to record online and target network updates, which
    # are later used by the non-integer update logic in _train_step.
    self._online_network_updates = 0
    self._target_network_updates = 0

    # pylint: disable=protected-access
    buffer_to_oldest_policy_ratio = (
        float(self._replay.memory._replay_capacity) /
        float(self._oldest_policy_in_buffer))
    # pylint: enable=protected-access

    # This ratio is used to adjust other attributes that are explicitly tied to
    # agent steps. When designed, the Dopamine agents assumed that the replay
    # ratio remains fixed, and therefore elements such as epsilon_decay_period
    # will not be set appropriately without adjustment.
    self._gin_param_multiplier = (
        buffer_to_oldest_policy_ratio / self.update_period)
    tf.logging.info('\t self._gin_param_multiplier: %f',
                    self._gin_param_multiplier)

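    # Illustrative arithmetic (the replay capacity and update_period values
    # here are Dopamine's usual DQN defaults, assumed for the example rather
    # than set in this file): with a 1,000,000-transition buffer,
    # oldest_policy_in_buffer=250,000 and update_period=4, the multiplier is
    # (1e6 / 2.5e5) / 4 = 1 and the periods below are unchanged; halving
    # oldest_policy_in_buffer to 125,000 doubles the multiplier and therefore
    # doubles each adjusted period.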
    # Adjust agent attributes that are tied to the agent steps.
    self.update_period *= self._gin_param_multiplier
    self.target_update_period *= self._gin_param_multiplier
    self.epsilon_decay_period *= self._gin_param_multiplier

  def _build_replay_buffer(self, use_staging):
    """Creates the replay buffer used by the agent.

    Args:
      use_staging: bool, if True, uses a staging area to prefetch data for
        faster training.

    Returns:
      A `WrappedPrioritizedReplayBuffer` object.

    Raises:
      ValueError: if given an invalid replay scheme.
    """
    if self._replay_scheme not in ['uniform', 'prioritized']:
      raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme))
    # Both replay schemes use the same data structure, but the 'uniform' scheme
    # sets all priorities to the same value (which yields uniform sampling).

    return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        use_staging=use_staging,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype.as_numpy_dtype,
        replay_forgetting='default',
        sample_newest_immediately=False)

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.q_values * replay_action_one_hot,
        reduction_indices=1,
        name='replay_chosen_q')

    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
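    # Reduction.NONE keeps a per-transition Huber loss so that it can be
    # reweighted per sample and reused as updated priorities below; the mean
    # is only taken inside optimizer.minimize().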

    if self._replay_scheme == 'prioritized':
      # The original prioritized experience replay uses a linear exponent
      # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 0.5
      # on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) suggested
      # a fixed exponent actually performs better, except on Pong.
      probs = self._replay.transition['sampling_probabilities']
      loss_weights = 1.0 / tf.math.pow(probs + 1e-10, 0.5)
      loss_weights /= tf.reduce_max(loss_weights)
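      # Normalizing by the maximum keeps the importance-sampling weights in
      # (0, 1], so the correction only ever scales per-sample gradients down.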

      # Rainbow and prioritized replay are parametrized by an exponent alpha,
      # but in both cases it is set to 0.5, i.e. a square root (applied here
      # via tf.math.pow(..., 0.5)). Taking the square root "makes sense", as
      # we are dealing with a squared loss.
      # Add a small nonzero value to the loss to avoid 0 priority items. While
      # technically this may be okay, setting all items to 0 priority will
      # cause trouble, and also result in 1.0 / 0.0 = NaN correction terms.
      update_priorities_op = self._replay.tf_set_priority(
          self._replay.indices, tf.math.pow(loss + 1e-10, 0.5))

      # Weight the loss by the inverse priorities.
      loss = loss_weights * loss
    else:
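      # With uniform sampling every transition is stored with the default
      # priority of 1.0 (see _store_transition), so there is nothing to
      # update here.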
      update_priorities_op = tf.no_op()

    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
    with tf.control_dependencies([update_priorities_op]):
      # Schaul et al. report a slightly different rule, where 1/N is also
      # exponentiated by beta. Not doing so seems more reasonable, and did not
      # impact performance in our experiments.
      return self.optimizer.minimize(tf.reduce_mean(loss))

  def _train_step(self):
    """Runs a single training step.

    Runs training ops if both:
      (1) A minimum number of frames have been added to the replay buffer.
      (2) Fewer than `training_steps / update_period` online updates have been
          performed so far (`update_period` may be non-integer).

    Also syncs weights from the online to the target network whenever the
    number of completed target updates falls behind
    training_steps / target_update_period.
    """
    # Run a train_op at the rate of self.update_period if enough training steps
    # have been run. This matches the Nature DQN behaviour.
    # training_steps still advances once per environment step; the counters
    # _online_network_updates and _target_network_updates record how many
    # network updates have actually been performed and gate the updates below.
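    # Because update_period and target_update_period may be non-integer after
    # the rescaling in __init__, we do not test training_steps % period == 0;
    # instead each loop below runs updates until its counter catches up with
    # training_steps / period.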
    if self._replay.memory.add_count > self.min_replay_history:
      while (self._online_network_updates * self.update_period <
             self.training_steps):
        self._sess.run(self._train_op)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          summary = self._sess.run(self._merged_summaries)
          self.summary_writer.add_summary(summary, self.training_steps)
        self._online_network_updates += 1

      while (self._target_network_updates * self.target_update_period <
             self.training_steps):
        self._sess.run(self._sync_qt_ops)
        self._target_network_updates += 1

    self.training_steps += 1

  def _store_transition(self,
                        last_observation,
                        action,
                        reward,
                        is_terminal,
                        priority=None):
    """Stores a transition when in training mode.

    Executes a tf session and executes replay buffer ops in order to store the
    following tuple in the replay buffer (last_observation, action, reward,
    is_terminal, priority).

    Args:
      last_observation: Last observation, type determined via observation_type
        parameter in the replay_memory constructor.
      action: An integer, the action taken.
      reward: A float, the reward.
      is_terminal: Boolean indicating if the current state is a terminal state.
      priority: Float. Priority of sampling the transition. If None, the
        default priority will be used. If replay scheme is uniform, the default
        priority is 1. If the replay scheme is prioritized, the default
        priority is the maximum ever seen [Schaul et al., 2015].
    """
    if priority is None:
      if self._replay_scheme == 'uniform':
        priority = 1.0
      else:
        priority = self._replay.memory.sum_tree.max_recorded_priority

    if not self.eval_mode:
      self._replay.add(last_observation,
                       action,
                       reward,
                       is_terminal,
                       priority)

  def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
    """Returns a self-contained bundle of the agent's state.

    This is used for checkpointing. It will return a dictionary containing all
    non-TensorFlow objects (to be saved into a file by the caller), and it
    saves all TensorFlow objects into a checkpoint file.

    Args:
      checkpoint_dir: str, directory where TensorFlow objects will be saved.
      iteration_number: int, iteration number to use for naming the checkpoint
        file.

    Returns:
      A dict containing additional Python objects to be checkpointed by the
        experiment. If the checkpoint directory does not exist, returns None.
    """
    bundle_dictionary = super(ElephantDQNAgent, self).bundle_and_checkpoint(
        checkpoint_dir, iteration_number)
    if bundle_dictionary is None:
      # The parent returns None when the checkpoint directory does not exist;
      # propagate that instead of attempting to add keys to None.
      return None
    bundle_dictionary['_online_network_updates'] = self._online_network_updates
    bundle_dictionary['_target_network_updates'] = self._target_network_updates
    return bundle_dictionary
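

# Example gin bindings for this agent (illustrative values, not a published
# configuration):
#
#   ElephantDQNAgent.replay_scheme = 'prioritized'
#   ElephantDQNAgent.oldest_policy_in_buffer = 250000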