# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implicit Quantile agent with KSMe loss."""

import collections
import functools

from absl import logging
from dopamine.jax import losses
from dopamine.jax.agents.implicit_quantile import implicit_quantile_agent
from dopamine.metrics import statistics_instance
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ksme.atari import metric_utils

NetworkType = collections.namedtuple(
    'network', ['quantile_values', 'quantiles', 'representation'])


def stable_scaled_log_softmax(x, tau, axis=-1):
  """Computes tau * log_softmax(x / tau) in a numerically stable way."""
  max_x = jnp.amax(x, axis=axis, keepdims=True)
  y = x - max_x
  tau_lse = max_x + tau * jnp.log(
      jnp.sum(jnp.exp(y / tau), axis=axis, keepdims=True))
  return x - tau_lse


def stable_softmax(x, tau, axis=-1):
  """Computes softmax(x / tau) in a numerically stable way."""
  max_x = jnp.amax(x, axis=axis, keepdims=True)
  y = x - max_x
  return jax.nn.softmax(y / tau, axis=axis)
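
# A minimal, illustrative sanity check (not part of the original module):
# `stable_scaled_log_softmax(x, tau)` should agree with
# `tau * jax.nn.log_softmax(x / tau)`, and `stable_softmax(x, tau)` with
# `jax.nn.softmax(x / tau)`; the max-subtraction above only adds numerical
# stability. For example:
#
#   x = jnp.array([1.0, 2.0, 3.0])
#   tau = 0.03
#   assert jnp.allclose(
#       stable_scaled_log_softmax(x, tau),
#       tau * jax.nn.log_softmax(x / tau), atol=1e-5)
#   assert jnp.allclose(stable_softmax(x, tau), jax.nn.softmax(x / tau))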


class AtariImplicitQuantileNetwork(nn.Module):
  """The Implicit Quantile Network (Dabney et al., 2018)."""
  num_actions: int
  quantile_embedding_dim: int

  @nn.compact
  def __call__(self, x, num_quantiles, rng):
    initializer = jax.nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    representation = x.reshape((-1))  # flatten
    state_vector_length = representation.shape[-1]
    state_net_tiled = jnp.tile(representation, [num_quantiles, 1])
    quantiles_shape = [num_quantiles, 1]
    quantiles = jax.random.uniform(rng, shape=quantiles_shape)
    quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
    quantile_net = (
        jnp.arange(1, self.quantile_embedding_dim + 1, 1).astype(jnp.float32)
        * np.pi
        * quantile_net)
    quantile_net = jnp.cos(quantile_net)
    quantile_net = nn.Dense(features=state_vector_length,
                            kernel_init=initializer)(quantile_net)
    quantile_net = nn.relu(quantile_net)
    x = state_net_tiled * quantile_net
    x = nn.Dense(features=512, kernel_init=initializer)(x)
    x = nn.relu(x)
    quantile_values = nn.Dense(features=self.num_actions,
                               kernel_init=initializer)(x)
    return NetworkType(quantile_values, quantiles, representation)
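
# A minimal usage sketch (illustrative assumptions only; the agent below wires
# this network up through Dopamine, so none of these values are prescribed
# here). The network is applied to a single unbatched frame stack; the agent
# vmaps over the batch dimension inside `train` below.
#
#   rng = jax.random.PRNGKey(0)
#   init_rng, quantile_rng = jax.random.split(rng)
#   net = AtariImplicitQuantileNetwork(num_actions=6, quantile_embedding_dim=64)
#   frame_stack = jnp.zeros((84, 84, 4), dtype=jnp.uint8)  # hypothetical input
#   params = net.init(init_rng, frame_stack, num_quantiles=32, rng=quantile_rng)
#   out = net.apply(params, frame_stack, num_quantiles=32, rng=quantile_rng)
#   # out.quantile_values has shape (32, 6); out.quantiles has shape (32, 1);
#   # out.representation holds the flattened conv features.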


@functools.partial(
    jax.vmap,
    in_axes=(None, None, 0, 0, 0, 0, 0, None, None, None, None, None,
             None, None),
    out_axes=(None, 0, 0, 0))
def munchausen_target_quantile_values(network, target_params, states,
                                      actions, next_states, rewards, terminals,
                                      num_tau_prime_samples,
                                      num_quantile_samples, cumulative_gamma,
                                      rng, tau, alpha, clip_value_min):
  """Build the Munchausen target for return values at given quantiles."""
  rng, rng1, rng2, rng3 = jax.random.split(rng, num=4)
  target_action = network.apply(
      target_params, states, num_quantiles=num_quantile_samples, rng=rng1)
  curr_state_representation = target_action.representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate terminal state to discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = jnp.tile(gamma_with_terminal, [num_tau_prime_samples])

  replay_net_target_outputs = network.apply(
      target_params, next_states, num_quantiles=num_tau_prime_samples,
      rng=rng2)
  replay_quantile_values = replay_net_target_outputs.quantile_values

  target_next_action = network.apply(target_params,
                                     next_states,
                                     num_quantiles=num_quantile_samples,
                                     rng=rng3)
  target_next_quantile_values_action = target_next_action.quantile_values
  replay_next_target_q_values = jnp.squeeze(
      jnp.mean(target_next_quantile_values_action, axis=0))

  q_state_values = target_action.quantile_values
  replay_target_q_values = jnp.squeeze(jnp.mean(q_state_values, axis=0))

  num_actions = q_state_values.shape[-1]
  replay_action_one_hot = jax.nn.one_hot(actions, num_actions)
  replay_next_log_policy = stable_scaled_log_softmax(
      replay_next_target_q_values, tau, axis=0)
  replay_next_policy = stable_softmax(
      replay_next_target_q_values, tau, axis=0)
  replay_log_policy = stable_scaled_log_softmax(replay_target_q_values,
                                                tau, axis=0)

  tau_log_pi_a = jnp.sum(replay_log_policy * replay_action_one_hot, axis=0)
  tau_log_pi_a = jnp.clip(tau_log_pi_a, a_min=clip_value_min, a_max=1)
  munchausen_term = alpha * tau_log_pi_a
  weighted_logits = (
      replay_next_policy * (replay_quantile_values -
                            replay_next_log_policy))

  target_quantile_vals = jnp.sum(weighted_logits, axis=1)
  rewards += munchausen_term
  rewards = jnp.tile(rewards, [num_tau_prime_samples])
  target_quantile_vals = (
      rewards + gamma_with_terminal * target_quantile_vals)
  next_state_representation = target_next_action.representation
  next_state_representation = jnp.squeeze(next_state_representation)

  return (
      rng,
      jax.lax.stop_gradient(target_quantile_vals[:, None]),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))
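
# As a reading of the code above, in the notation of M-IQN
# (https://arxiv.org/abs/2007.14430) and writing pi = softmax(Q_target / tau),
# the per-quantile-sample target computed here is
#
#   T_j = r + alpha * clip(tau * log pi(a|s), clip_value_min, 1)
#         + gamma * sum_{a'} pi(a'|s') * (z_j(s', a') - tau * log pi(a'|s')),
#
# with gamma zeroed out at terminal transitions, returned alongside the
# stop-gradient current- and next-state representations used by the KSMe loss.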


@functools.partial(
    jax.vmap,
    in_axes=(None, None, None, 0, 0, 0, 0, None, None, None, None, None),
    out_axes=(None, 0, 0, 0))
def target_quantile_values(network, online_params, target_params, states,
                           next_states, rewards, terminals,
                           num_tau_prime_samples, num_quantile_samples,
                           cumulative_gamma, double_dqn, rng):
  """Build the target for return values at given quantiles."""
  rng, rng1, rng2, rng3 = jax.random.split(rng, num=4)
  curr_state_representation = network.apply(
      target_params, states, num_quantiles=num_quantile_samples,
      rng=rng3).representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  rewards = jnp.tile(rewards, [num_tau_prime_samples])
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate terminal state to discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = jnp.tile(gamma_with_terminal, [num_tau_prime_samples])
  # Compute Q-values which are used for action selection for the next states
  # in the replay buffer. Compute the argmax over the Q-values.
  if double_dqn:
    outputs_action = network.apply(online_params,
                                   next_states,
                                   num_quantiles=num_quantile_samples,
                                   rng=rng1)
  else:
    outputs_action = network.apply(target_params,
                                   next_states,
                                   num_quantiles=num_quantile_samples,
                                   rng=rng1)
  target_quantile_values_action = outputs_action.quantile_values
  target_q_values = jnp.squeeze(
      jnp.mean(target_quantile_values_action, axis=0))
  # Shape: batch_size.
  next_qt_argmax = jnp.argmax(target_q_values)
  # Get the indices of the maximum Q-value across the action dimension.
  # Shape of next_qt_argmax: (num_tau_prime_samples x batch_size).
  next_state_target_outputs = network.apply(
      target_params,
      next_states,
      num_quantiles=num_tau_prime_samples,
      rng=rng2)
  next_qt_argmax = jnp.tile(next_qt_argmax, [num_tau_prime_samples])
  target_quantile_vals = (
      jax.vmap(lambda x, y: x[y])(next_state_target_outputs.quantile_values,
                                  next_qt_argmax))
  target_quantile_vals = rewards + gamma_with_terminal * target_quantile_vals
  # We return with an extra dimension, which is expected by train.
  next_state_representation = next_state_target_outputs.representation
  next_state_representation = jnp.squeeze(next_state_representation)
  return (
      rng,
      jax.lax.stop_gradient(target_quantile_vals[:, None]),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))
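
# As a reading of the code above: without the Munchausen terms, the
# per-quantile target is the usual (double-)IQN target
#
#   T_j = r + gamma * z_j(s', argmax_{a'} Q(s', a')),
#
# where Q(s', .) is the mean over sampled quantile values and gamma is zeroed
# out at terminal transitions.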


@functools.partial(jax.jit, static_argnums=(0, 3, 10, 11, 12, 13, 14, 15, 17,
                                            18, 19, 20, 21, 22))
def train(network, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, num_tau_samples,
          num_tau_prime_samples, num_quantile_samples, cumulative_gamma,
          double_dqn, kappa, rng, mico_weight, distance_fn, similarity_fn,
          tau, alpha, clip_value_min):
  """Run a training step."""
  # The parameters tau, alpha, and clip_value_min are only used for
  # Munchausen-IQN (https://arxiv.org/abs/2007.14430), i.e. when tau is
  # not None.
  def loss_fn(params, rng_input, target_quantile_vals, target_r, target_next_r):
    def online(state):
      return network.apply(params, state, num_quantiles=num_tau_samples,
                           rng=rng_input)

    model_output = jax.vmap(online)(states)
    quantile_values = model_output.quantile_values
    quantiles = model_output.quantiles
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    chosen_action_quantile_values = jax.vmap(lambda x, y: x[:, y][:, None])(
        quantile_values, actions)
    # Shape of bellman_errors and huber_loss:
    # batch_size x num_tau_prime_samples x num_tau_samples x 1.
    bellman_errors = (target_quantile_vals[:, :, None, :] -
                      chosen_action_quantile_values[:, None, :, :])
    # The huber loss (see Section 2.3 of the paper) is defined via two cases:
    # case_one: |bellman_errors| <= kappa
    # case_two: |bellman_errors| > kappa
    huber_loss_case_one = (
        (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors ** 2)
    huber_loss_case_two = (
        (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
        kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))
    huber_loss = huber_loss_case_one + huber_loss_case_two
    # Tile by num_tau_prime_samples along a new dimension. Shape is now
    # batch_size x num_tau_prime_samples x num_tau_samples x 1.
    # These quantiles will be used for computation of the quantile huber loss
    # below (see section 2.3 of the paper).
    quantiles = jnp.tile(quantiles[:, None, :, :],
                         [1, num_tau_prime_samples, 1, 1]).astype(jnp.float32)
    # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    quantile_huber_loss = (jnp.abs(quantiles - jax.lax.stop_gradient(
        (bellman_errors < 0).astype(jnp.float32))) * huber_loss) / kappa
    # Sum over current quantile value (num_tau_samples) dimension,
    # average over target quantile value (num_tau_prime_samples) dimension.
    # Shape: batch_size x num_tau_prime_samples x 1.
    quantile_huber_loss = jnp.sum(quantile_huber_loss, axis=2)
    quantile_huber_loss = jnp.mean(quantile_huber_loss, axis=1)
    online_similarities = metric_utils.representation_similarities(
        representations, target_r, distance_fn, similarity_fn)
    target_similarities = metric_utils.target_similarities(
        target_next_r, rewards, distance_fn, similarity_fn, cumulative_gamma)
    kernel_loss = jnp.mean(jax.vmap(losses.huber_loss)(online_similarities,
                                                       target_similarities))
    loss = ((1. - mico_weight) * quantile_huber_loss +
            mico_weight * kernel_loss)
    return jnp.mean(loss), (jnp.mean(quantile_huber_loss), kernel_loss)

  if tau is None:
    rng, target_quantile_vals, target_r, target_next_r = target_quantile_values(
        network,
        online_params,
        target_params,
        states,
        next_states,
        rewards,
        terminals,
        num_tau_prime_samples,
        num_quantile_samples,
        cumulative_gamma,
        double_dqn,
        rng)
  else:
    rng, target_quantile_vals, target_r, target_next_r = (
        munchausen_target_quantile_values(
            network,
            target_params,
            states,
            actions,
            next_states,
            rewards,
            terminals,
            num_tau_prime_samples,
            num_quantile_samples,
            cumulative_gamma,
            rng,
            tau,
            alpha,
            clip_value_min))
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  rng, rng_input = jax.random.split(rng)
  all_losses, grad = grad_fn(online_params, rng_input, target_quantile_vals,
                             target_r, target_next_r)
  loss, component_losses = all_losses
  quantile_loss, kernel_loss = component_losses
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return rng, optimizer_state, online_params, loss, quantile_loss, kernel_loss
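
# The scalar loss optimized by `train` above is the convex combination
#
#   L = (1 - mico_weight) * L_quantile_huber + mico_weight * L_kernel,
#
# where L_quantile_huber is the IQN quantile Huber loss and L_kernel is a
# Huber loss between similarities of online representations and stop-gradient
# target similarities computed via `metric_utils` (the KSMe term). This
# restates `loss_fn` above and adds nothing new.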


@gin.configurable
class KSMeImplicitQuantileAgent(
    implicit_quantile_agent.JaxImplicitQuantileAgent):
  """Implicit Quantile Agent with the KSMe loss."""

  def __init__(self, num_actions, summary_writer=None,
               mico_weight=0.5, distance_fn='dot',
               similarity_fn='dot',
               tau=None, alpha=0.9, clip_value_min=-1):
    self._mico_weight = mico_weight
    if distance_fn == 'cosine':
      self._distance_fn = metric_utils.cosine_distance
    elif distance_fn == 'dot':
      self._distance_fn = metric_utils.l2
    else:
      raise ValueError(f'Unknown distance function: {distance_fn}')

    if similarity_fn == 'cosine':
      self._similarity_fn = metric_utils.cosine_similarity
    elif similarity_fn == 'dot':
      self._similarity_fn = metric_utils.dot
    else:
      raise ValueError(f'Unknown similarity function: {similarity_fn}')

    self._tau = tau
    self._alpha = alpha
    self._clip_value_min = clip_value_min
    super().__init__(num_actions, network=AtariImplicitQuantileNetwork,
                     summary_writer=summary_writer)
    logging.info('\t mico_weight: %f', mico_weight)
    logging.info('\t distance_fn: %s', distance_fn)
    logging.info('\t similarity_fn: %s', similarity_fn)

  def _train_step(self):
    """Runs a single training step."""
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self._sample_from_replay_buffer()
        (self._rng, self.optimizer_state, self.online_params,
         loss, quantile_loss, kernel_loss) = train(
             self.network_def,
             self.online_params,
             self.target_network_params,
             self.optimizer,
             self.optimizer_state,
             self.replay_elements['state'],
             self.replay_elements['action'],
             self.replay_elements['next_state'],
             self.replay_elements['reward'],
             self.replay_elements['terminal'],
             self.num_tau_samples,
             self.num_tau_prime_samples,
             self.num_quantile_samples,
             self.cumulative_gamma,
             self.double_dqn,
             self.kappa,
             self._rng,
             self._mico_weight,
             self._distance_fn,
             self._similarity_fn,
             self._tau,
             self._alpha,
             self._clip_value_min)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          if hasattr(self, 'collector_dispatcher'):
            self.collector_dispatcher.write(
                [statistics_instance.StatisticsInstance(
                    'Losses/Aggregate', np.asarray(loss),
                    step=self.training_steps),
                 statistics_instance.StatisticsInstance(
                     'Losses/Quantile', np.asarray(quantile_loss),
                     step=self.training_steps),
                 statistics_instance.StatisticsInstance(
                     'Losses/Metric', np.asarray(kernel_loss),
                     step=self.training_steps),
                 ],
                collector_allowlist=self._collector_allowlist)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1
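
# A hypothetical gin sketch for configuring this agent through Dopamine's
# standard train script; the binding names follow the @gin.configurable class
# above, but the specific values are illustrative assumptions, not
# recommended settings.
#
#   KSMeImplicitQuantileAgent.mico_weight = 0.5
#   KSMeImplicitQuantileAgent.distance_fn = 'dot'
#   KSMeImplicitQuantileAgent.similarity_fn = 'dot'
#   # Setting tau enables the Munchausen-IQN target:
#   KSMeImplicitQuantileAgent.tau = 0.03
#   KSMeImplicitQuantileAgent.alpha = 0.9
#   KSMeImplicitQuantileAgent.clip_value_min = -1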