# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PWIL training script."""

from absl import app
from absl import flags
from acme import specs
from acme.agents.tf import d4pg
from acme.agents.tf.actors import FeedForwardActor
from acme.utils.loggers import csv as csv_logger
import sonnet as snt

from pwil import imitation_loop
from pwil import rewarder
from pwil import utils


flags.DEFINE_string('workdir', None, 'Logging directory.')
flags.DEFINE_string('env_name', None, 'Environment name.')
flags.DEFINE_string('demo_dir', None, 'Directory of expert demonstrations.')
flags.DEFINE_boolean('state_only', False,
                     'Use only the state for reward computation.')
flags.DEFINE_float('sigma', 0.2, 'Exploration noise.')
flags.DEFINE_integer('num_transitions_rb', 50000,
                     'Number of transitions to fill the replay buffer with.')
flags.DEFINE_integer('num_demonstrations', 1, 'Number of expert episodes.')
flags.DEFINE_integer('subsampling', 20, 'Subsampling factor of demonstrations.')
flags.DEFINE_integer('random_seed', 1, 'Experiment random seed.')
flags.DEFINE_integer('num_steps_per_iteration', 10000,
                     'Number of training steps per iteration.')
flags.DEFINE_integer('num_iterations', 100, 'Number of iterations.')
flags.DEFINE_integer('num_eval_episodes', 10, 'Number of evaluation episodes.')
flags.DEFINE_integer('samples_per_insert', 256, 'Controls update frequency.')
flags.DEFINE_float('policy_learning_rate', 1e-4,
                   'Learning rate for policy updates.')
flags.DEFINE_float('critic_learning_rate', 1e-4,
                   'Learning rate for critic updates.')

FLAGS = flags.FLAGS
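
# Example invocation (flag values are illustrative; env_name must be an
# environment supported by utils.load_environment):
#
#   python -m pwil.trainer \
#     --workdir=/tmp/pwil \
#     --env_name=Hopper-v2 \
#     --demo_dir=/path/to/demonstrations \
#     --num_demonstrations=1 \
#     --subsampling=20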


def main(_):
  # Load environment.
  environment = utils.load_environment(FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Create Rewarder.
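  # The PWIL rewarder scores each agent transition by its distance to the
  # (subsampled) expert transitions, so the imitation reward is high when the
  # agent stays close to the demonstrations; with --state_only only
  # observations (not actions) are compared.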
  demonstrations = utils.load_demonstrations(
      demo_dir=FLAGS.demo_dir, env_name=FLAGS.env_name)
  pwil_rewarder = rewarder.PWILRewarder(
      demonstrations,
      subsampling=FLAGS.subsampling,
      env_specs=environment_spec,
      num_demonstrations=FLAGS.num_demonstrations,
      observation_only=FLAGS.state_only)

  # Define optimizers.
  policy_optimizer = snt.optimizers.Adam(
      learning_rate=FLAGS.policy_learning_rate)
  critic_optimizer = snt.optimizers.Adam(
      learning_rate=FLAGS.critic_learning_rate)

  # Define D4PG agent.
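  # D4PG is an off-policy actor-critic agent with a distributional critic;
  # sigma is the standard deviation of the Gaussian exploration noise added to
  # the behaviour policy, and samples_per_insert sets how often learner updates
  # run relative to actor inserts into the replay buffer.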
  agent_networks = utils.make_d4pg_networks(environment_spec.actions)
  agent = d4pg.D4PG(
      environment_spec=environment_spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      observation_network=agent_networks['observation'],
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      samples_per_insert=FLAGS.samples_per_insert,
      sigma=FLAGS.sigma,
  )

  # Prefill the agent's Replay Buffer.
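  # Demonstration transitions are inserted into the replay buffer before
  # training starts, each paired with the constant reward
  # pwil_rewarder.reward_scale, so the first learner updates already see
  # expert behaviour.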
  utils.prefill_rb_with_demonstrations(
      agent=agent,
      demonstrations=pwil_rewarder.demonstrations,
      num_transitions_rb=FLAGS.num_transitions_rb,
      reward=pwil_rewarder.reward_scale)

  # Create the eval policy (without exploration noise).
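  # Chaining the observation and policy networks gives the deterministic
  # policy; FeedForwardActor simply feeds observations through it, so
  # evaluation runs without the Gaussian noise the training agent adds.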
  eval_policy = snt.Sequential([
      agent_networks['observation'],
      agent_networks['policy'],
  ])
  eval_agent = FeedForwardActor(policy_network=eval_policy)

  # Define train/eval loops.
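  # TrainEnvironmentLoop steps the agent with rewards supplied by the PWIL
  # rewarder, while EvalEnvironmentLoop runs the noise-free eval actor; both
  # write CSV logs under --workdir.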
  train_logger = csv_logger.CSVLogger(
      directory=FLAGS.workdir, label='train_logs')
  eval_logger = csv_logger.CSVLogger(
      directory=FLAGS.workdir, label='eval_logs')

  train_loop = imitation_loop.TrainEnvironmentLoop(
      environment, agent, pwil_rewarder, logger=train_logger)

  eval_loop = imitation_loop.EvalEnvironmentLoop(
      environment, eval_agent, pwil_rewarder, logger=eval_logger)

  for _ in range(FLAGS.num_iterations):
    train_loop.run(num_steps=FLAGS.num_steps_per_iteration)
    eval_loop.run(num_episodes=FLAGS.num_eval_episodes)


if __name__ == '__main__':
  app.run(main)