# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Implementation of a Rainbow agent with bootstrap."""
17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from bonus_based_exploration.noisy_networks import noisy_dqn_agent

from dopamine.agents.rainbow import rainbow_agent as base_rainbow_agent
import gin
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim


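# NoisyRainbowAgent is gin-configurable; bindings use the constructor
# parameter names below. A plausible sketch (the actual experiment configs in
# this repo may differ):
#
#   NoisyRainbowAgent.noise_distribution = 'factorised'
#   NoisyRainbowAgent.min_replay_history = 20000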
@gin.configurable
class NoisyRainbowAgent(base_rainbow_agent.RainbowAgent):
  """A Rainbow agent with noisy networks."""

  def __init__(self,
               sess,
               num_actions,
               num_atoms=51,
               vmax=10.,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=20000,
               update_period=4,
               target_update_period=8000,
               epsilon_fn=lambda w, x, y, z: 0,
               epsilon_decay_period=250000,
               replay_scheme='prioritized',
               tf_device='/cpu:*',
               use_staging=True,
               optimizer=tf.train.AdamOptimizer(
                   learning_rate=0.00025, epsilon=0.0003125),
               summary_writer=None,
               summary_writing_frequency=500,
               noise_distribution='independent'):
    """Initializes the agent and constructs the components of its graph.

    Args:
      sess: `tf.Session`, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      num_atoms: int, the number of buckets of the value function distribution.
      vmax: float, the value distribution support is [-vmax, vmax].
      gamma: float, discount factor with the usual RL meaning.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of transitions that should be experienced
        before the agent begins training its value function.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      epsilon_fn: function expecting 4 parameters: (decay_period, step,
        warmup_steps, epsilon). This function should return the epsilon value
        used for exploration during training.
      epsilon_decay_period: int, length of the epsilon decay schedule.
      replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of
        the replay memory.
      tf_device: str, Tensorflow device on which the agent's graph is executed.
      use_staging: bool, when True use a staging area to prefetch the next
        training batch, speeding training up by about 30%.
      optimizer: tf.train.Optimizer, for training the value function.
      summary_writer: SummaryWriter object for outputting training statistics.
        Summary writing disabled if set to None.
      summary_writing_frequency: int, frequency with which summaries will be
        written. Lower values will result in slower training.
      noise_distribution: str, distribution used to sample noise, must be
        `factorised` or `independent`.
    """
    # Record the noise distribution before the parent constructor builds the
    # agent's graph.
    self.noise_distribution = noise_distribution
    super(NoisyRainbowAgent, self).__init__(
        sess=sess,
        num_actions=num_actions,
        num_atoms=num_atoms,
        vmax=vmax,
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_decay_period=epsilon_decay_period,
        replay_scheme=replay_scheme,
        tf_device=tf_device,
        use_staging=use_staging,
        optimizer=optimizer,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency)

  def _network_template(self, state):
    """Builds the convolutional network used to compute the agent's Q-values.

    Args:
      state: `tf.Tensor`, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
    """
    weights_initializer = slim.variance_scaling_initializer(
        factor=1.0 / np.sqrt(3.0), mode='FAN_IN', uniform=True)

    # Scale uint8 pixel observations to [0, 1] before the conv stack.
    net = tf.cast(state, tf.float32)
    net = tf.div(net, 255.)
    # Standard Nature-DQN convolutional torso.
    net = slim.conv2d(
        net, 32, [8, 8], stride=4, weights_initializer=weights_initializer)
    net = slim.conv2d(
        net, 64, [4, 4], stride=2, weights_initializer=weights_initializer)
    net = slim.conv2d(
        net, 64, [3, 3], stride=1, weights_initializer=weights_initializer)
    net = slim.flatten(net)
    # The fully connected layers are the noisy ones, replacing the standard
    # dense layers used by Rainbow.
    net = noisy_dqn_agent.fully_connected(
        net, 512, scope='fully_connected')
    net = noisy_dqn_agent.fully_connected(
        net,
        self.num_actions * self._num_atoms,
        activation_fn=None,
        scope='fully_connected_1')

    # Distributional (C51) head: a softmax over num_atoms per action, with
    # Q-values given by the expectation over the fixed support.
    logits = tf.reshape(net, [-1, self.num_actions, self._num_atoms])
    probabilities = contrib_layers.softmax(logits)
    q_values = tf.reduce_sum(self._support * probabilities, axis=2)
    return self._get_network_type()(q_values, logits, probabilities)
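

# A minimal, self-contained sketch (not part of the original agent) of the two
# ideas this module combines. Everything below is illustrative: shapes, the
# random seed, and the sample_noise helper are assumptions, not repo APIs.
if __name__ == '__main__':
  # 1) Distributional (C51) head: Q(s, a) is the expectation of a categorical
  # distribution over num_atoms fixed support points in [-vmax, vmax], exactly
  # as computed at the end of _network_template.
  num_actions, num_atoms, vmax = 4, 51, 10.
  support = np.linspace(-vmax, vmax, num_atoms)  # the z_i atoms
  rng = np.random.RandomState(0)
  # Fake logits for a batch of one state, as the final noisy layer would emit.
  logits = rng.randn(1, num_actions, num_atoms)
  probabilities = np.exp(logits) / np.exp(logits).sum(axis=2, keepdims=True)
  q_values = (support * probabilities).sum(axis=2)  # shape (1, num_actions)
  print('Q-values per action:', q_values)

  # 2) The two noise_distribution options, following Fortunato et al. (2018):
  # 'independent' samples one Gaussian per weight, while 'factorised' samples
  # one vector per input and one per output and takes a scaled outer product.
  def sample_noise(num_in, num_out, distribution):
    if distribution == 'independent':
      return rng.randn(num_in, num_out)
    f = lambda x: np.sign(x) * np.sqrt(np.abs(x))  # factorised noise scaling
    return np.outer(f(rng.randn(num_in)), f(rng.randn(num_out)))

  print('factorised noise shape:', sample_noise(8, 4, 'factorised').shape)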