google-research

Форк
0
/
rce_agent.py 
710 строк · 30.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implements the RCE Agent."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import collections
23
from typing import Callable, Optional, Text
24

25
import gin
26
from six.moves import zip
27
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
28
import tensorflow_probability as tfp
29

30
from tf_agents.agents import data_converter
31
from tf_agents.agents import tf_agent
32
from tf_agents.networks import network
33
from tf_agents.policies import actor_policy
34
from tf_agents.policies import tf_policy
35
from tf_agents.trajectories import time_step as ts
36
from tf_agents.typing import types
37
from tf_agents.utils import common
38
from tf_agents.utils import eager_utils
39
from tf_agents.utils import nest_utils
40
from tf_agents.utils import object_identity
41

42

43
# Diagnostic losses reported in `LossInfo.extra` by `RceAgent._train`.
RceLossInfo = collections.namedtuple(
    typename='RceLossInfo',
    field_names=['critic_loss', 'actor_loss'])
45

46

47
@gin.configurable
class RceAgent(tf_agent.TFAgent):
  """An agent for Recursive Classification of Examples (RCE).

  The critic is trained on replay-buffer transitions (regressed onto
  bootstrapped targets from the target critics) together with a batch of
  success examples (regressed onto a target of 1). The actor maximizes the
  combined critic output. `critic_loss(loss_name='q')` selects a SQIL-style
  target instead of the RCE one.

  `_train` expects `experience` to be a pair
  `(trajectory, expert_experience)`, where `expert_experience` holds the
  observations of success examples.
  """

  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               actor_loss_weight=1.0,
               critic_loss_weight=0.5,
               actor_policy_ctor=actor_policy.ActorPolicy,
               critic_network_2=None,
               target_critic_network=None,
               target_critic_network_2=None,
               target_update_tau=1.0,
               target_update_period=1,
               td_errors_loss_fn=tf.math.squared_difference,
               gamma=1.0,
               reward_scale_factor=1.0,
               gradient_clipping=None,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None,
               name=None,
               n_step=None,
               use_behavior_policy=False):
    """Creates a RCE Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Integer (discrete) action specs are rejected.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network (and for the
        behavior network when `use_behavior_policy=True`).
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided,
        it is created by making a copy of `critic_network`, which initializes a
        new network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
      n_step: An integer specifying whether to use n-step returns. Empirically,
        a value of 10 works well for most tasks. Use None to disable n-step
        returns.
      use_behavior_policy: A boolean indicating how to sample actions for the
        success states. When use_behavior_policy=True, we use the historical
        average policy; otherwise, we use the current policy. Requires
        `actor_network` to be provided.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)
    critic_input_spec = (time_step_spec.observation, action_spec)

    # First critic and its target.
    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables(critic_input_spec)
    if target_critic_network is not None:
      target_critic_network.create_variables(critic_input_spec)
      self._target_critic_network_1 = target_critic_network
    else:
      self._target_critic_network_1 = (
          common.maybe_copy_target_network_with_checks(
              self._critic_network_1, None, 'TargetCriticNetwork1'))

    # Second critic and its target.
    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables(critic_input_spec)

    if target_critic_network_2 is not None:
      target_critic_network_2.create_variables(critic_input_spec)
      # Must be the *second* target network; assigning `target_critic_network`
      # here would make both critics share (and fight over) one target.
      self._target_critic_network_2 = target_critic_network_2
    else:
      self._target_critic_network_2 = (
          common.maybe_copy_target_network_with_checks(
              self._critic_network_2, None, 'TargetCriticNetwork2'))

    if actor_network is not None:
      actor_network.create_variables(time_step_spec.observation)
    self._actor_network = actor_network

    self._use_behavior_policy = use_behavior_policy
    if use_behavior_policy:
      # A separate actor copy, fit to replay-buffer actions by
      # `behavior_loss`, used to sample actions for the success examples.
      self._behavior_actor_network = actor_network.copy(
          name='BehaviorActorNetwork')
      self._behavior_policy = actor_policy_ctor(
          time_step_spec=time_step_spec,
          action_spec=action_spec,
          actor_network=self._behavior_actor_network,
          training=True)

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=False)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=True)

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)
    self._n_step = n_step

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(RceAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        # Disabled because n-step experience is longer than
        # `train_sequence_length`.
        validate_args=False
    )

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))

  def _check_action_spec(self, action_spec):
    """Raises NotImplementedError if any action spec is integer-typed."""
    flat_action_spec = tf.nest.flatten(action_spec)
    for spec in flat_action_spec:
      if spec.dtype.is_integer:
        raise NotImplementedError(
            'RceAgent does not currently support discrete actions. '
            'Action spec: {}'.format(action_spec))

  def _initialize(self):
    """Returns an op to initialize the agent.

    Hard-copies (tau=1.0) the weights of both critic networks into their
    respective target networks.
    """
    common.soft_variables_update(
        self._critic_network_1.variables,
        self._target_critic_network_1.variables,
        tau=1.0)
    common.soft_variables_update(
        self._critic_network_2.variables,
        self._target_critic_network_2.variables,
        tau=1.0)

  def _train(self, experience, weights):
    """Runs one update of the critics, the actor and (optionally) behavior net.

    Args:
      experience: A pair `(trajectory, expert_experience)`. `trajectory` is a
        time-stacked, batched trajectory from the replay buffer; when `n_step`
        is set, its first two steps form the one-step transition and its first
        and last steps form the n-step transition. `expert_experience` holds
        the observations of success examples.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights. Note that `critic_loss` asserts this is None.

    Returns:
      A `tf_agent.LossInfo` whose `loss` is `critic_loss + actor_loss` and
      whose `extra` is an `RceLossInfo`.

    Raises:
      AssertionError: If a network being optimized has no trainable variables.
    """
    experience, expert_experience = experience

    if self._n_step is None:
      time_steps, policy_steps, next_time_steps = self._as_transition(
          experience)
      future_time_steps = next_time_steps
    else:
      # One-step transition: the first two steps of each sequence.
      experience_1 = experience._replace(
          observation=experience.observation[:, :2],
          action=experience.action[:, :2],
          discount=experience.discount[:, :2],
          reward=experience.reward[:, :2],
          step_type=experience.step_type[:, :2],
          next_step_type=experience.next_step_type[:, :2],
      )

      def first_and_last(t):
        # Pairs step 0 with the final step of each sequence along axis 1.
        return tf.stack([t[:, 0], t[:, -1]], axis=1)

      # n-step transition: first step paired with the last step.
      experience_2 = experience._replace(
          observation=first_and_last(experience.observation),
          action=first_and_last(experience.action),
          discount=first_and_last(experience.discount),
          step_type=first_and_last(experience.step_type),
          next_step_type=first_and_last(experience.next_step_type),
          reward=first_and_last(experience.reward))
      time_steps, policy_steps, next_time_steps = self._as_transition(
          experience_1)
      _, _, future_time_steps = self._as_transition(experience_2)

    actions = policy_steps.action

    # Critic update (both critics share one optimizer step).
    trainable_critic_variables = list(object_identity.ObjectIdentitySet(
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables))

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self._critic_loss_weight * self.critic_loss(
          time_steps,
          expert_experience,
          actions,
          next_time_steps,
          future_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

    # Actor update.
    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight * self.actor_loss(
          time_steps, actions, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

    # Train the behavior policy
    if self._use_behavior_policy:
      trainable_behavior_variables = (
          self._behavior_actor_network.trainable_variables)
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert trainable_behavior_variables, ('No trainable behavior variables '
                                              'to optimize.')
        tape.watch(trainable_behavior_variables)
        behavior_loss = self._actor_loss_weight * self.behavior_loss(
            time_steps, actions, weights=weights)
      tf.debugging.check_numerics(behavior_loss, 'Behavior loss is inf or nan.')
      behavior_grads = tape.gradient(behavior_loss,
                                     trainable_behavior_variables)
      self._apply_gradients(behavior_grads, trainable_behavior_variables,
                            self._actor_optimizer)
    else:
      behavior_loss = 0.0

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(name='behavior_loss', data=behavior_loss,
                                  step=self.train_step_counter)

    self.train_step_counter.assign_add(1)
    self._update_target()

    total_loss = critic_loss + actor_loss

    extra = RceLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)

  def _apply_gradients(self, gradients, variables, optimizer):
    """Optionally clips and summarizes gradients, then applies them."""
    # list(...) is required for Python3.
    grads_and_vars = list(zip(gradients, variables))
    if self._gradient_clipping is not None:
      grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,
                                                       self._gradient_clipping)

    if self._summarize_grads_and_vars:
      eager_utils.add_variables_summaries(grads_and_vars,
                                          self.train_step_counter)
      eager_utils.add_gradients_summaries(grads_and_vars,
                                          self.train_step_counter)

    optimizer.apply_gradients(grads_and_vars)

  def _get_target_updater(self, tau=1.0, period=1):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the original network, and its corresponding
    weight w_t in the target network, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Non-trainable variables are always hard-copied (tau_non_trainable=1.0).

    Args:
      tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
      period: Step interval at which the target network is updated.

    Returns:
      A callable that performs a soft update of the target network parameters.
    """
    with tf.name_scope('update_target'):

      def update():
        """Update target network."""
        critic_update_1 = common.soft_variables_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables,
            tau,
            tau_non_trainable=1.0)

        # Skip variables shared with critic 1 so they are not updated twice.
        critic_2_update_vars = common.deduped_network_variables(
            self._critic_network_2, self._critic_network_1)

        target_critic_2_update_vars = common.deduped_network_variables(
            self._target_critic_network_2, self._target_critic_network_1)

        critic_update_2 = common.soft_variables_update(
            critic_2_update_vars,
            target_critic_2_update_vars,
            tau,
            tau_non_trainable=1.0)

        return tf.group(critic_update_1, critic_update_2)

      return common.Periodically(update, period, 'update_targets')

  def _actions_and_log_probs(self, time_steps):
    """Samples actions from the training policy and returns their log-probs."""
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state=policy_state).action

    # Sample actions and log_pis from transformed distribution.
    actions = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
    log_pi = common.log_probability(action_distribution, actions,
                                    self.action_spec)

    return actions, log_pi

  @staticmethod
  def _combine_q_values(q_values1, q_values2, q_combinator):
    """Combines two Q estimates elementwise: 'min' (as in TD3) or 'max'."""
    if q_combinator == 'min':
      return tf.minimum(q_values1, q_values2)
    assert q_combinator == 'max'
    return tf.maximum(q_values1, q_values2)

  @gin.configurable
  def critic_loss(self,
                  time_steps,
                  expert_experience,
                  actions,
                  next_time_steps,
                  future_time_steps,
                  td_errors_loss_fn,
                  gamma=1.0,
                  reward_scale_factor=1.0,
                  weights=None,
                  training=False,
                  loss_name='c',
                  use_done=False,
                  q_combinator='min'):
    """Computes the critic loss for RCE (or SQIL) training.

    Both critics are evaluated on the replay batch concatenated with the batch
    of success examples. Replay entries are regressed onto bootstrapped targets
    from the target critics; success examples are regressed onto 1.

    Args:
      time_steps: A batch of timesteps.
      expert_experience: A batch of success-example observations. It is
        concatenated with `time_steps.observation` along axis 0, so it must
        have a compatible shape.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      future_time_steps: A batch of future timesteps, used for n-step returns
        (only read when `n_step` is set).
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Currently unused by this loss.
      weights: Must be None. For `loss_name='c'` the per-example weights are
        computed internally.
      training: Whether this loss is being used for training.
      loss_name: Which loss function to use. Use 'c' for RCE and 'q' for SQIL.
      use_done: Whether to use the terminal flag from the environment in the
        Bellman backup. We found that omitting it led to better results.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.

    Returns:
      critic_loss: A scalar critic loss.
    """
    assert weights is None
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      target_q_values1, _ = self._target_critic_network_1(
          target_input, next_time_steps.step_type, training=False)
      target_q_values2, _ = self._target_critic_network_2(
          target_input, next_time_steps.step_type, training=False)
      if self._n_step is not None:
        future_actions, _ = self._actions_and_log_probs(future_time_steps)
        future_input = (future_time_steps.observation, future_actions)
        future_q_values1, _ = self._target_critic_network_1(
            future_input, future_time_steps.step_type, training=False)
        future_q_values2, _ = self._target_critic_network_2(
            future_input, future_time_steps.step_type, training=False)

        gamma_n = gamma**self._n_step  # Discount for n-step returns
        # Average of the one-step and discounted n-step estimates.
        target_q_values1 = (target_q_values1 + gamma_n * future_q_values1) / 2.0
        target_q_values2 = (target_q_values2 + gamma_n * future_q_values2) / 2.0

      target_q_values = self._combine_q_values(target_q_values1,
                                               target_q_values2, q_combinator)

      # Dynamic batch size: the static `shape[0]` may be None inside a
      # `tf.function`, which would break `tf.ones` below.
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      if loss_name == 'q':
        if use_done:
          td_targets = gamma * next_time_steps.discount * target_q_values
        else:
          td_targets = gamma * target_q_values
      else:
        assert loss_name == 'c'
        # NOTE(review): w diverges as target_q_values -> 1; a NaN/Inf here is
        # caught by check_numerics in `_train`.
        w = target_q_values / (1 - target_q_values)
        td_targets = gamma * w / (gamma * w + 1)
        if use_done:
          td_targets = next_time_steps.discount * td_targets
        # Replay entries are weighted by (1 + gamma * w); success examples by
        # (1 - gamma).
        weights = tf.concat(
            [1 + gamma * w, (1 - gamma) * tf.ones([batch_size])], axis=0)

      td_targets = tf.stop_gradient(td_targets)
      # Success examples are labeled with a target of 1.
      td_targets = tf.concat([td_targets, tf.ones([batch_size])], axis=0)

      # Note that the actions only depend on the observations. We create the
      # expert_time_steps object simply to make this look like a time step
      # object.
      expert_time_steps = time_steps._replace(observation=expert_experience)
      if self._use_behavior_policy:
        # Actions for the success examples must be conditioned on the success
        # observations, not on the replay-buffer observations.
        policy_state = self._behavior_policy.get_initial_state(batch_size)
        action_distribution = self._behavior_policy.distribution(
            expert_time_steps, policy_state=policy_state).action
        expert_actions = tf.nest.map_structure(lambda d: d.sample(),
                                               action_distribution)
      else:
        expert_actions, _ = self._actions_and_log_probs(expert_time_steps)

      observation = time_steps.observation
      pred_input = (tf.concat([observation, expert_experience], axis=0),
                    tf.concat([actions, expert_actions], axis=0))

      pred_td_targets1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)

      critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
      critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
      critic_loss = critic_loss1 + critic_loss2

      if critic_loss.shape.rank > 1:
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            critic_loss, axis=range(1, critic_loss.shape.rank))

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss

      # Written once per call (previously these summaries were emitted twice).
      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      return critic_loss

  @gin.configurable
  def actor_loss(self,
                 time_steps,
                 rb_actions=None,
                 weights=None,
                 q_combinator='min',
                 entropy_coef=1e-4):
    """Computes the actor loss: entropy_coef * log_pi - Q(s, pi(s)).

    Args:
      time_steps: A batch of timesteps.
      rb_actions: Actions from the replay buffer. While not used in the main RCE
        method, we used these actions to train a behavior policy for the
        ablation experiment studying how to sample actions for the success
        examples. Unused by this loss.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.
      entropy_coef: Coefficient for entropy regularization term. We found that
        1e-4 worked well for all environments.

    Returns:
      actor_loss: A scalar actor loss.
    """
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, actions)

      target_q_values1, _ = self._critic_network_1(
          target_input, time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, time_steps.step_type, training=False)
      target_q_values = self._combine_q_values(target_q_values1,
                                               target_q_values2, q_combinator)
      if entropy_coef == 0:
        # Skip the log_pi term entirely so 0 * log_pi cannot yield NaN.
        actor_loss = -target_q_values
      else:
        actor_loss = entropy_coef * log_pi - target_q_values
      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)

      return actor_loss

  @gin.configurable
  def behavior_loss(self,
                    time_steps,
                    actions,
                    weights=None):
    """Negative mean log-likelihood of `actions` under the behavior policy.

    Args:
      time_steps: A batch of timesteps.
      actions: Replay-buffer actions to fit.
      weights: Currently ignored; the loss is an unweighted mean.

    Returns:
      A scalar behavior-cloning loss.
    """
    with tf.name_scope('behavior_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._behavior_policy.get_initial_state(batch_size)
      action_distribution = self._behavior_policy.distribution(
          time_steps, policy_state=policy_state).action
      log_pi = common.log_probability(action_distribution, actions,
                                      self.action_spec)
      return -1.0 * tf.reduce_mean(log_pi)

  def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                   pred_td_targets2):
    """Writes TD-error/target summaries when debug summaries are enabled."""
    if self._debug_summaries:
      td_errors1 = td_targets - pred_td_targets1
      td_errors2 = td_targets - pred_td_targets2
      td_errors = tf.concat([td_errors1, td_errors2], axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                       self.train_step_counter)

  def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                  target_q_values, time_steps):
    """Writes actor/policy summaries when debug summaries are enabled."""
    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      try:
        common.generate_tensor_summaries('actions', actions,
                                         self.train_step_counter)
      except ValueError:
        pass  # Guard against internal SAC variants that do not directly
        # generate actions.

      common.generate_tensor_summaries('log_pi', log_pi,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pi),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._train_policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode', action_distribution.mode(),
                                         self.train_step_counter)
      try:
        common.generate_tensor_summaries('entropy_action',
                                         action_distribution.entropy(),
                                         self.train_step_counter)
      except NotImplementedError:
        pass  # Some distributions do not have an analytic entropy.
711

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.