ksme_rainbow_agent.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rainbow Agent with the KSMe loss."""

import collections
import functools

from absl import logging
from dopamine.jax import losses
from dopamine.jax.agents.rainbow import rainbow_agent
from dopamine.metrics import statistics_instance
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ksme.atari import metric_utils


NetworkType = collections.namedtuple(
    'network', ['q_values', 'logits', 'probabilities', 'representation'])
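# In addition to the standard Rainbow outputs (q_values, logits,
# probabilities), `representation` carries the flattened convolutional
# features on which the KSMe kernel loss operates.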


@gin.configurable
class AtariRainbowNetwork(nn.Module):
  """Convolutional network used to compute the agent's return distributions."""
  num_actions: int
  num_atoms: int

  @nn.compact
  def __call__(self, x, support):
    initializer = jax.nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
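    # The flattened convolutional features below double as the state
    # representation consumed by the KSMe kernel loss in `train`.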
    representation = x.reshape(-1)  # flatten
    x = nn.Dense(features=512, kernel_init=initializer)(representation)
    x = nn.relu(x)
    x = nn.Dense(features=self.num_actions * self.num_atoms,
                 kernel_init=initializer)(x)
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=1)
    return NetworkType(q_values, logits, probabilities, representation)


@functools.partial(jax.jit, static_argnums=(0, 3, 12, 13, 14, 15))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, loss_weights,
          support, cumulative_gamma, mico_weight, distance_fn, similarity_fn):
  """Run a training step."""
  def loss_fn(params, bellman_target, loss_multipliers, target_r,
              target_next_r):
    def q_online(state):
      return network_def.apply(params, state, support)

    model_output = jax.vmap(q_online)(states)
    logits = model_output.logits
    logits = jnp.squeeze(logits)
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    # Fetch the logits for its selected action. We use vmap to perform this
    # indexing across the batch.
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    c51_loss = jax.vmap(losses.softmax_cross_entropy_loss_with_logits)(
        bellman_target,
        chosen_action_logits)
    c51_loss *= loss_multipliers
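    # KSMe term: similarities computed from the online representations (and
    # the target representations of the current states) are regressed, via
    # the Huber loss below, towards target similarities built from the
    # rewards and the target network's next-state representations.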
    online_similarities, norm_sum, repr_distances = (
        metric_utils.representation_similarities(
            representations, target_r, distance_fn,
            similarity_fn, return_distance_components=True))
    target_similarities = metric_utils.target_similarities(
        target_next_r, rewards, distance_fn, similarity_fn, cumulative_gamma)
    kernel_loss = jnp.mean(jax.vmap(losses.huber_loss)(online_similarities,
                                                       target_similarities))
    loss = ((1. - mico_weight) * c51_loss +
            mico_weight * kernel_loss)
    aux_losses = {
        'loss': loss,
        'mean_loss': jnp.mean(loss),
        'c51_loss': jnp.mean(c51_loss),
        'kernel_loss': kernel_loss,
        'norm_sum': jnp.mean(norm_sum),
        'repr_distances': jnp.mean(repr_distances),
        'online_similarities': jnp.mean(online_similarities),
    }
    return jnp.mean(loss), aux_losses

  def q_target(state):
    return network_def.apply(target_params, state, support)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  bellman_target, target_r, target_next_r = target_distribution(
      q_target,
      states,
      next_states,
      rewards,
      terminals,
      support,
      cumulative_gamma)
  (_, aux_losses), grad = grad_fn(online_params, bellman_target,
                                  loss_weights, target_r, target_next_r)
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return optimizer_state, online_params, aux_losses


@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, None, None))
def target_distribution(target_network, states, next_states, rewards, terminals,
                        support, cumulative_gamma):
  """Builds the C51 target distribution as per Bellemare et al. (2017)."""
  curr_state_representation = target_network(states).representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate terminal state to discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  target_support = rewards + gamma_with_terminal * support
  next_state_target_outputs = target_network(next_states)
  q_values = jnp.squeeze(next_state_target_outputs.q_values)
  next_qt_argmax = jnp.argmax(q_values)
  probabilities = jnp.squeeze(next_state_target_outputs.probabilities)
  next_probabilities = probabilities[next_qt_argmax]
  next_state_representation = next_state_target_outputs.representation
  next_state_representation = jnp.squeeze(next_state_representation)
  return (
      jax.lax.stop_gradient(rainbow_agent.project_distribution(
          target_support, next_probabilities, support)),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))


@gin.configurable
class KSMeRainbowAgent(rainbow_agent.JaxRainbowAgent):
  """Rainbow Agent with the KSMe loss."""

  def __init__(self, num_actions, summary_writer=None,
               mico_weight=0.01, distance_fn='dot',
               similarity_fn='dot'):
    self._mico_weight = mico_weight
    if distance_fn == 'cosine':
      self._distance_fn = metric_utils.cosine_distance
    elif distance_fn == 'dot':
      self._distance_fn = metric_utils.l2
    else:
      raise ValueError(f'Unknown distance function: {distance_fn}')

    if similarity_fn == 'cosine':
      self._similarity_fn = metric_utils.cosine_similarity
    elif similarity_fn == 'dot':
      self._similarity_fn = metric_utils.dot
    else:
      raise ValueError(f'Unknown similarity function: {similarity_fn}')

    network = AtariRainbowNetwork
    super().__init__(num_actions, network=network,
                     summary_writer=summary_writer)
    logging.info('\t mico_weight: %f', mico_weight)
    logging.info('\t distance_fn: %s', distance_fn)
    logging.info('\t similarity_fn: %s', similarity_fn)

  def _train_step(self):
    """Runs a single training step."""
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self._sample_from_replay_buffer()

        if self._replay_scheme == 'prioritized':
          # The original prioritized experience replay uses a linear exponent
          # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of
          # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders)
          # suggested a fixed exponent actually performs better, except on Pong.
          probs = self.replay_elements['sampling_probabilities']
          # Weight the loss by the inverse priorities.
          loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
          loss_weights /= jnp.max(loss_weights)
        else:
          loss_weights = jnp.ones(self.replay_elements['state'].shape[0])

        self.optimizer_state, self.online_params, aux_losses = train(
            self.network_def,
            self.online_params,
            self.target_network_params,
            self.optimizer,
            self.optimizer_state,
            self.replay_elements['state'],
            self.replay_elements['action'],
            self.replay_elements['next_state'],
            self.replay_elements['reward'],
            self.replay_elements['terminal'],
            loss_weights,
            self._support,
            self.cumulative_gamma,
            self._mico_weight,
            self._distance_fn,
            self._similarity_fn)

        loss = aux_losses.pop('loss')
        if self._replay_scheme == 'prioritized':
          # Rainbow and prioritized replay are parametrized by an exponent
          # alpha, but in both cases it is set to 0.5 - for simplicity's sake we
          # leave it as is here, using the more direct sqrt(). Taking the square
          # root "makes sense", as we are dealing with a squared loss.  Add a
          # small nonzero value to the loss to avoid 0 priority items. While
          # technically this may be okay, setting all items to 0 priority will
          # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms.
          self._replay.set_priority(self.replay_elements['indices'],
                                    jnp.sqrt(loss + 1e-10))

          # Weight the per-sample loss by the inverse priorities computed above.
          loss = loss_weights * loss
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          if hasattr(self, 'collector_dispatcher'):
            stats = []
            for k in aux_losses:
              stats.append(statistics_instance.StatisticsInstance(
                  f'Losses/{k}', np.asarray(aux_losses[k]),
                  step=self.training_steps))
            self.collector_dispatcher.write(
                stats, collector_allowlist=self._collector_allowlist)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1
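

# Hypothetical usage sketch: this agent is gin-configurable, so its loss
# weight and distance/similarity choices can be set via gin bindings before
# construction; the training setup shipped with the ksme package (runners and
# gin config files) may differ from this illustration.
#
#   import gin
#   gin.bind_parameter('KSMeRainbowAgent.mico_weight', 0.01)
#   gin.bind_parameter('KSMeRainbowAgent.distance_fn', 'dot')
#   gin.bind_parameter('KSMeRainbowAgent.similarity_fn', 'dot')
#   # `environment` is assumed to be an Atari environment with a discrete
#   # action space.
#   agent = KSMeRainbowAgent(num_actions=environment.action_space.n)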