# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""C-learning.

Implements the off-policy goal-conditioned C-learning algorithm from
"C-Learning: Learning to Achieve Goals via Recursive Classification" by
Eysenbach et al (2020): https://arxiv.org/abs/2011.08909
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from typing import Callable, Optional, Text

import gin
from six.moves import zip
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp

from tf_agents.agents import tf_agent
from tf_agents.networks import network
from tf_agents.policies import actor_policy
from tf_agents.policies import tf_policy
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory
from tf_agents.typing import types
from tf_agents.utils import common
from tf_agents.utils import eager_utils
from tf_agents.utils import nest_utils
from tf_agents.utils import object_identity


CLearningLossInfo = collections.namedtuple(
    'LossInfo', ('critic_loss', 'actor_loss'))


@gin.configurable
class CLearningAgent(tf_agent.TFAgent):
  """A C-learning Agent."""

  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               actor_loss_weight = 1.0,
               critic_loss_weight = 0.5,
               actor_policy_ctor = actor_policy.ActorPolicy,
               critic_network_2 = None,
               target_critic_network = None,
               target_critic_network_2 = None,
               target_update_tau = 1.0,
               target_update_period = 1,
               td_errors_loss_fn = tf.math.squared_difference,
               gamma = 1.0,
               gradient_clipping = None,
               debug_summaries = False,
               summarize_grads_and_vars = False,
               train_step_counter = None,
               name = None):
    """Creates a C-learning Agent.

    By default, the environment observation contains the current state and goal
    state. By setting the obs_to_goal gin config in c_learning_utils, the user
    can specify that the agent should only look at certain subsets of the goal
    state.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables()
    if target_critic_network:
      target_critic_network.create_variables()
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables()
    if target_critic_network_2:
      target_critic_network_2.create_variables()
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    if actor_network:
      actor_network.create_variables()
    self._actor_network = actor_network

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=False)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=True)

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(CLearningAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)

  def _check_action_spec(self, action_spec):
    flat_action_spec = tf.nest.flatten(action_spec)
    for spec in flat_action_spec:
      if spec.dtype.is_integer:
        raise NotImplementedError(
            'CLearningAgent does not currently support discrete actions. '
            'Action spec: {}'.format(action_spec))

  def _initialize(self):
    """Returns an op to initialize the agent.

    Copies weights from the Q networks to the target Q networks.
    """
    common.soft_variables_update(
        self._critic_network_1.variables,
        self._target_critic_network_1.variables,
        tau=1.0)
    common.soft_variables_update(
        self._critic_network_2.variables,
        self._target_critic_network_2.variables,
        tau=1.0)

  def _train(self, experience, weights):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    squeeze_time_dim = not self._critic_network_1.state_spec
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action

    trainable_critic_variables = list(object_identity.ObjectIdentitySet(
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables))

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self._critic_loss_weight*self.critic_loss(
          time_steps,
          actions,
          next_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          weights=weights,
          training=True)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight*self.actor_loss(
          time_steps, actions, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)

    self.train_step_counter.assign_add(1)
    self._update_target()

    total_loss = critic_loss + actor_loss

    extra = CLearningLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)

  def _apply_gradients(self, gradients, variables, optimizer):
    # list(...) is required for Python3.
    grads_and_vars = list(zip(gradients, variables))
    if self._gradient_clipping is not None:
      grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,
                                                       self._gradient_clipping)

    if self._summarize_grads_and_vars:
      eager_utils.add_variables_summaries(grads_and_vars,
                                          self.train_step_counter)
      eager_utils.add_gradients_summaries(grads_and_vars,
                                          self.train_step_counter)

    optimizer.apply_gradients(grads_and_vars)

  def _get_target_updater(self, tau=1.0, period=1):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the original network, and its corresponding
    weight w_t in the target network, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
      period: Step interval at which the target network is updated.

    Returns:
      A callable that performs a soft update of the target network parameters.
    """
    with tf.name_scope('update_target'):

      def update():
        """Update target network."""
        critic_update_1 = common.soft_variables_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables,
            tau,
            tau_non_trainable=1.0)

        critic_2_update_vars = common.deduped_network_variables(
            self._critic_network_2, self._critic_network_1)

        target_critic_2_update_vars = common.deduped_network_variables(
            self._target_critic_network_2, self._target_critic_network_1)

        critic_update_2 = common.soft_variables_update(
            critic_2_update_vars,
            target_critic_2_update_vars,
            tau,
            tau_non_trainable=1.0)

        return tf.group(critic_update_1, critic_update_2)

      return common.Periodically(update, period, 'update_targets')

  def _actions_and_log_probs(self, time_steps):
    """Get actions and corresponding log probabilities from policy."""
    # Get raw action distribution from policy, and initialize bijectors list.
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state=policy_state).action

    # Sample actions and log_pis from transformed distribution.
    actions = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
    log_pi = common.log_probability(action_distribution, actions,
                                    self.action_spec)

    return actions, log_pi

  @gin.configurable
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  td_errors_loss_fn,
                  gamma = 1.0,
                  weights = None,
                  training = False,
                  w_clipping = 20.0,
                  self_normalized = False,
                  lambda_fix = False,
                  ):
    """Computes the critic loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      w_clipping: Maximum value used for clipping the weights. Use -1 to do no
        clipping; use None to use the recommended value of 1 / (1 - gamma).
      self_normalized: Whether to normalize the weights so that their average
        is 1. Empirically this usually hurts performance.
      lambda_fix: Whether to include the adjustment when using future positives.
        Empirically this has little effect.

    Returns:
      critic_loss: A scalar critic loss.
    """
    del weights
    if w_clipping is None:
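      # Illustrative arithmetic (not in the original): with gamma = 0.99, this
      # recommended value is 1 / (1 - 0.99) = 100.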
      w_clipping = 1 / (1 - gamma)
    rfp = gin.query_parameter('goal_fn.relabel_future_prob')
    rnp = gin.query_parameter('goal_fn.relabel_next_prob')
    assert rfp + rnp == 0.5
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      target_q_values1, unused_network_state1 = self._target_critic_network_1(
          target_input, next_time_steps.step_type, training=False)
      target_q_values2, unused_network_state2 = self._target_critic_network_2(
          target_input, next_time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)

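      # The target critics output the classifier value C(s', a', g); the ratio
      # w = C / (1 - C) below is the importance weight that enters the
      # recursive-classification TD target computed further down.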
      w = tf.stop_gradient(target_q_values / (1 - target_q_values))
      if w_clipping >= 0:
        w = tf.clip_by_value(w, 0, w_clipping)
      tf.debugging.assert_all_finite(w, 'Not all elements of w are finite')
      if self_normalized:
        w = w / tf.reduce_mean(w)

      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      half_batch = batch_size // 2
      float_batch_size = tf.cast(batch_size, float)
      num_next = tf.cast(tf.round(float_batch_size * rnp), tf.int32)
      num_future = tf.cast(tf.round(float_batch_size * rfp), tf.int32)
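      # Batch layout assumption (implied by the relabeling setup above): the
      # first half of the batch holds relabeled goals (next-state goals, then
      # future-state goals), while the second half keeps its sampled goals and
      # receives the 1 + gamma * w style weights below.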
      if lambda_fix:
        lambda_coef = 2 * rnp
        weights = tf.concat([tf.fill((num_next,), (1 - gamma)),
                             tf.fill((num_future,), 1.0),
                             (1 + lambda_coef * gamma * w)[half_batch:]],
                            axis=0)
      else:
        weights = tf.concat([tf.fill((num_next,), (1 - gamma)),
                             tf.fill((num_future,), 1.0),
                             (1 + gamma * w)[half_batch:]],
                            axis=0)

      # Note that we assume that episodes never terminate. If they do, then
      # we would need to include next_time_steps.discount in the (negative) TD
      # target. We exclude the termination here so that we can use termination
      # to indicate task success during evaluation. In the evaluation setting,
      # task success depends on the task, but we don't want the termination
      # here to depend on the task. Hence, we ignore it.
      if lambda_fix:
        lambda_coef = 2 * rnp
        y = lambda_coef * gamma * w / (1 + lambda_coef * gamma * w)
      else:
        y = gamma * w / (1 + gamma * w)
      td_targets = tf.stop_gradient(next_time_steps.reward +
                                    (1 - next_time_steps.reward) * y)
      if rfp > 0:
        td_targets = tf.concat([tf.ones(half_batch),
                                td_targets[half_batch:]], axis=0)

      observation = time_steps.observation
      pred_input = (observation, actions)
      pred_td_targets1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)

      critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
      critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
      critic_loss = critic_loss1 + critic_loss2

      if critic_loss.shape.rank > 1:
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            critic_loss, axis=range(1, critic_loss.shape.rank))

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss
      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2, weights)

      return critic_loss

  @gin.configurable
  def actor_loss(self,
                 time_steps,
                 actions,
                 weights = None,
                 ce_loss = False):
    """Computes the actor_loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      ce_loss: (bool) Whether to update the actor using the cross entropy loss,
        which corresponds to using the log C-value. The default actor loss
        differs by not including the log. Empirically we observed no difference.

    Returns:
      actor_loss: A scalar actor loss.
    """
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      sampled_actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, sampled_actions)
      target_q_values1, _ = self._critic_network_1(
          target_input, time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
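      # With all-ones labels, binary cross-entropy reduces to -log(C), i.e. the
      # log C-value mentioned in the docstring; the default branch below
      # maximizes C directly instead.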
      if ce_loss:
        actor_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(target_q_values), target_q_values)
      else:
        actor_loss = -1.0 * target_q_values

      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)

      return actor_loss

  def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                   pred_td_targets2, weights):
    if self._debug_summaries:
      td_errors1 = td_targets - pred_td_targets1
      td_errors2 = td_targets - pred_td_targets2
      td_errors = tf.concat([td_errors1, td_errors2], axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                       self.train_step_counter)
      common.generate_tensor_summaries('weights', weights,
                                       self.train_step_counter)

  def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                  target_q_values, time_steps):
    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('actions', actions,
                                       self.train_step_counter)
      common.generate_tensor_summaries('log_pi', log_pi,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pi),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._train_policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode', action_distribution.mode(),
                                         self.train_step_counter)
      common.generate_tensor_summaries('entropy_action',
                                       action_distribution.entropy(),
                                       self.train_step_counter)