train.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Train low-level policy."""
# pylint: disable=unused-variable
# pylint: disable=g-import-not-at-top
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys

from absl import app
from absl import flags
import tensorflow as tf

from hal.experiment_config import get_exp_config
from hal.experiment_setup import experiment_setup
from hal.utils.config import Config
from hal.utils.logger import Logger
from hal.utils.logger import Logger2

# Fall back to TensorFlow's gfile when a google-internal gfile module has not
# already been imported.
if 'gfile' not in sys.modules:
  gfile = tf.io.gfile


FLAGS = flags.FLAGS
flags.DEFINE_bool('use_tf2', True, 'use eager execution')
flags.DEFINE_bool('use_nn_relabeling', False,
                  'use function approximators for relabeling')
flags.DEFINE_bool(
    'use_labeler_as_reward', False,
    'use the labeling function reward instead of true environment reward')
flags.DEFINE_bool(
    'use_oracle_instruction', True,
    'use the oracle/environment to generate the relabeling instructions')
flags.DEFINE_string('save_dir', None, 'experiment home directory')
flags.DEFINE_string('agent_type', 'pm', 'Which agent to use')
flags.DEFINE_string('scenario_type', 'fixed_primitive', 'Which env to use')
flags.DEFINE_bool('save_model', False, 'save model and log')
flags.DEFINE_bool('save_video', False, 'save video for evaluation')
flags.DEFINE_integer('save_interval', 50, 'intervals between saving models')
flags.DEFINE_integer('video_interval', 400, 'intervals between videos')
flags.DEFINE_bool('direct_obs', True, 'direct observation')
flags.DEFINE_string('action_type', 'perfect', 'what type of action to use')
flags.DEFINE_string('obs_type', 'order_invariant', 'type of observation')
flags.DEFINE_integer('img_resolution', 64, 'resolution of image observations')
flags.DEFINE_integer('render_resolution', 300, 'resolution of rendered image')
flags.DEFINE_integer('max_episode_length', 50, 'maximum episode duration')
flags.DEFINE_integer('num_epoch', 200, 'number of epochs')
flags.DEFINE_integer('num_cycle', 50, 'number of cycles per epoch')
flags.DEFINE_integer('num_episode', 50, 'number of episodes per cycle')
flags.DEFINE_integer('optimization_steps', 100,
                     'optimization steps per episode')
flags.DEFINE_integer('collect_cycle', 10, 'cycles for populating buffer')
flags.DEFINE_integer('future_k', 3,
                     'number of future goals to put into buffer')
flags.DEFINE_integer('buffer_size', int(1e6), 'size of replay buffer')
flags.DEFINE_integer('batchsize', 128, 'batchsize')
flags.DEFINE_integer('rollout_episode', 10, 'number of episodes for testing')
flags.DEFINE_bool('record_trajectory', False,
                  'test using the same epsilon as training and record traj')
flags.DEFINE_float('polyak_rate', 0.95, 'moving average factor for target')
flags.DEFINE_float('discount', 0.5, 'discount factor')
flags.DEFINE_float('initial_epsilon', 1.0, 'initial epsilon')
flags.DEFINE_float('min_epsilon', 0.1, 'minimum epsilon')
flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
flags.DEFINE_float('epsilon_decay', 0.95, 'decay factor for epsilon')
flags.DEFINE_float('sample_new_scene_prob', 0.1,
                   'Probability of sampling a new scene')
flags.DEFINE_bool('record_atomic_instruction', False, 'record atomic goals')
flags.DEFINE_integer('k_immediate', 2,
                     'number of immediate correct statements added')
flags.DEFINE_bool('masking_q', False, 'mask the end of the episode')
flags.DEFINE_bool('paraphrase', False, 'paraphrase sentences')
flags.DEFINE_bool('relabeling', True, 'use hindsight experience replay')
flags.DEFINE_bool('use_subset_instruction', False,
                  'use a subset of 600 sentences, for generalization')
flags.DEFINE_string('image_action_parameterization', 'regular',
                    'type of action parameterization used by the image model')
flags.DEFINE_integer('frame_skip', 20, 'simulation steps for the physics')
flags.DEFINE_bool('use_polar', False,
                  'use polar coordinate for neighbor assignment')
flags.DEFINE_bool('suppress', False,
                  'suppress movements of unnecessary objects')
flags.DEFINE_bool('diverse_scene_content', False,
                  'whether to use variable scene content')
flags.DEFINE_bool('use_synonym_for_rollout', False,
                  'use unseen synonyms for rolling out')
flags.DEFINE_float('reward_shape_val', 0.25, 'Value for reward shaping')
flags.DEFINE_string('instruction_repr', 'language',
                    'representation of the instruction')
flags.DEFINE_string('encoder_type', 'vanilla_rnn', 'type of language encoder')
flags.DEFINE_string('embedding_type', 'random',
                    'type of word embedding to use for training agent')
flags.DEFINE_bool('negate_unary', True, 'Negate unary sentences')
flags.DEFINE_string('experiment_config', None, 'specific experiment config')
flags.DEFINE_bool('trainable_encoder', True, 'Is encoder trainable for agent.')
flags.DEFINE_bool('use_movement_bonus', False, 'Encourage moving objects.')
flags.DEFINE_string('varying', None, 'What parameters are changing.')
flags.DEFINE_integer('generated_label_num', 50, 'Number of generated labels')
flags.DEFINE_float('sampling_temperature', 1.0, 'Sampling temperature')
flags.DEFINE_string('reset_mode', 'regular', 'How the environment is reset')
flags.DEFINE_float('reward_scale', 1.0, 'Reward scale of the environment')
# Maxent IRL
flags.DEFINE_bool('maxent_irl', False, 'Use maximum entropy IRL')
flags.DEFINE_integer('irl_parallel_n', 1, 'Number of parallel inference in irl')
flags.DEFINE_integer('irl_sample_goal_n', 32, 'Number of goals sampled for irl')
flags.DEFINE_float('relabel_proportion', 0.5, 'portion of minibatch to relabel')
flags.DEFINE_float('entropy_alpha', 0.001, 'alpha for max ent')


def main(_):
  if FLAGS.use_tf2:
    tf.enable_v2_behavior()
  config_content = {
      'action_type': FLAGS.action_type,
      'obs_type': FLAGS.obs_type,
      'reward_shape_val': FLAGS.reward_shape_val,
      'use_subset_instruction': FLAGS.use_subset_instruction,
      'frame_skip': FLAGS.frame_skip,
      'use_polar': FLAGS.use_polar,
      'suppress': FLAGS.suppress,
      'diverse_scene_content': FLAGS.diverse_scene_content,
      'buffer_size': FLAGS.buffer_size,
      'use_movement_bonus': FLAGS.use_movement_bonus,
      'reward_scale': FLAGS.reward_scale,
      'scenario_type': FLAGS.scenario_type,
      'img_resolution': FLAGS.img_resolution,
      'render_resolution': FLAGS.render_resolution,

      # agent
      'agent_type': FLAGS.agent_type,
      'masking_q': FLAGS.masking_q,
      'discount': FLAGS.discount,
      'instruction_repr': FLAGS.instruction_repr,
      'encoder_type': FLAGS.encoder_type,
      'learning_rate': FLAGS.learning_rate,
      'polyak_rate': FLAGS.polyak_rate,
      'trainable_encoder': FLAGS.trainable_encoder,
      'embedding_type': FLAGS.embedding_type,

      # learner
      'num_episode': FLAGS.num_episode,
      'optimization_steps': FLAGS.optimization_steps,
      'batchsize': FLAGS.batchsize,
      'sample_new_scene_prob': FLAGS.sample_new_scene_prob,
      'max_episode_length': FLAGS.max_episode_length,
      'record_atomic_instruction': FLAGS.record_atomic_instruction,
      'paraphrase': FLAGS.paraphrase,
      'relabeling': FLAGS.relabeling,
      'k_immediate': FLAGS.k_immediate,
      'future_k': FLAGS.future_k,
      'negate_unary': FLAGS.negate_unary,
      'min_epsilon': FLAGS.min_epsilon,
      'epsilon_decay': FLAGS.epsilon_decay,
      'collect_cycle': FLAGS.collect_cycle,
      'use_synonym_for_rollout': FLAGS.use_synonym_for_rollout,
      'reset_mode': FLAGS.reset_mode,
      'maxent_irl': FLAGS.maxent_irl,

      # relabeler
      'sampling_temperature': FLAGS.sampling_temperature,
      'generated_label_num': FLAGS.generated_label_num,
      'use_labeler_as_reward': FLAGS.use_labeler_as_reward,
      'use_oracle_instruction': FLAGS.use_oracle_instruction,
  }

  if FLAGS.maxent_irl:
    assert FLAGS.batchsize % FLAGS.irl_parallel_n == 0
    config_content['irl_parallel_n'] = FLAGS.irl_parallel_n
    config_content['irl_sample_goal_n'] = FLAGS.irl_sample_goal_n
    config_content['relabel_proportion'] = FLAGS.relabel_proportion
    config_content['entropy_alpha'] = FLAGS.entropy_alpha

  cfg = Config(config_content)

  if FLAGS.experiment_config:
    cfg.update(get_exp_config(FLAGS.experiment_config))

  save_home = FLAGS.save_dir if FLAGS.save_dir else tf.test.get_temp_dir()
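  # When --varying is set, the experiment directory name records the varied
  # flags and their values, e.g. --varying=discount,batchsize with the default
  # flag values yields 'exp-discount=0.5-batchsize=128-'.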
  if FLAGS.varying:
    exp_name = 'exp-'
    for varied_var in FLAGS.varying.split(','):
      exp_name += str(varied_var) + '=' + str(FLAGS[varied_var].value) + '-'
  else:
    exp_name = 'SingleExperiment'
  save_dir = os.path.join(save_home, exp_name)
  try:
    gfile.mkdir(save_home)
  except tf.errors.OpError as e:
    print(e)
  try:
    gfile.mkdir(save_dir)
  except tf.errors.OpError as e:
    print(e)

  cfg.update(Config({'model_dir': save_dir}))

  print('############################################################')
  print(cfg)
  print('############################################################')

  env, learner, replay_buffer, agent, extra_components = experiment_setup(
      cfg, FLAGS.use_tf2, FLAGS.use_nn_relabeling)
  agent.init_networks()

  if FLAGS.use_tf2:
    logger = Logger2(save_dir)
  else:
    logger = Logger(save_dir)

  with gfile.GFile(os.path.join(save_dir, 'config.json'), mode='w+') as f:
    json.dump(cfg.as_dict(), f, sort_keys=True, indent=4)

  if FLAGS.save_model and tf.train.latest_checkpoint(save_dir):
    print('Loading saved weights from {}'.format(save_dir))
    agent.load_model(save_dir)

  if FLAGS.save_model:
    video_dir = os.path.join(save_dir, 'rollout_cycle_{}.mp4'.format('init'))
    print('Saving video to {}'.format(video_dir))
    learner.rollout(
        env,
        agent,
        video_dir,
        num_episode=FLAGS.rollout_episode,
        record_trajectory=FLAGS.record_trajectory)

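  # Exponential moving average of the training success rate; it is initialized
  # from the first cycle's stats inside the loop below and is used both to
  # decide when to checkpoint and to detect sudden performance drops.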
  success_rate_ema = -1.0
  # Training loop
  for epoch in range(FLAGS.num_epoch):
    for cycle in range(FLAGS.num_cycle):
      stats = learner.learn(env, agent, replay_buffer)

      if success_rate_ema < 0:
        success_rate_ema = stats['achieved_goal']

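      # If the success rate collapses to below 10% of its running average late
      # in training, reload the last saved checkpoint rather than continuing
      # from the current (likely diverged) weights.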
      success_dropped = stats['achieved_goal'] < 0.1 * success_rate_ema
      far_along_training = stats['global_step'] > 100000
      if FLAGS.save_model and success_dropped and far_along_training:
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        print('Step {}: Loading models due to sudden success drop D:'.format(
            stats['global_step']))
        print('Dropped from {} to {}'.format(success_rate_ema,
                                             stats['achieved_goal']))
        agent.load_model(save_dir)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        continue
      success_rate_ema = 0.95 * success_rate_ema + 0.05 * stats['achieved_goal']

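      # Save a checkpoint at the save interval, but only when the current
      # cycle's success rate beats its running average.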
      at_save_interval = stats['global_step'] % FLAGS.save_interval == 0
      better_reward = stats['achieved_goal'] > success_rate_ema
      if FLAGS.save_model and at_save_interval and better_reward:
        print('Saving model to {}'.format(save_dir))
        agent.save_model(save_dir)

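      # Periodically run evaluation rollouts, optionally recording a video.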
      if FLAGS.save_model and stats['global_step'] % FLAGS.video_interval == 0:
        video_dir = os.path.join(save_dir, 'rollout_cycle_{}.mp4'.format(cycle))
        print('Saving video to {}'.format(video_dir))
        test_success_rate = learner.rollout(
            env,
            agent,
            video_dir,
            record_video=FLAGS.save_video,
            num_episode=FLAGS.rollout_episode,
            record_trajectory=FLAGS.record_trajectory)
        stats['Test Success Rate'] = test_success_rate
        print('Test Success Rate: {}'.format(test_success_rate))

      stats['ema success rate'] = success_rate_ema
      logger.log(epoch, cycle, stats)


if __name__ == '__main__':
  app.run(main)
