# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for C-learning."""

import gin
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.metrics import tf_metric
from tf_agents.metrics import tf_metrics
from tf_agents.utils import common

def truncated_geometric(horizon, gamma):
  """Generates samples from a truncated geometric distribution.

  Args:
    horizon: A 1-d tensor of horizon lengths for each element in the batch.
      The returned samples will be less than the corresponding horizon.
    gamma: The discount factor. Importantly, we sample from a Geom(1 - gamma)
      distribution.

  Returns:
    indices: A 1-d tensor of integers, one for each element of the batch.
  """
  max_horizon = tf.reduce_max(horizon)
  batch_size = tf.shape(horizon)[0]
  indices = tf.tile(
      tf.range(max_horizon, dtype=tf.float32)[None], (batch_size, 1))
  probs = tf.where(indices < horizon[:, None], gamma**indices,
                   tf.zeros_like(indices))
  probs = probs / tf.reduce_sum(probs, axis=1)[:, None]
  indices = tf.random.categorical(tf.math.log(probs), 1, dtype=tf.int32)
  return indices[:, 0]  # Remove the extra dimension.
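

# A minimal usage sketch (not from the original file): the horizon values and
# gamma below are made-up inputs, chosen only to show the expected shapes.
def _example_truncated_geometric():
  """Illustrative helper: samples one index per trajectory."""
  horizon = tf.constant([3.0, 5.0, 10.0])  # Per-element horizons (float32).
  samples = truncated_geometric(horizon, gamma=0.99)
  # `samples` has shape [3]; samples[i] is an integer in [0, horizon[i]).
  return samples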


def get_future_goals(observation, discount, gamma):
  """Samples future goals according to a geometric distribution."""
  num_obs = observation.shape[0]
  traj_len = observation.shape[1]
  first_terminal_or_zero = tf.argmax(
      discount == 0, axis=1, output_type=tf.int32)
  any_terminal = tf.reduce_any(discount == 0, axis=1)
  first_terminal = tf.where(any_terminal, first_terminal_or_zero, traj_len)
  first_terminal = tf.cast(first_terminal, tf.float32)
  if num_obs == 0:
    # The truncated_geometric function breaks if called on an empty list.
    # In that case, we manually create an empty list of future goals.
    indices = tf.zeros((0,), dtype=tf.int32)
  else:
    indices = truncated_geometric(first_terminal, gamma)
  stacked_indices = tf.stack([tf.range(num_obs), indices], axis=1)
  return tf.gather_nd(observation, stacked_indices)
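

# A minimal usage sketch (not from the original file): a batch of two 4-step
# trajectories with 3-d observations; the values are made up.
def _example_get_future_goals():
  """Illustrative helper: samples a future goal for each trajectory."""
  observation = tf.random.uniform((2, 4, 3))
  # The second trajectory terminates at step 2 (discount == 0), so its goal is
  # drawn only from steps 0-1; the first may use any of its 4 steps.
  discount = tf.constant([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 1.0, 0.0, 0.0]])
  return get_future_goals(observation, discount, gamma=0.99)  # Shape [2, 3].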


def get_last_goals(observation, discount):
  """Extracts the final observation before termination.

  Args:
    observation: a B x T x D tensor storing the next T time steps. These time
      steps may be part of a new trajectory. This function will only consider
      observations that occur before the first terminal.
    discount: a B x T tensor indicating whether the episode has terminated.

  Returns:
    last_obs: a B x D tensor storing the last observation in each trajectory
      that occurs before the first terminal.
  """
  num_obs = observation.shape[0]
  traj_len = observation.shape[1]
  first_terminal_or_zero = tf.argmax(
      discount == 0, axis=1, output_type=tf.int32)
  any_terminal = tf.reduce_any(discount == 0, axis=1)
  first_terminal = tf.where(any_terminal, first_terminal_or_zero, traj_len)
  # If the first state is terminal then first_terminal - 1 = -1. In this case
  # we use the state itself as the goal.
  last_nonterminal = tf.clip_by_value(first_terminal - 1, 0, traj_len)
  stacked_indices = tf.stack([tf.range(num_obs), last_nonterminal], axis=1)
  last_obs = tf.gather_nd(observation, stacked_indices)
  return last_obs
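

# A minimal usage sketch (not from the original file), mirroring the example
# above: the last pre-terminal observation is used as the goal.
def _example_get_last_goals():
  """Illustrative helper: extracts the last non-terminal observation."""
  observation = tf.random.uniform((2, 4, 3))
  discount = tf.constant([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 1.0, 0.0, 0.0]])
  # The first trajectory never terminates, so index 3 is used; the second
  # terminates at step 2, so index 1 is used.
  return get_last_goals(observation, discount)  # Shape [2, 3].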


@gin.configurable
def obs_to_goal(obs, start_index=0, end_index=None):
  """Extracts the goal-relevant coordinates obs[:, start_index:end_index]."""
  if end_index is None:
    return obs[:, start_index:]
  else:
    return obs[:, start_index:end_index]


@gin.configurable
def goal_fn(experience,
            buffer_info,
            relabel_orig_prob=0.0,
            relabel_next_prob=0.5,
            relabel_future_prob=0.0,
            relabel_last_prob=0.0,
            batch_size=None,
            obs_dim=None,
            gamma=None):
  """Given experience, sample goals in several ways.

  Goals can be kept as-is or relabeled with the next state, an arbitrary
  future state, the final state, or a random state. For the future state
  relabeling, care must be taken to ensure that we don't sample experience
  across the episode boundary. We automatically set relabel_random_prob =
  1 - (relabel_orig_prob + relabel_next_prob + relabel_future_prob +
  relabel_last_prob).

  Args:
    experience: The experience that we aim to relabel.
    buffer_info: Information about the replay buffer. We will not change this.
    relabel_orig_prob: (float) Fraction of experience to not relabel.
    relabel_next_prob: (float) Fraction of experience to relabel with the next
      state.
    relabel_future_prob: (float) Fraction of experience to relabel with a
      future state.
    relabel_last_prob: (float) Fraction of experience to relabel with the
      final state.
    batch_size: (int) The size of the batch.
    obs_dim: (int) The dimension of the observation.
    gamma: (float) The discount factor. Future states are sampled according to
      a Geom(1 - gamma) distribution.

  Returns:
    experience: A modified version of the input experience where the goals
      have been changed and the rewards are recomputed.
    buffer_info: Information about the replay buffer, unchanged.
  """
  assert batch_size is not None
  assert obs_dim is not None
  assert gamma is not None
  relabel_orig_num = int(relabel_orig_prob * batch_size)
  relabel_next_num = int(relabel_next_prob * batch_size)
  relabel_future_num = int(relabel_future_prob * batch_size)
  relabel_last_num = int(relabel_last_prob * batch_size)
  relabel_random_num = batch_size - (
      relabel_orig_num + relabel_next_num + relabel_future_num +
      relabel_last_num)
  assert relabel_random_num >= 0

  orig_goals = experience.observation[:relabel_orig_num, 0, obs_dim:]

  index = relabel_orig_num
  next_goals = experience.observation[index:index + relabel_next_num,
                                      1, :obs_dim]

  index = relabel_orig_num + relabel_next_num
  future_goals = get_future_goals(
      experience.observation[index:index + relabel_future_num, :, :obs_dim],
      experience.discount[index:index + relabel_future_num], gamma)

  index = relabel_orig_num + relabel_next_num + relabel_future_num
  last_goals = get_last_goals(
      experience.observation[index:index + relabel_last_num, :, :obs_dim],
      experience.discount[index:index + relabel_last_num])

  # For random goals we take other states from the same batch.
  random_goals = tf.random.shuffle(experience.observation[:relabel_random_num,
                                                          0, :obs_dim])
  new_goals = obs_to_goal(tf.concat([next_goals, future_goals,
                                     last_goals, random_goals], axis=0))
  goals = tf.concat([orig_goals, new_goals], axis=0)

  obs = experience.observation[:, :2, :obs_dim]
  reward = tf.reduce_all(obs_to_goal(obs[:, 1]) == goals, axis=-1)
  reward = tf.cast(reward, tf.float32)
  reward = tf.tile(reward[:, None], [1, 2])
  new_obs = tf.concat([obs, tf.tile(goals[:, None, :], [1, 2, 1])], axis=2)
  experience = experience.replace(
      observation=new_obs,  # [B x 2 x 2 * obs_dim]
      action=experience.action[:, :2],
      step_type=experience.step_type[:, :2],
      next_step_type=experience.next_step_type[:, :2],
      discount=experience.discount[:, :2],
      reward=reward,
  )
  return experience, buffer_info
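

# A worked partition example (illustrative; the settings are hypothetical):
# with batch_size=256, relabel_next_prob=0.5 and relabel_future_prob=0.25
# (all other relabel probabilities left at 0.0), goal_fn splits the batch as
#   relabel_next_num   = int(0.50 * 256) = 128
#   relabel_future_num = int(0.25 * 256) = 64
#   relabel_random_num = 256 - 128 - 64  = 64
# so rows [0, 128) are relabeled with their next state, rows [128, 192) with a
# sampled future state, and rows [192, 256) with a shuffled copy of the first
# 64 states in the batch.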


@gin.configurable
class ClassifierCriticNetwork(critic_network.CriticNetwork):
  """Critic network whose final layer is a sigmoid, so outputs lie in (0, 1)."""

  def __init__(self,
               input_tensor_spec,
               observation_fc_layer_params=None,
               action_fc_layer_params=None,
               joint_fc_layer_params=None,
               kernel_initializer=None,
               last_kernel_initializer=None,
               name='ClassifierCriticNetwork'):
    super(ClassifierCriticNetwork, self).__init__(
        input_tensor_spec,
        observation_fc_layer_params=observation_fc_layer_params,
        action_fc_layer_params=action_fc_layer_params,
        joint_fc_layer_params=joint_fc_layer_params,
        kernel_initializer=kernel_initializer,
        last_kernel_initializer=last_kernel_initializer,
        name=name,
    )

    last_layers = [
        tf.keras.layers.Dense(
            1,
            activation=tf.math.sigmoid,
            kernel_initializer=last_kernel_initializer,
            name='value')
    ]
    self._joint_layers = self._joint_layers[:-1] + last_layers
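    # The parent CriticNetwork ends in a linear Dense(1) layer; swapping it
    # for a sigmoid-activated Dense(1) constrains the critic output to (0, 1),
    # so it can be read as a classifier-style probability rather than an
    # unbounded Q-value.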


class BaseDistanceMetric(tf_metric.TFStepMetric):
  """Base class for metrics that track the distance to the goal.

  Observations are assumed to be the concatenation [state, goal], where the
  state occupies the first obs_dim entries, and trajectories are assumed to
  have batch size 1.
  """

  def __init__(self,
               prefix='Metrics',
               dtype=tf.float32,
               batch_size=1,
               buffer_size=10,
               obs_dim=None,
               start_index=0,
               end_index=None,
               name=None):
    assert obs_dim is not None
    self._start_index = start_index
    self._end_index = end_index
    self._obs_dim = obs_dim
    name = self.NAME if name is None else name
    super(BaseDistanceMetric, self).__init__(name=name, prefix=prefix)
    self._buffer = tf_metrics.TFDeque(buffer_size, dtype)
    self._dist_buffer = tf_metrics.TFDeque(
        1000, dtype)  # Episodes should have length less than 1k.
    self.dtype = dtype

  @common.function(autograph=True)
  def call(self, trajectory):
    obs = trajectory.observation
    s = obs[:, :self._obs_dim]
    g = obs[:, self._obs_dim:]
    dist_to_goal = tf.norm(
        obs_to_goal(obs_to_goal(s), self._start_index, self._end_index) -
        obs_to_goal(g, self._start_index, self._end_index),
        axis=1)
    tf.assert_equal(tf.shape(obs)[0], 1)
    if trajectory.is_mid():
      self._dist_buffer.extend(dist_to_goal)
    if trajectory.is_last()[0] and self._dist_buffer.length > 0:
      self._update_buffer()
      self._dist_buffer.clear()
    return trajectory

  def result(self):
    return self._buffer.mean()

  @common.function
  def reset(self):
    self._buffer.clear()

  def _update_buffer(self):
    raise NotImplementedError


class InitialDistance(BaseDistanceMetric):
  """Computes the initial distance to the goal."""
  NAME = 'InitialDistance'

  def _update_buffer(self):
    initial_dist = self._dist_buffer.data[0]
    self._buffer.add(initial_dist)


class FinalDistance(BaseDistanceMetric):
  """Computes the final distance to the goal."""
  NAME = 'FinalDistance'

  def _update_buffer(self):
    final_dist = self._dist_buffer.data[-1]
    self._buffer.add(final_dist)


class AverageDistance(BaseDistanceMetric):
  """Computes the average distance to the goal."""
  NAME = 'AverageDistance'

  def _update_buffer(self):
    avg_dist = self._dist_buffer.mean()
    self._buffer.add(avg_dist)


class MinimumDistance(BaseDistanceMetric):
  """Computes the minimum distance to the goal."""
  NAME = 'MinimumDistance'

  def _update_buffer(self):
    min_dist = self._dist_buffer.min()
    tf.Assert(
        tf.math.is_finite(min_dist), [
            min_dist, self._dist_buffer.length, self._dist_buffer._head,  # pylint: disable=protected-access
            self._dist_buffer.data
        ],
        summarize=1000)
    self._buffer.add(min_dist)


class DeltaDistance(BaseDistanceMetric):
  """Computes the net distance traveled towards the goal. Positive is good."""
  NAME = 'DeltaDistance'

  def _update_buffer(self):
    delta_dist = self._dist_buffer.data[0] - self._dist_buffer.data[-1]
    self._buffer.add(delta_dist)
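

# A minimal usage sketch (illustrative; `tf_env`, `policy` and `obs_dim` are
# hypothetical, and it assumes
# `from tf_agents.drivers import dynamic_episode_driver`). The metrics above
# follow the tf_agents TFStepMetric interface, so they can be attached as
# driver observers and read via `result()`:
#
#   final_dist = FinalDistance(obs_dim=obs_dim)
#   driver = dynamic_episode_driver.DynamicEpisodeDriver(
#       tf_env, policy, observers=[final_dist], num_episodes=10)
#   driver.run()
#   tf.print(final_dist.result())  # Mean final distance over recent episodes.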