# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hindsight Instruction Relabeling."""
# pylint: disable=unused-variable
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
from __future__ import absolute_import
from __future__ import division

import random
import time

import numpy as np

from hal.learner.language_utils import get_vocab_path
from hal.learner.language_utils import instruction_type
from hal.learner.language_utils import negate_unary_sentence
from hal.learner.language_utils import pad_to_max_length
from hal.learner.language_utils import paraphrase_sentence
from hal.utils.video_utils import add_text
from hal.utils.video_utils import pad_image
from hal.utils.video_utils import save_json
from hal.utils.video_utils import save_video
import hal.utils.word_vectorization as wv


class HIR:
  """Learner that executes Hindsight Instruction Relabeling.

  Attributes:
    cfg: configuration of this learner
    step: current training step
    epsilon: value of the epsilon for sampling random actions
    vocab_list: vocabulary list used for the instruction labeler
    encode_fn: function that encodes an instruction
    decode_fn: function that converts an encoded instruction back to text
    labeler: object that generates labels for transitions
  """

  def __init__(self, cfg):
    self.cfg = cfg
    self.step = 0
    self.epsilon = 1.0

    # Vocab loading
    vocab_path = get_vocab_path(cfg)
    self.vocab_list = wv.load_vocab_list(vocab_path)
    v2i, i2v = wv.create_look_up_table(self.vocab_list)
    self.encode_fn = wv.encode_text_with_lookup_table(
        v2i, max_sequence_length=self.cfg.max_sequence_length)
    self.decode_fn = wv.decode_with_lookup_table(i2v)
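
    # Round-trip example (hedged; the instruction is hypothetical and the
    # exact token ids depend on the vocab file):
    #   encoded = self.encode_fn('go to the red sphere')
    #   # -> int ids padded to cfg.max_sequence_length, e.g. [7, 2, 9, ...]
    #   self.decode_fn(np.squeeze(encoded))  # -> 'go to the red sphere'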

  def reset(self, env, agent, sample_new_scene=False, **kwargs):
    """Reset at the episode boundary.

    Args:
      env: the RL environment
      agent: the RL agent
      sample_new_scene: sample a brand new set of objects for the scene
      **kwargs: other potential arguments

    Returns:
      the reset state of the environment
    """
    if self.cfg.reset_mode == 'random_action':
      for _ in range(20):
        s, _, _, _ = env.step(env.sample_random_action())
    elif self.cfg.reset_mode == 'none':
      s = env.get_obs()
    else:
      s = env.reset(sample_new_scene)
    return s
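
  # The branches in reset() above correspond to cfg.reset_mode values:
  # 'random_action' perturbs the current scene with 20 random actions,
  # 'none' keeps the scene and just reads the current observation, and any
  # other value triggers a full environment reset (optionally sampling a
  # brand new scene).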

  def learn(self, env, agent, replay_buffer, **kwargs):
    """Run learning for one cycle, which consists of num_episode episodes.

    Args:
      env: the RL environment
      agent: the RL agent
      replay_buffer: the experience replay buffer
      **kwargs: other potential arguments

    Returns:
      statistics of the training episode
    """
    average_per_ep_reward = []
    average_per_ep_achieved_n = []
    average_per_ep_relabel_n = []
    average_batch_loss = []
    curiosity_loss = 0

    curr_step = agent.get_global_step()
    self.update_epsilon(curr_step)
    tic = time.time()
    time_rolling_out, time_training = 0.0, 0.0
    for _ in range(self.cfg.num_episode):
      curr_step = agent.increase_global_step()

      sample_new_scene = random.uniform(0, 1) < self.cfg.sample_new_scene_prob
      s = self.reset(env, agent, sample_new_scene)
      episode_experience = []
      episode_reward = 0
      episode_achieved_n = 0
      episode_relabel_n = 0

      # rollout
      rollout_tic = time.time()
      g_text, p = env.sample_goal()
      if env.all_goals_satisfied:
        s = self.reset(env, agent, True)
        g_text, p = env.sample_goal()
      g = np.squeeze(self.encode_fn(g_text))

      for t in range(self.cfg.max_episode_length):
        a = agent.step(s, g, env, self.epsilon)
        s_tp1, r, _, _ = env.step(
            a,
            record_achieved_goal=True,
            goal=p,
            atomic_goal=self.cfg.record_atomic_instruction)
        ag = env.get_achieved_goals()
        ag_text = env.get_achieved_goal_programs()
        ag_total = ag  # TODO(ydjiang): more can be stored in ag
        episode_experience.append((s, a, r, s_tp1, g, ag_total))
        episode_reward += r
        s = s_tp1
        if r > env.shape_val:
          episode_achieved_n += 1
          g_text, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g = np.squeeze(self.encode_fn(g_text))
      time_rolling_out += time.time() - rollout_tic

      average_per_ep_reward.append(episode_reward)
      average_per_ep_achieved_n.append(episode_achieved_n)
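
      # Each episode_experience entry is a 6-tuple:
      # (state, action, reward, next_state, encoded_goal, achieved_goals);
      # achieved_goals is the list of instructions satisfied at that step,
      # consumed below by hir_relabel.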

      # processing trajectory
      train_tic = time.time()
      episode_length = len(episode_experience)
      for t in range(episode_length):
        s, a, r, s_tp1, g, ag = episode_experience[t]
        episode_relabel_n += float(len(ag) > 0)
        g_text = self.decode_fn(g)
        if self.cfg.paraphrase:
          g_text = paraphrase_sentence(
              g_text, delete_color=self.cfg.diverse_scene_content)
        g = self.encode_fn(g_text)
        replay_buffer.add((s, a, r, s_tp1, g))
        if self.cfg.relabeling:
          self.hir_relabel(episode_experience, t, replay_buffer, env)

      average_per_ep_relabel_n.append(episode_relabel_n / float(episode_length))
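
      # Replay rows are 5-tuples (s, a, r, s_tp1, g); np.split(..., 5, 1)
      # in the training step below recovers the per-field columns from the
      # sampled batch before they are re-stacked into dense arrays.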

      # training
      if not self.is_warming_up(curr_step):
        batch_loss = 0
        for _ in range(self.cfg.optimization_steps):
          experience = replay_buffer.sample(self.cfg.batchsize)
          s, a, r, s_tp1, g = [
              np.squeeze(elem, axis=1) for elem in np.split(experience, 5, 1)
          ]
          s = np.stack(s)
          s_tp1 = np.stack(s_tp1)
          g = np.array(list(g))
          if self.cfg.instruction_repr == 'language':
            g = np.array(pad_to_max_length(g, self.cfg.max_sequence_length))
          batch = {
              'obs': np.asarray(s),
              'action': np.asarray(a),
              'reward': np.asarray(r),
              'obs_next': np.asarray(s_tp1),
              'g': np.asarray(g)
          }
          loss_dict = agent.train(batch)
          batch_loss += loss_dict['loss']
          if 'prediction_loss' in loss_dict:
            curiosity_loss += loss_dict['prediction_loss']
        average_batch_loss.append(batch_loss / self.cfg.optimization_steps)
      time_training += time.time() - train_tic

    time_per_episode = (time.time() - tic) / self.cfg.num_episode
    time_training_per_episode = time_training / self.cfg.num_episode
    time_rolling_out_per_episode = time_rolling_out / self.cfg.num_episode

    # Update the target network
    agent.update_target_network()

    ################## Debug ##################
    sample = replay_buffer.sample(min(10000, len(replay_buffer.buffer)))
    _, _, sample_r, _, _ = [
        np.squeeze(elem, axis=1) for elem in np.split(sample, 5, 1)
    ]
    print('n one:', np.sum(np.float32(sample_r == 1.0)), 'n zero',
          np.sum(np.float32(sample_r == 0.0)), 'n buff',
          len(replay_buffer.buffer))
    ################## Debug ##################

    stats = {
        'loss': np.mean(average_batch_loss) if average_batch_loss else 0,
        'reward': np.mean(average_per_ep_reward),
        'achieved_goal': np.mean(average_per_ep_achieved_n),
        'average_relabel_goal': np.mean(average_per_ep_relabel_n),
        'epsilon': self.epsilon,
        'global_step': curr_step,
        'time_per_episode': time_per_episode,
        'time_training_per_episode': time_training_per_episode,
        'time_rolling_out_per_episode': time_rolling_out_per_episode,
        'replay_buffer_reward_avg': np.mean(sample_r),
        'replay_buffer_reward_var': np.var(sample_r)
    }
    return stats

  def hir_relabel(self, episode_experience, current_t, replay_buffer, env):
    """Relabel the transition at current_t and add the results to the buffer.

    Args:
      episode_experience: list of transitions of the current episode
      current_t: time step at which the experience is relabeled
      replay_buffer: the experience replay buffer
      env: the RL environment
    """
    ep_len = len(episode_experience)
    s, a, _, s_tp1, _, ag = episode_experience[current_t]
    if ag:
      for _ in range(min(self.cfg.k_immediate, len(ag) + 1)):
        ag_text_single = random.choice(ag)
        g_type = instruction_type(ag_text_single)
        if self.cfg.paraphrase and g_type != 'unary':
          ag_text_single = paraphrase_sentence(
              ag_text_single, delete_color=self.cfg.diverse_scene_content)
        replay_buffer.add(
            (s, a, env.reward_scale, s_tp1, self.encode_fn(ag_text_single)))
        if g_type == 'unary' and self.cfg.negate_unary:
          negative_ag = negate_unary_sentence(ag_text_single)
          if negative_ag:
            replay_buffer.add((s, a, 0., s_tp1, self.encode_fn(negative_ag)))
    goal_count, repeat = 0, 0
    while goal_count < self.cfg.future_k and repeat < (ep_len - current_t) * 2:
      repeat += 1
      future = np.random.randint(current_t, ep_len)
      _, _, _, _, _, ag_future = episode_experience[future]
      if not ag_future:
        continue
      random.shuffle(ag_future)
      for single_g in ag_future:
        if instruction_type(single_g) != 'unary':
          discount = self.cfg.discount**(future - current_t)
          if self.cfg.paraphrase:
            single_g = paraphrase_sentence(
                single_g, delete_color=self.cfg.diverse_scene_content)
          replay_buffer.add((s, a, discount * env.reward_scale, s_tp1,
                             self.encode_fn(single_g)))
          goal_count += 1
          break
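
  # Worked example of the future-goal discount above (hedged, illustrative
  # numbers): with cfg.discount = 0.9, relabeling the transition at
  # current_t = 3 with a goal achieved at future = 6 stores reward
  # 0.9**3 * env.reward_scale, so goals achieved sooner after the current
  # step earn closer-to-full reward.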

  def update_epsilon(self, step):
    new_epsilon = self.cfg.epsilon_decay**(step // self.cfg.num_episode)
    self.epsilon = max(new_epsilon, self.cfg.min_epsilon)
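
  # Decay example (hedged, illustrative config values): with
  # epsilon_decay = 0.99, num_episode = 50 and min_epsilon = 0.1, step 5000
  # yields 0.99**(5000 // 50) = 0.99**100 ~ 0.366, and the schedule reaches
  # the 0.1 floor around step 11500.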

  def is_warming_up(self, step):
    return step <= self.cfg.collect_cycle * self.cfg.num_episode
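
  # Warm-up spans the first cfg.collect_cycle cycles (of cfg.num_episode
  # episodes each); while it is active, learn() only collects experience
  # and skips the gradient-update block.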

  def rollout(self,
              env,
              agent,
              directory,
              record_video=False,
              timeout=8,
              num_episode=10,
              record_trajectory=False):
    """Rollout and save.

    Args:
      env: the RL environment
      agent: the RL agent
      directory: directory where the output of the rollout is saved
      record_video: record the video
      timeout: timeout step if the agent is stuck
      num_episode: number of rollout episodes
      record_trajectory: record the ground truth trajectory

    Returns:
      average number of goals achieved per episode during this rollout
    """
    print('#######################################')
    print('Rolling out...')
    print('#######################################')
    all_frames = []
    ep_observation, ep_action, ep_agn = [], [], []
    black_frame = pad_image(env.render(mode='rgb_array')) * 0.0
    goal_sampled = 0
    timeout_count, success = 0, 0
    for ep in range(num_episode):
      s = self.reset(env, agent, self.cfg.diverse_scene_content)
      all_frames += [black_frame] * 10
      g_text, p = env.sample_goal()
      if env.all_goals_satisfied:
        s = self.reset(env, agent, True)
        g_text, p = env.sample_goal()
        goal_sampled += 1
      g = np.squeeze(self.encode_fn(g_text))
      current_goal_repetition = 0
      for t in range(self.cfg.max_episode_length):
        prob = self.epsilon if record_trajectory else 0.0
        action = agent.step(s, g, env, explore_prob=prob)
        s_tp1, r, _, _ = env.step(
            action,
            record_achieved_goal=True,
            goal=p,
            atomic_goal=self.cfg.record_atomic_instruction)
        ag = env.get_achieved_goals()
        s = s_tp1
        all_frames.append(
            add_text(pad_image(env.render(mode='rgb_array')), g_text))
        current_goal_repetition += 1

        if record_trajectory:
          ep_observation.append(env.get_direct_obs().tolist())
          ep_action.append(action)
          ep_agn.append(len(ag))

        sample_new_goal = False
        if r > env.shape_val:
          for _ in range(5):
            all_frames.append(
                add_text(
                    pad_image(env.render(mode='rgb_array')),
                    g_text,
                    color='green'))
          success += 1
          sample_new_goal = True

        if current_goal_repetition >= timeout:
          all_frames.append(
              add_text(pad_image(env.render(mode='rgb_array')), 'time out :('))
          timeout_count += 1
          sample_new_goal = True

        if sample_new_goal:
          g_text, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g = np.squeeze(self.encode_fn(g_text))
          current_goal_repetition = 0
          goal_sampled += 1

    print('Rollout finished')
    print('{} instructions tried'.format(goal_sampled))
    print('{} instructions timed out'.format(timeout_count))
    if record_video:
      save_video(np.uint8(all_frames), directory, fps=5)
      print('Video saved...')
    if record_trajectory:
      print('Recording trajectory...')
      datum = {
          'obs': ep_observation,
          'action': ep_action,
          'achieved goal': ep_agn,
      }
      save_json(datum, directory[:-4] + '_trajectory.json')
    return success / float(num_episode)
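
# Example invocation (hedged; the path is illustrative, and `directory` is
# effectively a video file name here, since the trajectory JSON path is
# derived from it by replacing the last four characters):
#   success_rate = learner.rollout(
#       env, agent, '/tmp/rollout.mp4', record_video=True, num_episode=20)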