# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contrastive RL learner implementation."""
import time
from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple

import acme
from acme import types
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.utils import counting
from acme.utils import loggers
from contrastive import config as contrastive_config
from contrastive import networks as contrastive_networks
import jax
import jax.numpy as jnp
import optax
import reverb

class TrainingState(NamedTuple):
  """Contains training state for the learner."""
  policy_optimizer_state: optax.OptState
  q_optimizer_state: optax.OptState
  policy_params: networks_lib.Params
  q_params: networks_lib.Params
  target_q_params: networks_lib.Params
  key: networks_lib.PRNGKey
  alpha_optimizer_state: Optional[optax.OptState] = None
  alpha_params: Optional[networks_lib.Params] = None

class ContrastiveLearner(acme.Learner):
  """Contrastive RL learner."""

  _state: TrainingState

  def __init__(
      self,
      networks,
      rng,
      policy_optimizer,
      q_optimizer,
      iterator,
      counter,
      logger,
      obs_to_goal,
      config):
    """Initializes the Contrastive RL learner.

    Args:
      networks: Contrastive RL networks.
      rng: a key for random number generation.
      policy_optimizer: the policy optimizer.
      q_optimizer: the Q-function optimizer.
      iterator: an iterator over training data.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      obs_to_goal: a function for extracting the goal coordinates.
      config: the experiment config.
    """
    if config.add_mc_to_td:
      assert config.use_td
    adaptive_entropy_coefficient = config.entropy_coefficient is None
    self._num_sgd_steps_per_step = config.num_sgd_steps_per_step
    self._obs_dim = config.obs_dim
    self._use_td = config.use_td
    if adaptive_entropy_coefficient:
      # alpha is the temperature parameter that determines the relative
      # importance of the entropy term versus the reward.
      log_alpha = jnp.asarray(0., dtype=jnp.float32)
      alpha_optimizer = optax.adam(learning_rate=3e-4)
      alpha_optimizer_state = alpha_optimizer.init(log_alpha)
    else:
      if config.target_entropy:
        raise ValueError('target_entropy should not be set when '
                         'entropy_coefficient is provided')
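
    # For reference, the temperature update follows Eq. 18 of
    # https://arxiv.org/pdf/1812.05905.pdf: minimizing
    #   J(alpha) = E_{a ~ pi}[-alpha * (log pi(a|s) + target_entropy)]
    # raises alpha when the policy's entropy falls below the target and
    # lowers it otherwise.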
    def alpha_loss(log_alpha,
                   policy_params,
                   transitions,
                   key):
      """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
      dist_params = networks.policy_network.apply(
          policy_params, transitions.observation)
      action = networks.sample(dist_params, key)
      log_prob = networks.log_prob(dist_params, action)
      alpha = jnp.exp(log_alpha)
      alpha_loss = alpha * jax.lax.stop_gradient(
          -log_prob - config.target_entropy)
      return jnp.mean(alpha_loss)
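
    # The critic outputs a [B, B] matrix of logits ([B, B, 2] with twin
    # critics), where logits[i, j] scores whether goal j is a future state of
    # state-action pair i. Diagonal entries are positives and off-diagonal
    # entries are negatives: the MC variants classify this matrix against
    # identity labels (a sigmoid loss per pair, or CPC's softmax over goals),
    # while the TD variant bootstraps through the target critic.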
    def critic_loss(q_params,
                    policy_params,
                    target_q_params,
                    transitions,
                    key):
      batch_size = transitions.observation.shape[0]
      # Note: We might be able to speed up the computation for some of the
      # baselines by making a single network that returns all the values. This
      # avoids computing some of the underlying representations multiple
      # times.
      if config.use_td:
        # For TD learning, the diagonal elements are the immediate next state.
        s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
        next_s, _ = jnp.split(transitions.next_observation, [config.obs_dim],
                              axis=1)
        if config.add_mc_to_td:
          next_fraction = (1 - config.discount) / ((1 - config.discount) + 1)
          num_next = int(batch_size * next_fraction)
          new_g = jnp.concatenate([
              obs_to_goal(next_s[:num_next]),
              g[num_next:],
          ], axis=0)
        else:
          new_g = obs_to_goal(next_s)
        obs = jnp.concatenate([s, new_g], axis=1)
        transitions = transitions._replace(observation=obs)
      I = jnp.eye(batch_size)  # pylint: disable=invalid-name
      logits = networks.q_network.apply(
          q_params, transitions.observation, transitions.action)

      if config.use_td:
        # Make sure to use the twin Q trick.
        assert len(logits.shape) == 3

        # We evaluate the next-state Q function using random goals.
        s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
        del s
        next_s = transitions.next_observation[:, :config.obs_dim]
        goal_indices = jnp.roll(jnp.arange(batch_size, dtype=jnp.int32), -1)
        g = g[goal_indices]
        transitions = transitions._replace(
            next_observation=jnp.concatenate([next_s, g], axis=1))
        next_dist_params = networks.policy_network.apply(
            policy_params, transitions.next_observation)
        next_action = networks.sample(next_dist_params, key)
        next_q = networks.q_network.apply(target_q_params,
                                          transitions.next_observation,
                                          next_action)  # This outputs logits.
        next_q = jax.nn.sigmoid(next_q)
        next_v = jnp.min(next_q, axis=-1)
        next_v = jax.lax.stop_gradient(next_v)
        next_v = jnp.diag(next_v)
        # diag(logits) are predictions for future states.
        # diag(next_q) are predictions for random states, which correspond to
        # the predictions logits[range(B), goal_indices].
        # So, the only thing that's meaningful for next_q is the diagonal. Off
        # diagonal entries are meaningless and shouldn't be used.
        w = next_v / (1 - next_v)
        w_clipping = 20.0
        w = jnp.clip(w, 0, w_clipping)
        # (B, B, 2) --> (B, 2), computes diagonal of each twin Q.
        pos_logits = jax.vmap(jnp.diag, -1, -1)(logits)
        loss_pos = optax.sigmoid_binary_cross_entropy(
            logits=pos_logits, labels=1)  # [B, 2]

        neg_logits = logits[jnp.arange(batch_size), goal_indices]
        loss_neg1 = w[:, None] * optax.sigmoid_binary_cross_entropy(
            logits=neg_logits, labels=1)  # [B, 2]
        loss_neg2 = optax.sigmoid_binary_cross_entropy(
            logits=neg_logits, labels=0)  # [B, 2]
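
        # Combining the three terms gives a C-learning-style objective,
        #   (1 - gamma) * CE(pos, 1) + gamma * w * CE(neg, 1) + CE(neg, 0),
        # where the first term fits the immediate next state, the second
        # bootstraps through the target critic via the weight w, and the
        # third pushes down the scores of random goals.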
        if config.add_mc_to_td:
          loss = ((1 + (1 - config.discount)) * loss_pos
                  + config.discount * loss_neg1 + 2 * loss_neg2)
        else:
          loss = ((1 - config.discount) * loss_pos
                  + config.discount * loss_neg1 + loss_neg2)
        # Take the mean here so that we can compute the accuracy.
        logits = jnp.mean(logits, axis=-1)

      else:  # For the MC losses.
        def loss_fn(_logits):  # pylint: disable=invalid-name
          if config.use_cpc:
            return (optax.softmax_cross_entropy(logits=_logits, labels=I)
                    + 0.01 * jax.nn.logsumexp(_logits, axis=1)**2)
          else:
            return optax.sigmoid_binary_cross_entropy(logits=_logits,
                                                      labels=I)
        if len(logits.shape) == 3:  # twin q
          # loss.shape = [..., num_q]
          loss = jax.vmap(loss_fn, in_axes=2, out_axes=-1)(logits)
          loss = jnp.mean(loss, axis=-1)
          # Take the mean here so that we can compute the accuracy.
          logits = jnp.mean(logits, axis=-1)
        else:
          loss = loss_fn(logits)
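
      # In both branches the logits are trained toward the identity matrix I,
      # so binary accuracy below checks the sign of each logit against I and
      # categorical accuracy checks that each row's largest logit falls on
      # the diagonal.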
      loss = jnp.mean(loss)
      correct = (jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1))
      logits_pos = jnp.sum(logits * I) / jnp.sum(I)
      logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
      if len(logits.shape) == 3:
        logsumexp = jax.nn.logsumexp(logits[:, :, 0], axis=1)**2
      else:
        logsumexp = jax.nn.logsumexp(logits, axis=1)**2
      metrics = {
          'binary_accuracy': jnp.mean((logits > 0) == I),
          'categorical_accuracy': jnp.mean(correct),
          'logits_pos': logits_pos,
          'logits_neg': logits_neg,
          'logsumexp': logsumexp.mean(),
      }

      return loss, metrics

    def actor_loss(policy_params,
                   q_params,
                   alpha,
                   transitions,
                   key,
                   ):
      obs = transitions.observation
      if config.use_gcbc:
        dist_params = networks.policy_network.apply(
            policy_params, obs)
        log_prob = networks.log_prob(dist_params, transitions.action)
        actor_loss = -1.0 * jnp.mean(log_prob)
      else:
        state = obs[:, :config.obs_dim]
        goal = obs[:, config.obs_dim:]
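
        # config.random_goals controls goal relabeling for the actor update:
        # 0.0 keeps the sampled goals, 1.0 replaces every goal with one drawn
        # from another transition in the batch, and 0.5 duplicates the batch
        # so that half the examples keep their original goal.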
        if config.random_goals == 0.0:
          new_state = state
          new_goal = goal
        elif config.random_goals == 0.5:
          new_state = jnp.concatenate([state, state], axis=0)
          new_goal = jnp.concatenate([goal, jnp.roll(goal, 1, axis=0)],
                                     axis=0)
        else:
          assert config.random_goals == 1.0
          new_state = state
          new_goal = jnp.roll(goal, 1, axis=0)
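
        # The actor minimizes alpha * log pi(a | s, g) - Q(s, g, a),
        # maximizing the critic's diagonal scores for its own goal with an
        # entropy bonus weighted by alpha.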
        new_obs = jnp.concatenate([new_state, new_goal], axis=1)
        dist_params = networks.policy_network.apply(
            policy_params, new_obs)
        action = networks.sample(dist_params, key)
        log_prob = networks.log_prob(dist_params, action)
        q_action = networks.q_network.apply(
            q_params, new_obs, action)
        if len(q_action.shape) == 3:  # twin q trick
          assert q_action.shape[2] == 2
          q_action = jnp.min(q_action, axis=-1)
        actor_loss = alpha * log_prob - jnp.diag(q_action)

        assert 0.0 <= config.bc_coef <= 1.0
        if config.bc_coef > 0:
          orig_action = transitions.action
          if config.random_goals == 0.5:
            orig_action = jnp.concatenate([orig_action, orig_action], axis=0)

          bc_loss = -1.0 * networks.log_prob(dist_params, orig_action)
          actor_loss = (config.bc_coef * bc_loss
                        + (1 - config.bc_coef) * actor_loss)

      return jnp.mean(actor_loss)

    alpha_grad = jax.value_and_grad(alpha_loss)
    critic_grad = jax.value_and_grad(critic_loss, has_aux=True)
    actor_grad = jax.value_and_grad(actor_loss)

    def update_step(
        state,
        transitions,
    ):

      key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
      if adaptive_entropy_coefficient:
        alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                             state.policy_params, transitions,
                                             key_alpha)
        alpha = jnp.exp(state.alpha_params)
      else:
        alpha = config.entropy_coefficient

      if not config.use_gcbc:
        (critic_loss, critic_metrics), critic_grads = critic_grad(
            state.q_params, state.policy_params, state.target_q_params,
            transitions, key_critic)

      actor_loss, actor_grads = actor_grad(state.policy_params, state.q_params,
                                           alpha, transitions, key_actor)

      # Apply policy gradients.
      actor_update, policy_optimizer_state = policy_optimizer.update(
          actor_grads, state.policy_optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, actor_update)

      # Apply critic gradients.
      if config.use_gcbc:
        metrics = {}
        critic_loss = 0.0
        q_params = state.q_params
        q_optimizer_state = state.q_optimizer_state
        new_target_q_params = state.target_q_params
      else:
        critic_update, q_optimizer_state = q_optimizer.update(
            critic_grads, state.q_optimizer_state)
        q_params = optax.apply_updates(state.q_params, critic_update)
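
        # Polyak-average the target critic:
        #   target <- (1 - tau) * target + tau * online.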
        new_target_q_params = jax.tree_map(
            lambda x, y: x * (1 - config.tau) + y * config.tau,
            state.target_q_params, q_params)
        metrics = critic_metrics

      metrics.update({
          'critic_loss': critic_loss,
          'actor_loss': actor_loss,
      })

      new_state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          q_optimizer_state=q_optimizer_state,
          policy_params=policy_params,
          q_params=q_params,
          target_q_params=new_target_q_params,
          key=key,
      )
      if adaptive_entropy_coefficient:
        # Apply alpha gradients.
        alpha_update, alpha_optimizer_state = alpha_optimizer.update(
            alpha_grads, state.alpha_optimizer_state)
        alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
        metrics.update({
            'alpha_loss': alpha_loss,
            'alpha': jnp.exp(alpha_params),
        })
        new_state = new_state._replace(
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params)

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray,
        time_delta=10.0)

    # Iterator on demonstration transitions.
    self._iterator = iterator
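
    # process_multiple_batches wraps update_step so that a single call (and
    # hence a single jitted computation) performs
    # config.num_sgd_steps_per_step gradient steps, amortizing dispatch
    # overhead across many updates.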
    update_step = utils.process_multiple_batches(update_step,
                                                 config.num_sgd_steps_per_step)
    # Use the JIT compiler.
    if config.jit:
      self._update_step = jax.jit(update_step)
    else:
      self._update_step = update_step

    def make_initial_state(key):
      """Initialises the training state (parameters and optimiser state)."""
      key_policy, key_q, key = jax.random.split(key, 3)

      policy_params = networks.policy_network.init(key_policy)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      q_params = networks.q_network.init(key_q)
      q_optimizer_state = q_optimizer.init(q_params)

      state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          q_optimizer_state=q_optimizer_state,
          policy_params=policy_params,
          q_params=q_params,
          target_q_params=q_params,
          key=key)

      if adaptive_entropy_coefficient:
        state = state._replace(alpha_optimizer_state=alpha_optimizer_state,
                               alpha_params=log_alpha)
      return state

    # Create initial state.
    self._state = make_initial_state(rng)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None

  def step(self):
    with jax.profiler.StepTraceAnnotation('step', step_num=self._counter):
      sample = next(self._iterator)
      transitions = types.Transition(*sample.data)
      self._state, metrics = self._update_step(self._state, transitions)

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time.
    counts = self._counter.increment(steps=1, walltime=elapsed_time)
    if elapsed_time > 0:
      metrics['steps_per_second'] = (
          self._num_sgd_steps_per_step / elapsed_time)
    else:
      metrics['steps_per_second'] = 0.

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})

  def get_variables(self, names):
    variables = {
        'policy': self._state.policy_params,
        'critic': self._state.q_params,
    }
    return [variables[name] for name in names]

  def save(self):
    return self._state

  def restore(self, state):
    self._state = state
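

# A minimal usage sketch (hypothetical; the exact factory signatures and
# config fields live in contrastive/networks.py and contrastive/config.py):
#
#   config = contrastive_config.ContrastiveConfig()
#   learner = ContrastiveLearner(
#       networks=contrastive_networks.make_networks(spec, config.obs_dim),
#       rng=jax.random.PRNGKey(0),
#       policy_optimizer=optax.adam(3e-4),
#       q_optimizer=optax.adam(3e-4),
#       iterator=dataset_iterator,  # yields reverb.ReplaySample batches
#       counter=counting.Counter(),
#       logger=loggers.make_default_logger('learner'),
#       obs_to_goal=obs_to_goal_fn,
#       config=config,
#   )
#   learner.step()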