google-research
320 lines · 11.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Helper functions for C-learning."""
17
18import gin
19import tensorflow as tf
20from tf_agents.agents.ddpg import critic_network
21from tf_agents.metrics import tf_metric
22from tf_agents.metrics import tf_metrics
23from tf_agents.utils import common
24
25
def truncated_geometric(horizon, gamma):
  """Draws one sample per batch element from a truncated geometric distribution.

  Args:
    horizon: A 1-d float tensor of horizon lengths, one per batch element.
      Each returned sample is strictly less than its corresponding horizon.
    gamma: The discount factor. Importantly, we sample from a Geom(1 - gamma)
      distribution, truncated at the horizon.

  Returns:
    A 1-d int32 tensor of sampled indices, one per batch element.
  """
  num_rows = tf.shape(horizon)[0]
  longest = tf.reduce_max(horizon)
  # One row of step indices [0, 1, ..., longest) per batch element.
  steps = tf.tile(tf.range(longest, dtype=tf.float32)[None], (num_rows, 1))
  # Unnormalized geometric weights, zeroed at and beyond each row's horizon.
  # NOTE(review): a row whose horizon is 0 gets all-zero weights, so the
  # normalization below yields NaN for that row — confirm callers never pass
  # a zero horizon.
  weights = tf.where(steps < horizon[:, None], gamma**steps,
                     tf.zeros_like(steps))
  weights = weights / tf.reduce_sum(weights, axis=1)[:, None]
  sampled = tf.random.categorical(tf.math.log(weights), 1, dtype=tf.int32)
  # Drop the trailing "number of samples" dimension.
  return sampled[:, 0]
46
47
def get_future_goals(observation, discount, gamma):
  """Samples future goals according to a geometric distribution.

  Args:
    observation: A B x T x D tensor of future time steps.
    discount: A B x T tensor; a zero marks episode termination.
    gamma: The discount factor used for the Geom(1 - gamma) sampling.

  Returns:
    A B x D tensor holding one sampled future observation per batch element,
    drawn from before the first terminal of each trajectory.
  """
  batch_dim = observation.shape[0]
  steps_per_traj = observation.shape[1]
  # Index of the first terminal step; argmax returns 0 when none exists, so
  # disambiguate via reduce_any and fall back to the trajectory length.
  terminal_idx = tf.argmax(discount == 0, axis=1, output_type=tf.int32)
  has_terminal = tf.reduce_any(discount == 0, axis=1)
  sample_horizon = tf.where(has_terminal, terminal_idx, steps_per_traj)
  sample_horizon = tf.cast(sample_horizon, tf.float32)
  if batch_dim == 0:
    # truncated_geometric breaks on an empty batch, so build the empty
    # index list by hand.
    chosen = tf.zeros((0,), dtype=tf.int32)
  else:
    chosen = truncated_geometric(sample_horizon, gamma)
  gather_idx = tf.stack([tf.range(batch_dim), chosen], axis=1)
  return tf.gather_nd(observation, gather_idx)
65
66
def get_last_goals(observation, discount):
  """Extracts the final observation before termination.

  Args:
    observation: a B x T x D tensor storing the next T time steps. These time
      steps may be part of a new trajectory. This function will only consider
      observations that occur before the first terminal.
    discount: a B x T tensor indicating whether the episode has terminated.

  Returns:
    last_obs: a B x D tensor storing the last observation in each trajectory
      that occurs before the first terminal.
  """
  batch_dim = observation.shape[0]
  steps_per_traj = observation.shape[1]
  # argmax yields 0 both for "first step terminal" and "no terminal", so use
  # reduce_any to tell the cases apart.
  terminal_idx = tf.argmax(discount == 0, axis=1, output_type=tf.int32)
  has_terminal = tf.reduce_any(discount == 0, axis=1)
  first_terminal = tf.where(has_terminal, terminal_idx, steps_per_traj)
  # If the very first state is terminal then first_terminal - 1 = -1; clamp
  # to 0 so the state itself becomes the goal.
  last_valid = tf.clip_by_value(first_terminal - 1, 0, steps_per_traj)
  gather_idx = tf.stack([tf.range(batch_dim), last_valid], axis=1)
  return tf.gather_nd(observation, gather_idx)
91
92
@gin.configurable
def obs_to_goal(obs, start_index=0, end_index=None):
  """Projects a batch of observations onto their goal coordinates.

  Slicing with end_index=None is identical to slicing to the end, so a
  single slice covers both cases.
  """
  return obs[:, start_index:end_index]
99
100
@gin.configurable
def goal_fn(experience,
            buffer_info,
            relabel_orig_prob=0.0,
            relabel_next_prob=0.5,
            relabel_future_prob=0.0,
            relabel_last_prob=0.0,
            batch_size=None,
            obs_dim=None,
            gamma=None):
  """Given experience, sample goals in three ways.

  The three ways are using the next state, an arbitrary future state, or a
  random state. For the future state relabeling, care must be taken to ensure
  that we don't sample experience across the episode boundary. We automatically
  set relabel_random_prob = (1 - relabel_next_prob - relabel_future_prob).

  Observations are treated as [state; goal] concatenated on the last axis,
  with the state in the first `obs_dim` entries and the goal in the rest.

  Args:
    experience: The experience that we aim to relabel.
    buffer_info: Information about the replay buffer. We will not change this.
    relabel_orig_prob: (float) Fraction of experience to not relabel.
    relabel_next_prob: (float) Fraction of experience to relabel with the next
      state.
    relabel_future_prob: (float) Fraction of experience to relabel with a future
      state.
    relabel_last_prob: (float) Fraction of experience to relabel with the
      final state.
    batch_size: (int) The size of the batch.
    obs_dim: (int) The dimension of the observation.
    gamma: (float) The discount factor. Future states are sampled according to
      a Geom(1 - gamma) distribution.

  Returns:
    experience: A modified version of the input experience where the goals
      have been changed and the rewards and terminal flags are recomputed.
    buffer_info: Information about the replay buffer.
  """
  assert batch_size is not None
  assert obs_dim is not None
  assert gamma is not None
  # Partition the batch into five contiguous segments, one per relabeling
  # strategy; leftover elements get random-goal relabeling.
  relabel_orig_num = int(relabel_orig_prob * batch_size)
  relabel_next_num = int(relabel_next_prob * batch_size)
  relabel_future_num = int(relabel_future_prob * batch_size)
  relabel_last_num = int(relabel_last_prob * batch_size)
  relabel_random_num = batch_size - (
      relabel_orig_num + relabel_next_num + relabel_future_num +
      relabel_last_num)
  assert relabel_random_num >= 0

  # Segment 1: keep the goals already stored in the observation.
  orig_goals = experience.observation[:relabel_orig_num, 0, obs_dim:]

  # Segment 2: use the state at the next time step (t = 1) as the goal.
  index = relabel_orig_num
  next_goals = experience.observation[index:index + relabel_next_num,
                                      1, :obs_dim]

  # Segment 3: sample a future state (before the first terminal) as the goal.
  index = relabel_orig_num + relabel_next_num
  future_goals = get_future_goals(
      experience.observation[index:index + relabel_future_num, :, :obs_dim],
      experience.discount[index:index + relabel_future_num], gamma)

  # Segment 4: use the last pre-terminal state as the goal.
  index = relabel_orig_num + relabel_next_num + relabel_future_num
  last_goals = get_last_goals(
      experience.observation[index:index + relabel_last_num, :, :obs_dim],
      experience.discount[index:index + relabel_last_num])

  # Segment 5 (remainder): for random goals we take other states from the
  # same batch.
  random_goals = tf.random.shuffle(experience.observation[:relabel_random_num,
                                                          0, :obs_dim])
  # New goals are states, so project them into goal space; original goals are
  # already goals and are concatenated as-is. Segment order must match the
  # index arithmetic above.
  new_goals = obs_to_goal(tf.concat([next_goals, future_goals,
                                     last_goals, random_goals], axis=0))
  goals = tf.concat([orig_goals, new_goals], axis=0)

  # Keep only the first two time steps and recompute the reward as an
  # indicator that the next state's goal projection equals the relabeled goal.
  obs = experience.observation[:, :2, :obs_dim]
  reward = tf.reduce_all(obs_to_goal(obs[:, 1]) == goals, axis=-1)
  reward = tf.cast(reward, tf.float32)
  reward = tf.tile(reward[:, None], [1, 2])
  # Re-attach the (relabeled) goal to both time steps of each observation.
  new_obs = tf.concat([obs, tf.tile(goals[:, None, :], [1, 2, 1])], axis=2)
  experience = experience.replace(
      observation=new_obs,  # [B x 2 x 2 * obs_dim]
      action=experience.action[:, :2],
      step_type=experience.step_type[:, :2],
      next_step_type=experience.next_step_type[:, :2],
      discount=experience.discount[:, :2],
      reward=reward,
  )
  return experience, buffer_info
187
188
@gin.configurable
class ClassifierCriticNetwork(critic_network.CriticNetwork):
  """Critic network whose final layer is a sigmoid-activated classifier head."""

  def __init__(self,
               input_tensor_spec,
               observation_fc_layer_params=None,
               action_fc_layer_params=None,
               joint_fc_layer_params=None,
               kernel_initializer=None,
               last_kernel_initializer=None,
               name='ClassifierCriticNetwork'):
    super(ClassifierCriticNetwork, self).__init__(
        input_tensor_spec,
        observation_fc_layer_params=observation_fc_layer_params,
        action_fc_layer_params=action_fc_layer_params,
        joint_fc_layer_params=joint_fc_layer_params,
        kernel_initializer=kernel_initializer,
        last_kernel_initializer=last_kernel_initializer,
        name=name,
    )

    # Replace the parent's final linear value layer with a single-unit
    # sigmoid head so outputs lie in [0, 1].
    sigmoid_head = tf.keras.layers.Dense(
        1,
        activation=tf.math.sigmoid,
        kernel_initializer=last_kernel_initializer,
        name='value')
    self._joint_layers = self._joint_layers[:-1] + [sigmoid_head]
220
class BaseDistanceMetric(tf_metric.TFStepMetric):
  """Base metric tracking per-episode distances between state and goal.

  Subclasses define `NAME` and implement `_update_buffer`, which reduces the
  per-step distances of a finished episode (`self._dist_buffer`) into a single
  scalar that is appended to `self._buffer`.
  """

  def __init__(self,
               prefix='Metrics',
               dtype=tf.float32,
               batch_size=1,
               buffer_size=10,
               obs_dim=None,
               start_index=0,
               end_index=None,
               name=None):
    # batch_size is accepted but unused here; `call` asserts batch size 1.
    assert obs_dim is not None
    self._start_index = start_index
    self._end_index = end_index
    self._obs_dim = obs_dim
    name = self.NAME if name is None else name
    super(BaseDistanceMetric, self).__init__(name=name, prefix=prefix)
    # Rolling window of per-episode scalars produced by `_update_buffer`.
    self._buffer = tf_metrics.TFDeque(buffer_size, dtype)
    # Per-step distances of the current episode, cleared when it ends.
    self._dist_buffer = tf_metrics.TFDeque(
        1000, dtype)  # Episodes should have length less than 1k
    self.dtype = dtype

  @common.function(autograph=True)
  def call(self, trajectory):
    obs = trajectory.observation
    # Observation layout is [state; goal]: state in the first `_obs_dim`
    # entries, goal in the remainder.
    s = obs[:, :self._obs_dim]
    g = obs[:, self._obs_dim:]
    # NOTE(review): `s` goes through obs_to_goal twice (once with gin-bound
    # defaults to map state -> goal space, then again with this metric's
    # indices) while `g` goes through once — presumably because `g` is
    # already in goal space; confirm against how observations are built.
    dist_to_goal = tf.norm(
        obs_to_goal(obs_to_goal(s), self._start_index, self._end_index) -
        obs_to_goal(g, self._start_index, self._end_index),
        axis=1)
    # Only a single (unbatched) environment is supported.
    tf.assert_equal(tf.shape(obs)[0], 1)
    if trajectory.is_mid():
      self._dist_buffer.extend(dist_to_goal)
    # At episode end, fold the accumulated distances into one scalar.
    if trajectory.is_last()[0] and self._dist_buffer.length > 0:
      self._update_buffer()
      self._dist_buffer.clear()
    return trajectory

  def result(self):
    # Mean over the recorded per-episode scalars.
    return self._buffer.mean()

  @common.function
  def reset(self):
    self._buffer.clear()

  def _update_buffer(self):
    # Subclass hook: reduce self._dist_buffer into a scalar and add it to
    # self._buffer.
    raise NotImplementedError
270
271
class InitialDistance(BaseDistanceMetric):
  """Records the distance to the goal at the first step of each episode."""
  NAME = 'InitialDistance'

  def _update_buffer(self):
    self._buffer.add(self._dist_buffer.data[0])
279
280
class FinalDistance(BaseDistanceMetric):
  """Records the distance to the goal at the last step of each episode."""
  NAME = 'FinalDistance'

  def _update_buffer(self):
    self._buffer.add(self._dist_buffer.data[-1])
288
289
class AverageDistance(BaseDistanceMetric):
  """Records the mean distance to the goal over each episode."""
  NAME = 'AverageDistance'

  def _update_buffer(self):
    self._buffer.add(self._dist_buffer.mean())
297
298
class MinimumDistance(BaseDistanceMetric):
  """Records the minimum distance to the goal over each episode."""
  NAME = 'MinimumDistance'

  def _update_buffer(self):
    smallest = self._dist_buffer.min()
    # Guard against NaN/Inf entering the metric; dump buffer internals to aid
    # debugging if it ever fires.
    tf.Assert(
        tf.math.is_finite(smallest),
        [
            smallest,
            self._dist_buffer.length,
            self._dist_buffer._head,  # pylint: disable=protected-access
            self._dist_buffer.data,
        ],
        summarize=1000)
    self._buffer.add(smallest)
312
313
class DeltaDistance(BaseDistanceMetric):
  """Computes the net distance traveled towards the goal. Positive is good."""
  NAME = 'DeltaDistance'

  def _update_buffer(self):
    progress = self._dist_buffer.data[0] - self._dist_buffer.data[-1]
    self._buffer.add(progress)
321