# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train an Acme agent on mime."""

import datetime
import os
from typing import Dict, Sequence

from absl import app
from absl import flags
from acme import specs
from acme import types
from acme.agents.tf import dmpo
from acme.agents.tf import mpo
from acme.tf import networks
from acme.tf import utils as tf_utils
from acme.utils import counting
from acme.utils.loggers.google import cns
import numpy as np
import sonnet as snt
import tensorflow as tf

from rrlfd import environment_loop
from rrlfd.env_wrapper import DmMimeWrapper
from rrlfd.env_wrapper import KwargWrapper

flags.DEFINE_string('task', 'Pick', 'Mime task.')
flags.DEFINE_enum('input_type', 'position',
                  ['depth', 'rgb', 'rgbd', 'position'],
                  'Input modality.')
flags.DEFINE_boolean('dense_reward', True, 'If True, use dense reward signal.')
flags.DEFINE_float('dense_reward_multiplier', 1.0,
                   'Multiplier for dense rewards.')

flags.DEFINE_string('agent', 'DMPO', 'Acme agent to train.')
flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to run for.')
flags.DEFINE_integer('max_episode_steps', None,
                     'If set, override environment default for max episode '
                     'length.')
flags.DEFINE_integer('seed', 0, 'Experiment seed.')

flags.DEFINE_string('logdir', None, 'Location to log results to.')
flags.DEFINE_boolean('log_learner', False, 'If True, save learner logs.')
flags.DEFINE_boolean('render', False, 'If True, render environment.')
flags.DEFINE_boolean('verbose', False, 'If True, log actions at each step.')

FLAGS = flags.FLAGS
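
# Example invocation (a sketch, not part of the original file; the module path
# and flag values below are assumptions about a typical local setup):
#   python -m rrlfd.mime_acme --task=Pick --input_type=depth --agent=DMPO \
#     --num_episodes=1000 --logdir=/tmp/mime_acme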


def make_mpo_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
) -> Dict[str, types.TensorTransformation]:
  """Creates the networks used by the MPO agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_layer_sizes = list(critic_layer_sizes) + [1]

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }


def make_dmpo_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, types.TensorTransformation]:
  """Creates the networks used by the distributional MPO agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  # The distributional critic outputs a categorical value distribution over
  # num_atoms atoms spaced evenly in [vmin, vmax].
  critic_network = snt.Sequential(
      [critic_network,
       networks.DiscreteValuedHead(vmin, vmax, num_atoms)])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }


def main(_):
  tf.random.set_seed(FLAGS.seed)

  # Log to a timestamped subdirectory of --logdir, or disable logging to disk
  # if no logdir is given.
  if FLAGS.logdir is not None:
    logdir = os.path.join(
        FLAGS.logdir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
  else:
    logdir = None

  # Create an environment and grab the spec.
  env = DmMimeWrapper(
      task=FLAGS.task,
      seed=FLAGS.seed,
      input_type=FLAGS.input_type,
      dense_reward=FLAGS.dense_reward,
      dense_reward_multiplier=FLAGS.dense_reward_multiplier,
      max_episode_steps=FLAGS.max_episode_steps,
      logdir=logdir,
      render=FLAGS.render,
      verbose=FLAGS.verbose)
  environment = KwargWrapper(env)
  environment_spec = specs.make_environment_spec(environment)
  print(environment_spec)

  counter = counting.Counter()
  agent_logger = (
      cns.CNSLogger(logdir, 'learner')
      if logdir is not None and FLAGS.log_learner else None)

  if FLAGS.agent == 'MPO':
    agent_networks = make_mpo_networks(environment_spec.actions)

    agent = mpo.MPO(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        checkpoint=True,
        logger=agent_logger,
        counter=counter,
    )
  elif FLAGS.agent == 'DMPO':
    agent_networks = make_dmpo_networks(environment_spec.actions)

    agent = dmpo.DistributionalMPO(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        checkpoint=True,
        logger=agent_logger,
        counter=counter,
    )
  else:
    raise NotImplementedError('Supported agents: MPO, DMPO.')
  env_logger = (
      cns.CNSLogger(logdir, 'env_loop') if logdir is not None else None)

  # Run the environment loop.
  loop = environment_loop.EnvironmentLoop(
      environment, agent, logger=env_logger, counter=counter)
  loop.run(num_episodes=FLAGS.num_episodes)


if __name__ == '__main__':
  app.run(main)
