# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rainbow Agent with the KSMe loss."""

import collections
import functools

from absl import logging
from dopamine.jax import losses
from dopamine.jax.agents.rainbow import rainbow_agent
from dopamine.metrics import statistics_instance
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import numpy as np
import optax
from ksme.atari import metric_utils


NetworkType = collections.namedtuple(
    'network', ['q_values', 'logits', 'probabilities', 'representation'])

@gin.configurable
class AtariRainbowNetwork(nn.Module):
  """Convolutional network used to compute the agent's return distributions."""
  num_actions: int
  num_atoms: int

  @nn.compact
  def __call__(self, x, support):
    initializer = jax.nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    representation = x.reshape(-1)  # flatten
    x = nn.Dense(features=512, kernel_init=initializer)(representation)
    x = nn.relu(x)
    x = nn.Dense(features=self.num_actions * self.num_atoms,
                 kernel_init=initializer)(x)
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.sum(support * probabilities, axis=1)
    return NetworkType(q_values, logits, probabilities, representation)

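# A minimal usage sketch of the network above (illustrative only; `train`
# below vmaps the network over the batch, so a single unbatched frame stack
# is passed here). Shapes and constants are assumptions, not values fixed by
# this file:
#
#   num_actions, num_atoms = 6, 51
#   support = jnp.linspace(-10., 10., num_atoms)
#   network_def = AtariRainbowNetwork(num_actions=num_actions,
#                                     num_atoms=num_atoms)
#   dummy_state = jnp.zeros((84, 84, 4))  # (height, width, frame stack)
#   params = network_def.init(jax.random.PRNGKey(0), dummy_state, support)
#   output = network_def.apply(params, dummy_state, support)
#   # output.q_values: (num_actions,); output.logits: (num_actions, num_atoms)
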
@functools.partial(jax.jit, static_argnums=(0, 3, 12, 13, 14, 15))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, loss_weights,
          support, cumulative_gamma, mico_weight, distance_fn, similarity_fn):
  """Run a training step."""
  def loss_fn(params, bellman_target, loss_multipliers, target_r,
              target_next_r):
    def q_online(state):
      return network_def.apply(params, state, support)

    model_output = jax.vmap(q_online)(states)
    logits = model_output.logits
    logits = jnp.squeeze(logits)
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    # Fetch the logits for the selected actions. We use vmap to perform this
    # indexing across the batch.
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    c51_loss = jax.vmap(losses.softmax_cross_entropy_loss_with_logits)(
        bellman_target,
        chosen_action_logits)
    c51_loss *= loss_multipliers
    online_similarities, norm_sum, repr_distances = (
        metric_utils.representation_similarities(
            representations, target_r, distance_fn,
            similarity_fn, return_distance_components=True))
    target_similarities = metric_utils.target_similarities(
        target_next_r, rewards, distance_fn, similarity_fn, cumulative_gamma)
    kernel_loss = jnp.mean(jax.vmap(losses.huber_loss)(online_similarities,
                                                       target_similarities))
    loss = ((1. - mico_weight) * c51_loss +
            mico_weight * kernel_loss)
    aux_losses = {
        'loss': loss,
        'mean_loss': jnp.mean(loss),
        'c51_loss': jnp.mean(c51_loss),
        'kernel_loss': kernel_loss,
        'norm_sum': jnp.mean(norm_sum),
        'repr_distances': jnp.mean(repr_distances),
        'online_similarities': jnp.mean(online_similarities),
    }
    return jnp.mean(loss), aux_losses

  def q_target(state):
    return network_def.apply(target_params, state, support)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  bellman_target, target_r, target_next_r = target_distribution(
      q_target,
      states,
      next_states,
      rewards,
      terminals,
      support,
      cumulative_gamma)
  (_, aux_losses), grad = grad_fn(online_params, bellman_target,
                                  loss_weights, target_r, target_next_r)
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return optimizer_state, online_params, aux_losses

@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, None, None))
def target_distribution(target_network, states, next_states, rewards, terminals,
                        support, cumulative_gamma):
  """Builds the C51 target distribution as per Bellemare et al. (2017)."""
  curr_state_representation = target_network(states).representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate the terminal state into the discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  target_support = rewards + gamma_with_terminal * support
  next_state_target_outputs = target_network(next_states)
  q_values = jnp.squeeze(next_state_target_outputs.q_values)
  next_qt_argmax = jnp.argmax(q_values)
  probabilities = jnp.squeeze(next_state_target_outputs.probabilities)
  next_probabilities = probabilities[next_qt_argmax]
  next_state_representation = next_state_target_outputs.representation
  next_state_representation = jnp.squeeze(next_state_representation)
  return (
      jax.lax.stop_gradient(rainbow_agent.project_distribution(
          target_support, next_probabilities, support)),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))

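# For intuition, a hedged numeric sketch of the target support computed above
# (the horizon and gamma are illustrative, not fixed by this file): with an
# update horizon of 3 and gamma = 0.99, cumulative_gamma = 0.99**3 ~= 0.9703,
# so a non-terminal transition with reward 1.0 shifts every atom z of the
# support to 1.0 + 0.9703 * z before rainbow_agent.project_distribution maps
# the next-state probabilities back onto the fixed support. For a terminal
# transition, is_terminal_multiplier zeroes the discount, so the whole target
# distribution collapses onto the reward.
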
@gin.configurable
class KSMeRainbowAgent(rainbow_agent.JaxRainbowAgent):
  """Rainbow Agent with the KSMe loss."""

  def __init__(self, num_actions, summary_writer=None,
               mico_weight=0.01, distance_fn='dot',
               similarity_fn='dot'):
    self._mico_weight = mico_weight
    if distance_fn == 'cosine':
      self._distance_fn = metric_utils.cosine_distance
    elif distance_fn == 'dot':
      self._distance_fn = metric_utils.l2
    else:
      raise ValueError(f'Unknown distance function: {distance_fn}')

    if similarity_fn == 'cosine':
      self._similarity_fn = metric_utils.cosine_similarity
    elif similarity_fn == 'dot':
      self._similarity_fn = metric_utils.dot
    else:
      raise ValueError(f'Unknown similarity function: {similarity_fn}')

    network = AtariRainbowNetwork
    super().__init__(num_actions, network=network,
                     summary_writer=summary_writer)
    logging.info('\t mico_weight: %f', mico_weight)
    logging.info('\t distance_fn: %s', distance_fn)
    logging.info('\t similarity_fn: %s', similarity_fn)

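  # Since the class is @gin.configurable, its keyword arguments can be set
  # from a gin config. A hedged example of such bindings (the values are
  # illustrative, not defaults mandated by this file):
  #
  #   KSMeRainbowAgent.mico_weight = 0.5
  #   KSMeRainbowAgent.distance_fn = 'dot'
  #   KSMeRainbowAgent.similarity_fn = 'dot'
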
  def _train_step(self):
    """Runs a single training step."""
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self._sample_from_replay_buffer()

        if self._replay_scheme == 'prioritized':
          # The original prioritized experience replay uses a linear exponent
          # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of
          # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders)
          # suggested a fixed exponent actually performs better, except on
          # Pong.
          probs = self.replay_elements['sampling_probabilities']
          # Weight the loss by the inverse priorities.
          loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
          loss_weights /= jnp.max(loss_weights)
        else:
          loss_weights = jnp.ones(self.replay_elements['state'].shape[0])
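
        # Hedged numeric example of the weighting above: if two samples have
        # sampling probabilities [0.1, 0.4], the raw inverse-priority weights
        # are [1/sqrt(0.1), 1/sqrt(0.4)] ~= [3.16, 1.58]; after dividing by
        # the max they become [1.0, 0.5], so the more frequently sampled
        # (higher-priority) item receives the smaller correction weight.
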
        self.optimizer_state, self.online_params, aux_losses = train(
            self.network_def,
            self.online_params,
            self.target_network_params,
            self.optimizer,
            self.optimizer_state,
            self.replay_elements['state'],
            self.replay_elements['action'],
            self.replay_elements['next_state'],
            self.replay_elements['reward'],
            self.replay_elements['terminal'],
            loss_weights,
            self._support,
            self.cumulative_gamma,
            self._mico_weight,
            self._distance_fn,
            self._similarity_fn)

        loss = aux_losses.pop('loss')
        if self._replay_scheme == 'prioritized':
          # Rainbow and prioritized replay are parametrized by an exponent
          # alpha, but in both cases it is set to 0.5 - for simplicity's sake
          # we leave it as is here, using the more direct sqrt(). Taking the
          # square root "makes sense", as we are dealing with a squared loss.
          # Add a small nonzero value to the loss to avoid 0 priority items.
          # While technically this may be okay, setting all items to 0
          # priority will cause troubles, and also result in
          # 1.0 / 0.0 = NaN correction terms.
          self._replay.set_priority(self.replay_elements['indices'],
                                    jnp.sqrt(loss + 1e-10))

        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          if hasattr(self, 'collector_dispatcher'):
            stats = []
            for k in aux_losses:
              stats.append(statistics_instance.StatisticsInstance(
                  f'Losses/{k}', np.asarray(aux_losses[k]),
                  step=self.training_steps))
            self.collector_dispatcher.write(
                stats, collector_allowlist=self._collector_allowlist)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1