# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hindsight Instruction Relabeling with Random Network Distillation."""
# pylint: disable=unused-variable
# pylint: disable=g-explicit-length-test
# pylint: disable=unused-import
# pylint: disable=line-too-long
from __future__ import absolute_import
from __future__ import division

import random
import time

import numpy as np

from hal.learner.hir import HIR
from hal.learner.language_utils import get_vocab_path
from hal.learner.language_utils import instruction_type
from hal.learner.language_utils import negate_unary_sentence
from hal.learner.language_utils import pad_to_max_length
from hal.learner.language_utils import paraphrase_sentence
from hal.utils.video_utils import add_text
from hal.utils.video_utils import pad_image
from hal.utils.video_utils import save_json
# from hal.utils.video_utils import save_video
import hal.utils.word_vectorization as wv

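# Illustrative sketch (not part of the original learner): Random Network
# Distillation scores the novelty of a state as the prediction error of a
# trained predictor network against a fixed, randomly initialized target
# network, so rarely visited states receive a high intrinsic reward. The
# helper below is a minimal numpy analogue of what
# `rnd_model.compute_intrinsic_reward` is assumed to return; the real model
# uses learned networks rather than fixed linear maps.
def _rnd_intrinsic_reward_sketch(states, target_weights, predictor_weights):
  """Per-state novelty as squared prediction error (illustrative only)."""
  target_embedding = states @ target_weights  # fixed random projection
  predicted_embedding = states @ predictor_weights  # trained to match target
  return np.mean((target_embedding - predicted_embedding)**2, axis=-1)
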
class RndHIR(HIR):
  """Learner that executes Hindsight Instruction Relabeling."""

  def reset(self, env, agent, rnd_agent, sample_new_scene=False):
    """Reset at the episode boundary.

    Args:
      env: the RL environment
      agent: the RL agent
      rnd_agent: the second agent responsible for doing RND
      sample_new_scene: sample a brand new set of objects for the scene

    Returns:
      the reset state of the environment
    """
    if self.cfg.reset_mode == 'random_action':
      for _ in range(20):
        s, _, _, _ = env.step(env.sample_random_action())
    elif self.cfg.reset_mode == 'none':
      s = env.get_obs()
    elif self.cfg.reset_mode == 'rnd':
      s = env.get_obs()
      for _ in range(20):
        s, _, _, _ = env.step(rnd_agent.step(s, env, epsilon=0.1))
    else:
      s = env.reset(sample_new_scene)
    return s

  def learn(self, env, agent, replay_buffer, rnd_model, rnd_agent):
    """Run learning for one cycle, which consists of num_episode episodes.

    Args:
      env: the RL environment
      agent: the RL agent
      replay_buffer: the experience replay buffer
      rnd_model: the rnd model for computing the pseudocount of a state
      rnd_agent: the second agent responsible for doing RND

    Returns:
      statistics of the training episode
    """
    average_per_ep_reward = []
    average_per_ep_achieved_n = []
    average_per_ep_relabel_n = []
    average_batch_loss = []
    average_rnd_loss = []
    average_rnd_agent_loss = []

    curr_step = agent.get_global_step()
    self.update_epsilon(curr_step)
    tic = time.time()
    for _ in range(self.cfg.num_episode):
      curr_step = agent.increase_global_step()

      sample_new_scene = random.uniform(0, 1) < self.cfg.sample_new_scene_prob
      s = self.reset(env, agent, rnd_agent, sample_new_scene)
      episode_experience = []
      episode_reward = 0
      episode_achieved_n = 0
      episode_relabel_n = 0

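      # Each episode rolls out against instructions sampled by the
      # environment: whenever the current instruction is achieved (reward
      # above env.shape_val), a new one is sampled, until every goal in the
      # scene is satisfied or the step budget runs out.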
      # rollout
      g_text, p = env.sample_goal()
      if env.all_goals_satisfied:
        s = self.reset(env, agent, rnd_agent, True)
        g_text, p = env.sample_goal()
      g = np.squeeze(self.encode_fn(g_text))

      for t in range(self.cfg.max_episode_length):
        rnd_model.update_stats(s)  # Update the moving statistics of rnd model
        a = agent.step(s, g, env, self.epsilon)
        s_tp1, r, _, _ = env.step(
            a,
            record_achieved_goal=True,
            goal=p,
            atomic_goal=self.cfg.record_atomic_instruction)
        ag = env.get_achieved_goals()
        ag_text = env.get_achieved_goal_programs()
        ag_total = ag  # TODO(ydjiang): more can be stored in ag
        episode_experience.append((s, a, r, s_tp1, g, ag_total))
        episode_reward += r
        s = s_tp1
        if r > env.shape_val:
          episode_achieved_n += 1
          g_text, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g = np.squeeze(self.encode_fn(g_text))

      average_per_ep_reward.append(episode_reward)
      average_per_ep_achieved_n.append(episode_achieved_n)

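      # Hindsight relabeling: every transition is stored with its original
      # (optionally paraphrased) instruction, and HIR.hir_relabel additionally
      # stores copies whose goal is an instruction the agent actually
      # achieved, so unsuccessful episodes still provide positive examples.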
      # processing trajectory
      episode_length = len(episode_experience)
      for t in range(episode_length):
        s, a, r, s_tp1, g, ag = episode_experience[t]
        episode_relabel_n += float(len(ag) > 0)
        g_text = self.decode_fn(g)
        if self.cfg.paraphrase:
          g_text = paraphrase_sentence(
              g_text, delete_color=self.cfg.diverse_scene_content)
        g = self.encode_fn(g_text)
        replay_buffer.add((s, a, r, s_tp1, g))
        if self.cfg.relabeling:
          self.hir_relabel(episode_experience, t, replay_buffer, env)

      average_per_ep_relabel_n.append(episode_relabel_n / float(episode_length))

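      # Fit the RND predictor on the states visited this episode; its
      # prediction error (the intrinsic reward) decays for familiar states
      # and stays high for novel ones.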
      if not self.is_warming_up(curr_step):
        state_trajectory = []
        for t in range(episode_length):
          state_trajectory.append(episode_experience[t][0])
        state_trajectory = np.stack(state_trajectory)
        curiosity_loss = 0
        for _ in range(self.cfg.optimization_steps):
          curiosity_loss += rnd_model.train(state_trajectory)['prediction_loss']

        average_rnd_loss.append(curiosity_loss / self.cfg.optimization_steps)

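      # Replay training: the task agent learns from the extrinsic
      # instruction-following reward, while the RND exploration agent is
      # trained on the same batch with the reward replaced by the intrinsic
      # reward of the next observations.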
      # training
      if not self.is_warming_up(curr_step):
        batch_loss, rnd_batch_loss = 0, 0
        for _ in range(self.cfg.optimization_steps):
          experience = replay_buffer.sample(self.cfg.batchsize)
          s, a, r, s_tp1, g = [
              np.squeeze(elem, axis=1) for elem in np.split(experience, 5, 1)
          ]
          s = np.stack(s)
          s_tp1 = np.stack(s_tp1)
          g = np.array(list(g))
          if self.cfg.instruction_repr == 'language':
            g = np.array(pad_to_max_length(g, self.cfg.max_sequence_length))
          batch = {
              'obs': np.asarray(s),
              'action': np.asarray(a),
              'reward': np.asarray(r),
              'obs_next': np.asarray(s_tp1),
              'g': np.asarray(g)
          }
          loss_dict = agent.train(batch)
          batch_loss += loss_dict['loss']
          # update rnd agent
          batch['reward'] = rnd_model.compute_intrinsic_reward(
              batch['obs_next'])
          batch['done'] = np.zeros(self.cfg.batchsize)
          rnd_agent_loss = rnd_agent.train(batch)
          rnd_batch_loss += rnd_agent_loss['loss']
        average_batch_loss.append(batch_loss / self.cfg.optimization_steps)
        average_rnd_agent_loss.append(rnd_batch_loss /
                                      self.cfg.optimization_steps)

    time_per_episode = (time.time() - tic) / self.cfg.num_episode

    # Update the target network
    agent.update_target_network()
    rnd_agent.update_target_network()

    ################## Debug ##################
    sample = replay_buffer.sample(min(10000, len(replay_buffer.buffer)))
    sample_s, _, sample_r, _, _ = [
        np.squeeze(elem, axis=1) for elem in np.split(sample, 5, 1)
    ]
    sample_intrinsic_r = rnd_model.compute_intrinsic_reward(sample_s)
    print('n one:', np.sum(np.float32(sample_r == 1.0)), 'n zero',
          np.sum(np.float32(sample_r == 0.0)), 'n buff',
          len(replay_buffer.buffer))
    ################## Debug ##################
    stats = {
        'loss': np.mean(average_batch_loss) if average_batch_loss else 0,
        'reward': np.mean(average_per_ep_reward),
        'achieved_goal': np.mean(average_per_ep_achieved_n),
        'average_relabel_goal': np.mean(average_per_ep_relabel_n),
        'epsilon': self.epsilon,
        'global_step': curr_step,
        'time_per_episode': time_per_episode,
        'replay_buffer_reward_avg': np.mean(sample_r),
        'replay_buffer_reward_var': np.var(sample_r),
        'intrinsic_reward_avg': np.mean(sample_intrinsic_r),
        'intrinsic_reward_var': np.var(sample_intrinsic_r)
    }
    return stats
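

# Illustrative outer loop (not part of the original module). How `env`,
# `agent`, `rnd_model`, `rnd_agent` and `replay_buffer` are constructed is
# project-specific and omitted here; this sketch only shows how
# `RndHIR.learn` is assumed to be driven, one cycle at a time, with the
# returned statistics collected for logging.
def _training_loop_sketch(learner, env, agent, replay_buffer, rnd_model,
                          rnd_agent, num_cycles=100):
  """Runs `num_cycles` learning cycles and returns the per-cycle stats."""
  all_stats = []
  for _ in range(num_cycles):
    stats = learner.learn(env, agent, replay_buffer, rnd_model, rnd_agent)
    all_stats.append(stats)
  return all_stats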