# google-research
# 393 lines · 13.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16r"""Script for training the RCE agent.
17
18Example usage:
19python train_eval.py --root_dir=~/c_learning/sawyer_drawer_open \
20--gin_bindings='train_eval.env_name="sawyer_drawer_open"'
21"""
22from __future__ import absolute_import23from __future__ import division24from __future__ import print_function25
26import functools27import os28import time29
30from absl import app31from absl import flags32from absl import logging33import gin34import numpy as np35import rce_agent36import rce_envs37from six.moves import range38import tensorflow as tf39from tf_agents.agents.ddpg import critic_network40from tf_agents.agents.sac import tanh_normal_projection_network41from tf_agents.drivers import dynamic_step_driver42from tf_agents.eval import metric_utils43from tf_agents.metrics import tf_metrics44from tf_agents.networks import actor_distribution_network45from tf_agents.policies import greedy_policy46from tf_agents.policies import random_tf_policy47from tf_agents.replay_buffers import tf_uniform_replay_buffer48from tf_agents.utils import common49
50flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),51'Root directory for writing logs/summaries/checkpoints.')52flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')53flags.DEFINE_multi_string('gin_bindings', None, 'Gin binding to pass through.')54
55FLAGS = flags.FLAGS56
@gin.configurable
def bce_loss(y_true, y_pred, label_smoothing=0):
  """Element-wise binary cross-entropy loss used as the critic TD loss.

  Args:
    y_true: Tensor of target values, shape [batch].
    y_pred: Tensor of predicted values (sigmoid outputs), shape [batch].
    label_smoothing: Float in [0, 1]; gin-configurable label smoothing
      passed through to `tf.keras.losses.BinaryCrossentropy`.

  Returns:
    Tensor of per-example losses, shape [batch]. Reduction is NONE so the
    agent can apply its own weighting.
  """
  # [:, None] adds a trailing axis so Keras reduces over it and returns a
  # per-example (unreduced) loss vector.
  loss_fn = tf.keras.losses.BinaryCrossentropy(
      label_smoothing=label_smoothing, reduction=tf.keras.losses.Reduction.NONE)
  return loss_fn(y_true[:, None], y_pred[:, None])
@gin.configurable
class ClassifierCriticNetwork(critic_network.CriticNetwork):
  """Critic network whose final layer is a sigmoid classifier head.

  Identical to `critic_network.CriticNetwork` except that the last joint
  layer is replaced by a 1-unit Dense layer with sigmoid activation, so the
  critic outputs values in (0, 1) as required by RCE's classifier view of
  the value function.
  """

  def __init__(self,
               input_tensor_spec,
               observation_fc_layer_params=None,
               action_fc_layer_params=None,
               joint_fc_layer_params=None,
               kernel_initializer=None,
               last_kernel_initializer=None,
               name='ClassifierCriticNetwork'):
    super(ClassifierCriticNetwork, self).__init__(
        input_tensor_spec,
        observation_fc_layer_params=observation_fc_layer_params,
        action_fc_layer_params=action_fc_layer_params,
        joint_fc_layer_params=joint_fc_layer_params,
        kernel_initializer=kernel_initializer,
        last_kernel_initializer=last_kernel_initializer,
        name=name,
    )

    # Swap the parent's final (linear) value layer for a sigmoid head.
    last_layers = [
        tf.keras.layers.Dense(
            1,
            activation=tf.math.sigmoid,
            kernel_initializer=last_kernel_initializer,
            name='value')
    ]
    self._joint_layers = self._joint_layers[:-1] + last_layers
@gin.configurable
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    gamma=0.99,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=200000,
    # policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    random_seed=0,
    actor_min_std=1e-3,  # Added for numerical stability.
    n_step=10):
  """Trains and periodically evaluates an RCE agent (SAC-style loop).

  Builds the environment, actor/critic networks, RCE agent, replay buffer,
  and collection drivers, then alternates environment collection with agent
  training, logging metrics and saving checkpoints along the way.

  Args:
    root_dir: Directory for train/eval summaries and checkpoints.
    env_name: Name of the environment to load via `rce_envs.load_env`.
    num_iterations: Total number of training iterations.
    actor_fc_layers: FC layer sizes for the actor network.
    critic_obs_fc_layers: FC layer sizes for the critic observation tower.
    critic_action_fc_layers: FC layer sizes for the critic action tower.
    critic_joint_fc_layers: FC layer sizes for the critic joint tower.
    initial_collect_steps: Steps collected with a random policy before
      training starts (skipped if a replay-buffer checkpoint is restored).
    collect_steps_per_iteration: Environment steps collected per iteration.
    replay_buffer_capacity: Max transitions held in the replay buffer.
    target_update_tau: Soft target-network update coefficient.
    target_update_period: Period (in train steps) of target updates.
    train_steps_per_iteration: Gradient steps per collect iteration.
    batch_size: Training batch size.
    actor_learning_rate: Adam learning rate for the actor.
    critic_learning_rate: Adam learning rate for the critic.
    gamma: Discount factor.
    gradient_clipping: Optional global-norm gradient clip.
    use_tf_functions: Whether to wrap drivers/training in tf.function.
    num_eval_episodes: Episodes per evaluation.
    eval_interval: Period (in global steps) between evaluations.
    train_checkpoint_interval: Period between agent checkpoints.
    rb_checkpoint_interval: Period between replay-buffer checkpoints.
    log_interval: Period between log lines / throughput summaries.
    summary_interval: Period gating `tf.summary.record_if`.
    summaries_flush_secs: Summary-writer flush interval in seconds.
    debug_summaries: Whether the agent emits debug summaries.
    summarize_grads_and_vars: Whether to summarize gradients/variables.
    random_seed: Seed for numpy RNG.
    actor_min_std: Lower bound added to the policy std for stability.
    n_step: N-step horizon for sampled trajectories (2-step if None).

  Returns:
    The `LossInfo` from the final training step.
  """
  np.random.seed(random_seed)
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record summaries every summary_interval steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = rce_envs.load_env(env_name)
    eval_tf_env = rce_envs.load_env(env_name)
    if env_name == 'sawyer_lift':
      # sawyer_lift uses a special evaluation mode — TODO confirm semantics
      # against rce_envs.
      eval_tf_env.MODE = 'eval'

    # Expert (success example) observations used for RCE's classifier loss.
    expert_obs = rce_envs.get_data(tf_env.envs[0], env_name=env_name)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    # Bound the policy std away from zero for numerical stability.
    proj_net = functools.partial(
        tanh_normal_projection_network.TanhNormalProjectionNetwork,
        std_transform=lambda t: actor_min_std + tf.nn.softplus(t))
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=proj_net)
    critic_net = ClassifierCriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    tf_agent = rce_agent.RceAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=bce_loss,
        gamma=gamma,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        n_step=n_step)
    tf_agent.initialize()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                       batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes,
                                              batch_size=tf_env.batch_size)
    ]
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=None)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    replay_observer = [replay_buffer.add_batch]

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Save the hyperparameters
    operative_filename = os.path.join(root_dir, 'operative.gin')
    with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
      f.write(gin.operative_config_str())
    print(gin.operative_config_str())

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    # Initial evaluation before any training.
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    del results
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0
    env_time_acc = 0

    def _filter_invalid_transition(trajectories, unused_arg1):
      # Drop sampled windows that start on an episode boundary.
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2 if n_step is None else n_step)
    dataset = dataset.unbatch()
    dataset = dataset.filter(_filter_invalid_transition)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(5)
    iterator = iter(dataset)

    ### Expert dataset
    expert_dataset = tf.data.Dataset.from_tensors(expert_obs)
    expert_dataset = expert_dataset.unbatch()
    expert_dataset = expert_dataset.repeat().shuffle(int(1e6))
    expert_dataset = expert_dataset.batch(batch_size, drop_remainder=True)
    expert_iterator = iter(expert_dataset)

    def train_step():
      # RCE trains on (replay experience, expert success examples) pairs.
      experience, _ = next(iterator)
      expert_experience = next(expert_iterator)
      return tf_agent.train(experience=(experience, expert_experience))

    if use_tf_functions:
      train_step = common.function(train_step)

    global_step_val = global_step.numpy()
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      env_time_acc += time.time() - start_time
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        env_steps_per_sec = (global_step_val - timed_at_step) / env_time_acc
        logging.info('Env: %.3f steps/sec', env_steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='env_steps_per_sec', data=env_steps_per_sec, step=global_step)

        timed_at_step = global_step_val
        time_acc = 0
        env_time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      # if global_step_val % policy_checkpoint_interval == 0:
      #   policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
def main(_):
  """App entry point: enable TF2 behavior, parse gin config, run training."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  root_dir = FLAGS.root_dir
  train_eval(root_dir)
if __name__ == '__main__':
  # root_dir must be provided on the command line (the env-var default may
  # be unset outside test environments).
  flags.mark_flag_as_required('root_dir')
  app.run(main)