# google-research
# (file listing residue: 185 lines · 5.8 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Train an Acme agent on mime."""
17
18import datetime19import os20from typing import Dict, Sequence21
22from absl import app23from absl import flags24from acme import specs25from acme import types26from acme.agents.tf import dmpo27from acme.agents.tf import mpo28from acme.tf import networks29from acme.tf import utils as tf_utils30from acme.utils import counting31from acme.utils.loggers.google import cns32import numpy as np33import sonnet as snt34import tensorflow as tf35
36from rrlfd import environment_loop37from rrlfd.env_wrapper import DmMimeWrapper38from rrlfd.env_wrapper import KwargWrapper39
# --- Environment flags ---
flags.DEFINE_string('task', 'Pick', 'Mime task.')
flags.DEFINE_enum('input_type', 'position',
                  ['depth', 'rgb', 'rgbd', 'position'],
                  'Input modality.')
flags.DEFINE_boolean('dense_reward', True, 'If True, use dense reward signal.')
flags.DEFINE_float('dense_reward_multiplier', 1.0,
                   'Multiplier for dense rewards.')

# --- Agent / training flags ---
flags.DEFINE_string('agent', 'DMPO', 'Acme agent to train.')
flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to run for.')
flags.DEFINE_integer('max_episode_steps', None,
                     'If set, override environment default for max episode '
                     'length.')
flags.DEFINE_integer('seed', 0, 'Experiment seed.')

# --- Logging / debugging flags ---
flags.DEFINE_string('logdir', None, 'Location to log results to.')
flags.DEFINE_boolean('log_learner', False, 'If True, save learner logs.')
flags.DEFINE_boolean('render', False, 'If True, render environment.')
flags.DEFINE_boolean('verbose', False, 'If True, log actions at each step.')

FLAGS = flags.FLAGS
62
def make_mpo_networks(
    action_spec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
):
  """Creates networks used by the agent.

  Args:
    action_spec: Environment action spec; its shape sets the policy output
      dimensionality and its bounds are used to clip actions fed to the critic.
    policy_layer_sizes: Hidden layer sizes of the policy MLP.
    critic_layer_sizes: Hidden layer sizes of the critic MLP.

  Returns:
    A dict with 'policy', 'critic' and 'observation' entries, as expected by
    the MPO agent constructed in main().
  """

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  # The critic emits a single scalar Q-value, hence the extra final layer of
  # width 1 appended to the requested hidden sizes.
  critic_layer_sizes = list(critic_layer_sizes) + [1]

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }
88
def make_dmpo_networks(
    action_spec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
  """Creates networks used by the agent.

  Like make_mpo_networks, but the critic ends in a DiscreteValuedHead so the
  distributional (DMPO) agent can learn a categorical value distribution.

  Args:
    action_spec: Environment action spec; its shape sets the policy output
      dimensionality and its bounds are used to clip actions fed to the critic.
    policy_layer_sizes: Hidden layer sizes of the policy MLP.
    critic_layer_sizes: Hidden layer sizes of the critic MLP.
    vmin: Lower bound of the critic's discrete value support.
    vmax: Upper bound of the critic's discrete value support.
    num_atoms: Number of atoms in the discrete value support.

  Returns:
    A dict with 'policy', 'critic' and 'observation' entries, as expected by
    the DMPO agent constructed in main().
  """

  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential(
      [critic_network,
       networks.DiscreteValuedHead(vmin, vmax, num_atoms)])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }
119
def main(_):
  """Trains an MPO/DMPO Acme agent on a mime task selected via flags."""
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.logdir is not None:
    logdir = FLAGS.logdir
  else:
    # BUG FIX: the original branch computed
    # os.path.join(FLAGS.logdir, datetime.datetime.now().strftime(...)),
    # but in this branch FLAGS.logdir is None, so the join raised TypeError
    # whenever --logdir was unset. The logger construction below explicitly
    # treats logdir=None as "logging disabled", so propagate None instead.
    logdir = None

  # Create an environment and grab the spec.
  env = DmMimeWrapper(
      task=FLAGS.task,
      seed=FLAGS.seed,
      input_type=FLAGS.input_type,
      dense_reward=FLAGS.dense_reward,
      dense_reward_multiplier=FLAGS.dense_reward_multiplier,
      max_episode_steps=FLAGS.max_episode_steps,
      logdir=logdir,
      render=FLAGS.render,
      verbose=FLAGS.verbose)
  environment = KwargWrapper(env)
  environment_spec = specs.make_environment_spec(environment)
  print(environment_spec)

  # Counter shared between the learner and the environment loop so step /
  # episode counts stay consistent across logs.
  counter = counting.Counter()
  agent_logger = (
      cns.CNSLogger(logdir, 'learner')
      if logdir is not None and FLAGS.log_learner else None)

  if FLAGS.agent == 'MPO':
    agent_networks = make_mpo_networks(environment_spec.actions)
    agent = mpo.MPO(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        checkpoint=True,
        logger=agent_logger,
        counter=counter,
    )
  elif FLAGS.agent == 'DMPO':
    agent_networks = make_dmpo_networks(environment_spec.actions)
    agent = dmpo.DistributionalMPO(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        checkpoint=True,
        logger=agent_logger,
        counter=counter,
    )
  else:
    raise NotImplementedError('Supported agents: MPO, DMPO.')
  env_logger = (
      cns.CNSLogger(logdir, 'env_loop') if logdir is not None else None)

  # Run the environment loop.
  loop = environment_loop.EnvironmentLoop(
      environment, agent, logger=env_logger, counter=counter)
  loop.run(num_episodes=FLAGS.num_episodes)
183
if __name__ == '__main__':
  # app.run parses absl flags before invoking main.
  app.run(main)