google-research
202 lines · 7.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implementation of noisy networks DQN https://arxiv.org/abs/1706.10295.
17"""
18
19from __future__ import absolute_import20from __future__ import division21from __future__ import print_function22
23import enum24from dopamine.agents.dqn import dqn_agent as base_dqn_agent25import gin26import numpy as np27import tensorflow.compat.v1 as tf28from tensorflow.contrib import slim as contrib_slim29
30slim = contrib_slim31linearly_decaying_epsilon = base_dqn_agent.linearly_decaying_epsilon32
33
@gin.constants_from_enum
class NoiseDistribution(enum.Enum):
  """Noise sampling schemes for noisy layers (see Section 3.2 of the paper).

  INDEPENDENT draws one Gaussian sample per weight and per bias;
  FACTORISED combines per-input and per-output noise vectors, which is
  cheaper because it needs only (inputs + outputs) samples per layer.
  """
  INDEPENDENT = 0
  FACTORISED = 1
39
def signed_sqrt(tensor):
  """Computes the elementwise signed square root: sign(x) * sqrt(|x|).

  This is the real-valued scaling function f(x) used to build factorised
  noise in noisy networks.

  Args:
    tensor: `tf.Tensor` of any shape.

  Returns:
    `tf.Tensor` of the same shape with the signed square root applied
    elementwise.
  """
  magnitude = tf.sqrt(tf.abs(tensor))
  return magnitude * tf.sign(tensor)
43
def fully_connected(inputs, num_outputs,
                    activation_fn=tf.nn.relu,
                    scope=None,
                    collection=None,
                    distribution=NoiseDistribution.INDEPENDENT,
                    summary_writer=None):
  """Creates a fully connected layer with noise (NoisyNet).

  The effective parameters are w = mu_w + sigma_w * epsilon_w and
  b = mu_b + sigma_b * epsilon_b, where mu and sigma are learned and
  epsilon is freshly sampled noise (Fortunato et al., 2017).

  Args:
    inputs: `tf.Tensor` of shape [batch, num_inputs].
    num_outputs: int, number of output units.
    activation_fn: activation applied to the result, or None for a linear
      layer.
    scope: str, variable scope in which the layer parameters are created.
    collection: optional graph-collection key; when set, the four noisy
      parameter variables are added to it.
    distribution: `NoiseDistribution`, the noise sampling scheme; see
      Section 3.2 of the original paper.
    summary_writer: when not None, a scalar summary of the mean weight
      sigma is recorded.

  Returns:
    `tf.Tensor` of shape [batch, num_outputs].

  Raises:
    ValueError: if `distribution` is not a known `NoiseDistribution`.
  """
  num_inputs = int(inputs.get_shape()[-1])
  weight_shape = (num_inputs, num_outputs)
  biases_shape = [num_outputs]

  # Parameters for each noise distribution, see Section 3.2 in original paper.
  if distribution == NoiseDistribution.INDEPENDENT:
    stddev = np.sqrt(3. / num_inputs)
    constant = 0.017
    epsilon_w = tf.truncated_normal(weight_shape)
    epsilon_b = tf.truncated_normal(biases_shape)
  elif distribution == NoiseDistribution.FACTORISED:
    stddev = np.sqrt(1. / num_inputs)
    constant = 0.5 * np.sqrt(1 / num_inputs)
    # Factorised noise draws one noise vector per input unit and one per
    # output unit ((p + q) samples instead of p*q); the weight noise is the
    # outer product f(eps_in) f(eps_out)^T (paper, Eq. 10-11).
    # Bug fix: the previous code sampled `noise_input` with the full
    # `weight_shape`, which both defeats factorisation and makes the
    # tf.matmul below fail (rank-2 x rank-3 with mismatched inner dims).
    noise_input = tf.truncated_normal([num_inputs])
    noise_output = tf.truncated_normal(biases_shape)
    # Outer product with shape (num_inputs, num_outputs) == weight_shape.
    epsilon_w = tf.matmul(
        signed_sqrt(noise_input)[:, None], signed_sqrt(noise_output)[None, :])
    epsilon_b = signed_sqrt(noise_output)
  else:
    raise ValueError('Unknown noise distribution')

  mu_initializer = tf.initializers.random_uniform(
      minval=-stddev,
      maxval=stddev)
  sigma_initializer = tf.constant_initializer(value=constant)

  with tf.variable_scope(scope):
    mu_w = tf.get_variable('mu_w', weight_shape, trainable=True,
                           initializer=mu_initializer)
    sigma_w = tf.get_variable('sigma_w', weight_shape, trainable=True,
                              initializer=sigma_initializer)
    mu_b = tf.get_variable('mu_b', biases_shape, trainable=True,
                           initializer=mu_initializer)
    sigma_b = tf.get_variable('sigma_b', biases_shape, trainable=True,
                              initializer=sigma_initializer)
    if collection is not None:
      tf.add_to_collection(collection, mu_w)
      tf.add_to_collection(collection, mu_b)
      tf.add_to_collection(collection, sigma_w)
      tf.add_to_collection(collection, sigma_b)

  # Noisy parameters: learned mean plus learned scale times sampled noise.
  w = mu_w + sigma_w * epsilon_w
  b = mu_b + sigma_b * epsilon_b
  layer = tf.matmul(inputs, w)
  layer_bias = tf.nn.bias_add(layer, b)

  if summary_writer is not None:
    with tf.variable_scope('Noisy'):
      tf.summary.scalar('Sigma', tf.reduce_mean(sigma_w))

  if activation_fn is not None:
    layer_bias = activation_fn(layer_bias)
  return layer_bias
105
@gin.configurable
class NoisyDQNAgent(base_dqn_agent.DQNAgent):
  """Base class for a DQN agent with noisy layers."""

  def __init__(self,
               sess,
               num_actions,
               observation_shape=base_dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=20000,
               update_period=4,
               target_update_period=8000,
               # Noisy networks replace epsilon-greedy exploration, so the
               # default schedule returns a constant epsilon of 0.
               epsilon_fn=lambda w, x, y, z: 0,
               epsilon_decay_period=250000,
               tf_device='/cpu:*',
               use_staging=True,
               max_tf_checkpoints_to_keep=3,
               optimizer=tf.train.RMSPropOptimizer(
                   learning_rate=0.00025,
                   decay=0.95,
                   momentum=0.0,
                   epsilon=0.00001,
                   centered=True),
               summary_writer=None,
               summary_writing_frequency=500,
               noise_distribution=NoiseDistribution.INDEPENDENT):
    """Initializes the agent and constructs the components of its graph.

    Args:
      sess: `tf.Session`, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      observation_shape: tuple of ints describing the observation shape.
      gamma: float, discount factor with the usual RL meaning.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of transitions that should be experienced
        before the agent begins training its value function.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      epsilon_fn: function expecting 4 parameters:
        (decay_period, step, warmup_steps, epsilon). This function should
        return the epsilon value used for exploration during training.
      epsilon_decay_period: int, length of the epsilon decay schedule.
      tf_device: str, Tensorflow device on which the agent's graph is executed.
      use_staging: bool, when True use a staging area to prefetch the next
        training batch, speeding training up by about 30%.
      max_tf_checkpoints_to_keep: int, the number of TensorFlow checkpoints to
        keep.
      optimizer: `tf.train.Optimizer`, for training the value function.
      summary_writer: SummaryWriter object for outputting training statistics.
        Summary writing disabled if set to None.
      summary_writing_frequency: int, frequency with which summaries will be
        written. Lower values will result in slower training.
      noise_distribution: `NoiseDistribution` member, distribution used to
        sample noise, must be FACTORISED or INDEPENDENT.
    """
    self.noise_distribution = noise_distribution
    super(NoisyDQNAgent, self).__init__(
        sess=sess,
        num_actions=num_actions,
        observation_shape=observation_shape,
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_decay_period=epsilon_decay_period,
        tf_device=tf_device,
        use_staging=use_staging,
        # Bug fix: `max_tf_checkpoints_to_keep` was previously accepted (and
        # documented) but never forwarded, so the base agent silently used
        # its own default instead of the configured value.
        max_tf_checkpoints_to_keep=max_tf_checkpoints_to_keep,
        optimizer=optimizer,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency)

  def _network_template(self, state):
    """Builds the convolutional network used to compute the agent's Q-values.

    Uses the standard Nature-DQN convolutional torso, but replaces the two
    final dense layers with noisy fully connected layers.

    Args:
      state: `tf.Tensor`, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
    """
    net = tf.cast(state, tf.float32)
    # Rescale byte-valued pixels to [0, 1].
    net = tf.div(net, 255.)
    net = slim.conv2d(net, 32, [8, 8], stride=4)
    net = slim.conv2d(net, 64, [4, 4], stride=2)
    net = slim.conv2d(net, 64, [3, 3], stride=1)
    net = slim.flatten(net)
    net = fully_connected(net, 512, distribution=self.noise_distribution,
                          scope='fully_connected')
    q_values = fully_connected(net, self.num_actions,
                               activation_fn=None,
                               distribution=self.noise_distribution,
                               scope='fully_connected_1')
    return self._get_network_type()(q_values)