ksme_quantile_agent.py 
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantile regression agent with KSMe loss."""

import collections
import functools

from absl import logging
from dopamine.jax import losses
from dopamine.jax.agents.quantile import quantile_agent
from dopamine.metrics import statistics_instance
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ksme.atari import ksme_rainbow_agent
from ksme.atari import metric_utils


NetworkType = collections.namedtuple(
    'network', ['q_values', 'logits', 'probabilities', 'representation'])


@gin.configurable
class AtariQuantileNetwork(nn.Module):
  """Convolutional network used to compute the agent's return quantiles."""
  num_actions: int
  num_atoms: int

  @nn.compact
  def __call__(self, x):
    initializer = nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    representation = x.reshape(-1)  # flatten
    x = nn.Dense(features=512, kernel_init=initializer)(representation)
    x = nn.relu(x)
    x = nn.Dense(features=self.num_actions * self.num_atoms,
                 kernel_init=initializer)(x)
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=1)
    return ksme_rainbow_agent.NetworkType(q_values, logits, probabilities,
                                          representation)


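# Illustrative sketch (not part of the original module): building the network
# above and running a single Atari frame stack through it. The observation
# shape (84, 84, 4) and num_atoms=200 are assumptions taken from Dopamine's
# standard quantile-agent configuration, not from this file. The function is
# never called by the agent; it only documents the expected shapes.
def _example_network_forward():
  """Applies AtariQuantileNetwork to one dummy observation (sketch only)."""
  network_def = AtariQuantileNetwork(num_actions=6, num_atoms=200)
  dummy_obs = jnp.zeros((84, 84, 4), dtype=jnp.uint8)
  params = network_def.init(jax.random.PRNGKey(0), dummy_obs)
  output = network_def.apply(params, dummy_obs)
  # output.logits: (num_actions, num_atoms) return quantiles per action.
  # output.q_values: (num_actions,) mean of the quantiles for each action.
  # output.representation: flattened conv features consumed by the KSMe loss.
  return output

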
@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, None))
def target_distribution(target_network, states, next_states, rewards, terminals,
                        cumulative_gamma):
  """Builds the Quantile target distribution as per Dabney et al. (2017)."""
  curr_state_representation = target_network(states).representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate the terminal state into the discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  next_state_target_outputs = target_network(next_states)
  q_values = jnp.squeeze(next_state_target_outputs.q_values)
  next_qt_argmax = jnp.argmax(q_values)
  logits = jnp.squeeze(next_state_target_outputs.logits)
  next_logits = logits[next_qt_argmax]
  next_state_representation = next_state_target_outputs.representation
  next_state_representation = jnp.squeeze(next_state_representation)
  return (
      jax.lax.stop_gradient(rewards + gamma_with_terminal * next_logits),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))


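# Note (added commentary, not in the original file): because of the jax.vmap
# decorator above, `target_distribution` is applied element-wise over a batch;
# the target network and cumulative_gamma are shared across the batch (in_axes
# None). For a replay batch of size B, a network with num_atoms quantiles and
# a flattened representation of size D, the three stop-gradient outputs have
# shapes (B, num_atoms), (B, D) and (B, D): the quantile Bellman targets
# r + gamma * Z(s', argmax_a Q(s', a)), the target-network representation of
# the current state, and the target-network representation of the next state.

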
@functools.partial(jax.jit, static_argnums=(0, 3, 10, 11, 12, 13, 14, 15))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, kappa, num_atoms,
          cumulative_gamma, mico_weight, distance_fn, similarity_fn):
  """Run a training step."""
  def loss_fn(params, bellman_target, target_r, target_next_r):
    def q_online(state):
      return network_def.apply(params, state)

    model_output = jax.vmap(q_online)(states)
    logits = model_output.logits
    logits = jnp.squeeze(logits)
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    # Fetch the logits for each transition's chosen action. We use vmap to
    # perform this indexing across the batch.
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    bellman_errors = (bellman_target[:, None, :] -
                      chosen_action_logits[:, :, None])  # Input `u' of Eq. 9.
    # Eq. 9 of the paper.
    huber_loss = (
        (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors ** 2 +
        (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
        kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))

    tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) /
               num_atoms)  # Quantile midpoints.  See Lemma 2 of the paper.
    # Eq. 10 of the paper.
    tau_bellman_diff = jnp.abs(
        tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32))
    quantile_huber_loss = tau_bellman_diff * huber_loss
    # Sum over the tau dimension, average over the target value dimension.
    quantile_loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
    online_similarities = metric_utils.representation_similarities(
        representations, target_r, distance_fn, similarity_fn)
    target_similarities = metric_utils.target_similarities(
        target_next_r, rewards, distance_fn, similarity_fn, cumulative_gamma)
    kernel_loss = jnp.mean(jax.vmap(losses.huber_loss)(online_similarities,
                                                       target_similarities))
    loss = ((1. - mico_weight) * quantile_loss +
            mico_weight * kernel_loss)
    return jnp.mean(loss), (loss, jnp.mean(quantile_loss), kernel_loss)

  def q_target(state):
    return network_def.apply(target_params, state)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  bellman_target, target_r, target_next_r = target_distribution(
      q_target, states, next_states, rewards, terminals, cumulative_gamma)
  all_losses, grad = grad_fn(online_params, bellman_target, target_r,
                             target_next_r)
  mean_loss, component_losses = all_losses
  loss, quantile_loss, kernel_loss = component_losses
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return (optimizer_state, online_params, loss, mean_loss, quantile_loss,
          kernel_loss)


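# Illustrative sketch (not part of the original module): one call to `train`
# on dummy data, showing the expected argument shapes and the positional
# static arguments. The batch size (32), observation shape (84, 84, 4),
# num_atoms (200), kappa (1.0), learning rate, and cumulative_gamma (0.99)
# are assumptions; the real values are supplied through the agent below via
# gin. The function is never called by the agent.
def _example_train_call():
  """Runs a single `train` update on zero-filled transitions (sketch only)."""
  network_def = AtariQuantileNetwork(num_actions=6, num_atoms=200)
  dummy_state = jnp.zeros((84, 84, 4))
  online_params = network_def.init(jax.random.PRNGKey(0), dummy_state)
  target_params = online_params
  optimizer = optax.adam(learning_rate=5e-5)
  optimizer_state = optimizer.init(online_params)
  batch_size = 32
  states = jnp.zeros((batch_size, 84, 84, 4))
  actions = jnp.zeros((batch_size,), dtype=jnp.int32)
  next_states = jnp.zeros((batch_size, 84, 84, 4))
  rewards = jnp.zeros((batch_size,))
  terminals = jnp.zeros((batch_size,))
  # Positional order matches the jit static_argnums above: kappa, num_atoms,
  # cumulative_gamma, mico_weight, distance_fn, similarity_fn.
  return train(network_def, online_params, target_params, optimizer,
               optimizer_state, states, actions, next_states, rewards,
               terminals, 1.0, 200, 0.99, 0.5,
               metric_utils.l2, metric_utils.dot)

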
@gin.configurable
class KSMeQuantileAgent(quantile_agent.JaxQuantileAgent):
  """Quantile Agent with the KSMe loss."""

  def __init__(self, num_actions, summary_writer=None,
               mico_weight=0.5, distance_fn='dot',
               similarity_fn='dot'):
    self._mico_weight = mico_weight
    if distance_fn == 'cosine':
      self._distance_fn = metric_utils.cosine_distance
    elif distance_fn == 'dot':
      self._distance_fn = metric_utils.l2
    else:
      raise ValueError(f'Unknown distance function: {distance_fn}')

    if similarity_fn == 'cosine':
      self._similarity_fn = metric_utils.cosine_similarity
    elif similarity_fn == 'dot':
      self._similarity_fn = metric_utils.dot
    else:
      raise ValueError(f'Unknown similarity function: {similarity_fn}')

    network = AtariQuantileNetwork
    super().__init__(num_actions, network=network,
                     summary_writer=summary_writer)
    logging.info('\t mico_weight: %f', mico_weight)
    logging.info('\t distance_fn: %s', distance_fn)
    logging.info('\t similarity_fn: %s', similarity_fn)

  def _train_step(self):
    """Runs a single training step."""
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self._sample_from_replay_buffer()
        (self.optimizer_state, self.online_params,
         loss, mean_loss, quantile_loss, kernel_loss) = train(
             self.network_def,
             self.online_params,
             self.target_network_params,
             self.optimizer,
             self.optimizer_state,
             self.replay_elements['state'],
             self.replay_elements['action'],
             self.replay_elements['next_state'],
             self.replay_elements['reward'],
             self.replay_elements['terminal'],
             self._kappa,
             self._num_atoms,
             self.cumulative_gamma,
             self._mico_weight,
             self._distance_fn,
             self._similarity_fn)
        if self._replay_scheme == 'prioritized':
          probs = self.replay_elements['sampling_probabilities']
          loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
          loss_weights /= jnp.max(loss_weights)
          self._replay.set_priority(self.replay_elements['indices'],
                                    jnp.sqrt(loss + 1e-10))
          loss = loss_weights * loss
          mean_loss = jnp.mean(loss)
        if hasattr(self, 'collector_dispatcher'):
          self.collector_dispatcher.write(
              [statistics_instance.StatisticsInstance(
                  'Losses/Aggregate', np.asarray(mean_loss),
                  step=self.training_steps),
               statistics_instance.StatisticsInstance(
                   'Losses/Quantile', np.asarray(quantile_loss),
                   step=self.training_steps),
               statistics_instance.StatisticsInstance(
                   'Losses/Metric', np.asarray(kernel_loss),
                   step=self.training_steps),
               ],
              collector_allowlist=self._collector_allowlist)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1
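

# Usage note (added commentary, not in the original file): the agent is
# normally created by the repository's training entry point, with gin
# supplying mico_weight, distance_fn, similarity_fn and the underlying
# JaxQuantileAgent hyperparameters (num_atoms, kappa, replay scheme, ...).
# A direct construction for a 6-action game would look roughly like:
#
#   agent = KSMeQuantileAgent(num_actions=6, mico_weight=0.5,
#                             distance_fn='dot', similarity_fn='dot')
#
# This sketch assumes default gin bindings for the remaining arguments.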
