google-research

Форк
0
202 строки · 7.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implementation of noisy networks DQN https://arxiv.org/abs/1706.10295.
"""
18

19
from __future__ import absolute_import
20
from __future__ import division
21
from __future__ import print_function
22

23
import enum
24
from dopamine.agents.dqn import dqn_agent as base_dqn_agent
25
import gin
26
import numpy as np
27
import tensorflow.compat.v1 as tf
28
from tensorflow.contrib import slim as contrib_slim
29

30
# Short alias for TF-Slim; `slim.conv2d` / `slim.flatten` build the
# convolutional torso in `NoisyDQNAgent._network_template`.
slim = contrib_slim
31
# Re-export of the base DQN agent's epsilon schedule under this module's
# namespace. NOTE(review): presumably kept so gin configs that reference this
# module can still select it — confirm before removing.
linearly_decaying_epsilon = base_dqn_agent.linearly_decaying_epsilon
32

33

34
@gin.constants_from_enum
class NoiseDistribution(enum.Enum):
  """Noise model selectable for the noisy fully connected layers.

  The two members correspond to the "independent Gaussian noise" and
  "factorised Gaussian noise" variants of Section 3.2 of
  https://arxiv.org/abs/1706.10295. The class is registered with gin so the
  members can be referenced as constants in gin configs.
  """
  # One noise sample per weight and per bias.
  INDEPENDENT = 0
  # Weight noise derived from per-input and per-output noise vectors.
  FACTORISED = 1
38

39

40
def signed_sqrt(tensor):
  """Element-wise signed square root: sign(x) * sqrt(|x|).

  Used to shape the raw noise samples of the factorised noise distribution.

  Args:
    tensor: `tf.Tensor` of any shape.

  Returns:
    `tf.Tensor` with the same shape as `tensor`.
  """
  magnitude = tf.sqrt(tf.abs(tensor))
  return magnitude * tf.sign(tensor)
42

43

44
def fully_connected(inputs, num_outputs,
                    activation_fn=tf.nn.relu,
                    scope=None,
                    collection=None,
                    distribution=NoiseDistribution.INDEPENDENT,
                    summary_writer=None):
  """Creates a fully connected layer with learnable parametric noise.

  The effective parameters are `w = mu_w + sigma_w * epsilon_w` and
  `b = mu_b + sigma_b * epsilon_b`, where `mu_*` and `sigma_*` are trainable
  variables and `epsilon_*` is noise sampled by ops created here (Section 3.2
  of https://arxiv.org/abs/1706.10295).

  Args:
    inputs: `tf.Tensor`, the layer input. Its last dimension is the number of
      input units; it is multiplied with `tf.matmul` by a 2-D
      (num_inputs, num_outputs) weight matrix, so it is expected to be 2-D
      (batch, num_inputs).
    num_outputs: int, number of output units.
    activation_fn: callable applied to the layer output, or None for a linear
      layer.
    scope: optional variable-scope name for the layer's variables. When None,
      a uniquified scope named 'noisy_fully_connected' is used.
    collection: optional collection name; when given, `mu_w`, `mu_b`,
      `sigma_w` and `sigma_b` are added to it (in that order).
    distribution: `NoiseDistribution`, the noise model to use.
    summary_writer: when not None, a scalar summary 'Sigma' holding the mean
      of `sigma_w` is added to the graph.

  Returns:
    `tf.Tensor` whose last dimension is `num_outputs`: the (optionally
    activated) layer output.

  Raises:
    ValueError: if `distribution` is not a known `NoiseDistribution`.
  """
  num_inputs = int(inputs.get_shape()[-1])
  weight_shape = (num_inputs, num_outputs)
  biases_shape = [num_outputs]

  # Parameters for each noise distribution, see Section 3.2 in original paper.
  if distribution == NoiseDistribution.INDEPENDENT:
    # mu ~ U[-sqrt(3/p), sqrt(3/p)], sigma initialised to 0.017; one noise
    # sample per weight and per bias.
    stddev = np.sqrt(3. / num_inputs)
    constant = 0.017
    epsilon_w = tf.truncated_normal(weight_shape)
    epsilon_b = tf.truncated_normal(biases_shape)
  elif distribution == NoiseDistribution.FACTORISED:
    # mu ~ U[-1/sqrt(p), 1/sqrt(p)], sigma initialised to 0.5/sqrt(p).
    # Sample one noise value per input and one per output, then form the
    # outer product f(eps_in) f(eps_out)^T so epsilon_w has exactly
    # `weight_shape` = (num_inputs, num_outputs), matching `inputs @ w`.
    stddev = np.sqrt(1. / num_inputs)
    constant = 0.5 * np.sqrt(1. / num_inputs)
    noise_input = signed_sqrt(tf.truncated_normal([num_inputs]))
    noise_output = signed_sqrt(tf.truncated_normal(biases_shape))
    epsilon_w = tf.matmul(noise_input[:, None], noise_output[None, :])
    epsilon_b = noise_output
  else:
    raise ValueError('Unknown noise distribution')

  mu_initializer = tf.initializers.random_uniform(
      minval=-stddev,
      maxval=stddev)
  sigma_initializer = tf.constant_initializer(value=constant)

  # tf.variable_scope needs a name when `default_name` is None, so supplying
  # `default_name` makes the documented default `scope=None` usable. It has
  # no effect when the caller passes an explicit scope.
  with tf.variable_scope(scope, default_name='noisy_fully_connected'):
    mu_w = tf.get_variable('mu_w', weight_shape, trainable=True,
                           initializer=mu_initializer)
    sigma_w = tf.get_variable('sigma_w', weight_shape, trainable=True,
                              initializer=sigma_initializer)
    mu_b = tf.get_variable('mu_b', biases_shape, trainable=True,
                           initializer=mu_initializer)
    sigma_b = tf.get_variable('sigma_b', biases_shape, trainable=True,
                              initializer=sigma_initializer)
    if collection is not None:
      for variable in (mu_w, mu_b, sigma_w, sigma_b):
        tf.add_to_collection(collection, variable)

    w = mu_w + sigma_w * epsilon_w
    b = mu_b + sigma_b * epsilon_b
    layer = tf.nn.bias_add(tf.matmul(inputs, w), b)

    if summary_writer is not None:
      with tf.variable_scope('Noisy'):
        tf.summary.scalar('Sigma', tf.reduce_mean(sigma_w))

    if activation_fn is not None:
      layer = activation_fn(layer)
  return layer
104

105

106
@gin.configurable
class NoisyDQNAgent(base_dqn_agent.DQNAgent):
  """Base class for a DQN agent whose fully connected layers are noisy."""

  def __init__(self,
               sess,
               num_actions,
               observation_shape=base_dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=20000,
               update_period=4,
               target_update_period=8000,
               epsilon_fn=lambda w, x, y, z: 0,
               epsilon_decay_period=250000,
               tf_device='/cpu:*',
               use_staging=True,
               max_tf_checkpoints_to_keep=3,
               optimizer=tf.train.RMSPropOptimizer(
                   learning_rate=0.00025,
                   decay=0.95,
                   momentum=0.0,
                   epsilon=0.00001,
                   centered=True),
               summary_writer=None,
               summary_writing_frequency=500,
               noise_distribution=NoiseDistribution.INDEPENDENT):
    """Initializes the agent and constructs the components of its graph.

    Args:
      sess: `tf.Session`, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      observation_shape: tuple of ints describing the observation shape.
      gamma: float, discount factor with the usual RL meaning.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of transitions that should be experienced
        before the agent begins training its value function.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      epsilon_fn: function expecting 4 parameters:
        (decay_period, step, warmup_steps, epsilon). This function should return
        the epsilon value used for exploration during training. The default
        always returns 0, disabling epsilon-greedy exploration (the noisy
        layers are intended to drive exploration instead).
      epsilon_decay_period: int, length of the epsilon decay schedule.
      tf_device: str, Tensorflow device on which the agent's graph is executed.
      use_staging: bool, when True use a staging area to prefetch the next
        training batch, speeding training up by about 30%.
      max_tf_checkpoints_to_keep: int, the number of TensorFlow checkpoints to
        keep.
      optimizer: `tf.train.Optimizer`, for training the value function.
      summary_writer: SummaryWriter object for outputting training statistics.
        Summary writing disabled if set to None.
      summary_writing_frequency: int, frequency with which summaries will be
        written. Lower values will result in slower training.
      noise_distribution: `NoiseDistribution`, noise model used by the noisy
        fully connected layers; `NoiseDistribution.INDEPENDENT` or
        `NoiseDistribution.FACTORISED`.
    """
    # Set before calling the base constructor: `_network_template` reads it,
    # and the base constructor presumably builds the networks.
    self.noise_distribution = noise_distribution
    super(NoisyDQNAgent, self).__init__(
        sess=sess,
        num_actions=num_actions,
        observation_shape=observation_shape,
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_decay_period=epsilon_decay_period,
        tf_device=tf_device,
        use_staging=use_staging,
        # Previously accepted but never forwarded, so it was silently ignored.
        max_tf_checkpoints_to_keep=max_tf_checkpoints_to_keep,
        optimizer=optimizer,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency)

  def _network_template(self, state):
    """Builds the network used to compute the agent's Q-values.

    Three convolutional layers (32 filters 8x8 stride 4, 64 filters 4x4
    stride 2, 64 filters 3x3 stride 1) are followed by two noisy fully
    connected layers: 512 ReLU units, then one linear output per action.
    Both noisy layers use `self.noise_distribution`.

    Args:
      state: `tf.Tensor`, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
    """
    net = tf.cast(state, tf.float32)
    # Rescale to [0, 1]; assumes `state` holds 0-255 pixel values — confirm.
    net = tf.div(net, 255.)
    net = slim.conv2d(net, 32, [8, 8], stride=4)
    net = slim.conv2d(net, 64, [4, 4], stride=2)
    net = slim.conv2d(net, 64, [3, 3], stride=1)
    net = slim.flatten(net)
    net = fully_connected(net, 512, distribution=self.noise_distribution,
                          scope='fully_connected')
    q_values = fully_connected(net, self.num_actions,
                               activation_fn=None,
                               distribution=self.noise_distribution,
                               scope='fully_connected_1')
    return self._get_network_type()(q_values)
203

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.