# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hindsight Instruction Relabeling with Random Network Distillation."""
# pylint: disable=unused-variable
# pylint: disable=g-explicit-length-test
# pylint: disable=unused-import
# pylint: disable=line-too-long
from __future__ import absolute_import
from __future__ import division

import random
import time

import numpy as np

from hal.learner.hir import HIR
from hal.learner.language_utils import get_vocab_path
from hal.learner.language_utils import instruction_type
from hal.learner.language_utils import negate_unary_sentence
from hal.learner.language_utils import pad_to_max_length
from hal.learner.language_utils import paraphrase_sentence
from hal.utils.video_utils import add_text
from hal.utils.video_utils import pad_image
from hal.utils.video_utils import save_json
# from hal.utils.video_utils import save_video
import hal.utils.word_vectorization as wv


class RndHIR(HIR):
  """Learner that executes Hindsight Instruction Relabeling."""

  def reset(self, env, agent, rnd_agent, sample_new_scene=False):
    """Reset at the episode boundary.

    Args:
      env: the RL environment
      agent: the RL agent
      rnd_agent: the second agent responsible for doing RND
      sample_new_scene: sample a brand new set of objects for the scene

    Returns:
      the reset state of the environment
    """
    if self.cfg.reset_mode == 'random_action':
      for _ in range(20):
        s, _, _, _ = env.step(env.sample_random_action())
    elif self.cfg.reset_mode == 'none':
      s = env.get_obs()
    elif self.cfg.reset_mode == 'rnd':
      s = env.get_obs()
      for _ in range(20):
        s, _, _, _ = env.step(rnd_agent.step(s, env, epsilon=0.1))
    else:
      s = env.reset(sample_new_scene)
    return s

  def learn(self, env, agent, replay_buffer, rnd_model, rnd_agent):
    """Run learning for 1 cycle with consists of num_episode of episodes.
72

73
    Args:
      env: the RL environment
      agent: the RL agent
      replay_buffer: the experience replay buffer
      rnd_model: the rnd model for computing the pseudocount of a state
      rnd_agent: the second agent responsible for doing RND

    Returns:
      statistics of the training episode
    """
    average_per_ep_reward = []
    average_per_ep_achieved_n = []
    average_per_ep_relabel_n = []
    average_batch_loss = []
    average_rnd_loss = []
    average_rnd_agent_loss = []

    curr_step = agent.get_global_step()
    self.update_epsilon(curr_step)
    tic = time.time()
    for _ in range(self.cfg.num_episode):
      curr_step = agent.increase_global_step()

      sample_new_scene = random.uniform(0, 1) < self.cfg.sample_new_scene_prob
      s = self.reset(env, agent, rnd_agent, sample_new_scene)
      episode_experience = []
      episode_reward = 0
      episode_achieved_n = 0
      episode_relabel_n = 0

      # rollout
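      # Sample a goal (instruction text g_text and goal p); if every goal in
      # the current scene is already satisfied, reset with a fresh scene and
      # sample again.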
      g_text, p = env.sample_goal()
      if env.all_goals_satisfied:
        s = self.reset(env, agent, rnd_agent, True)
        g_text, p = env.sample_goal()
      g = np.squeeze(self.encode_fn(g_text))

      for t in range(self.cfg.max_episode_length):
        rnd_model.update_stats(s)  # Update the moving statistics of rnd model
        a = agent.step(s, g, env, self.epsilon)
        s_tp1, r, _, _ = env.step(
            a,
            record_achieved_goal=True,
            goal=p,
            atomic_goal=self.cfg.record_atomic_instruction)
        ag = env.get_achieved_goals()
        ag_text = env.get_achieved_goal_programs()
        ag_total = ag  # TODO(ydjiang): more can be stored in ag
        episode_experience.append((s, a, r, s_tp1, g, ag_total))
        episode_reward += r
        s = s_tp1
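        # A reward above the shaping threshold means the current instruction
        # was achieved: count it and move on to the next goal, or stop early
        # once every goal in the scene is satisfied.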
        if r > env.shape_val:
          episode_achieved_n += 1
          g_text, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g = np.squeeze(self.encode_fn(g_text))

      average_per_ep_reward.append(episode_reward)
      average_per_ep_achieved_n.append(episode_achieved_n)

      # processing trajectory
      episode_length = len(episode_experience)
      for t in range(episode_length):
        s, a, r, s_tp1, g, ag = episode_experience[t]
        episode_relabel_n += float(len(ag) > 0)
        g_text = self.decode_fn(g)
        if self.cfg.paraphrase:
          g_text = paraphrase_sentence(
              g_text, delete_color=self.cfg.diverse_scene_content)
        g = self.encode_fn(g_text)
        replay_buffer.add((s, a, r, s_tp1, g))
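        # Hindsight relabeling: add relabeled copies of this transition to the
        # buffer, using the goals that were actually achieved at this step.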
        if self.cfg.relabeling:
          self.hir_relabel(episode_experience, t, replay_buffer, env)

      average_per_ep_relabel_n.append(episode_relabel_n / float(episode_length))

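      # Fit the RND predictor on the states visited in this episode; its
      # prediction error is what compute_intrinsic_reward later uses as the
      # novelty signal.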
      if not self.is_warming_up(curr_step):
        state_trajectory = []
        for t in range(episode_length):
          state_trajectory.append(episode_experience[t][0])
        state_trajectory = np.stack(state_trajectory)
        curiosity_loss = 0
        for _ in range(self.cfg.optimization_steps):
          curiosity_loss += rnd_model.train(state_trajectory)['prediction_loss']

        average_rnd_loss.append(curiosity_loss / self.cfg.optimization_steps)

      # training
      if not self.is_warming_up(curr_step):
        batch_loss, rnd_batch_loss = 0, 0
        for _ in range(self.cfg.optimization_steps):
          experience = replay_buffer.sample(self.cfg.batchsize)
          s, a, r, s_tp1, g = [
              np.squeeze(elem, axis=1) for elem in np.split(experience, 5, 1)
          ]
          s = np.stack(s)
          s_tp1 = np.stack(s_tp1)
          g = np.array(list(g))
          if self.cfg.instruction_repr == 'language':
            g = np.array(pad_to_max_length(g, self.cfg.max_sequence_length))
          batch = {
              'obs': np.asarray(s),
              'action': np.asarray(a),
              'reward': np.asarray(r),
              'obs_next': np.asarray(s_tp1),
              'g': np.asarray(g)
          }
          loss_dict = agent.train(batch)
          batch_loss += loss_dict['loss']
          # update rnd agent
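          # The RND exploration agent trains on the same batch, but with the
          # extrinsic reward replaced by the RND intrinsic reward and a zero
          # terminal signal.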
          batch['reward'] = rnd_model.compute_intrinsic_reward(
              batch['obs_next'])
          batch['done'] = np.zeros(self.cfg.batchsize)
          rnd_agent_loss = rnd_agent.train(batch)
          rnd_batch_loss += rnd_agent_loss['loss']
        average_batch_loss.append(batch_loss / self.cfg.optimization_steps)
        average_rnd_agent_loss.append(rnd_batch_loss /
                                      self.cfg.optimization_steps)

    time_per_episode = (time.time() - tic) / self.cfg.num_episode

    # Update the target network
    agent.update_target_network()
    rnd_agent.update_target_network()

    ################## Debug ##################
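    # Sample up to 10k transitions from the replay buffer to report how
    # extrinsic and intrinsic rewards are currently distributed.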
    sample = replay_buffer.sample(min(10000, len(replay_buffer.buffer)))
    sample_s, _, sample_r, _, _ = [
        np.squeeze(elem, axis=1) for elem in np.split(sample, 5, 1)
    ]
    sample_intrinsic_r = rnd_model.compute_intrinsic_reward(sample_s)
    print('n one:', np.sum(np.float32(sample_r == 1.0)), 'n zero',
          np.sum(np.float32(sample_r == 0.0)), 'n buff',
          len(replay_buffer.buffer))
    ################## Debug ##################
    stats = {
        'loss': np.mean(average_batch_loss) if average_batch_loss else 0,
        'reward': np.mean(average_per_ep_reward),
        'achieved_goal': np.mean(average_per_ep_achieved_n),
        'average_relabel_goal': np.mean(average_per_ep_relabel_n),
        'epsilon': self.epsilon,
        'global_step': curr_step,
        'time_per_episode': time_per_episode,
        'replay_buffer_reward_avg': np.mean(sample_r),
        'replay_buffer_reward_var': np.var(sample_r),
        'intrinsic_reward_avg': np.mean(sample_intrinsic_r),
        'intrinsic_reward_var': np.var(sample_intrinsic_r)
    }
    return stats
