# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implicit Quantile agent with KSMe loss."""

import collections
import functools

from absl import logging
from dopamine.jax import losses
from dopamine.jax.agents.implicit_quantile import implicit_quantile_agent
from dopamine.metrics import statistics_instance
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import numpy as np
import optax

from ksme.atari import metric_utils


NetworkType = collections.namedtuple(
    'network', ['quantile_values', 'quantiles', 'representation'])


def stable_scaled_log_softmax(x, tau, axis=-1):
  """Returns tau * log_softmax(x / tau), computed stably."""
  max_x = jnp.amax(x, axis=axis, keepdims=True)
  y = x - max_x
  tau_lse = max_x + tau * jnp.log(
      jnp.sum(jnp.exp(y / tau), axis=axis, keepdims=True))
  return x - tau_lse


def stable_softmax(x, tau, axis=-1):
  """Returns softmax(x / tau), computed stably via max subtraction."""
  max_x = jnp.amax(x, axis=axis, keepdims=True)
  y = x - max_x
  return jax.nn.softmax(y / tau, axis=axis)
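

# A quick numerical sanity check for the two helpers above (a hypothetical
# example, not part of the original module): subtracting the per-axis max
# before exponentiating leaves the outputs unchanged while avoiding overflow
# for large logits, e.g.
#
#   x = jnp.array([1000.0, 1001.0])
#   stable_softmax(x, tau=1.0)             # ~[0.269, 0.731], no inf/NaN
#   stable_scaled_log_softmax(x, tau=1.0)  # equals x - logsumexp(x) at tau=1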


class AtariImplicitQuantileNetwork(nn.Module):
  """The Implicit Quantile Network (Dabney et al., 2018)."""
  num_actions: int
  quantile_embedding_dim: int

  @nn.compact
  def __call__(self, x, num_quantiles, rng):
    initializer = jax.nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    representation = x.reshape((-1))  # Flatten to the state representation.
    state_vector_length = representation.shape[-1]
    state_net_tiled = jnp.tile(representation, [num_quantiles, 1])
    quantiles_shape = [num_quantiles, 1]
    quantiles = jax.random.uniform(rng, shape=quantiles_shape)
    quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
    # Cosine embedding of the sampled quantiles: cos(pi * i * tau) for
    # i = 1..quantile_embedding_dim, as in the IQN paper.
    quantile_net = (
        jnp.arange(1, self.quantile_embedding_dim + 1, 1).astype(jnp.float32)
        * np.pi
        * quantile_net)
    quantile_net = jnp.cos(quantile_net)
    quantile_net = nn.Dense(features=state_vector_length,
                            kernel_init=initializer)(quantile_net)
    quantile_net = nn.relu(quantile_net)
    # Hadamard product of the state and quantile embeddings.
    x = state_net_tiled * quantile_net
    x = nn.Dense(features=512, kernel_init=initializer)(x)
    x = nn.relu(x)
    quantile_values = nn.Dense(features=self.num_actions,
                               kernel_init=initializer)(x)
    return NetworkType(quantile_values, quantiles, representation)


@functools.partial(
    jax.vmap,
    in_axes=(None, None, 0, 0, 0, 0, 0, None, None, None, None, None,
             None, None),
    out_axes=(None, 0, 0, 0))
def munchausen_target_quantile_values(network, target_params, states,
                                      actions, next_states, rewards, terminals,
                                      num_tau_prime_samples,
                                      num_quantile_samples, cumulative_gamma,
                                      rng, tau, alpha, clip_value_min):
  """Builds the Munchausen target for return values at given quantiles."""
  rng, rng1, rng2, rng3 = jax.random.split(rng, num=4)
  target_action = network.apply(
      target_params, states, num_quantiles=num_quantile_samples, rng=rng1)
  curr_state_representation = target_action.representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate the terminal state into the discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = jnp.tile(gamma_with_terminal, [num_tau_prime_samples])

  replay_net_target_outputs = network.apply(
      target_params, next_states, num_quantiles=num_tau_prime_samples,
      rng=rng2)
  replay_quantile_values = replay_net_target_outputs.quantile_values

  target_next_action = network.apply(target_params,
                                     next_states,
                                     num_quantiles=num_quantile_samples,
                                     rng=rng3)
  target_next_quantile_values_action = target_next_action.quantile_values
  replay_next_target_q_values = jnp.squeeze(
      jnp.mean(target_next_quantile_values_action, axis=0))

  q_state_values = target_action.quantile_values
  replay_target_q_values = jnp.squeeze(jnp.mean(q_state_values, axis=0))

  num_actions = q_state_values.shape[-1]
  replay_action_one_hot = jax.nn.one_hot(actions, num_actions)
  replay_next_log_policy = stable_scaled_log_softmax(
      replay_next_target_q_values, tau, axis=0)
  replay_next_policy = stable_softmax(
      replay_next_target_q_values, tau, axis=0)
  replay_log_policy = stable_scaled_log_softmax(replay_target_q_values,
                                                tau, axis=0)
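
  # Munchausen-RL (https://arxiv.org/abs/2007.14430) augments the reward with
  # a scaled, clipped log-policy bonus, alpha * clip(tau * log pi(a|s), ...),
  # and replaces the hard max in the target with a soft expectation under the
  # next-state policy: sum_a' pi(a'|s') * (Z(s', a') - tau * log pi(a'|s')).
  # Note that stable_scaled_log_softmax already includes the factor tau.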
  tau_log_pi_a = jnp.sum(replay_log_policy * replay_action_one_hot, axis=0)
  tau_log_pi_a = jnp.clip(tau_log_pi_a, a_min=clip_value_min, a_max=1)
  munchausen_term = alpha * tau_log_pi_a
  weighted_logits = (
      replay_next_policy * (replay_quantile_values -
                            replay_next_log_policy))

  target_quantile_vals = jnp.sum(weighted_logits, axis=1)
  rewards += munchausen_term
  rewards = jnp.tile(rewards, [num_tau_prime_samples])
  target_quantile_vals = (
      rewards + gamma_with_terminal * target_quantile_vals)
  next_state_representation = target_next_action.representation
  next_state_representation = jnp.squeeze(next_state_representation)

  return (
      rng,
      jax.lax.stop_gradient(target_quantile_vals[:, None]),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))


@functools.partial(
    jax.vmap,
    in_axes=(None, None, None, 0, 0, 0, 0, None, None, None, None, None),
    out_axes=(None, 0, 0, 0))
def target_quantile_values(network, online_params, target_params, states,
                           next_states, rewards, terminals,
                           num_tau_prime_samples, num_quantile_samples,
                           cumulative_gamma, double_dqn, rng):
  """Builds the target for return values at given quantiles."""
  rng, rng1, rng2, rng3 = jax.random.split(rng, num=4)
  curr_state_representation = network.apply(
      target_params, states, num_quantiles=num_quantile_samples,
      rng=rng3).representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  rewards = jnp.tile(rewards, [num_tau_prime_samples])
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate the terminal state into the discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = jnp.tile(gamma_with_terminal, [num_tau_prime_samples])
  # Compute Q-values for the next states in the replay buffer, which are used
  # for action selection, then take the argmax over the Q-values.
  if double_dqn:
    outputs_action = network.apply(online_params,
                                   next_states,
                                   num_quantiles=num_quantile_samples,
                                   rng=rng1)
  else:
    outputs_action = network.apply(target_params,
                                   next_states,
                                   num_quantiles=num_quantile_samples,
                                   rng=rng1)
  target_quantile_values_action = outputs_action.quantile_values
  target_q_values = jnp.squeeze(
      jnp.mean(target_quantile_values_action, axis=0))
  # Get the index of the maximum Q-value across the action dimension.
  # Shape: batch_size.
  next_qt_argmax = jnp.argmax(target_q_values)
  next_state_target_outputs = network.apply(
      target_params,
      next_states,
      num_quantiles=num_tau_prime_samples,
      rng=rng2)
  # Shape of next_qt_argmax: (num_tau_prime_samples x batch_size).
  next_qt_argmax = jnp.tile(next_qt_argmax, [num_tau_prime_samples])
  target_quantile_vals = (
      jax.vmap(lambda x, y: x[y])(next_state_target_outputs.quantile_values,
                                  next_qt_argmax))
  target_quantile_vals = rewards + gamma_with_terminal * target_quantile_vals
  next_state_representation = next_state_target_outputs.representation
  next_state_representation = jnp.squeeze(next_state_representation)
  # We return target_quantile_vals with an extra dimension, which is expected
  # by train.
  return (
      rng,
      jax.lax.stop_gradient(target_quantile_vals[:, None]),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))
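
# Note: with double_dqn=True, the bootstrap action above is selected with the
# online parameters while its quantile values come from the target parameters,
# decoupling action selection from evaluation as in double DQN.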


@functools.partial(jax.jit, static_argnums=(0, 3, 10, 11, 12, 13, 14, 15, 17,
                                            18, 19, 20, 21, 22))
def train(network, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, num_tau_samples,
          num_tau_prime_samples, num_quantile_samples, cumulative_gamma,
          double_dqn, kappa, rng, mico_weight, distance_fn, similarity_fn,
          tau, alpha, clip_value_min):
  """Runs a training step."""
  # The parameters tau, alpha, and clip_value_min are specific to
  # Munchausen-IQN (https://arxiv.org/abs/2007.14430) and take effect only
  # when tau is not None.
  def loss_fn(params, rng_input, target_quantile_vals, target_r,
              target_next_r):
    def online(state):
      return network.apply(params, state, num_quantiles=num_tau_samples,
                           rng=rng_input)
    model_output = jax.vmap(online)(states)
    quantile_values = model_output.quantile_values
    quantiles = model_output.quantiles
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    chosen_action_quantile_values = jax.vmap(lambda x, y: x[:, y][:, None])(
        quantile_values, actions)
    # Shape of bellman_errors and huber_loss:
    # batch_size x num_tau_prime_samples x num_tau_samples x 1.
    bellman_errors = (target_quantile_vals[:, :, None, :] -
                      chosen_action_quantile_values[:, None, :, :])
    # The Huber loss (see Section 2.3 of the IQN paper) is defined via two
    # cases:
    # case_one: |bellman_errors| <= kappa
    # case_two: |bellman_errors| > kappa
    huber_loss_case_one = (
        (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors ** 2)
    huber_loss_case_two = (
        (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
        kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))
    huber_loss = huber_loss_case_one + huber_loss_case_two
    # Tile by num_tau_prime_samples along a new dimension. Shape is now
    # batch_size x num_tau_prime_samples x num_tau_samples x 1.
    # These quantiles are used to compute the quantile Huber loss below
    # (see Section 2.3 of the IQN paper).
    quantiles = jnp.tile(quantiles[:, None, :, :],
                         [1, num_tau_prime_samples, 1, 1]).astype(jnp.float32)
    # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    quantile_huber_loss = (jnp.abs(quantiles - jax.lax.stop_gradient(
        (bellman_errors < 0).astype(jnp.float32))) * huber_loss) / kappa
    # Sum over the current quantile value (num_tau_samples) dimension,
    # average over the target quantile value (num_tau_prime_samples)
    # dimension. Shape: batch_size x num_tau_prime_samples x 1.
    quantile_huber_loss = jnp.sum(quantile_huber_loss, axis=2)
    quantile_huber_loss = jnp.mean(quantile_huber_loss, axis=1)
    online_similarities = metric_utils.representation_similarities(
        representations, target_r, distance_fn, similarity_fn)
    target_similarities = metric_utils.target_similarities(
        target_next_r, rewards, distance_fn, similarity_fn, cumulative_gamma)
    kernel_loss = jnp.mean(jax.vmap(losses.huber_loss)(online_similarities,
                                                       target_similarities))
    loss = ((1. - mico_weight) * quantile_huber_loss +
            mico_weight * kernel_loss)
    return jnp.mean(loss), (jnp.mean(quantile_huber_loss), kernel_loss)
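
  # For reference: the per-sample quantile Huber loss above implements
  #   rho_tau^kappa(delta) = |tau - 1{delta < 0}| * L_kappa(delta) / kappa,
  # and the aggregate objective mixes it with the KSMe kernel loss as
  #   (1 - mico_weight) * quantile_huber_loss + mico_weight * kernel_loss,
  # where the kernel loss regresses online representation similarities onto
  # bootstrapped target similarities built from rewards and next-state
  # representations.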
  if tau is None:
    rng, target_quantile_vals, target_r, target_next_r = target_quantile_values(
        network,
        online_params,
        target_params,
        states,
        next_states,
        rewards,
        terminals,
        num_tau_prime_samples,
        num_quantile_samples,
        cumulative_gamma,
        double_dqn,
        rng)
  else:
    rng, target_quantile_vals, target_r, target_next_r = (
        munchausen_target_quantile_values(
            network,
            target_params,
            states,
            actions,
            next_states,
            rewards,
            terminals,
            num_tau_prime_samples,
            num_quantile_samples,
            cumulative_gamma,
            rng,
            tau,
            alpha,
            clip_value_min))
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  rng, rng_input = jax.random.split(rng)
  all_losses, grad = grad_fn(online_params, rng_input, target_quantile_vals,
                             target_r, target_next_r)
  loss, component_losses = all_losses
  quantile_loss, kernel_loss = component_losses
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return rng, optimizer_state, online_params, loss, quantile_loss, kernel_loss


@gin.configurable
class KSMeImplicitQuantileAgent(
    implicit_quantile_agent.JaxImplicitQuantileAgent):
  """Implicit Quantile Agent with the KSMe loss."""

  def __init__(self, num_actions, summary_writer=None,
               mico_weight=0.5, distance_fn='dot',
               similarity_fn='dot',
               tau=None, alpha=0.9, clip_value_min=-1):
    self._mico_weight = mico_weight
    if distance_fn == 'cosine':
      self._distance_fn = metric_utils.cosine_distance
    elif distance_fn == 'dot':
      self._distance_fn = metric_utils.l2
    else:
      raise ValueError(f'Unknown distance function: {distance_fn}')

    if similarity_fn == 'cosine':
      self._similarity_fn = metric_utils.cosine_similarity
    elif similarity_fn == 'dot':
      self._similarity_fn = metric_utils.dot
    else:
      raise ValueError(f'Unknown similarity function: {similarity_fn}')

    self._tau = tau
    self._alpha = alpha
    self._clip_value_min = clip_value_min
    super().__init__(num_actions, network=AtariImplicitQuantileNetwork,
                     summary_writer=summary_writer)
    logging.info('\t mico_weight: %f', mico_weight)
    logging.info('\t distance_fn: %s', distance_fn)
    logging.info('\t similarity_fn: %s', similarity_fn)

  def _train_step(self):
    """Runs a single training step."""
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self._sample_from_replay_buffer()
        (self._rng, self.optimizer_state, self.online_params,
         loss, quantile_loss, kernel_loss) = train(
             self.network_def,
             self.online_params,
             self.target_network_params,
             self.optimizer,
             self.optimizer_state,
             self.replay_elements['state'],
             self.replay_elements['action'],
             self.replay_elements['next_state'],
             self.replay_elements['reward'],
             self.replay_elements['terminal'],
             self.num_tau_samples,
             self.num_tau_prime_samples,
             self.num_quantile_samples,
             self.cumulative_gamma,
             self.double_dqn,
             self.kappa,
             self._rng,
             self._mico_weight,
             self._distance_fn,
             self._similarity_fn,
             self._tau,
             self._alpha,
             self._clip_value_min)
        if (self.summary_writer is not None and
            self.training_steps > 0 and
            self.training_steps % self.summary_writing_frequency == 0):
          if hasattr(self, 'collector_dispatcher'):
            self.collector_dispatcher.write(
                [statistics_instance.StatisticsInstance(
                    'Losses/Aggregate', np.asarray(loss),
                    step=self.training_steps),
                 statistics_instance.StatisticsInstance(
                     'Losses/Quantile', np.asarray(quantile_loss),
                     step=self.training_steps),
                 statistics_instance.StatisticsInstance(
                     'Losses/Metric', np.asarray(kernel_loss),
                     step=self.training_steps),
                ],
                collector_allowlist=self._collector_allowlist)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1
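

# A minimal usage sketch (hypothetical, not part of the original module),
# assuming the usual Dopamine gin bindings for JaxImplicitQuantileAgent are in
# place; passing a non-None tau switches training to the Munchausen-IQN
# target:
#
#   agent = KSMeImplicitQuantileAgent(
#       num_actions=4,
#       mico_weight=0.5,
#       distance_fn='dot',
#       similarity_fn='dot',
#       tau=0.03, alpha=0.9, clip_value_min=-1)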