# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hindsight Instruction Relabeling."""
# pylint: disable=unused-variable
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
from __future__ import absolute_import
from __future__ import division

import random
import time

import numpy as np

from hal.learner.language_utils import get_vocab_path
from hal.learner.language_utils import instruction_type
from hal.learner.language_utils import negate_unary_sentence
from hal.learner.language_utils import pad_to_max_length
from hal.learner.language_utils import paraphrase_sentence
from hal.utils.video_utils import add_text
from hal.utils.video_utils import pad_image
from hal.utils.video_utils import save_json
from hal.utils.video_utils import save_video
import hal.utils.word_vectorization as wv


class HIR:
  """Learner that executes Hindsight Instruction Relabeling.

  Attributes:
    cfg: configuration of this learner
    step: current training step
    epsilon: value of epsilon for sampling a random action
    vocab_list: vocabulary list used for the instruction labeler
    encode_fn: function that encodes an instruction
    decode_fn: function that converts an encoded instruction back to text
    labeler: object that generates labels for transitions
  """

  def __init__(self, cfg):
    # making session
    self.cfg = cfg
    self.step = 0
    self.epsilon = 1.0

    # Vocab loading
    vocab_path = get_vocab_path(cfg)
    self.vocab_list = wv.load_vocab_list(vocab_path)
    v2i, i2v = wv.create_look_up_table(self.vocab_list)
    self.encode_fn = wv.encode_text_with_lookup_table(
        v2i, max_sequence_length=self.cfg.max_sequence_length)
    self.decode_fn = wv.decode_with_lookup_table(i2v)
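    # NOTE: encode_fn is assumed to map an instruction string to a sequence of
    # vocabulary indices (at most max_sequence_length long), and decode_fn is
    # assumed to map such index sequences back to text.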

  def reset(self, env, agent, sample_new_scene=False, **kwargs):
    """Reset at the episode boundary.

    Args:
      env: the RL environment
      agent: the RL agent
      sample_new_scene: sample a brand new set of objects for the scene
      **kwargs: other potential arguments

    Returns:
      the reset state of the environment
    """
    if self.cfg.reset_mode == 'random_action':
      for _ in range(20):
        s, _, _, _ = env.step(env.sample_random_action())
    elif self.cfg.reset_mode == 'none':
      s = env.get_obs()
    else:
      s = env.reset(sample_new_scene)
    return s

  def learn(self, env, agent, replay_buffer, **kwargs):
    """Run learning for one cycle, which consists of num_episode episodes.

    Args:
      env: the RL environment
      agent: the RL agent
      replay_buffer: the experience replay buffer
      **kwargs: other potential arguments

    Returns:
      statistics of the training episodes
    """
    average_per_ep_reward = []
    average_per_ep_achieved_n = []
    average_per_ep_relabel_n = []
    average_batch_loss = []
    curiosity_loss = 0

    curr_step = agent.get_global_step()
    self.update_epsilon(curr_step)
    tic = time.time()
    time_rolling_out, time_training = 0.0, 0.0
    for _ in range(self.cfg.num_episode):
      curr_step = agent.increase_global_step()

      sample_new_scene = random.uniform(0, 1) < self.cfg.sample_new_scene_prob
      s = self.reset(env, agent, sample_new_scene)
      episode_experience = []
      episode_reward = 0
      episode_achieved_n = 0
      episode_relabel_n = 0

      # rollout
      rollout_tic = time.time()
      g_text, p = env.sample_goal()
      if env.all_goals_satisfied:
        s = self.reset(env, agent, True)
        g_text, p = env.sample_goal()
      g = np.squeeze(self.encode_fn(g_text))

      for t in range(self.cfg.max_episode_length):
        a = agent.step(s, g, env, self.epsilon)
        s_tp1, r, _, _ = env.step(
            a,
            record_achieved_goal=True,
            goal=p,
            atomic_goal=self.cfg.record_atomic_instruction)
        ag = env.get_achieved_goals()
        ag_text = env.get_achieved_goal_programs()
        ag_total = ag  # TODO(ydjiang): more can be stored in ag
        episode_experience.append((s, a, r, s_tp1, g, ag_total))
        episode_reward += r
        s = s_tp1
        if r > env.shape_val:
          episode_achieved_n += 1
          g_text, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g = np.squeeze(self.encode_fn(g_text))
      time_rolling_out += time.time() - rollout_tic

      average_per_ep_reward.append(episode_reward)
      average_per_ep_achieved_n.append(episode_achieved_n)

      # processing trajectory
      train_tic = time.time()
      episode_length = len(episode_experience)
      for t in range(episode_length):
        s, a, r, s_tp1, g, ag = episode_experience[t]
        episode_relabel_n += float(len(ag) > 0)
        g_text = self.decode_fn(g)
        if self.cfg.paraphrase:
          g_text = paraphrase_sentence(
              g_text, delete_color=self.cfg.diverse_scene_content)
        g = self.encode_fn(g_text)
        replay_buffer.add((s, a, r, s_tp1, g))
        if self.cfg.relabeling:
          self.hir_relabel(episode_experience, t, replay_buffer, env)

      average_per_ep_relabel_n.append(episode_relabel_n / float(episode_length))

      # training
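      # Replay entries are 5-tuples (s, a, r, s_tp1, g); np.split along axis 1
      # followed by np.squeeze below recovers the individual fields from the
      # sampled batch before they are re-packed into the batch dict passed to
      # agent.train.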
      if not self.is_warming_up(curr_step):
        batch_loss = 0
        for _ in range(self.cfg.optimization_steps):
          experience = replay_buffer.sample(self.cfg.batchsize)
          s, a, r, s_tp1, g = [
              np.squeeze(elem, axis=1) for elem in np.split(experience, 5, 1)
          ]
          s = np.stack(s)
          s_tp1 = np.stack(s_tp1)
          g = np.array(list(g))
          if self.cfg.instruction_repr == 'language':
            g = np.array(pad_to_max_length(g, self.cfg.max_sequence_length))
          batch = {
              'obs': np.asarray(s),
              'action': np.asarray(a),
              'reward': np.asarray(r),
              'obs_next': np.asarray(s_tp1),
              'g': np.asarray(g)
          }
          loss_dict = agent.train(batch)
          batch_loss += loss_dict['loss']
          if 'prediction_loss' in loss_dict:
            curiosity_loss += loss_dict['prediction_loss']
        average_batch_loss.append(batch_loss / self.cfg.optimization_steps)
      time_training += time.time() - train_tic

    time_per_episode = (time.time() - tic) / self.cfg.num_episode
    time_training_per_episode = time_training / self.cfg.num_episode
    time_rolling_out_per_episode = time_rolling_out / self.cfg.num_episode

    # Update the target network
    agent.update_target_network()
    ################## Debug ##################
    sample = replay_buffer.sample(min(10000, len(replay_buffer.buffer)))
    _, _, sample_r, _, _ = [
        np.squeeze(elem, axis=1) for elem in np.split(sample, 5, 1)
    ]
    print('n one:', np.sum(np.float32(sample_r == 1.0)), 'n zero',
          np.sum(np.float32(sample_r == 0.0)), 'n buff',
          len(replay_buffer.buffer))
    ################## Debug ##################
    stats = {
        'loss': np.mean(average_batch_loss) if average_batch_loss else 0,
        'reward': np.mean(average_per_ep_reward),
        'achieved_goal': np.mean(average_per_ep_achieved_n),
        'average_relabel_goal': np.mean(average_per_ep_relabel_n),
        'epsilon': self.epsilon,
        'global_step': curr_step,
        'time_per_episode': time_per_episode,
        'time_training_per_episode': time_training_per_episode,
        'time_rolling_out_per_episode': time_rolling_out_per_episode,
        'replay_buffer_reward_avg': np.mean(sample_r),
        'replay_buffer_reward_var': np.var(sample_r)
    }
    return stats

  def hir_relabel(self, episode_experience, current_t, replay_buffer, env):
    """Relabels the transition at current_t with achieved and future goals.

    Args:
      episode_experience: list of transitions collected in the current episode
      current_t: time step at which the experience is relabeled
      replay_buffer: the experience replay buffer
      env: the RL environment
    """
    ep_len = len(episode_experience)
    s, a, _, s_tp1, _, ag = episode_experience[current_t]
    if ag:
      for _ in range(min(self.cfg.k_immediate, len(ag) + 1)):
        ag_text_single = random.choice(ag)
        g_type = instruction_type(ag_text_single)
        if self.cfg.paraphrase and g_type != 'unary':
          ag_text_single = paraphrase_sentence(
              ag_text_single, delete_color=self.cfg.diverse_scene_content)
        replay_buffer.add(
            (s, a, env.reward_scale, s_tp1, self.encode_fn(ag_text_single)))
        if g_type == 'unary' and self.cfg.negate_unary:
          negative_ag = negate_unary_sentence(ag_text_single)
          if negative_ag:
            replay_buffer.add((s, a, 0., s_tp1, self.encode_fn(negative_ag)))
    goal_count, repeat = 0, 0
    while goal_count < self.cfg.future_k and repeat < (ep_len - current_t) * 2:
      repeat += 1
      future = np.random.randint(current_t, ep_len)
      _, _, _, _, _, ag_future = episode_experience[future]
      if not ag_future:
        continue
      random.shuffle(ag_future)
      for single_g in ag_future:
        if instruction_type(single_g) != 'unary':
          discount = self.cfg.discount**(future - current_t)
          if self.cfg.paraphrase:
            single_g = paraphrase_sentence(
                single_g, delete_color=self.cfg.diverse_scene_content)
          replay_buffer.add((s, a, discount * env.reward_scale, s_tp1,
                             self.encode_fn(single_g)))
          goal_count += 1
          break

  def update_epsilon(self, step):
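    # Epsilon decays exponentially with the number of completed cycles
    # (step // num_episode), floored at min_epsilon. For example, with the
    # illustrative values epsilon_decay = 0.99 and 10 completed cycles, the
    # schedule gives 0.99 ** 10 ~= 0.904 before the floor is applied.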
    new_epsilon = self.cfg.epsilon_decay**(step // self.cfg.num_episode)
    self.epsilon = max(new_epsilon, self.cfg.min_epsilon)

  def is_warming_up(self, step):
    return step <= self.cfg.collect_cycle * self.cfg.num_episode

  def rollout(self,
              env,
              agent,
              directory,
              record_video=False,
              timeout=8,
              num_episode=10,
              record_trajectory=False):
    """Rollout and save.

    Args:
      env: the RL environment
      agent: the RL agent
      directory: directory where the output of the rollout is saved
      record_video: whether to record a video of the rollout
      timeout: number of steps on one goal after which the agent is
        considered stuck
      num_episode: number of rollout episodes
      record_trajectory: whether to record the ground truth trajectory

    Returns:
      success rate of this rollout (achieved goals divided by num_episode)
    """
    print('#######################################')
    print('Rolling out...')
    print('#######################################')
    all_frames = []
    ep_observation, ep_action, ep_agn = [], [], []
    black_frame = pad_image(env.render(mode='rgb_array')) * 0.0
    goal_sampled = 0
    timeout_count, success = 0, 0
    for ep in range(num_episode):
      s = self.reset(env, agent, self.cfg.diverse_scene_content)
      all_frames += [black_frame] * 10
      g_text, p = env.sample_goal()
      if env.all_goals_satisfied:
        s = self.reset(env, agent, True)
        g_text, p = env.sample_goal()
      goal_sampled += 1
      g = np.squeeze(self.encode_fn(g_text))
      current_goal_repetition = 0
      for t in range(self.cfg.max_episode_length):
        prob = self.epsilon if record_trajectory else 0.0
        action = agent.step(s, g, env, explore_prob=prob)
        s_tp1, r, _, _ = env.step(
            action,
            record_achieved_goal=True,
            goal=p,
            atomic_goal=self.cfg.record_atomic_instruction)
        ag = env.get_achieved_goals()
        s = s_tp1
        all_frames.append(
            add_text(pad_image(env.render(mode='rgb_array')), g_text))
        current_goal_repetition += 1

        if record_trajectory:
          ep_observation.append(env.get_direct_obs().tolist())
          ep_action.append(action)
          ep_agn.append(len(ag))

        sample_new_goal = False
        if r > env.shape_val:
          for _ in range(5):
            all_frames.append(
                add_text(
                    pad_image(env.render(mode='rgb_array')),
                    g_text,
                    color='green'))
          success += 1
          sample_new_goal = True

        if current_goal_repetition >= timeout:
          all_frames.append(
              add_text(pad_image(env.render(mode='rgb_array')), 'time out :('))
          timeout_count += 1
          sample_new_goal = True

        if sample_new_goal:
          g, p = env.sample_goal()
          if env.all_goals_satisfied:
            break
          g_text = g
          g = np.squeeze(self.encode_fn(g))
          current_goal_repetition = 0
          goal_sampled += 1

    print('Rollout finished')
    print('{} instructions tried'.format(goal_sampled))
    print('{} instructions timed out'.format(timeout_count))
    if record_video:
      save_video(np.uint8(all_frames), directory, fps=5)
      print('Video saved...')
    if record_trajectory:
      print('Recording trajectory...')
      datum = {
          'obs': ep_observation,
          'action': ep_action,
          'achieved goal': ep_agn,
      }
      save_json(datum, directory[:-4] + '_trajectory.json')
    return success / float(num_episode)