# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Wrapper for generative models used to derive intrinsic rewards."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

import cv2
from dopamine.discrete_domains import atari_lib
import gin
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import slim


# Number of intensity levels each pixel is quantized to before density
# estimation (also the softmax arity of the PixelCNN output).
PSEUDO_COUNT_QUANTIZATION_FACTOR = 8
# Frames are downsampled to this square shape for the density models.
PSEUDO_COUNT_OBSERVATION_SHAPE = (42, 42)
# Native frame shape produced by the Nature-DQN Atari preprocessing.
NATURE_DQN_OBSERVATION_SHAPE = atari_lib.NATURE_DQN_OBSERVATION_SHAPE


@slim.add_arg_scope
def masked_conv2d(inputs, num_outputs, kernel_size,
                  activation_fn=tf.nn.relu,
                  weights_initializer=tf.initializers.glorot_normal(),
                  biases_initializer=tf.initializers.zeros(),
                  stride=(1, 1),
                  scope=None,
                  mask_type='A',
                  collection=None,
                  output_multiplier=1):
  """Masked 2D convolution as used in PixelCNN models.

  Implements the type 'A' and type 'B' masks from Figure 1 of
  https://arxiv.org/abs/1606.05328: taps strictly above the kernel center,
  and taps left of the center on the center row, are kept; type 'B'
  additionally keeps the center tap itself.

  Args:
    inputs: input image.
    num_outputs: int, number of filters used in the convolution.
    kernel_size: int, size of convolution kernel.
    activation_fn: activation function used after the convolution.
    weights_initializer: distribution used to initialize the kernel.
    biases_initializer: distribution used to initialize biases.
    stride: convolution stride.
    scope: name of the tensorflow scope.
    mask_type: type of masked convolution, must be A or B.
    collection: tf variables collection.
    output_multiplier: number of convolutional network stacks.

  Returns:
    frame post convolution.
  """
  assert mask_type in ('A', 'B') and num_outputs % output_multiplier == 0
  depth_in = int(inputs.get_shape()[-1])
  kernel_shape = tuple(kernel_size) + (depth_in, num_outputs)
  strides = (1,) + tuple(stride) + (1,)
  biases_shape = [num_outputs]

  # Build the causal mask for a single stack; every stack uses the same
  # pattern, so it is concatenated `output_multiplier` times along the
  # output-channel axis below.
  center_row = kernel_shape[0] // 2
  center_col = kernel_shape[1] // 2
  stack_mask = np.zeros(
      tuple(kernel_size) + (depth_in, num_outputs // output_multiplier),
      dtype=np.float32)
  # Type 'A' connectivity: rows above the center, then the part of the
  # center row left of the center tap.
  if kernel_shape[0] > 1:
    stack_mask[:center_row] = 1.0
  if kernel_shape[1] > 1:
    stack_mask[center_row, :center_col] = 1.0
  # Type 'B' additionally connects the center tap.
  if mask_type == 'B':
    stack_mask[center_row, center_col] = 1.0
  mask_values = np.concatenate([stack_mask] * output_multiplier, axis=3)

  with tf.variable_scope(scope):
    kernel = tf.get_variable('W', kernel_shape, trainable=True,
                             initializer=weights_initializer)
    biases = tf.get_variable('biases', biases_shape, trainable=True,
                             initializer=biases_initializer)
    if collection is not None:
      tf.add_to_collection(collection, kernel)
      tf.add_to_collection(collection, biases)

    mask = tf.constant(mask_values, dtype=tf.float32)
    mask.set_shape(kernel_shape)

    # Masking the kernel (rather than the input) keeps the convolution
    # itself a plain conv2d.
    out = tf.nn.bias_add(
        tf.nn.conv2d(inputs, mask * kernel, strides, padding='SAME'),
        biases)
    if activation_fn is not None:
      out = activation_fn(out)
  return out


def gating_layer(x, embedding, hidden_units, scope_name=''):
  """Gated activation block used between PixelCNN layers.

  Applies a type-'B' masked convolution conditioned on a learned embedding,
  then combines the two output stacks with tanh/sigmoid nonlinearities.
  """
  with tf.variable_scope(scope_name):
    pre_activation = masked_conv2d(x, 2 * hidden_units, [3, 3],
                                   mask_type='B',
                                   activation_fn=None,
                                   output_multiplier=2,
                                   scope='masked_conv')
    # Condition on the (broadcast) embedding via a 1x1 convolution.
    pre_activation += slim.conv2d(embedding, 2 * hidden_units, [1, 1],
                                  activation_fn=None)
    # Pair up channels so column 0/1 feed the two nonlinearities.
    paired = tf.reshape(pre_activation, [-1, 2])
    # NOTE(review): the canonical gated PixelCNN activation
    # (arXiv:1606.05328, Eq. 2) is tanh(.) * sigmoid(.); this code SUMS the
    # two halves instead. Preserved as-is — confirm against training results
    # before changing.
    gated = tf.tanh(paired[:, 0]) + tf.sigmoid(paired[:, 1])
    return tf.reshape(gated, x.get_shape())


@gin.configurable
class CTSIntrinsicReward(object):
  """Class used to instantiate a CTS density model used for exploration."""

  def __init__(self,
               reward_scale,
               convolutional=False,
               observation_shape=PSEUDO_COUNT_OBSERVATION_SHAPE,
               quantization_factor=PSEUDO_COUNT_QUANTIZATION_FACTOR):
    """Constructor.

    Args:
      reward_scale: float, scale factor applied to the raw rewards.
      convolutional: bool, whether to use convolutional CTS.
      observation_shape: tuple, 2D dimensions of the observation predicted
        by the model. Needs to be square.
      quantization_factor: int, number of intensity levels the predicted
        image is quantized to.
    Raises:
      ValueError: when the `observation_shape` is not square.
    """
    self._reward_scale = reward_scale
    if (len(observation_shape) != 2
        or observation_shape[0] != observation_shape[1]):
      raise ValueError('Observation shape needs to be square')
    self._observation_shape = observation_shape
    # NOTE(review): `shannon` is not among this file's visible imports —
    # confirm the CTS model module is imported elsewhere, otherwise this
    # raises NameError at construction time.
    self.density_model = shannon.CTSTensorModel(
        observation_shape, convolutional)
    self._quantization_factor = quantization_factor

  def update(self, observation):
    """Updates the density model with the given observation.

    Args:
      observation: Input frame.

    Returns:
      Update log-probability.
    """
    # Resize/quantize first so the model always sees the same input format.
    input_tensor = self._preprocess(observation)
    return self.density_model.Update(input_tensor)

  def compute_intrinsic_reward(self, observation, training_steps, eval_mode):
    """Updates the model, returns the intrinsic reward.

    Args:
      observation: Input frame. For compatibility with other models, this
        may have a batch-size of 1 as its first dimension.
      training_steps: int, number of training steps.
      eval_mode: bool, whether or not running eval mode.

    Returns:
      The corresponding intrinsic reward.
    """
    del training_steps  # Unused by the CTS bonus.
    input_tensor = self._preprocess(observation)
    if not eval_mode:
      # Update returns the pre-update log-probability; querying again after
      # the update gives the post-update one. Their difference is the
      # information/prediction gain.
      log_rho_t = self.density_model.Update(input_tensor)
      log_rho_tp1 = self.density_model.LogProb(input_tensor)
      ipd = log_rho_tp1 - log_rho_t
    else:
      # Do not update the density model in evaluation mode
      ipd = self.density_model.IPD(input_tensor)

    # Compute the pseudo count
    # Clip the gain so expm1 below cannot overflow.
    ipd_clipped = min(ipd, 25)
    # expm1(gain) approximates the inverse pseudo-count 1/N; clamp at 0 in
    # case the gain is negative.
    inv_pseudo_count = max(0, math.expm1(ipd_clipped))
    # Bonus = reward_scale / sqrt(N).
    reward = float(self._reward_scale) * math.sqrt(inv_pseudo_count)
    return reward

  def _preprocess(self, observation):
    """Converts the given observation into something the model can use.

    Args:
      observation: Input frame.

    Returns:
      Processed frame.

    Raises:
      ValueError: If observation provided is not 2D.
    """
    if observation.ndim != 2:
      raise ValueError('Observation needs to be 2D.')
    # cv2.resize takes dsize as (width, height); the shape is square here so
    # the ordering is immaterial.
    input_tensor = cv2.resize(observation,
                              self._observation_shape,
                              interpolation=cv2.INTER_AREA)
    # Quantize pixels down to `quantization_factor` intensity levels.
    input_tensor //= (256 // self._quantization_factor)
    # Convert to signed int (this may be unpleasantly inefficient).
    input_tensor = input_tensor.astype('i', copy=False)
    return input_tensor


@gin.configurable
class PixelCNNIntrinsicReward(object):
  """PixelCNN class to instantiate a bonus using a PixelCNN density model."""

  def __init__(self,
               sess,
               reward_scale,
               ipd_scale,
               observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
               resize_shape=PSEUDO_COUNT_OBSERVATION_SHAPE,
               quantization_factor=PSEUDO_COUNT_QUANTIZATION_FACTOR,
               tf_device='/cpu:*',
               # NOTE(review): this default optimizer is constructed once at
               # import time and shared by every instance that uses the
               # default — fine if a single instance is created; verify.
               optimizer=tf.train.RMSPropOptimizer(
                   learning_rate=0.0001,
                   momentum=0.9,
                   epsilon=0.0001)):
    """Constructor.

    Args:
      sess: tf.Session used to run the reward/update ops.
      reward_scale: float, scale applied to the final bonus.
      ipd_scale: float, scale applied to the prediction gain before the
        pseudo-count conversion.
      observation_shape: tuple, 2D shape of raw input frames.
      resize_shape: tuple, 2D shape frames are resized to for the model.
      quantization_factor: int, number of intensity levels pixels are
        quantized to (also the arity of the model's output softmax).
      tf_device: str, tf device on which the graph is built.
      optimizer: tf.train optimizer used for the density-model updates.
    """
    self._sess = sess
    self.reward_scale = reward_scale
    self.ipd_scale = ipd_scale
    self.observation_shape = observation_shape
    self.resize_shape = resize_shape
    self.quantization_factor = quantization_factor
    self.optimizer = optimizer

    with tf.device(tf_device), tf.name_scope('intrinsic_pixelcnn'):
      # Fixed batch size of 1: (1, H, W, 1).
      observation_shape = (1,) + observation_shape + (1,)
      self.obs_ph = tf.placeholder(tf.uint8, shape=observation_shape,
                                   name='obs_ph')
      # (sic) attribute-name typo kept: external code may reference it.
      self.preproccessed_obs = self._preprocess(self.obs_ph, resize_shape)
      self.iter_ph = tf.placeholder(tf.uint32, shape=[], name='iter_num')
      self.eval_ph = tf.placeholder(tf.bool, shape=[], name='eval_mode')
      # make_template shares a single set of PixelCNN variables across all
      # calls to self.network below.
      self.network = tf.make_template('PixelCNN', self._network_template)
      # Training branch updates the weights; eval branch computes the same
      # quantity while restoring the weights afterwards.
      self.ipd = tf.cond(tf.logical_not(self.eval_ph),
                         self.update,
                         self.virtual_update)
      self.reward = self.ipd_to_reward(self.ipd, self.iter_ph)

  def compute_intrinsic_reward(self, observation, training_steps, eval_mode):
    """Updates the model (during training), returns the intrinsic reward.

    Args:
      observation: Input frame. For compatibility with other models, this
        may have a batch-size of 1 as its first dimension.
      training_steps: Number of training steps, int.
      eval_mode: bool, whether or not running eval mode.

    Returns:
      The corresponding intrinsic reward.
    """
    # Add batch and channel dimensions expected by obs_ph.
    observation = observation[np.newaxis, :, :, np.newaxis]
    return self._sess.run(self.reward, {self.obs_ph: observation,
                                        self.iter_ph: training_steps,
                                        self.eval_ph: eval_mode})

  def _preprocess(self, obs, obs_shape):
    """Preprocess the input: resize and quantize intensities."""
    obs = tf.cast(obs, tf.float32)
    obs = tf.image.resize_bilinear(obs, obs_shape)
    # Quantize to `quantization_factor` levels, matching the softmax arity
    # of the network's logits.
    denom = tf.constant(256 // self.quantization_factor, dtype=tf.float32)
    return tf.floordiv(obs, denom)

  @gin.configurable
  def _network_template(self, obs, num_layers, hidden_units):
    """PixelCNN network architecture.

    Args:
      obs: preprocessed (quantized) observation batch.
      num_layers: int, number of gated layers (supplied via gin).
      hidden_units: int, channel width of the hidden layers (via gin).

    Returns:
      Namedtuple with `logits` (per-pixel distribution over quantized
      intensities) and scalar `loss` (mean cross-entropy).
    """
    with slim.arg_scope([slim.conv2d, masked_conv2d],
                        weights_initializer=tf.variance_scaling_initializer(
                            distribution='uniform'),
                        biases_initializer=tf.constant_initializer(0.0)):
      # First layer uses mask type 'A' so a pixel never sees itself.
      net = masked_conv2d(obs, hidden_units, [7, 7], mask_type='A',
                          activation_fn=None, scope='masked_conv_1')

      # Learned positional embedding added at every layer.
      embedding = slim.model_variable(
          'embedding',
          shape=(1,) + self.resize_shape + (4,),
          initializer=tf.variance_scaling_initializer(
              distribution='uniform'))
      # Stack of gated layers with residual connections.
      for i in range(1, num_layers + 1):
        net2 = gating_layer(net, embedding, hidden_units,
                            'gating_{}'.format(i))
        net += masked_conv2d(net2, hidden_units, [1, 1],
                             mask_type='B',
                             activation_fn=None,
                             scope='masked_conv_{}'.format(i+1))

      net += slim.conv2d(embedding, hidden_units, [1, 1],
                         activation_fn=None)
      net = tf.nn.relu(net)
      net = masked_conv2d(net, 64, [1, 1], scope='1x1_conv_out',
                          mask_type='B',
                          activation_fn=tf.nn.relu)
      # One logit per quantized intensity level at each pixel.
      logits = masked_conv2d(net, self.quantization_factor, [1, 1],
                             scope='logits', mask_type='B',
                             activation_fn=None)
    # Labels are the quantized input itself: autoregressive reconstruction.
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.cast(obs, tf.int32),
        logits=logits,
        reduction=tf.losses.Reduction.MEAN)
    return collections.namedtuple('PixelCNN_network', ['logits', 'loss'])(
        logits, loss)

  def update(self):
    """Computes the log likelihood difference and updates the density model."""
    with tf.name_scope('update'):
      with tf.name_scope('pre_update'):
        # Mean per-pixel loss with the current weights.
        loss = self.network(self.preproccessed_obs).loss

      train_op = self.optimizer.minimize(loss)

      with tf.name_scope('post_update'), tf.control_dependencies([train_op]):
        # Re-evaluate after the gradient step; the control dependency forces
        # this second template call to run post-update.
        loss_post_training = self.network(self.preproccessed_obs).loss
        # Scale the mean per-pixel difference back up to a whole-image
        # log-likelihood difference.
        ipd = (loss - loss_post_training) * (
            self.resize_shape[0] * self.resize_shape[1])
      return ipd

  def virtual_update(self):
    """Computes the log likelihood difference without updating the network."""
    with tf.name_scope('virtual_update'):
      with tf.name_scope('pre_update'):
        loss = self.network(self.preproccessed_obs).loss

      grads_and_vars = self.optimizer.compute_gradients(loss)
      model_vars = [gv[1] for gv in grads_and_vars]
      # Snapshot the weights so the gradient step below can be undone.
      saved_vars = [tf.Variable(v.initialized_value()) for v in model_vars]
      backup_op = tf.group(*[t.assign(s)
                             for t, s in zip(saved_vars, model_vars)])
      with tf.control_dependencies([backup_op]):
        # Apply the step only after the snapshot is taken.
        train_op = self.optimizer.apply_gradients(grads_and_vars)
      with tf.control_dependencies([train_op]), tf.name_scope('post_update'):
        loss_post_training = self.network(self.preproccessed_obs).loss
      with tf.control_dependencies([loss_post_training]):
        # Roll the weights back to the snapshot once the post-update loss
        # has been computed.
        restore_op = tf.group(*[d.assign(s)
                                for d, s in zip(model_vars, saved_vars)])
      with tf.control_dependencies([restore_op]):
        ipd = (loss - loss_post_training) * \
            self.resize_shape[0] * self.resize_shape[1]
      return ipd

  def ipd_to_reward(self, ipd, steps):
    """Computes the intrinsic reward from IPD."""
    # Prediction gain decay: shrink the gain as 1/sqrt(t).
    ipd = self.ipd_scale * ipd / tf.sqrt(tf.to_float(steps))
    # expm1(gain) approximates the inverse pseudo-count 1/N; clamp at 0.
    inv_pseudo_count = tf.maximum(tf.expm1(ipd), 0.0)
    # Bonus = reward_scale / sqrt(N).
    return self.reward_scale * tf.sqrt(inv_pseudo_count)


@gin.configurable
class RNDIntrinsicReward(object):
  """Class used to instantiate a bonus using random network distillation."""

  def __init__(self,
               sess,
               embedding_size=512,
               observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
               tf_device='/gpu:0',
               reward_scale=1.0,
               # NOTE(review): default optimizer is built once at import time
               # and shared by every instance using the default — fine only
               # if a single instance is constructed; verify.
               optimizer=tf.train.AdamOptimizer(
                   learning_rate=0.0001,
                   epsilon=0.00001),
               summary_writer=None):
    """Constructor.

    Args:
      sess: tf.Session used to run the reward/update ops.
      embedding_size: int, dimensionality of target/predictor embeddings.
      observation_shape: tuple, 2D shape of input frames.
      tf_device: str, tf device on which the graph is built.
      reward_scale: float, scale applied to the normalized loss.
      optimizer: optimizer used to train the prediction network.
      summary_writer: optional summary writer (stored, not used here).
    """
    self.embedding_size = embedding_size
    self.reward_scale = reward_scale
    self.optimizer = optimizer
    self._sess = sess
    self.summary_writer = summary_writer

    with tf.device(tf_device), tf.name_scope('intrinsic_rnd'):
      # Fixed batch size of 1: (1, H, W, 1).
      obs_shape = (1,) + observation_shape + (1,)
      self.iter_ph = tf.placeholder(tf.uint64, shape=[], name='iter_num')
      self.iter = tf.cast(self.iter_ph, tf.float32)
      self.obs_ph = tf.placeholder(tf.uint8, shape=obs_shape,
                                   name='obs_ph')
      self.eval_ph = tf.placeholder(tf.bool, shape=[], name='eval_mode')
      self.obs = tf.cast(self.obs_ph, tf.float32)
      # Placeholder for running mean and std of observations and rewards
      self.obs_mean = tf.Variable(tf.zeros(shape=obs_shape),
                                  trainable=False,
                                  name='obs_mean',
                                  dtype=tf.float32)
      self.obs_std = tf.Variable(tf.ones(shape=obs_shape),
                                 trainable=False,
                                 name='obs_std',
                                 dtype=tf.float32)
      self.reward_mean = tf.Variable(tf.zeros(shape=[]),
                                     trainable=False,
                                     name='reward_mean',
                                     dtype=tf.float32)
      self.reward_std = tf.Variable(tf.ones(shape=[]),
                                    trainable=False,
                                    name='reward_std',
                                    dtype=tf.float32)
      # self.obs is rebound here to the normalized observation tensor.
      self.obs = self._preprocess(self.obs)
      self.target_embedding = self._target_network(self.obs)
      self.prediction_embedding = self._prediction_network(self.obs)
      self._train_op = self._build_train_op()

  def _preprocess(self, obs):
    # Normalize by the running moments and clip to [-5, 5].
    return tf.clip_by_value((obs - self.obs_mean) / self.obs_std, -5.0, 5.0)

  def compute_intrinsic_reward(self, obs, training_step, eval_mode=False):
    """Computes the RND intrinsic reward."""
    # Add batch and channel dimensions expected by obs_ph.
    obs = obs[np.newaxis, :, :, np.newaxis]
    to_evaluate = [self.intrinsic_reward]
    if not eval_mode:
      # Also update the prediction network
      to_evaluate.append(self._train_op)
    reward = self._sess.run(to_evaluate,
                            {self.obs_ph: obs,
                             self.iter_ph: training_step,
                             self.eval_ph: eval_mode})[0]
    return self.reward_scale * float(reward)

  def _target_network(self, obs):
    """Implements the random target network used by RND."""
    # trainable=False: the target stays at its random initialization.
    with slim.arg_scope([slim.conv2d, slim.fully_connected], trainable=False,
                        weights_initializer=tf.orthogonal_initializer(
                            gain=np.sqrt(2)),
                        biases_initializer=tf.zeros_initializer()):
      net = slim.conv2d(obs, 32, [8, 8], stride=4,
                        activation_fn=tf.nn.leaky_relu)
      net = slim.conv2d(net, 64, [4, 4], stride=2,
                        activation_fn=tf.nn.leaky_relu)
      net = slim.conv2d(net, 64, [3, 3], stride=1,
                        activation_fn=tf.nn.leaky_relu)
      net = slim.flatten(net)
      embedding = slim.fully_connected(net, self.embedding_size,
                                       activation_fn=None)
    return embedding

  def _prediction_network(self, obs):
    """Prediction network used by RND to predict the target network output."""
    # Same conv trunk as the target, plus two extra fully-connected layers;
    # this one IS trainable.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        weights_initializer=tf.orthogonal_initializer(
                            gain=np.sqrt(2)),
                        biases_initializer=tf.zeros_initializer()):
      net = slim.conv2d(obs, 32, [8, 8], stride=4,
                        activation_fn=tf.nn.leaky_relu)
      net = slim.conv2d(net, 64, [4, 4], stride=2,
                        activation_fn=tf.nn.leaky_relu)
      net = slim.conv2d(net, 64, [3, 3], stride=1,
                        activation_fn=tf.nn.leaky_relu)
      net = slim.flatten(net)
      net = slim.fully_connected(net, 512, activation_fn=tf.nn.relu)
      net = slim.fully_connected(net, 512, activation_fn=tf.nn.relu)
      embedding = slim.fully_connected(net, self.embedding_size,
                                       activation_fn=None)
    return embedding

  def _update_moments(self):
    """Update the moments estimates, assumes a batch size of 1."""
    def update():
      """Update moment function passed later to a tf.cond."""
      moments = [
          (self.obs, self.obs_mean, self.obs_std),
          (self.loss, self.reward_mean, self.reward_std)
      ]
      ops = []
      for value, mean, std in moments:
        delta = value - mean
        # Streaming mean: mean += delta / t.
        # NOTE(review): divides by self.iter — a training_step of 0 would
        # produce inf/NaN on the first call; confirm callers start at 1.
        assign_mean = mean.assign_add(delta / self.iter)
        # NOTE(review): despite the name, this recurrence tracks a
        # variance-like quantity (no sqrt is ever taken), and it is used
        # directly as the normalization divisor in _preprocess and
        # _build_train_op — confirm this is the intended normalization.
        std_ = std * self.iter + (delta ** 2) * self.iter / (self.iter + 1)
        assign_std = std.assign(std_ / (self.iter + 1))
        ops.extend([assign_mean, assign_std])
      return ops

    return tf.cond(
        tf.logical_not(self.eval_ph),
        update,
        # false_fn must have the same number and type of outputs.
        lambda: 4 * [tf.constant(0., tf.float32)])

  def _build_train_op(self):
    """Returns train op to update the prediction network."""
    prediction = self.prediction_embedding
    # Target net is already non-trainable; stop_gradient makes doubly sure
    # no gradient flows into it.
    target = tf.stop_gradient(self.target_embedding)
    self.loss = tf.losses.mean_squared_error(
        target, prediction, reduction=tf.losses.Reduction.MEAN)
    with tf.control_dependencies(self._update_moments()):
      # Bonus is the prediction error normalized by its running moments.
      self.intrinsic_reward = (self.loss - self.reward_mean) / self.reward_std
    return self.optimizer.minimize(self.loss)