# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for generative models used to derive intrinsic rewards."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

import cv2
from dopamine.discrete_domains import atari_lib
import gin
import numpy as np
import shannon  # Assumed CTS density model library; see CTSIntrinsicReward.
import tensorflow.compat.v1 as tf
from tensorflow.contrib import slim


# Frames are downsampled to a 42x42 square and quantized to 8 gray levels
# before being fed to the density models.
PSEUDO_COUNT_QUANTIZATION_FACTOR = 8
PSEUDO_COUNT_OBSERVATION_SHAPE = (42, 42)
NATURE_DQN_OBSERVATION_SHAPE = atari_lib.NATURE_DQN_OBSERVATION_SHAPE


@slim.add_arg_scope
def masked_conv2d(inputs, num_outputs, kernel_size,
                  activation_fn=tf.nn.relu,
                  weights_initializer=tf.initializers.glorot_normal(),
                  biases_initializer=tf.initializers.zeros(),
                  stride=(1, 1),
                  scope=None,
                  mask_type='A',
                  collection=None,
                  output_multiplier=1):
  """Creates masked convolutions used in PixelCNN.

  There are two types of masked convolutions, type A and B; see Figure 1 in
  https://arxiv.org/abs/1606.05328 for more details.

  Args:
    inputs: input image.
    num_outputs: int, number of filters used in the convolution.
    kernel_size: list of 2 ints, height and width of the convolution kernel.
    activation_fn: activation function used after the convolution.
    weights_initializer: distribution used to initialize the kernel.
    biases_initializer: distribution used to initialize biases.
    stride: convolution stride.
    scope: name of the tensorflow scope.
    mask_type: type of masked convolution, must be A or B.
    collection: tf variables collection.
    output_multiplier: number of convolutional network stacks.

  Returns:
    frame after the masked convolution.
  """
  assert mask_type in ('A', 'B') and num_outputs % output_multiplier == 0
  num_inputs = int(inputs.get_shape()[-1])
  kernel_shape = tuple(kernel_size) + (num_inputs, num_outputs)
  strides = (1,) + tuple(stride) + (1,)
  biases_shape = [num_outputs]

  mask_list = [np.zeros(
      tuple(kernel_size) + (num_inputs, num_outputs // output_multiplier),
      dtype=np.float32) for _ in range(output_multiplier)]
  for i in range(output_multiplier):
    # Mask type A
    if kernel_shape[0] > 1:
      mask_list[i][:kernel_shape[0]//2] = 1.0
    if kernel_shape[1] > 1:
      mask_list[i][kernel_shape[0]//2, :kernel_shape[1]//2] = 1.0
    # Mask type B
    if mask_type == 'B':
      mask_list[i][kernel_shape[0]//2, kernel_shape[1]//2] = 1.0
  mask_values = np.concatenate(mask_list, axis=3)

  with tf.variable_scope(scope):
    w = tf.get_variable('W', kernel_shape, trainable=True,
                        initializer=weights_initializer)
    b = tf.get_variable('biases', biases_shape, trainable=True,
                        initializer=biases_initializer)
    if collection is not None:
      tf.add_to_collection(collection, w)
      tf.add_to_collection(collection, b)

    mask = tf.constant(mask_values, dtype=tf.float32)
    mask.set_shape(kernel_shape)

    convolution = tf.nn.conv2d(inputs, mask * w, strides, padding='SAME')
    convolution_bias = tf.nn.bias_add(convolution, b)

    if activation_fn is not None:
      convolution_bias = activation_fn(convolution_bias)
  return convolution_bias


def gating_layer(x, embedding, hidden_units, scope_name=''):
  """Creates the gating layer used in the PixelCNN architecture."""
  with tf.variable_scope(scope_name):
    out = masked_conv2d(x, 2*hidden_units, [3, 3],
                        mask_type='B',
                        activation_fn=None,
                        output_multiplier=2,
                        scope='masked_conv')
    out += slim.conv2d(embedding, 2*hidden_units, [1, 1],
                       activation_fn=None)
    # Pair adjacent channels: even channels go through tanh, odd channels
    # through a sigmoid, and the two halves are summed back into a tensor
    # with the same shape as `x`.
    out = tf.reshape(out, [-1, 2])
    out = tf.tanh(out[:, 0]) + tf.sigmoid(out[:, 1])
  return tf.reshape(out, x.get_shape())


@gin.configurable
class CTSIntrinsicReward(object):
  """Class used to instantiate a CTS density model used for exploration."""

  def __init__(self,
               reward_scale,
               convolutional=False,
               observation_shape=PSEUDO_COUNT_OBSERVATION_SHAPE,
               quantization_factor=PSEUDO_COUNT_QUANTIZATION_FACTOR):
    """Constructor.

    Args:
      reward_scale: float, scale factor applied to the raw rewards.
      convolutional: bool, whether to use convolutional CTS.
      observation_shape: tuple, 2D dimensions of the observation predicted
        by the model. Needs to be square.
      quantization_factor: int, number of gray levels the predicted image is
        quantized to.

    Raises:
      ValueError: when the `observation_shape` is not square.
    """
    self._reward_scale = reward_scale
    if (len(observation_shape) != 2
        or observation_shape[0] != observation_shape[1]):
      raise ValueError('Observation shape needs to be square.')
    self._observation_shape = observation_shape
    self.density_model = shannon.CTSTensorModel(
        observation_shape, convolutional)
    self._quantization_factor = quantization_factor

  def update(self, observation):
    """Updates the density model with the given observation.

    Args:
      observation: Input frame.

    Returns:
      Update log-probability.
    """
    input_tensor = self._preprocess(observation)
    return self.density_model.Update(input_tensor)

  def compute_intrinsic_reward(self, observation, training_steps, eval_mode):
    """Updates the model, returns the intrinsic reward.

    Args:
      observation: Input frame. For compatibility with other models, this
        may have a batch-size of 1 as its first dimension.
      training_steps: int, number of training steps.
      eval_mode: bool, whether or not running eval mode.

    Returns:
      The corresponding intrinsic reward.
    """
    del training_steps
    input_tensor = self._preprocess(observation)
    if not eval_mode:
      log_rho_t = self.density_model.Update(input_tensor)
      log_rho_tp1 = self.density_model.LogProb(input_tensor)
      ipd = log_rho_tp1 - log_rho_t
    else:
      # Do not update the density model in evaluation mode.
      ipd = self.density_model.IPD(input_tensor)

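    # The prediction gain `ipd` relates to a pseudo-count N via
    # N ~= 1 / (exp(ipd) - 1), so the count-based bonus
    # reward_scale / sqrt(N) becomes reward_scale * sqrt(expm1(ipd)).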
    # Compute the pseudo count.
    ipd_clipped = min(ipd, 25)
    inv_pseudo_count = max(0, math.expm1(ipd_clipped))
    reward = float(self._reward_scale) * math.sqrt(inv_pseudo_count)
    return reward

  def _preprocess(self, observation):
    """Converts the given observation into something the model can use.

    Args:
      observation: Input frame.

    Returns:
      Processed frame.

    Raises:
      ValueError: If observation provided is not 2D.
    """
    if observation.ndim != 2:
      raise ValueError('Observation needs to be 2D.')
    input_tensor = cv2.resize(observation,
                              self._observation_shape,
                              interpolation=cv2.INTER_AREA)
    input_tensor //= (256 // self._quantization_factor)
    # Convert to signed int (this may be unpleasantly inefficient).
    input_tensor = input_tensor.astype('i', copy=False)
    return input_tensor
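
# A minimal usage sketch (illustrative only; `frame` stands for a 2-D
# grayscale frame from the environment and is not defined in this module):
#
#   cts_bonus = CTSIntrinsicReward(reward_scale=0.1)
#   bonus = cts_bonus.compute_intrinsic_reward(
#       frame, training_steps=0, eval_mode=False)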


@gin.configurable
class PixelCNNIntrinsicReward(object):
  """PixelCNN class to instantiate a bonus using a PixelCNN density model."""

  def __init__(self,
               sess,
               reward_scale,
               ipd_scale,
               observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
               resize_shape=PSEUDO_COUNT_OBSERVATION_SHAPE,
               quantization_factor=PSEUDO_COUNT_QUANTIZATION_FACTOR,
               tf_device='/cpu:*',
               optimizer=tf.train.RMSPropOptimizer(
                   learning_rate=0.0001,
                   momentum=0.9,
                   epsilon=0.0001)):
    self._sess = sess
    self.reward_scale = reward_scale
    self.ipd_scale = ipd_scale
    self.observation_shape = observation_shape
    self.resize_shape = resize_shape
    self.quantization_factor = quantization_factor
    self.optimizer = optimizer

    with tf.device(tf_device), tf.name_scope('intrinsic_pixelcnn'):
      observation_shape = (1,) + observation_shape + (1,)
      self.obs_ph = tf.placeholder(tf.uint8, shape=observation_shape,
                                   name='obs_ph')
      self.preprocessed_obs = self._preprocess(self.obs_ph, resize_shape)
      self.iter_ph = tf.placeholder(tf.uint32, shape=[], name='iter_num')
      self.eval_ph = tf.placeholder(tf.bool, shape=[], name='eval_mode')
      self.network = tf.make_template('PixelCNN', self._network_template)
      self.ipd = tf.cond(tf.logical_not(self.eval_ph),
                         self.update,
                         self.virtual_update)
      self.reward = self.ipd_to_reward(self.ipd, self.iter_ph)
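      # `eval_ph` switches between `update` (a real gradient step on the
      # density model) and `virtual_update` (the same step, applied and then
      # rolled back), so evaluation never modifies the model.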

  def compute_intrinsic_reward(self, observation, training_steps, eval_mode):
    """Updates the model (during training), returns the intrinsic reward.

    Args:
      observation: Input frame. For compatibility with other models, this
        may have a batch-size of 1 as its first dimension.
      training_steps: Number of training steps, int.
      eval_mode: bool, whether or not running eval mode.

    Returns:
      The corresponding intrinsic reward.
    """
    observation = observation[np.newaxis, :, :, np.newaxis]
    return self._sess.run(self.reward, {self.obs_ph: observation,
                                        self.iter_ph: training_steps,
                                        self.eval_ph: eval_mode})

  def _preprocess(self, obs, obs_shape):
    """Preprocess the input."""
    obs = tf.cast(obs, tf.float32)
    obs = tf.image.resize_bilinear(obs, obs_shape)
    denom = tf.constant(256 // self.quantization_factor, dtype=tf.float32)
    return tf.floordiv(obs, denom)

  @gin.configurable
  def _network_template(self, obs, num_layers, hidden_units):
    """PixelCNN network architecture."""
    with slim.arg_scope([slim.conv2d, masked_conv2d],
                        weights_initializer=tf.variance_scaling_initializer(
                            distribution='uniform'),
                        biases_initializer=tf.constant_initializer(0.0)):
      net = masked_conv2d(obs, hidden_units, [7, 7], mask_type='A',
                          activation_fn=None, scope='masked_conv_1')

      embedding = slim.model_variable(
          'embedding',
          shape=(1,) + self.resize_shape + (4,),
          initializer=tf.variance_scaling_initializer(
              distribution='uniform'))
      for i in range(1, num_layers + 1):
        net2 = gating_layer(net, embedding, hidden_units,
                            'gating_{}'.format(i))
        net += masked_conv2d(net2, hidden_units, [1, 1],
                             mask_type='B',
                             activation_fn=None,
                             scope='masked_conv_{}'.format(i+1))

      net += slim.conv2d(embedding, hidden_units, [1, 1],
                         activation_fn=None)
      net = tf.nn.relu(net)
      net = masked_conv2d(net, 64, [1, 1], scope='1x1_conv_out',
                          mask_type='B',
                          activation_fn=tf.nn.relu)
      logits = masked_conv2d(net, self.quantization_factor, [1, 1],
                             scope='logits', mask_type='B',
                             activation_fn=None)
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.cast(obs, tf.int32),
        logits=logits,
        reduction=tf.losses.Reduction.MEAN)
    return collections.namedtuple('PixelCNN_network', ['logits', 'loss'])(
        logits, loss)

  def update(self):
    """Computes the log-likelihood difference and updates the density model."""
    with tf.name_scope('update'):
      with tf.name_scope('pre_update'):
        loss = self.network(self.preprocessed_obs).loss

      train_op = self.optimizer.minimize(loss)

      with tf.name_scope('post_update'), tf.control_dependencies([train_op]):
        loss_post_training = self.network(self.preprocessed_obs).loss
        # `loss` is the mean per-pixel cross entropy, so multiplying by the
        # number of pixels gives the change in total log-probability of the
        # frame (the prediction gain).
        ipd = (loss - loss_post_training) * (
            self.resize_shape[0] * self.resize_shape[1])
    return ipd

  def virtual_update(self):
    """Computes the log-likelihood difference without updating the network."""
    with tf.name_scope('virtual_update'):
      with tf.name_scope('pre_update'):
        loss = self.network(self.preprocessed_obs).loss

      grads_and_vars = self.optimizer.compute_gradients(loss)
      model_vars = [gv[1] for gv in grads_and_vars]
      saved_vars = [tf.Variable(v.initialized_value()) for v in model_vars]
      backup_op = tf.group(*[t.assign(s)
                             for t, s in zip(saved_vars, model_vars)])
      with tf.control_dependencies([backup_op]):
        train_op = self.optimizer.apply_gradients(grads_and_vars)
      with tf.control_dependencies([train_op]), tf.name_scope('post_update'):
        loss_post_training = self.network(self.preprocessed_obs).loss
      with tf.control_dependencies([loss_post_training]):
        restore_op = tf.group(*[d.assign(s)
                                for d, s in zip(model_vars, saved_vars)])
      with tf.control_dependencies([restore_op]):
        ipd = (loss - loss_post_training) * (
            self.resize_shape[0] * self.resize_shape[1])
      return ipd

  def ipd_to_reward(self, ipd, steps):
    """Computes the intrinsic reward from IPD."""
    # Prediction gain decay: the gain is scaled by ipd_scale / sqrt(steps)
    # before being converted to a pseudo-count bonus, so the bonus shrinks
    # as training progresses.
    ipd = self.ipd_scale * ipd / tf.sqrt(tf.to_float(steps))
    inv_pseudo_count = tf.maximum(tf.expm1(ipd), 0.0)
    return self.reward_scale * tf.sqrt(inv_pseudo_count)
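
# A minimal wiring sketch (illustrative only; `frame` stands for an 84x84
# grayscale frame and is not defined here). `_network_template`'s
# `num_layers` and `hidden_units` have no defaults and are expected to be
# supplied via gin bindings. Variables must be initialized before use:
#
#   sess = tf.Session()
#   pixelcnn_bonus = PixelCNNIntrinsicReward(sess, reward_scale=0.1,
#                                            ipd_scale=0.1)
#   sess.run(tf.global_variables_initializer())
#   bonus = pixelcnn_bonus.compute_intrinsic_reward(
#       frame, training_steps=1, eval_mode=False)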


@gin.configurable
class RNDIntrinsicReward(object):
  """Class used to instantiate a bonus using random network distillation."""

  def __init__(self,
               sess,
               embedding_size=512,
               observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
               tf_device='/gpu:0',
               reward_scale=1.0,
               optimizer=tf.train.AdamOptimizer(
                   learning_rate=0.0001,
                   epsilon=0.00001),
               summary_writer=None):
    self.embedding_size = embedding_size
    self.reward_scale = reward_scale
    self.optimizer = optimizer
    self._sess = sess
    self.summary_writer = summary_writer

    with tf.device(tf_device), tf.name_scope('intrinsic_rnd'):
      obs_shape = (1,) + observation_shape + (1,)
      self.iter_ph = tf.placeholder(tf.uint64, shape=[], name='iter_num')
      self.iter = tf.cast(self.iter_ph, tf.float32)
      self.obs_ph = tf.placeholder(tf.uint8, shape=obs_shape,
                                   name='obs_ph')
      self.eval_ph = tf.placeholder(tf.bool, shape=[], name='eval_mode')
      self.obs = tf.cast(self.obs_ph, tf.float32)
      # Variables holding the running mean and std of observations and
      # rewards, updated in _update_moments.
      self.obs_mean = tf.Variable(tf.zeros(shape=obs_shape),
                                  trainable=False,
                                  name='obs_mean',
                                  dtype=tf.float32)
      self.obs_std = tf.Variable(tf.ones(shape=obs_shape),
                                 trainable=False,
                                 name='obs_std',
                                 dtype=tf.float32)
      self.reward_mean = tf.Variable(tf.zeros(shape=[]),
                                     trainable=False,
                                     name='reward_mean',
                                     dtype=tf.float32)
      self.reward_std = tf.Variable(tf.ones(shape=[]),
                                    trainable=False,
                                    name='reward_std',
                                    dtype=tf.float32)
      self.obs = self._preprocess(self.obs)
      self.target_embedding = self._target_network(self.obs)
      self.prediction_embedding = self._prediction_network(self.obs)
      self._train_op = self._build_train_op()

  def _preprocess(self, obs):
    return tf.clip_by_value((obs - self.obs_mean) / self.obs_std, -5.0, 5.0)

  def compute_intrinsic_reward(self, obs, training_step, eval_mode=False):
    """Computes the RND intrinsic reward."""
    obs = obs[np.newaxis, :, :, np.newaxis]
    to_evaluate = [self.intrinsic_reward]
    if not eval_mode:
      # Also update the prediction network.
      to_evaluate.append(self._train_op)
    reward = self._sess.run(to_evaluate,
                            {self.obs_ph: obs,
                             self.iter_ph: training_step,
                             self.eval_ph: eval_mode})[0]
    return self.reward_scale * float(reward)
427

428
  def _target_network(self, obs):
429
    """Implements the random target network used by RND."""
430
    with slim.arg_scope([slim.conv2d, slim.fully_connected], trainable=False,
431
                        weights_initializer=tf.orthogonal_initializer(
432
                            gain=np.sqrt(2)),
433
                        biases_initializer=tf.zeros_initializer()):
434
      net = slim.conv2d(obs, 32, [8, 8], stride=4,
435
                        activation_fn=tf.nn.leaky_relu)
436
      net = slim.conv2d(net, 64, [4, 4], stride=2,
437
                        activation_fn=tf.nn.leaky_relu)
438
      net = slim.conv2d(net, 64, [3, 3], stride=1,
439
                        activation_fn=tf.nn.leaky_relu)
440
      net = slim.flatten(net)
441
      embedding = slim.fully_connected(net, self.embedding_size,
442
                                       activation_fn=None)
443
    return embedding
444

445
  def _prediction_network(self, obs):
446
    """Prediction network used by RND to predict to target network output."""
447
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
448
                        weights_initializer=tf.orthogonal_initializer(
449
                            gain=np.sqrt(2)),
450
                        biases_initializer=tf.zeros_initializer()):
451
      net = slim.conv2d(obs, 32, [8, 8], stride=4,
452
                        activation_fn=tf.nn.leaky_relu)
453
      net = slim.conv2d(net, 64, [4, 4], stride=2,
454
                        activation_fn=tf.nn.leaky_relu)
455
      net = slim.conv2d(net, 64, [3, 3], stride=1,
456
                        activation_fn=tf.nn.leaky_relu)
457
      net = slim.flatten(net)
458
      net = slim.fully_connected(net, 512, activation_fn=tf.nn.relu)
459
      net = slim.fully_connected(net, 512, activation_fn=tf.nn.relu)
460
      embedding = slim.fully_connected(net, self.embedding_size,
461
                                       activation_fn=None)
462
    return embedding
463

  def _update_moments(self):
    """Updates the moment estimates, assuming a batch size of 1."""
    def update():
      """Update function passed to tf.cond below."""
      moments = [
          (self.obs, self.obs_mean, self.obs_std),
          (self.loss, self.reward_mean, self.reward_std)
      ]
      ops = []
      for value, mean, std in moments:
        delta = value - mean
        assign_mean = mean.assign_add(delta / self.iter)
        std_ = std * self.iter + (delta ** 2) * self.iter / (self.iter + 1)
        assign_std = std.assign(std_ / (self.iter + 1))
        ops.extend([assign_mean, assign_std])
      return ops

    return tf.cond(
        tf.logical_not(self.eval_ph),
        update,
        # false_fn must have the same number and type of outputs.
        lambda: 4 * [tf.constant(0., tf.float32)])

  def _build_train_op(self):
    """Returns the train op used to update the prediction network."""
    prediction = self.prediction_embedding
    target = tf.stop_gradient(self.target_embedding)
    self.loss = tf.losses.mean_squared_error(
        target, prediction, reduction=tf.losses.Reduction.MEAN)
    with tf.control_dependencies(self._update_moments()):
      self.intrinsic_reward = (self.loss - self.reward_mean) / self.reward_std
    return self.optimizer.minimize(self.loss)
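
# A minimal wiring sketch (illustrative only; `frame` stands for an 84x84
# grayscale frame and is not defined here). `training_step` is used as a
# divisor in the moment updates, so it should be >= 1:
#
#   sess = tf.Session()
#   rnd_bonus = RNDIntrinsicReward(sess)
#   sess.run(tf.global_variables_initializer())
#   bonus = rnd_bonus.compute_intrinsic_reward(frame, training_step=1)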