google-research
174 строки · 5.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Collect demonstrations for Adroit using an expert policy."""
17
18import os19import pickle20
21from absl import app22from absl import flags23import gym24import numpy as np25
26from rrlfd import adroit_ext # pylint: disable=unused-import27from rrlfd.bc import pickle_dataset28from tensorflow.io import gfile29
30
# Task / rollout configuration.
flags.DEFINE_enum('task', None, ['door', 'hammer', 'pen', 'relocate'],
                  'Adroit task for which to collect demonstrations.')
flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to record.')
flags.DEFINE_integer('seed', 0, 'Experiment seed.')
flags.DEFINE_boolean('increment_seed', False,
                     'If True, increment seed at every episode.')
flags.DEFINE_integer('image_size', None, 'Size of rendered images.')

# Expert policy and output configuration.
# main() loads os.path.join(expert_policy_dir, f'{task}.pickle'), so this flag
# names a directory, not a single pickle file.
flags.DEFINE_string('expert_policy_dir', None,
                    'Directory containing expert policy pickle files, one '
                    'per task, named <task>.pickle.')
flags.DEFINE_boolean('record_failed', False,
                     'If True, save failed demonstrations.')
flags.DEFINE_string('logdir', None, 'Location to save demonstrations to.')
flags.DEFINE_string('run_id', None,
                    'If set, a custom string to append to saved demonstrations '
                    'file name.')

FLAGS = flags.FLAGS
50
def env_loop(env, agent, num_episodes, log_path, record_failed, seed,
             increment_seed, compress_images=True):
  """Loop for collecting demonstrations with an agent in a Gym environment.

  Runs the agent until `num_episodes` episodes have been kept. An episode is
  kept if the environment judges it a success, or unconditionally when
  `record_failed` is True; otherwise the failed attempt is discarded and
  retried.

  Args:
    env: Gym-style environment. Observations are dicts whose 'original_obs'
      entry is passed to the agent; the env must also provide `seed`, `reset`,
      `step` and `evaluate_success`.
    agent: Expert policy whose `get_action(obs)` returns a pair whose second
      element is a dict; its 'evaluation' entry is the action that is taken.
    num_episodes: Number of episodes to keep.
    log_path: Path prefix for output files ('_log.txt', '_success.txt' and
      '.pkl' are appended). If None, nothing is written to disk.
    record_failed: If True, keep failed episodes as well as successful ones.
    seed: Base seed; only used when `increment_seed` is True.
    increment_seed: If True, reseed the environment before every attempt with
      `seed + skipped_seeds + e`, which advances by one per attempt.
    compress_images: Forwarded to `pickle_dataset.DemoWriter`.
  """
  log_f = None
  success_f = None
  demo_writer = None
  try:
    if log_path is not None:
      log_f = gfile.GFile(log_path + '_log.txt', 'w')
      success_f = gfile.GFile(log_path + '_success.txt', 'w')
      # NOTE(review): the DemoWriter is never explicitly closed/flushed here;
      # confirm whether pickle_dataset.DemoWriter requires that.
      demo_writer = pickle_dataset.DemoWriter(log_path + '.pkl',
                                              compress_images)
      print('Writing demos to', log_path + '.pkl')

    e = 0  # Number of episodes kept so far.
    # Counter to keep track of seed offset, if not recording failed episodes.
    skipped_seeds = 0
    num_successes = 0
    num_attempts = 0
    min_reward, max_reward = np.inf, -np.inf
    while e < num_episodes:
      if e % 10 == 0 and e > 0:
        print(f'Episode {e} / {num_episodes}; '
              f'Success rate {num_successes} / {num_attempts}')
      if increment_seed:
        env.seed(seed + skipped_seeds + e)
      obs = env.reset()

      done = False
      _, agent_info = agent.get_action(obs['original_obs'])
      action = agent_info['evaluation']
      observations = []
      actions = []
      rewards = []
      # For envs with non-Markovian success criteria, track required fields.
      goals_achieved = []

      while not done:
        observations.append(obs)
        actions.append(action)
        obs, reward, done, info = env.step(action)
        rewards.append(reward)
        min_reward = min(min_reward, reward)
        max_reward = max(max_reward, reward)
        # Next action comes from the new observation; the one computed after
        # the final step is never executed.
        _, agent_info = agent.get_action(obs['original_obs'])
        action = agent_info['evaluation']
        if 'goal_achieved' in info:
          goals_achieved.append(info['goal_achieved'])

      # Environment defines success criteria based on full episode.
      success_percentage = env.evaluate_success(
          [{'env_infos': {'goal_achieved': goals_achieved}}])
      success = bool(success_percentage)

      num_successes += int(success)
      num_attempts += 1
      if success:
        print(f'{e}: success')
        if log_f is not None:
          log_f.write(f'{e}: success\n')
          log_f.flush()
        if success_f is not None:
          success_f.write('success\n')
          success_f.flush()
      else:
        timed_out = ('TimeLimit.truncated' in info
                     and info['TimeLimit.truncated'])
        if timed_out:
          print(f'{e}: failure: time limit')
        else:
          print(f'{e}: failure')
        if log_f is not None:
          if timed_out:
            log_f.write(f'{e}: failure: time limit \n')
          else:
            log_f.write(f'{e}: failure\n')
          log_f.flush()
        if success_f is not None:
          success_f.write('failure\n')
          success_f.flush()

      if success or record_failed:
        e += 1
        if demo_writer is not None:
          demo_writer.write_episode(observations, actions, rewards)
      else:
        # Discarded failure: keep `e` but shift the seed offset so the retry
        # uses a fresh seed.
        skipped_seeds += 1

    print(f'Done; Success rate {num_successes} / {num_attempts}')
    print('min reward', min_reward)
    print('max reward', max_reward)
    if log_f is not None:
      log_f.write(f'Done; Success rate {num_successes} / {num_attempts}\n')
      log_f.write(f'min reward {min_reward}\n')
      log_f.write(f'max reward {max_reward}\n')
  finally:
    # Close both text logs, including when collection raises part-way.
    for text_file in (log_f, success_f):
      if text_file is not None:
        text_file.close()
144
def main(_):
  """Loads the expert policy for --task and records demonstrations with it.

  Demonstrations are written under
  <logdir>/<task>/s<seed>[i]_e<num_episodes>[_<run_id>][_all][_<size>px]
  when --logdir is set; otherwise episodes are only rolled out and reported.
  """
  task = FLAGS.task
  policy_path = os.path.join(FLAGS.expert_policy_dir, f'{task}.pickle')
  # NOTE(review): pickle.load can execute arbitrary code; only point
  # --expert_policy_dir at trusted files.
  with gfile.GFile(policy_path, 'rb') as policy_file:
    agent = pickle.load(policy_file)

  env = gym.make(f'visual-{task}-v0')
  env.seed(FLAGS.seed)
  image_size = FLAGS.image_size
  if image_size is not None:
    # Rendering size is set on the wrapped (inner) environment.
    env.env.im_size = image_size

  log_path = None
  if FLAGS.logdir is not None:
    suffix_parts = []
    if FLAGS.run_id is not None:
      suffix_parts.append('_' + FLAGS.run_id)
    if FLAGS.record_failed:
      suffix_parts.append('_all')
    # Only tag the size in the file name when it differs from the default.
    if (image_size is not None
        and image_size != adroit_ext.camera_kwargs['im_size']):
      suffix_parts.append(f'_{image_size}px')
    seed_part = f's{FLAGS.seed}' + ('i' if FLAGS.increment_seed else '')
    file_stem = f'{seed_part}_e{FLAGS.num_episodes}' + ''.join(suffix_parts)
    log_path = os.path.join(FLAGS.logdir, f'{task}', file_stem)
    gfile.makedirs(os.path.dirname(log_path))
    print('Writing to', log_path)

  env_loop(env, agent, FLAGS.num_episodes, log_path, FLAGS.record_failed,
           FLAGS.seed, FLAGS.increment_seed)
172
if __name__ == '__main__':
  # main() needs both flags (they default to None) to locate the policy file
  # and build the environment; fail fast at flag parsing instead of with a
  # late TypeError or missing-file error. Done here so importers of this
  # module are not affected.
  flags.mark_flag_as_required('task')
  flags.mark_flag_as_required('expert_policy_dir')
  app.run(main)