# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""C-learning.

Implements the off-policy goal-conditioned C-learning algorithm from
"C-Learning: Learning to Achieve Goals via Recursive Classification" by
Eysenbach et al (2020): https://arxiv.org/abs/2011.08909
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from typing import Callable, Optional, Text

import gin
from six.moves import zip
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp

from tf_agents.agents import tf_agent
from tf_agents.networks import network
from tf_agents.policies import actor_policy
from tf_agents.policies import tf_policy
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory
from tf_agents.typing import types
from tf_agents.utils import common
from tf_agents.utils import eager_utils
from tf_agents.utils import nest_utils
from tf_agents.utils import object_identity


CLearningLossInfo = collections.namedtuple(
    'LossInfo', ('critic_loss', 'actor_loss'))


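# Illustrative usage sketch (not part of the original file): the agent expects
# pre-built actor and critic networks and their optimizers. The network
# variables and hyperparameters named below are assumptions, not prescriptions;
# any tf_agents-compatible networks with matching specs would work.
#
#   agent = CLearningAgent(
#       time_step_spec=tf_env.time_step_spec(),
#       action_spec=tf_env.action_spec(),
#       critic_network=critic_net,    # maps (observation, action) -> C-value
#       actor_network=actor_net,      # maps observation -> action distribution
#       actor_optimizer=tf.keras.optimizers.Adam(3e-4),
#       critic_optimizer=tf.keras.optimizers.Adam(3e-4),
#       gamma=0.99)
#   agent.initialize()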
@gin.configurable
class CLearningAgent(tf_agent.TFAgent):
  """A C-learning Agent."""

  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               actor_loss_weight=1.0,
               critic_loss_weight=0.5,
               actor_policy_ctor=actor_policy.ActorPolicy,
               critic_network_2=None,
               target_critic_network=None,
               target_critic_network_2=None,
               target_update_tau=1.0,
               target_update_period=1,
               td_errors_loss_fn=tf.math.squared_difference,
               gamma=1.0,
               gradient_clipping=None,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None,
               name=None):
    """Creates a C-learning Agent.

    By default, the environment observation contains the current state and goal
    state. By setting the obs_to_goal gin config in c_learning_utils, the user
    can specify that the agent should only look at certain subsets of the goal
    state.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns the action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.) A `tf_agents.network.Network` to be used as
        the second critic network during Q learning. The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.) A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`. If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a
        new network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another). In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target
        network. If you provide a `target_critic_network` that shares any
        weights with `critic_network`, a warning will be logged but no
        exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable
        summaries will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run. Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables()
    if target_critic_network:
      target_critic_network.create_variables()
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables()
    if target_critic_network_2:
      target_critic_network_2.create_variables()
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    if actor_network:
      actor_network.create_variables()
    self._actor_network = actor_network

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=False)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=True)

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)

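    # Descriptive note (added): transitions are built from pairs of adjacent
    # time steps, so a stateless critic trains on length-2 sequences; a
    # stateful (RNN) critic leaves the sequence length unconstrained (None).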
    train_sequence_length = 2 if not critic_network.state_spec else None

    super(CLearningAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)

  def _check_action_spec(self, action_spec):
    flat_action_spec = tf.nest.flatten(action_spec)
    for spec in flat_action_spec:
      if spec.dtype.is_integer:
        raise NotImplementedError(
            'CLearningAgent does not currently support discrete actions. '
            'Action spec: {}'.format(action_spec))

  def _initialize(self):
    """Returns an op to initialize the agent.

    Copies weights from the critic networks to the target critic networks.
    """
    common.soft_variables_update(
        self._critic_network_1.variables,
        self._target_critic_network_1.variables,
        tau=1.0)
    common.soft_variables_update(
        self._critic_network_2.variables,
        self._target_critic_network_2.variables,
        tau=1.0)

  def _train(self, experience, weights):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    squeeze_time_dim = not self._critic_network_1.state_spec
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action

    trainable_critic_variables = list(object_identity.ObjectIdentitySet(
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables))

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self._critic_loss_weight * self.critic_loss(
          time_steps,
          actions,
          next_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          weights=weights,
          training=True)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight * self.actor_loss(
          time_steps, actions, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)

    self.train_step_counter.assign_add(1)
    self._update_target()

    total_loss = critic_loss + actor_loss

    extra = CLearningLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)

  def _apply_gradients(self, gradients, variables, optimizer):
    # list(...) is required for Python 3.
    grads_and_vars = list(zip(gradients, variables))
    if self._gradient_clipping is not None:
      grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,
                                                       self._gradient_clipping)

    if self._summarize_grads_and_vars:
      eager_utils.add_variables_summaries(grads_and_vars,
                                          self.train_step_counter)
      eager_utils.add_gradients_summaries(grads_and_vars,
                                          self.train_step_counter)

    optimizer.apply_gradients(grads_and_vars)

  def _get_target_updater(self, tau=1.0, period=1):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the original network, and its corresponding
    weight w_t in the target network, a soft update is:
      w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
      period: Step interval at which the target network is updated.

    Returns:
      A callable that performs a soft update of the target network parameters.
    """
    with tf.name_scope('update_target'):

      def update():
        """Update target network."""
        critic_update_1 = common.soft_variables_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables,
            tau,
            tau_non_trainable=1.0)

        critic_2_update_vars = common.deduped_network_variables(
            self._critic_network_2, self._critic_network_1)

        target_critic_2_update_vars = common.deduped_network_variables(
            self._target_critic_network_2, self._target_critic_network_1)

        critic_update_2 = common.soft_variables_update(
            critic_2_update_vars,
            target_critic_2_update_vars,
            tau,
            tau_non_trainable=1.0)

        return tf.group(critic_update_1, critic_update_2)

      return common.Periodically(update, period, 'update_targets')

  def _actions_and_log_probs(self, time_steps):
    """Get actions and corresponding log probabilities from policy."""
    # Get raw action distribution from policy, and initialize bijectors list.
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state=policy_state).action

    # Sample actions and log_pis from transformed distribution.
    actions = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
    log_pi = common.log_probability(action_distribution, actions,
                                    self.action_spec)

    return actions, log_pi

  @gin.configurable
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  td_errors_loss_fn,
                  gamma=1.0,
                  weights=None,
                  training=False,
                  w_clipping=20.0,
                  self_normalized=False,
                  lambda_fix=False,
                  ):
    """Computes the critic loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      w_clipping: Maximum value used for clipping the weights. Use -1 to do no
        clipping; use None to use the recommended value of 1 / (1 - gamma).
      self_normalized: Whether to normalize the weights so that their average
        is 1. Empirically this usually hurts performance.
      lambda_fix: Whether to include the adjustment when using future
        positives. Empirically this has little effect.

    Returns:
      critic_loss: A scalar critic loss.
    """
    del weights
    if w_clipping is None:
      w_clipping = 1 / (1 - gamma)
    rfp = gin.query_parameter('goal_fn.relabel_future_prob')
    rnp = gin.query_parameter('goal_fn.relabel_next_prob')
    assert rfp + rnp == 0.5
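    # Descriptive note (added): goal_fn is expected to relabel exactly half of
    # each batch with next-state or future-state goals and leave the other half
    # with randomly sampled goals; the weighting below relies on the relabeled
    # examples being packed at the front of the batch.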
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      target_q_values1, unused_network_state1 = self._target_critic_network_1(
          target_input, next_time_steps.step_type, training=False)
      target_q_values2, unused_network_state2 = self._target_critic_network_2(
          target_input, next_time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)

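      # Descriptive note (added): the target critics output classification
      # probabilities C in [0, 1]; w = C / (1 - C) converts them into the
      # likelihood-ratio weight used by recursive classification.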
      w = tf.stop_gradient(target_q_values / (1 - target_q_values))
      if w_clipping >= 0:
        w = tf.clip_by_value(w, 0, w_clipping)
      tf.debugging.assert_all_finite(w, 'Not all elements of w are finite')
      if self_normalized:
        w = w / tf.reduce_mean(w)

      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      half_batch = batch_size // 2
      float_batch_size = tf.cast(batch_size, float)
      num_next = tf.cast(tf.round(float_batch_size * rnp), tf.int32)
      num_future = tf.cast(tf.round(float_batch_size * rfp), tf.int32)
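      # Per-example loss weights, matching the batch layout produced by
      # goal_fn: next-state goals get weight (1 - gamma), future-state goals
      # get weight 1, and the random-goal half gets weight (1 + gamma * w).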
      if lambda_fix:
        lambda_coef = 2 * rnp
        weights = tf.concat([tf.fill((num_next,), (1 - gamma)),
                             tf.fill((num_future,), 1.0),
                             (1 + lambda_coef * gamma * w)[half_batch:]],
                            axis=0)
      else:
        weights = tf.concat([tf.fill((num_next,), (1 - gamma)),
                             tf.fill((num_future,), 1.0),
                             (1 + gamma * w)[half_batch:]],
                            axis=0)

      # Note that we assume that episodes never terminate. If they do, then
      # we need to include next_time_steps.discount in the (negative) TD
      # target. We exclude the termination here so that we can use termination
      # to indicate task success during evaluation. In the evaluation setting,
      # task success depends on the task, but we don't want the termination
      # here to depend on the task. Hence, we ignore it.
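      # Descriptive note (added): examples with reward = 1 get a target of 1;
      # all other examples get the bootstrapped classification target
      # y = gamma * w / (1 + gamma * w). When future-goal relabeling is used
      # (rfp > 0), the whole relabeled half of the batch is given target 1.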
      if lambda_fix:
        lambda_coef = 2 * rnp
        y = lambda_coef * gamma * w / (1 + lambda_coef * gamma * w)
      else:
        y = gamma * w / (1 + gamma * w)
      td_targets = tf.stop_gradient(next_time_steps.reward +
                                    (1 - next_time_steps.reward) * y)
      if rfp > 0:
        td_targets = tf.concat([tf.ones(half_batch),
                                td_targets[half_batch:]], axis=0)

      observation = time_steps.observation
      pred_input = (observation, actions)
      pred_td_targets1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)

      critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
      critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
      critic_loss = critic_loss1 + critic_loss2

      if critic_loss.shape.rank > 1:
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            critic_loss, axis=range(1, critic_loss.shape.rank))

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss
      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2, weights)

      return critic_loss

  @gin.configurable
  def actor_loss(self,
                 time_steps,
                 actions,
                 weights=None,
                 ce_loss=False):
    """Computes the actor_loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      ce_loss: (bool) Whether to update the actor using the cross entropy loss,
        which corresponds to using the log C-value. The default actor loss
        differs by not including the log. Empirically we observed no
        difference.

    Returns:
      actor_loss: A scalar actor loss.
    """
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      sampled_actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, sampled_actions)
      target_q_values1, _ = self._critic_network_1(
          target_input, time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
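      # Descriptive note (added): the actor maximizes the (minimum) C-value of
      # its sampled actions, either through a binary cross-entropy against a
      # target label of 1 (ce_loss=True, i.e. maximizing log C) or by directly
      # maximizing the C-value itself (the default).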
      if ce_loss:
        actor_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(target_q_values), target_q_values)
      else:
        actor_loss = -1.0 * target_q_values

      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)

      return actor_loss

  def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                   pred_td_targets2, weights):
    if self._debug_summaries:
      td_errors1 = td_targets - pred_td_targets1
      td_errors2 = td_targets - pred_td_targets2
      td_errors = tf.concat([td_errors1, td_errors2], axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                       self.train_step_counter)
      common.generate_tensor_summaries('weights', weights,
                                       self.train_step_counter)

  def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                  target_q_values, time_steps):
    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('actions', actions,
                                       self.train_step_counter)
      common.generate_tensor_summaries('log_pi', log_pi,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pi),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._train_policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode', action_distribution.mode(),
                                         self.train_step_counter)
        common.generate_tensor_summaries('entropy_action',
                                         action_distribution.entropy(),
                                         self.train_step_counter)