# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Contrastive Value Learning builder."""
17from typing import Callable, Iterator, List, Optional18
19import acme20from acme import adders21from acme import core22from acme import specs23from acme import types24from acme.adders import reverb as adders_reverb25from acme.agents.jax import actor_core as actor_core_lib26from acme.agents.jax import actors27from acme.agents.jax import builders28from acme.jax import networks as networks_lib29from acme.jax import variable_utils30from acme.utils import counting31from acme.utils import loggers32import optax33import reverb34from reverb import rate_limiters35import tensorflow as tf36import tree37
38from cvl_public import config as contrastive_config39from cvl_public import learning40from cvl_public import networks as contrastive_networks41from cvl_public import utils as contrastive_utils42
43


class ContrastiveBuilder(builders.ActorLearnerBuilder):
  """Contrastive RL builder."""

  def __init__(
      self,
      config,
      logger_fn = lambda: None,
  ):
    """Creates a contrastive RL learner, a behavior policy and an eval actor.

    Args:
      config: a config with contrastive RL hyperparameters.
      logger_fn: a logger factory for the learner.
    """
    self._config = config
    self._logger_fn = logger_fn

  def make_learner(
      self,
      random_key,
      networks,
      dataset,
      replay_client = None,
      counter = None,
  ):
    # Create optimizers.
    policy_optimizer = optax.adam(
        learning_rate=self._config.actor_learning_rate, eps=1e-7)
    q_optimizer = optax.adam(learning_rate=self._config.learning_rate, eps=1e-7)
    return learning.ContrastiveLearner(
        networks=networks,
        rng=random_key,
        policy_optimizer=policy_optimizer,
        q_optimizer=q_optimizer,
        iterator=dataset,
        counter=counter,
        logger=self._logger_fn(),
        config=self._config)
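
  # Note: the policy and Q-function deliberately get separate Adam optimizers
  # so each can use its own learning rate (`actor_learning_rate` vs.
  # `learning_rate` in the config); eps=1e-7 is this codebase's choice, not
  # the optax default (1e-8).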

  def make_actor(
      self,
      random_key,
      policy_network,
      adder = None,
      variable_source = None):
    assert variable_source is not None
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy_network)
    variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                    device='cpu')
    if self._config.use_random_actor:
      ACTOR = contrastive_utils.InitiallyRandomActor  # pylint: disable=invalid-name
    else:
      ACTOR = actors.GenericActor  # pylint: disable=invalid-name
    return ACTOR(
        actor_core, random_key, variable_client, adder, backend='cpu')
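
  # When `use_random_actor` is set, InitiallyRandomActor (defined in
  # cvl_public/utils.py) is used instead of GenericActor; as the name
  # suggests, it acts randomly at the start of training to seed the replay
  # buffer before following the learned policy.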

  def make_replay_tables(
      self,
      environment_spec,
  ):
    """Create tables to insert data into."""
    samples_per_insert_tolerance = (
        self._config.samples_per_insert_tolerance_rate
        * self._config.samples_per_insert)
    min_replay_traj = self._config.min_replay_size // self._config.max_episode_steps  # pylint: disable=line-too-long
    max_replay_traj = self._config.max_replay_size // self._config.max_episode_steps  # pylint: disable=line-too-long
    error_buffer = min_replay_traj * samples_per_insert_tolerance
    limiter = rate_limiters.SampleToInsertRatio(
        min_size_to_sample=min_replay_traj,
        samples_per_insert=self._config.samples_per_insert,
        error_buffer=error_buffer)
    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_traj,
            rate_limiter=limiter,
            signature=adders_reverb.EpisodeAdder.signature(environment_spec, {}))  # pylint: disable=line-too-long
    ]
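
  # The table stores whole episodes, so sizes are measured in trajectories.
  # Worked example with illustrative (not shipped) numbers: if
  # min_replay_size=10_000, max_replay_size=1_000_000,
  # max_episode_steps=1_000, samples_per_insert=256 and
  # samples_per_insert_tolerance_rate=0.1, then min_replay_traj = 10 episodes,
  # max_replay_traj = 1_000 episodes, and the rate limiter tolerates an
  # error_buffer of 10 * (0.1 * 256) = 256 samples of drift around the target
  # sample/insert ratio before blocking inserts or samples.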

  def make_dataset_iterator(
      self, replay_client):
    """Create a dataset iterator to use for learning/updating the agent."""

    @tf.function
    def flatten_fn(sample):
      seq_len = tf.shape(sample.data.observation)[0]
      arange = tf.range(seq_len)
      is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32)
      discount = self._config.discount ** tf.cast(arange[None] - arange[:, None], tf.float32)  # pylint: disable=line-too-long
      probs = is_future_mask * discount
      # The indexing changes the shape from [seq_len, 1] to [seq_len].
      goal_index = tf.random.categorical(logits=tf.math.log(probs),
                                         num_samples=1)[:, 0]
      state = sample.data.observation[:-1]
      next_state = sample.data.observation[1:]

      # Create the goal observations in three steps.
      # 1. Take all future states (not future goals).
      # 2. Apply obs_to_goal (here the identity: the raw observation is used
      #    as the goal).
      # 3. Sample one of the future states. Note that we don't look for a goal
      #    for the final state, because there are no future states.
      goal = tf.gather(sample.data.observation, goal_index[:-1])
      goal_reward = tf.gather(sample.data.reward, goal_index[:-1])
      transition = types.Transition(
          observation=state,
          action=sample.data.action[:-1],
          reward=sample.data.reward[:-1],
          discount=sample.data.discount[:-1],
          next_observation=next_state,
          extras={
              'next_action': sample.data.action[1:],
              'goal': goal,
              'goal_reward': goal_reward
          })
      # Shift for the transpose_shuffle below.
      shift = tf.random.uniform((), 0, seq_len, tf.int32)
      transition = tree.map_structure(lambda t: tf.roll(t, shift, axis=0),
                                      transition)
      return transition
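
    # Note on flatten_fn: for a state at index i, the goal index g is sampled
    # with probability proportional to discount**(g - i) for g > i, i.e. a
    # truncated geometric distribution over strictly-future steps.
    # tf.math.log maps the zero entries of `probs` to -inf logits, so
    # tf.random.categorical can never select a non-future index.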

    if self._config.num_parallel_calls:
      num_parallel_calls = self._config.num_parallel_calls
    else:
      num_parallel_calls = tf.data.AUTOTUNE

    def _make_dataset(unused_idx):
      dataset = reverb.TrajectoryDataset.from_table_signature(
          server_address=replay_client.server_address,
          table=self._config.replay_table_name,
          max_in_flight_samples_per_worker=100)
      dataset = dataset.map(flatten_fn)

      # transpose_shuffle
      def _transpose_fn(t):
        dims = tf.range(tf.shape(tf.shape(t))[0])
        perm = tf.concat([[1, 0], dims[2:]], axis=0)
        return tf.transpose(t, perm)
      dataset = dataset.batch(self._config.batch_size, drop_remainder=True)
      dataset = dataset.map(
          lambda transition: tree.map_structure(_transpose_fn, transition))
      dataset = dataset.unbatch()
      # end transpose_shuffle

      dataset = dataset.unbatch()
      return dataset

    dataset = tf.data.Dataset.from_tensors(0).repeat()
    dataset = dataset.interleave(
        map_func=_make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls,
        deterministic=False)
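
    # How the transpose_shuffle works: after flatten_fn each element is one
    # episode of transitions with shape [T, ...]. Batching batch_size episodes
    # gives [batch_size, T, ...]; _transpose_fn swaps the leading axes to
    # [T, batch_size, ...]; the first unbatch yields time-slices that each mix
    # batch_size different episodes, and the second unbatch splits those into
    # single transitions. Together with the random tf.roll in flatten_fn, this
    # decorrelates consecutive samples without a large shuffle buffer.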

    dataset = dataset.batch(
        self._config.batch_size * self._config.num_sgd_steps_per_step,
        drop_remainder=True)

    @tf.function
    def add_info_fn(data):
      info = reverb.SampleInfo(key=0,
                               probability=0.0,
                               table_size=0,
                               priority=0.0,
                               times_sampled=0)
      return reverb.ReplaySample(info=info, data=data)

    dataset = dataset.map(add_info_fn, num_parallel_calls=tf.data.AUTOTUNE,
                          deterministic=False)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset.as_numpy_iterator()
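
  # In make_dataset_iterator above, the zero-valued SampleInfo is a
  # placeholder: the data has already been pulled out of Reverb, remixed and
  # re-batched, so the original per-item metadata no longer applies, but
  # downstream consumers still expect reverb.ReplaySample containers.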

  def make_adder(self,
                 replay_client):
    """Create an adder to record data generated by the actor/environment."""
    return adders_reverb.EpisodeAdder(
        client=replay_client,
        priority_fns={self._config.replay_table_name: None},
        max_sequence_length=self._config.max_episode_steps + 1)
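

# Minimal usage sketch (illustrative only, not part of this module; assumes
# an `environment_spec`, a `config`, pre-built `networks` and a
# `policy_network`, plus `import jax`):
#
#   builder = ContrastiveBuilder(config)
#   tables = builder.make_replay_tables(environment_spec)
#   server = reverb.Server(tables, port=None)
#   client = reverb.Client(f'localhost:{server.port}')
#   adder = builder.make_adder(client)
#   iterator = builder.make_dataset_iterator(client)
#   learner = builder.make_learner(jax.random.PRNGKey(0), networks, iterator)
#   actor = builder.make_actor(jax.random.PRNGKey(1), policy_network,
#                              adder=adder, variable_source=learner)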