google-research / train_eval.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Script for training the RCE agent.

Example usage:
  python train_eval.py --root_dir=~/c_learning/sawyer_drawer_open \
    --gin_bindings='train_eval.env_name="sawyer_drawer_open"'
"""
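# Note: `train_eval` below is gin-configurable, so any of its keyword arguments
# can be overridden from the command line the same way, e.g. (hypothetical
# value) --gin_bindings='train_eval.num_iterations=1000000'.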
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import time

from absl import app
from absl import flags
from absl import logging
import gin
import numpy as np
import rce_agent
import rce_envs
from six.moves import range
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')
flags.DEFINE_multi_string('gin_bindings', None, 'Gin binding to pass through.')

FLAGS = flags.FLAGS

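# RCE trains the critic as a classifier rather than regressing a scalar value,
# so the agent's TD loss is an element-wise binary cross-entropy on the
# critic's sigmoid output (see `ClassifierCriticNetwork` below). The
# `label_smoothing` argument is exposed through gin.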
@gin.configurable
def bce_loss(y_true, y_pred, label_smoothing=0):
  loss_fn = tf.keras.losses.BinaryCrossentropy(
      label_smoothing=label_smoothing, reduction=tf.keras.losses.Reduction.NONE)
  return loss_fn(y_true[:, None], y_pred[:, None])

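# Same critic architecture as the TF-Agents DDPG CriticNetwork, except that the
# final linear layer is replaced by a single sigmoid unit, so the critic
# outputs a value in [0, 1] that can be trained with `bce_loss` above.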
@gin.configurable
class ClassifierCriticNetwork(critic_network.CriticNetwork):
  """Creates a critic network."""

  def __init__(self,
               input_tensor_spec,
               observation_fc_layer_params=None,
               action_fc_layer_params=None,
               joint_fc_layer_params=None,
               kernel_initializer=None,
               last_kernel_initializer=None,
               name='ClassifierCriticNetwork'):
    super(ClassifierCriticNetwork, self).__init__(
        input_tensor_spec,
        observation_fc_layer_params=observation_fc_layer_params,
        action_fc_layer_params=action_fc_layer_params,
        joint_fc_layer_params=joint_fc_layer_params,
        kernel_initializer=kernel_initializer,
        last_kernel_initializer=last_kernel_initializer,
        name=name,
    )

    last_layers = [
        tf.keras.layers.Dense(
            1,
            activation=tf.math.sigmoid,
            kernel_initializer=last_kernel_initializer,
            name='value')
    ]
    self._joint_layers = self._joint_layers[:-1] + last_layers


@gin.configurable
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    gamma=0.99,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=200000,
    # policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    random_seed=0,
    actor_min_std=1e-3,  # Added for numerical stability.
    n_step=10):
  """A simple train and eval for the RCE agent."""
  np.random.seed(random_seed)
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = rce_envs.load_env(env_name)
    eval_tf_env = rce_envs.load_env(env_name)
    if env_name == 'sawyer_lift':
      eval_tf_env.MODE = 'eval'

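    # RCE learns from success examples instead of a reward function: load the
    # task's example success-state observations for training the classifier.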
    expert_obs = rce_envs.get_data(tf_env.envs[0], env_name=env_name)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    proj_net = functools.partial(
        tanh_normal_projection_network.TanhNormalProjectionNetwork,
        std_transform=lambda t: actor_min_std + tf.nn.softplus(t))
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=proj_net)
    critic_net = ClassifierCriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

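    # The RCE agent uses the classifier critic and the BCE loss defined above
    # as its TD loss; `n_step` also sets the length of the sub-trajectories
    # sampled from the replay buffer (see the dataset construction below).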
    tf_agent = rce_agent.RceAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=bce_loss,
        gamma=gamma,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        n_step=n_step)
    tf_agent.initialize()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                       batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes,
                                              batch_size=tf_env.batch_size)
    ]
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=None)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    replay_observer = [replay_buffer.add_batch]

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Save the hyperparameters
    operative_filename = os.path.join(root_dir, 'operative.gin')
    with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
      f.write(gin.operative_config_str())
      print(gin.operative_config_str())

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    del results
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0
    env_time_acc = 0

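    # Sample length-`n_step` windows of transitions from the replay buffer
    # (length 2 when `n_step` is None) and drop windows whose first step is an
    # episode boundary.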
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2 if n_step is None else n_step)
    dataset = dataset.unbatch()
    dataset = dataset.filter(_filter_invalid_transition)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(5)
    iterator = iter(dataset)

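    # Build an endless, shuffled stream of success-example observations,
    # batched to match the replay batch size, so each training step can pair a
    # replay batch with a batch of expert (success) examples.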
    ### Expert dataset
    expert_dataset = tf.data.Dataset.from_tensors(expert_obs)
    expert_dataset = expert_dataset.unbatch()
    expert_dataset = expert_dataset.repeat().shuffle(int(1e6))

    expert_dataset = expert_dataset.batch(batch_size, drop_remainder=True)
    expert_iterator = iter(expert_dataset)

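    # One gradient step: the agent's train() receives a tuple of
    # (replay experience, expert success-example observations).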
    def train_step():
      experience, _ = next(iterator)
      expert_experience = next(expert_iterator)
      return tf_agent.train(experience=(experience, expert_experience))

    if use_tf_functions:
      train_step = common.function(train_step)

    global_step_val = global_step.numpy()
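    # Main loop: each iteration collects `collect_steps_per_iteration`
    # environment steps, takes `train_steps_per_iteration` gradient steps, and
    # periodically logs throughput, runs evaluation, and writes checkpoints.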
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      env_time_acc += time.time() - start_time
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        env_steps_per_sec = (global_step_val - timed_at_step) / env_time_acc
        logging.info('Env: %.3f steps/sec', env_steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='env_steps_per_sec', data=env_steps_per_sec, step=global_step)

        timed_at_step = global_step_val
        time_acc = 0
        env_time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      # if global_step_val % policy_checkpoint_interval == 0:
      #   policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss


def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  root_dir = FLAGS.root_dir
  train_eval(root_dir)


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)