# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Train low-level policy."""
# pylint: disable=unused-variable
# pylint: disable=g-import-not-at-top
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys

from absl import app
from absl import flags
import tensorflow as tf

from hal.experiment_config import get_exp_config
from hal.experiment_setup import experiment_setup
from hal.utils.config import Config
from hal.utils.logger import Logger
from hal.utils.logger import Logger2

if 'gfile' not in sys.modules:
  # The original `import tf.io.gfile as gfile` cannot work because `tf` is a
  # module alias rather than an importable package; bind TensorFlow's gfile
  # explicitly instead. File-system calls below use the tf.io.gfile API.
  gfile = tf.io.gfile


FLAGS = flags.FLAGS
flags.DEFINE_bool('use_tf2', True, 'use eager execution')
flags.DEFINE_bool('use_nn_relabeling', False,
                  'use function approximators for relabeling')
flags.DEFINE_bool(
    'use_labeler_as_reward', False,
    'use the labeling function reward instead of true environment reward')
flags.DEFINE_bool(
    'use_oracle_instruction', True,
    'use the oracle/environment to generate the relabeling instructions')
flags.DEFINE_string('save_dir', None, 'experiment home directory')
flags.DEFINE_string('agent_type', 'pm', 'Which agent to use')
flags.DEFINE_string('scenario_type', 'fixed_primitive', 'Which env to use')
flags.DEFINE_bool('save_model', False, 'save model and log')
flags.DEFINE_bool('save_video', False, 'save video for evaluation')
flags.DEFINE_integer('save_interval', 50, 'intervals between saving models')
flags.DEFINE_integer('video_interval', 400, 'intervals between videos')
flags.DEFINE_bool('direct_obs', True, 'direct observation')
flags.DEFINE_string('action_type', 'perfect', 'what type of action to use')
flags.DEFINE_string('obs_type', 'order_invariant', 'type of observation')
flags.DEFINE_integer('img_resolution', 64, 'resolution of image observations')
flags.DEFINE_integer('render_resolution', 300, 'resolution of rendered image')
flags.DEFINE_integer('max_episode_length', 50, 'maximum episode duration')
flags.DEFINE_integer('num_epoch', 200, 'number of epochs')
flags.DEFINE_integer('num_cycle', 50, 'number of cycles per epoch')
flags.DEFINE_integer('num_episode', 50, 'number of episodes per cycle')
flags.DEFINE_integer('optimization_steps', 100,
                     'optimization steps per episode')
flags.DEFINE_integer('collect_cycle', 10, 'cycles for populating buffer')
flags.DEFINE_integer('future_k', 3,
                     'number of future goals to put into buffer')
flags.DEFINE_integer('buffer_size', int(1e6), 'size of replay buffer')
flags.DEFINE_integer('batchsize', 128, 'batch size')
flags.DEFINE_integer('rollout_episode', 10, 'number of episodes for testing')
flags.DEFINE_bool('record_trajectory', False,
                  'test using the same epsilon as training and record traj')
flags.DEFINE_float('polyak_rate', 0.95, 'moving average factor for target')
flags.DEFINE_float('discount', 0.5, 'discount factor')
flags.DEFINE_float('initial_epsilon', 1.0, 'initial epsilon')
flags.DEFINE_float('min_epsilon', 0.1, 'minimum epsilon')
flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
flags.DEFINE_float('epsilon_decay', 0.95, 'decay factor for epsilon')
flags.DEFINE_float('sample_new_scene_prob', 0.1,
                   'Probability of sampling a new scene')
flags.DEFINE_bool('record_atomic_instruction', False, 'record atomic goals')
flags.DEFINE_integer('k_immediate', 2,
                     'number of immediate correct statements added')
flags.DEFINE_bool('masking_q', False, 'mask the end of the episode')
flags.DEFINE_bool('paraphrase', False, 'paraphrase sentences')
flags.DEFINE_bool('relabeling', True, 'use hindsight experience replay')
flags.DEFINE_bool('use_subset_instruction', False,
                  'use a subset of 600 sentences, for generalization')
flags.DEFINE_string('image_action_parameterization', 'regular',
                    'type of action parameterization used by the image model')
flags.DEFINE_integer('frame_skip', 20, 'simulation steps for the physics')
flags.DEFINE_bool('use_polar', False,
                  'use polar coordinates for neighbor assignment')
flags.DEFINE_bool('suppress', False,
                  'suppress movements of unnecessary objects')
flags.DEFINE_bool('diverse_scene_content', False,
                  'whether to use variable scene content')
flags.DEFINE_bool('use_synonym_for_rollout', False,
                  'use unseen synonyms for rolling out')
flags.DEFINE_float('reward_shape_val', 0.25, 'Value for reward shaping')
flags.DEFINE_string('instruction_repr', 'language',
                    'representation of the instruction')
flags.DEFINE_string('encoder_type', 'vanilla_rnn', 'type of language encoder')
flags.DEFINE_string('embedding_type', 'random',
                    'type of word embedding to use for training agent')
flags.DEFINE_bool('negate_unary', True, 'Negate unary sentences')
flags.DEFINE_string('experiment_config', None, 'specific experiment config')
flags.DEFINE_bool('trainable_encoder', True, 'Is encoder trainable for agent.')
flags.DEFINE_bool('use_movement_bonus', False, 'Encourage moving objects.')
flags.DEFINE_string('varying', None, 'What parameters are changing.')
flags.DEFINE_integer('generated_label_num', 50, 'Number of generated labels')
flags.DEFINE_float('sampling_temperature', 1.0, 'Sampling temperature')
flags.DEFINE_string('reset_mode', 'regular', 'How the environment is reset')
flags.DEFINE_float('reward_scale', 1.0, 'Reward scale of the environment')
# Maxent IRL
flags.DEFINE_bool('maxent_irl', False, 'Use maximum entropy IRL')
flags.DEFINE_integer('irl_parallel_n', 1, 'number of parallel IRL inferences')
flags.DEFINE_integer('irl_sample_goal_n', 32,
                     'Number of goals sampled for IRL')
flags.DEFINE_float('relabel_proportion', 0.5, 'portion of minibatch to relabel')
flags.DEFINE_float('entropy_alpha', 0.001, 'alpha for max ent')

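# Example command line (the script name, save path, and flag values below are
# illustrative only, not canonical):
#
#   python train.py \
#     --save_dir=/tmp/hal_low_level \
#     --agent_type=pm \
#     --scenario_type=fixed_primitive \
#     --save_model --save_video
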
def main(_):
  if FLAGS.use_tf2:
    tf.enable_v2_behavior()
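  # Gather the experiment settings from the command-line flags into a single
  # configuration dict.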
  config_content = {
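      # environment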
      'action_type': FLAGS.action_type,
      'obs_type': FLAGS.obs_type,
      'reward_shape_val': FLAGS.reward_shape_val,
      'use_subset_instruction': FLAGS.use_subset_instruction,
      'frame_skip': FLAGS.frame_skip,
      'use_polar': FLAGS.use_polar,
      'suppress': FLAGS.suppress,
      'diverse_scene_content': FLAGS.diverse_scene_content,
      'buffer_size': FLAGS.buffer_size,
      'use_movement_bonus': FLAGS.use_movement_bonus,
      'reward_scale': FLAGS.reward_scale,
      'scenario_type': FLAGS.scenario_type,
      'img_resolution': FLAGS.img_resolution,
      'render_resolution': FLAGS.render_resolution,

      # agent
      'agent_type': FLAGS.agent_type,
      'masking_q': FLAGS.masking_q,
      'discount': FLAGS.discount,
      'instruction_repr': FLAGS.instruction_repr,
      'encoder_type': FLAGS.encoder_type,
      'learning_rate': FLAGS.learning_rate,
      'polyak_rate': FLAGS.polyak_rate,
      'trainable_encoder': FLAGS.trainable_encoder,
      'embedding_type': FLAGS.embedding_type,

      # learner
      'num_episode': FLAGS.num_episode,
      'optimization_steps': FLAGS.optimization_steps,
      'batchsize': FLAGS.batchsize,
      'sample_new_scene_prob': FLAGS.sample_new_scene_prob,
      'max_episode_length': FLAGS.max_episode_length,
      'record_atomic_instruction': FLAGS.record_atomic_instruction,
      'paraphrase': FLAGS.paraphrase,
      'relabeling': FLAGS.relabeling,
      'k_immediate': FLAGS.k_immediate,
      'future_k': FLAGS.future_k,
      'negate_unary': FLAGS.negate_unary,
      'min_epsilon': FLAGS.min_epsilon,
      'epsilon_decay': FLAGS.epsilon_decay,
      'collect_cycle': FLAGS.collect_cycle,
      'use_synonym_for_rollout': FLAGS.use_synonym_for_rollout,
      'reset_mode': FLAGS.reset_mode,
      'maxent_irl': FLAGS.maxent_irl,

      # relabeler
      'sampling_temperature': FLAGS.sampling_temperature,
      'generated_label_num': FLAGS.generated_label_num,
      'use_labeler_as_reward': FLAGS.use_labeler_as_reward,
      'use_oracle_instruction': FLAGS.use_oracle_instruction
  }

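  # IRL-specific settings are only added when maximum-entropy IRL is enabled;
  # the batch must split evenly across the parallel IRL inference calls.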
  if FLAGS.maxent_irl:
    assert FLAGS.batchsize % FLAGS.irl_parallel_n == 0
    config_content['irl_parallel_n'] = FLAGS.irl_parallel_n
    config_content['irl_sample_goal_n'] = FLAGS.irl_sample_goal_n
    config_content['relabel_proportion'] = FLAGS.relabel_proportion
    config_content['entropy_alpha'] = FLAGS.entropy_alpha

  cfg = Config(config_content)

  if FLAGS.experiment_config:
    cfg.update(get_exp_config(FLAGS.experiment_config))

  save_home = FLAGS.save_dir if FLAGS.save_dir else tf.test.get_temp_dir()
  if FLAGS.varying:
    exp_name = 'exp-'
    for varied_var in FLAGS.varying.split(','):
      exp_name += str(varied_var) + '=' + str(FLAGS[varied_var].value) + '-'
  else:
    exp_name = 'SingleExperiment'
  save_dir = os.path.join(save_home, exp_name)
  # tf.io.gfile raises OpError (e.g. when a directory already exists); report
  # the error and continue.
  try:
    gfile.mkdir(save_home)
  except tf.errors.OpError as e:
    print(e)
  try:
    gfile.mkdir(save_dir)
  except tf.errors.OpError as e:
    print(e)

  cfg.update(Config({'model_dir': save_dir}))

  print('############################################################')
  print(cfg)
  print('############################################################')

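  # Build the environment, learner, replay buffer, and agent from the config.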
  env, learner, replay_buffer, agent, extra_components = experiment_setup(
      cfg, FLAGS.use_tf2, FLAGS.use_nn_relabeling)
  agent.init_networks()

  if FLAGS.use_tf2:
    logger = Logger2(save_dir)
  else:
    logger = Logger(save_dir)

  with gfile.GFile(os.path.join(save_dir, 'config.json'), mode='w+') as f:
    json.dump(cfg.as_dict(), f, sort_keys=True, indent=4)

  if FLAGS.save_model and tf.train.latest_checkpoint(save_dir):
    print('Loading saved weights from {}'.format(save_dir))
    agent.load_model(save_dir)

  if FLAGS.save_model:
    video_dir = os.path.join(save_dir, 'rollout_cycle_{}.mp4'.format('init'))
    print('Saving video to {}'.format(video_dir))
    learner.rollout(
        env,
        agent,
        video_dir,
        num_episode=FLAGS.rollout_episode,
        record_trajectory=FLAGS.record_trajectory)

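  # Exponential moving average of the success rate; -1 means uninitialized.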
  success_rate_ema = -1.0

  # Training loop
  for epoch in range(FLAGS.num_epoch):
    for cycle in range(FLAGS.num_cycle):
      stats = learner.learn(env, agent, replay_buffer)

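      # Seed the moving average with the first measured success rate.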
      if success_rate_ema < 0:
        success_rate_ema = stats['achieved_goal']

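      # If the success rate suddenly collapses below 10% of its moving average
      # late in training, reload the last checkpoint instead of training on.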
      loss_dropped = stats['achieved_goal'] < 0.1 * success_rate_ema
      far_along_training = stats['global_step'] > 100000
      if FLAGS.save_model and loss_dropped and far_along_training:
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        print('Step {}: Loading models due to sudden loss drop D:'.format(
            stats['global_step']))
        print('Dropped from {} to {}'.format(success_rate_ema,
                                             stats['achieved_goal']))
        agent.load_model(save_dir)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        continue
      success_rate_ema = (0.95 * success_rate_ema +
                          0.05 * stats['achieved_goal'])

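      # Periodically checkpoint, but only when the current success rate beats
      # its moving average.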
      at_save_interval = stats['global_step'] % FLAGS.save_interval == 0
      better_reward = stats['achieved_goal'] > success_rate_ema
      if FLAGS.save_model and at_save_interval and better_reward:
        print('Saving model to {}'.format(save_dir))
        agent.save_model(save_dir)

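      # Periodically evaluate with test-time rollouts, optionally recording a
      # video.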
      if FLAGS.save_model and stats['global_step'] % FLAGS.video_interval == 0:
        video_dir = os.path.join(save_dir,
                                 'rollout_cycle_{}.mp4'.format(cycle))
        print('Saving video to {}'.format(video_dir))
        test_success_rate = learner.rollout(
            env,
            agent,
            video_dir,
            record_video=FLAGS.save_video,
            num_episode=FLAGS.rollout_episode,
            record_trajectory=FLAGS.record_trajectory)
        stats['Test Success Rate'] = test_success_rate
        print('Test Success Rate: {}'.format(test_success_rate))

      stats['ema success rate'] = success_rate_ema
      logger.log(epoch, cycle, stats)

if __name__ == '__main__':
  app.run(main)