# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements the RCE Agent."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from typing import Callable, Optional, Text

import gin
from six.moves import zip
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp

from tf_agents.agents import data_converter
from tf_agents.agents import tf_agent
from tf_agents.networks import network
from tf_agents.policies import actor_policy
from tf_agents.policies import tf_policy
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import common
from tf_agents.utils import eager_utils
from tf_agents.utils import nest_utils
from tf_agents.utils import object_identity


RceLossInfo = collections.namedtuple(
    'RceLossInfo', ('critic_loss', 'actor_loss'))


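# Illustrative usage (a minimal sketch, not part of this module). The network
# classes below are standard TF-Agents ones and are only an assumption here;
# `env` is assumed to be a TF-Agents TFEnvironment, and the accompanying
# training script may build different networks. Note that the RCE loss treats
# the critic as a classifier, so its output should lie in (0, 1) (e.g., end in
# a sigmoid); the default CriticNetwork does not do this, so this snippet is
# schematic only.
#
#   from tf_agents.agents.ddpg import critic_network as critic_network_lib
#   from tf_agents.networks import actor_distribution_network
#
#   obs_spec = env.time_step_spec().observation
#   act_spec = env.action_spec()
#   critic_net = critic_network_lib.CriticNetwork((obs_spec, act_spec))
#   actor_net = actor_distribution_network.ActorDistributionNetwork(
#       obs_spec, act_spec)
#   agent = RceAgent(
#       env.time_step_spec(),
#       act_spec,
#       critic_network=critic_net,
#       actor_network=actor_net,
#       actor_optimizer=tf.keras.optimizers.Adam(3e-4),
#       critic_optimizer=tf.keras.optimizers.Adam(3e-4),
#       n_step=10)
#   agent.initialize()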
@gin.configurable
class RceAgent(tf_agent.TFAgent):
  """An agent for Recursive Classification of Examples."""

  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               actor_loss_weight=1.0,
               critic_loss_weight=0.5,
               actor_policy_ctor=actor_policy.ActorPolicy,
               critic_network_2=None,
               target_critic_network=None,
               target_critic_network_2=None,
               target_update_tau=1.0,
               target_update_period=1,
               td_errors_loss_fn=tf.math.squared_difference,
               gamma=1.0,
               reward_scale_factor=1.0,
               gradient_clipping=None,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None,
               name=None,
               n_step=None,
               use_behavior_policy=False):
    """Creates an RCE Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns an action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.) A `tf_agents.network.Network` to be used as
        the second critic network during Q learning. The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.) A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`. If `target_critic_network` is not provided,
        it is created by making a copy of `critic_network`, which initializes
        a new network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or
        when the network is sharing layers with another). In these cases, it
        is up to you to build a copy having weights that are not shared with
        the original `critic_network`, so that this can be used as a target
        network. If you provide a `target_critic_network` that shares any
        weights with `critic_network`, a warning will be logged but no
        exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable
        summaries will be written during training.
      train_step_counter: An optional counter to increment every time the
        train op is run. Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
      n_step: An integer specifying whether to use n-step returns. Empirically,
        a value of 10 works well for most tasks. Use None to disable n-step
        returns.
      use_behavior_policy: A boolean indicating how to sample actions for the
        success states. When use_behavior_policy=True, we use the historical
        average policy; otherwise, we use the current policy.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables(
        (time_step_spec.observation, action_spec))
    if target_critic_network:
      target_critic_network.create_variables(
          (time_step_spec.observation, action_spec))
      self._target_critic_network_1 = target_critic_network
    else:
      self._target_critic_network_1 = (
          common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                       None,
                                                       'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables(
        (time_step_spec.observation, action_spec))

    if target_critic_network_2:
      target_critic_network_2.create_variables(
          (time_step_spec.observation, action_spec))
      self._target_critic_network_2 = target_critic_network_2
    else:
      self._target_critic_network_2 = (
          common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                       None,
                                                       'TargetCriticNetwork2'))

    if actor_network:
      actor_network.create_variables(time_step_spec.observation)
    self._actor_network = actor_network

    self._use_behavior_policy = use_behavior_policy
    if use_behavior_policy:
      self._behavior_actor_network = actor_network.copy(
          name='BehaviorActorNetwork')
      self._behavior_policy = actor_policy_ctor(
          time_step_spec=time_step_spec,
          action_spec=action_spec,
          actor_network=self._behavior_actor_network,
          training=True)

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=False)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=True)

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)
    self._n_step = n_step

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(RceAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False)

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))

  def _check_action_spec(self, action_spec):
    flat_action_spec = tf.nest.flatten(action_spec)
    for spec in flat_action_spec:
      if spec.dtype.is_integer:
        raise NotImplementedError(
            'RceAgent does not currently support discrete actions. '
            'Action spec: {}'.format(action_spec))

  def _initialize(self):
    """Returns an op to initialize the agent.

    Copies weights from the Q networks to the target Q network.
    """
    common.soft_variables_update(
        self._critic_network_1.variables,
        self._target_critic_network_1.variables,
        tau=1.0)
    common.soft_variables_update(
        self._critic_network_2.variables,
        self._target_critic_network_2.variables,
        tau=1.0)

  def _train(self, experience, weights):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    experience, expert_experience = experience

    if self._n_step is None:
      transition = self._as_transition(experience)
      time_steps, policy_steps, next_time_steps = transition
      future_time_steps = next_time_steps
    else:
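      # For n-step returns, build two 2-step trajectories from the n-step
      # sequence: `experience_1` keeps the first transition (indices 0 and 1)
      # for the ordinary 1-step backup, while `experience_2` pairs the first
      # and last observations so that `future_time_steps` below provides the
      # n-step bootstrap target.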
      experience_1 = experience._replace(
          observation=experience.observation[:, :2],
          action=experience.action[:, :2],
          discount=experience.discount[:, :2],
          reward=experience.reward[:, :2],
          step_type=experience.step_type[:, :2],
          next_step_type=experience.next_step_type[:, :2],
      )
      obs_2 = tf.stack([experience.observation[:, 0],
                        experience.observation[:, -1]], axis=1)
      action_2 = tf.stack([experience.action[:, 0],
                           experience.action[:, -1]], axis=1)
      discount_2 = tf.stack([experience.discount[:, 0],
                             experience.discount[:, -1]], axis=1)
      step_type_2 = tf.stack([experience.step_type[:, 0],
                              experience.step_type[:, -1]], axis=1)
      next_step_type_2 = tf.stack([experience.next_step_type[:, 0],
                                   experience.next_step_type[:, -1]], axis=1)
      reward_2 = tf.stack([experience.reward[:, 0],
                           experience.reward[:, -1]], axis=1)
      experience_2 = experience._replace(
          observation=obs_2,
          action=action_2,
          discount=discount_2,
          step_type=step_type_2,
          next_step_type=next_step_type_2,
          reward=reward_2)
      time_steps, policy_steps, next_time_steps = self._as_transition(
          experience_1)
      _, _, future_time_steps = self._as_transition(experience_2)

    actions = policy_steps.action

    trainable_critic_variables = list(object_identity.ObjectIdentitySet(
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables))

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self._critic_loss_weight * self.critic_loss(
          time_steps,
          expert_experience,
          actions,
          next_time_steps,
          future_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight * self.actor_loss(
          time_steps, actions, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

    # Train the behavior policy
    if self._use_behavior_policy:
      trainable_behavior_variables = (
          self._behavior_actor_network.trainable_variables)
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert trainable_behavior_variables, (
            'No trainable behavior variables to optimize.')
        tape.watch(trainable_behavior_variables)
        behavior_loss = self._actor_loss_weight * self.behavior_loss(
            time_steps, actions, weights=weights)
      tf.debugging.check_numerics(behavior_loss,
                                  'Behavior loss is inf or nan.')
      behavior_grads = tape.gradient(behavior_loss,
                                     trainable_behavior_variables)
      self._apply_gradients(behavior_grads, trainable_behavior_variables,
                            self._actor_optimizer)
    else:
      behavior_loss = 0.0

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(name='behavior_loss', data=behavior_loss,
                                  step=self.train_step_counter)

    self.train_step_counter.assign_add(1)
    self._update_target()

    total_loss = critic_loss + actor_loss

    extra = RceLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)

  def _apply_gradients(self, gradients, variables, optimizer):
    # list(...) is required for Python3.
    grads_and_vars = list(zip(gradients, variables))
    if self._gradient_clipping is not None:
      grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,
                                                       self._gradient_clipping)

    if self._summarize_grads_and_vars:
      eager_utils.add_variables_summaries(grads_and_vars,
                                          self.train_step_counter)
      eager_utils.add_gradients_summaries(grads_and_vars,
                                          self.train_step_counter)

    optimizer.apply_gradients(grads_and_vars)

  def _get_target_updater(self, tau=1.0, period=1):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the original network, and its corresponding
    weight w_t in the target network, a soft update is:
      w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
      period: Step interval at which the target network is updated.

    Returns:
      A callable that performs a soft update of the target network parameters.
    """
    with tf.name_scope('update_target'):

      def update():
        """Update target network."""
        critic_update_1 = common.soft_variables_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables,
            tau,
            tau_non_trainable=1.0)

        critic_2_update_vars = common.deduped_network_variables(
            self._critic_network_2, self._critic_network_1)

        target_critic_2_update_vars = common.deduped_network_variables(
            self._target_critic_network_2, self._target_critic_network_1)

        critic_update_2 = common.soft_variables_update(
            critic_2_update_vars,
            target_critic_2_update_vars,
            tau,
            tau_non_trainable=1.0)

        return tf.group(critic_update_1, critic_update_2)

      return common.Periodically(update, period, 'update_targets')

  def _actions_and_log_probs(self, time_steps):
    """Get actions and corresponding log probabilities from policy."""
    # Get raw action distribution from policy, and initialize bijectors list.
    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state=policy_state).action

    # Sample actions and log_pis from transformed distribution.
    actions = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
    log_pi = common.log_probability(action_distribution, actions,
                                    self.action_spec)

    return actions, log_pi

  @gin.configurable
  def critic_loss(self,
                  time_steps,
                  expert_experience,
                  actions,
                  next_time_steps,
                  future_time_steps,
                  td_errors_loss_fn,
                  gamma=1.0,
                  reward_scale_factor=1.0,
                  weights=None,
                  training=False,
                  loss_name='c',
                  use_done=False,
                  q_combinator='min'):
    """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      expert_experience: An array of success examples.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      future_time_steps: A batch of future timesteps, used for n-step returns.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      loss_name: Which loss function to use. Use 'c' for RCE and 'q' for SQIL.
      use_done: Whether to use the terminal flag from the environment in the
        Bellman backup. We found that omitting it led to better results.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.

    Returns:
      critic_loss: A scalar critic loss.
    """
    assert weights is None
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      target_q_values1, unused_network_state1 = self._target_critic_network_1(
          target_input, next_time_steps.step_type, training=False)
      target_q_values2, unused_network_state2 = self._target_critic_network_2(
          target_input, next_time_steps.step_type, training=False)
      if self._n_step is not None:
        future_actions, _ = self._actions_and_log_probs(future_time_steps)
        future_input = (future_time_steps.observation, future_actions)
        future_q_values1, _ = self._target_critic_network_1(
            future_input, future_time_steps.step_type, training=False)
        future_q_values2, _ = self._target_critic_network_2(
            future_input, future_time_steps.step_type, training=False)

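        # When using n-step returns, the final target averages the 1-step
        # bootstrapped estimate with the n-step bootstrapped estimate
        # (discounted by gamma**n_step).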
        gamma_n = gamma**self._n_step  # Discount for n-step returns
        target_q_values1 = (target_q_values1 + gamma_n * future_q_values1) / 2.0
        target_q_values2 = (target_q_values2 + gamma_n * future_q_values2) / 2.0

      if q_combinator == 'min':
        target_q_values = tf.minimum(target_q_values1, target_q_values2)
      else:
        assert q_combinator == 'max'
        target_q_values = tf.maximum(target_q_values1, target_q_values2)

      batch_size = time_steps.observation.shape[0]
      if loss_name == 'q':
        if use_done:
          td_targets = gamma * next_time_steps.discount * target_q_values
        else:
          td_targets = gamma * target_q_values
      else:
        assert loss_name == 'c'
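        # RCE classification targets: treating the critic as a classifier
        # C(s, a), its odds w = C / (1 - C) estimate the (discounted)
        # likelihood of reaching a success example from (s, a). Replay
        # transitions get the bootstrapped target gamma * w / (gamma * w + 1)
        # with per-example weight 1 + gamma * w; the success examples appended
        # below get weight 1 - gamma.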
        w = target_q_values / (1 - target_q_values)
        td_targets = gamma * w / (gamma * w + 1)
        if use_done:
          td_targets = next_time_steps.discount * td_targets
        weights = tf.concat([1 + gamma * w, (1 - gamma) * tf.ones(batch_size)],
                            axis=0)

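      # Success examples (appended as the second half of the batch) always get
      # a classification target of 1.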
      td_targets = tf.stop_gradient(td_targets)
      td_targets = tf.concat([td_targets, tf.ones(batch_size)], axis=0)

      # Note that the actions only depend on the observations. We create the
      # expert_time_steps object simply to make this look like a time step
      # object.
      expert_time_steps = time_steps._replace(observation=expert_experience)
      if self._use_behavior_policy:
        policy_state = self._train_policy.get_initial_state(batch_size)
        action_distribution = self._behavior_policy.distribution(
            time_steps, policy_state=policy_state).action
        # Sample actions and log_pis from transformed distribution.
        expert_actions = tf.nest.map_structure(lambda d: d.sample(),
                                               action_distribution)
      else:
        expert_actions, _ = self._actions_and_log_probs(expert_time_steps)

      observation = time_steps.observation
      pred_input = (tf.concat([observation, expert_experience], axis=0),
                    tf.concat([actions, expert_actions], axis=0))

      pred_td_targets1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
      critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
      critic_loss = critic_loss1 + critic_loss2

      if critic_loss.shape.rank > 1:
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            critic_loss, axis=range(1, critic_loss.shape.rank))

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      return critic_loss

  @gin.configurable
  def actor_loss(self,
                 time_steps,
                 rb_actions=None,
                 weights=None,
                 q_combinator='min',
                 entropy_coef=1e-4):
    """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      rb_actions: Actions from the replay buffer. While not used in the main
        RCE method, we used these actions to train a behavior policy for the
        ablation experiment studying how to sample actions for the success
        examples.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.
      entropy_coef: Coefficient for the entropy regularization term. We found
        that 1e-4 worked well for all environments.

    Returns:
      actor_loss: A scalar actor loss.
    """
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, actions)

      target_q_values1, _ = self._critic_network_1(
          target_input, time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, time_steps.step_type, training=False)
      if q_combinator == 'min':
        target_q_values = tf.minimum(target_q_values1, target_q_values2)
      else:
        assert q_combinator == 'max'
        target_q_values = tf.maximum(target_q_values1, target_q_values2)
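      # The critic output is monotone in the predicted probability of reaching
      # a success example, so the actor simply maximizes it; the small
      # `entropy_coef` term adds SAC-style entropy regularization.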
      if entropy_coef == 0:
        actor_loss = -target_q_values
      else:
        actor_loss = entropy_coef * log_pi - target_q_values
      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)

      return actor_loss

  @gin.configurable
  def behavior_loss(self,
                    time_steps,
                    actions,
                    weights=None):
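    """Computes a behavior-cloning loss: negative log-likelihood of actions.

    Only used for the `use_behavior_policy=True` ablation, where the behavior
    policy is fit to the replay-buffer actions.
    """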
    with tf.name_scope('behavior_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._behavior_policy.get_initial_state(batch_size)
      action_distribution = self._behavior_policy.distribution(
          time_steps, policy_state=policy_state).action
      log_pi = common.log_probability(action_distribution, actions,
                                      self.action_spec)
      return -1.0 * tf.reduce_mean(log_pi)

  def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                   pred_td_targets2):
    if self._debug_summaries:
      td_errors1 = td_targets - pred_td_targets1
      td_errors2 = td_targets - pred_td_targets2
      td_errors = tf.concat([td_errors1, td_errors2], axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                       self.train_step_counter)

  def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                  target_q_values, time_steps):
    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      try:
        common.generate_tensor_summaries('actions', actions,
                                         self.train_step_counter)
      except ValueError:
        pass  # Guard against internal SAC variants that do not directly
        # generate actions.

      common.generate_tensor_summaries('log_pi', log_pi,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pi),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._train_policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode',
                                         action_distribution.mode(),
                                         self.train_step_counter)
      try:
        common.generate_tensor_summaries('entropy_action',
                                         action_distribution.entropy(),
                                         self.train_step_counter)
      except NotImplementedError:
        pass  # Some distributions do not have an analytic entropy.