google-research
122 lines · 4.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""PWIL training script."""
17
18from absl import app19from absl import flags20from acme import specs21from acme.agents.tf import d4pg22from acme.agents.tf.actors import FeedForwardActor23from acme.utils.loggers import csv as csv_logger24import sonnet as snt25
26from pwil import imitation_loop27from pwil import rewarder28from pwil import utils29
30
# Command-line flags. Fix: help strings said "Larning rate" (typo) for the
# two learning-rate flags; punctuation is normalized so every help string
# ends with a period.
flags.DEFINE_string('workdir', None, 'Logging directory.')
flags.DEFINE_string('env_name', None, 'Environment name.')
flags.DEFINE_string('demo_dir', None, 'Directory of expert demonstrations.')
flags.DEFINE_boolean('state_only', False,
                     'Use only state for reward computation.')
flags.DEFINE_float('sigma', 0.2, 'Exploration noise.')
flags.DEFINE_integer('num_transitions_rb', 50000,
                     'Number of transitions to fill the rb with.')
flags.DEFINE_integer('num_demonstrations', 1, 'Number of expert episodes.')
flags.DEFINE_integer('subsampling', 20, 'Subsampling factor of demonstrations.')
flags.DEFINE_integer('random_seed', 1, 'Experiment random seed.')
flags.DEFINE_integer('num_steps_per_iteration', 10000,
                     'Number of training steps per iteration.')
flags.DEFINE_integer('num_iterations', 100, 'Number of iterations.')
flags.DEFINE_integer('num_eval_episodes', 10, 'Number of evaluation episodes.')
flags.DEFINE_integer('samples_per_insert', 256, 'Controls update frequency.')
flags.DEFINE_float('policy_learning_rate', 1e-4,
                   'Learning rate for policy updates.')
flags.DEFINE_float('critic_learning_rate', 1e-4,
                   'Learning rate for critic updates.')

FLAGS = flags.FLAGS
54
def main(_):
  """Trains a D4PG agent with PWIL imitation rewards and evaluates it.

  Reads all configuration from absl FLAGS. Alternates between training
  steps (with PWIL rewards) and evaluation episodes, logging both to CSV
  files under FLAGS.workdir.
  """
  # Environment and its spec.
  env = utils.load_environment(FLAGS.env_name)
  env_spec = specs.make_environment_spec(env)

  # Build the PWIL rewarder from expert demonstrations.
  demos = utils.load_demonstrations(
      demo_dir=FLAGS.demo_dir, env_name=FLAGS.env_name)
  imitation_rewarder = rewarder.PWILRewarder(
      demos,
      subsampling=FLAGS.subsampling,
      env_specs=env_spec,
      num_demonstrations=FLAGS.num_demonstrations,
      observation_only=FLAGS.state_only)

  # One Adam optimizer each for the actor and the critic.
  policy_opt = snt.optimizers.Adam(learning_rate=FLAGS.policy_learning_rate)
  critic_opt = snt.optimizers.Adam(learning_rate=FLAGS.critic_learning_rate)

  # Assemble the D4PG agent from the network dict built by utils.
  nets = utils.make_d4pg_networks(env_spec.actions)
  agent = d4pg.D4PG(
      environment_spec=env_spec,
      policy_network=nets['policy'],
      critic_network=nets['critic'],
      observation_network=nets['observation'],
      policy_optimizer=policy_opt,
      critic_optimizer=critic_opt,
      samples_per_insert=FLAGS.samples_per_insert,
      sigma=FLAGS.sigma,
  )

  # Seed the agent's replay buffer with demonstration transitions, using
  # the rewarder's reward scale as the stand-in reward.
  utils.prefill_rb_with_demonstrations(
      agent=agent,
      demonstrations=imitation_rewarder.demonstrations,
      num_transitions_rb=FLAGS.num_transitions_rb,
      reward=imitation_rewarder.reward_scale)

  # Deterministic (noise-free) actor for evaluation: observation net
  # followed by the policy net, with no exploration sigma.
  greedy_policy = snt.Sequential([
      nets['observation'],
      nets['policy'],
  ])
  eval_actor = FeedForwardActor(policy_network=greedy_policy)

  # CSV loggers and the train/eval loops that write to them.
  train_logger = csv_logger.CSVLogger(
      directory=FLAGS.workdir, label='train_logs')
  eval_logger = csv_logger.CSVLogger(
      directory=FLAGS.workdir, label='eval_logs')

  train_loop = imitation_loop.TrainEnvironmentLoop(
      env, agent, imitation_rewarder, logger=train_logger)
  eval_loop = imitation_loop.EvalEnvironmentLoop(
      env, eval_actor, imitation_rewarder, logger=eval_logger)

  # Alternate training and evaluation for the configured number of rounds.
  for _ in range(FLAGS.num_iterations):
    train_loop.run(num_steps=FLAGS.num_steps_per_iteration)
    eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
if __name__ == '__main__':
  # Entry point: app.run parses absl flags, then invokes main.
  app.run(main)