# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Quantile regression agent with KSMe loss."""
17
18import collections
19import functools
20
21from absl import logging
22from dopamine.jax import losses
23from dopamine.jax.agents.quantile import quantile_agent
24from dopamine.metrics import statistics_instance
25from flax import linen as nn
26import gin
27import jax
28import jax.numpy as jnp
29import numpy as np
30import optax
31from ksme.atari import ksme_rainbow_agent
32from ksme.atari import metric_utils
33
34
35NetworkType = collections.namedtuple(
36'network', ['q_values', 'logits', 'probabilities', 'representation'])
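# Note: the network below actually returns `ksme_rainbow_agent.NetworkType`;
# this local alias is assumed here to carry the same four fields.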


@gin.configurable
class AtariQuantileNetwork(nn.Module):
  """Convolutional network used to compute the agent's return quantiles."""
  num_actions: int
  num_atoms: int

  @nn.compact
  def __call__(self, x):
    initializer = nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=initializer)(x)
    x = nn.relu(x)
    representation = x.reshape(-1)  # flatten
    x = nn.Dense(features=512, kernel_init=initializer)(representation)
    x = nn.relu(x)
    x = nn.Dense(features=self.num_actions * self.num_atoms,
                 kernel_init=initializer)(x)
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=1)
    return ksme_rainbow_agent.NetworkType(q_values, logits, probabilities,
                                          representation)
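

# A minimal usage sketch (illustrative, not part of the original module): the
# network maps a single 84x84x4 Atari frame stack to per-action quantile
# logits; the agent vmaps it over the batch dimension during training.
#
#   network_def = AtariQuantileNetwork(num_actions=6, num_atoms=200)
#   obs = jnp.zeros((84, 84, 4))
#   params = network_def.init(jax.random.PRNGKey(0), obs)
#   outputs = network_def.apply(params, obs)
#   outputs.q_values.shape  # (6,)
#   outputs.logits.shape    # (6, 200)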


@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, None))
def target_distribution(target_network, states, next_states, rewards,
                        terminals, cumulative_gamma):
  """Builds the Quantile target distribution as per Dabney et al. (2017)."""
  curr_state_representation = target_network(states).representation
  curr_state_representation = jnp.squeeze(curr_state_representation)
  is_terminal_multiplier = 1. - terminals.astype(jnp.float32)
  # Incorporate terminal state to discount factor.
  gamma_with_terminal = cumulative_gamma * is_terminal_multiplier
  next_state_target_outputs = target_network(next_states)
  q_values = jnp.squeeze(next_state_target_outputs.q_values)
  next_qt_argmax = jnp.argmax(q_values)
  logits = jnp.squeeze(next_state_target_outputs.logits)
  next_logits = logits[next_qt_argmax]
  next_state_representation = next_state_target_outputs.representation
  next_state_representation = jnp.squeeze(next_state_representation)
  return (
      jax.lax.stop_gradient(rewards + gamma_with_terminal * next_logits),
      jax.lax.stop_gradient(curr_state_representation),
      jax.lax.stop_gradient(next_state_representation))
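

# Conceptually, for each transition the function above returns:
#   bellman_target = r + gamma * (1 - terminal) * Z_target(s', argmax_a Q(s', a))
#   target_r       = phi_target(s)   (target-network representation of s)
#   target_next_r  = phi_target(s')  (target-network representation of s')
# all with gradients stopped, so only the online network receives updates.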


@functools.partial(jax.jit, static_argnums=(0, 3, 10, 11, 12, 13, 14, 15))
def train(network_def, online_params, target_params, optimizer, optimizer_state,
          states, actions, next_states, rewards, terminals, kappa, num_atoms,
          cumulative_gamma, mico_weight, distance_fn, similarity_fn):
  """Run a training step."""
  def loss_fn(params, bellman_target, target_r, target_next_r):
    def q_online(state):
      return network_def.apply(params, state)

    model_output = jax.vmap(q_online)(states)
    logits = model_output.logits
    logits = jnp.squeeze(logits)
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    # Fetch the logits for its selected action. We use vmap to perform this
    # indexing across the batch.
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    bellman_errors = (bellman_target[:, None, :] -
                      chosen_action_logits[:, :, None])  # Input `u' of Eq. 9.
    # Eq. 9 of paper.
    huber_loss = (
        (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors ** 2 +
        (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
        kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))

    tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) /
               num_atoms)  # Quantile midpoints. See Lemma 2 of paper.
    # Eq. 10 of paper.
    tau_bellman_diff = jnp.abs(
        tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32))
    quantile_huber_loss = tau_bellman_diff * huber_loss
    # Sum over tau dimension, average over target value dimension.
    quantile_loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
    online_similarities = metric_utils.representation_similarities(
        representations, target_r, distance_fn, similarity_fn)
    target_similarities = metric_utils.target_similarities(
        target_next_r, rewards, distance_fn, similarity_fn, cumulative_gamma)
    kernel_loss = jnp.mean(jax.vmap(losses.huber_loss)(online_similarities,
                                                       target_similarities))
    loss = ((1. - mico_weight) * quantile_loss +
            mico_weight * kernel_loss)
    return jnp.mean(loss), (loss, jnp.mean(quantile_loss), kernel_loss)

  def q_target(state):
    return network_def.apply(target_params, state)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  bellman_target, target_r, target_next_r = target_distribution(
      q_target, states, next_states, rewards, terminals, cumulative_gamma)
  all_losses, grad = grad_fn(online_params, bellman_target, target_r,
                             target_next_r)
  mean_loss, component_losses = all_losses
  loss, quantile_loss, kernel_loss = component_losses
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return (optimizer_state, online_params, loss, mean_loss, quantile_loss,
          kernel_loss)
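

# The objective above interpolates between plain quantile regression and the
# kernel (metric) term:
#   loss = (1 - mico_weight) * quantile_loss + mico_weight * kernel_loss
# With mico_weight = 0 the update reduces to the standard quantile-regression
# step of Dopamine's JaxQuantileAgent.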


@gin.configurable
class KSMeQuantileAgent(quantile_agent.JaxQuantileAgent):
  """Quantile Agent with the KSMe loss."""

  def __init__(self, num_actions, summary_writer=None,
               mico_weight=0.5, distance_fn='dot',
               similarity_fn='dot'):
    self._mico_weight = mico_weight
    if distance_fn == 'cosine':
      self._distance_fn = metric_utils.cosine_distance
    elif distance_fn == 'dot':
      self._distance_fn = metric_utils.l2
    else:
      raise ValueError(f'Unknown distance function: {distance_fn}')

    if similarity_fn == 'cosine':
      self._similarity_fn = metric_utils.cosine_similarity
    elif similarity_fn == 'dot':
      self._similarity_fn = metric_utils.dot
    else:
      raise ValueError(f'Unknown similarity function: {similarity_fn}')

    network = AtariQuantileNetwork
    super().__init__(num_actions, network=network,
                     summary_writer=summary_writer)
    logging.info('\t mico_weight: %f', mico_weight)
    logging.info('\t distance_fn: %s', distance_fn)
    logging.info('\t similarity_fn: %s', similarity_fn)

  def _train_step(self):
    """Runs a single training step."""
    if self._replay.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self._sample_from_replay_buffer()
        (self.optimizer_state, self.online_params,
         loss, mean_loss, quantile_loss, kernel_loss) = train(
             self.network_def,
             self.online_params,
             self.target_network_params,
             self.optimizer,
             self.optimizer_state,
             self.replay_elements['state'],
             self.replay_elements['action'],
             self.replay_elements['next_state'],
             self.replay_elements['reward'],
             self.replay_elements['terminal'],
             self._kappa,
             self._num_atoms,
             self.cumulative_gamma,
             self._mico_weight,
             self._distance_fn,
             self._similarity_fn)
        if self._replay_scheme == 'prioritized':
          probs = self.replay_elements['sampling_probabilities']
          loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
          loss_weights /= jnp.max(loss_weights)
          self._replay.set_priority(self.replay_elements['indices'],
                                    jnp.sqrt(loss + 1e-10))
          loss = loss_weights * loss
          mean_loss = jnp.mean(loss)
        if hasattr(self, 'collector_dispatcher'):
          self.collector_dispatcher.write(
              [statistics_instance.StatisticsInstance(
                  'Losses/Aggregate', np.asarray(mean_loss),
                  step=self.training_steps),
               statistics_instance.StatisticsInstance(
                   'Losses/Quantile', np.asarray(quantile_loss),
                   step=self.training_steps),
               statistics_instance.StatisticsInstance(
                   'Losses/Metric', np.asarray(kernel_loss),
                   step=self.training_steps),
              ],
              collector_allowlist=self._collector_allowlist)
      if self.training_steps % self.target_update_period == 0:
        self._sync_weights()

    self.training_steps += 1
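

# Example gin bindings for this agent (an illustrative sketch; exact import
# paths depend on how the ksme package is installed):
#
#   KSMeQuantileAgent.mico_weight = 0.5
#   KSMeQuantileAgent.distance_fn = 'dot'
#   KSMeQuantileAgent.similarity_fn = 'dot'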