# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contrastive RL learner implementation."""
import time
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Callable

import acme
from acme import types
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.utils import counting
from acme.utils import loggers
from contrastive import config as contrastive_config
from contrastive import networks as contrastive_networks
import jax
import jax.numpy as jnp
import optax
import reverb


class TrainingState(NamedTuple):
  """Contains training state for the learner."""
  policy_optimizer_state: optax.OptState
  q_optimizer_state: optax.OptState
  policy_params: networks_lib.Params
  q_params: networks_lib.Params
  target_q_params: networks_lib.Params
  key: networks_lib.PRNGKey
  alpha_optimizer_state: Optional[optax.OptState] = None
  alpha_params: Optional[networks_lib.Params] = None


class ContrastiveLearner(acme.Learner):
  """Contrastive RL learner."""

  _state: TrainingState

  def __init__(
      self,
      networks,
      rng,
      policy_optimizer,
      q_optimizer,
      iterator,
      counter,
      logger,
      obs_to_goal,
      config):
    """Initialize the Contrastive RL learner.

    Args:
      networks: Contrastive RL networks.
      rng: a key for random number generation.
      policy_optimizer: the policy optimizer.
      q_optimizer: the Q-function optimizer.
      iterator: an iterator over training data.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      obs_to_goal: a function for extracting the goal coordinates.
      config: the experiment config file.
    """
    if config.add_mc_to_td:
      assert config.use_td
    adaptive_entropy_coefficient = config.entropy_coefficient is None
    self._num_sgd_steps_per_step = config.num_sgd_steps_per_step
    self._obs_dim = config.obs_dim
    self._use_td = config.use_td
    if adaptive_entropy_coefficient:
      # alpha is the temperature parameter that determines the relative
      # importance of the entropy term versus the reward.
      log_alpha = jnp.asarray(0., dtype=jnp.float32)
      alpha_optimizer = optax.adam(learning_rate=3e-4)
      alpha_optimizer_state = alpha_optimizer.init(log_alpha)
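      # Optimizing log_alpha rather than alpha itself keeps the temperature
      # strictly positive after the exp() below while leaving the
      # optimization itself unconstrained.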
87
    else:
88
      if config.target_entropy:
89
        raise ValueError('target_entropy should not be set when '
90
                         'entropy_coefficient is provided')
91

92
    def alpha_loss(log_alpha,
93
                   policy_params,
94
                   transitions,
95
                   key):
96
      """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
97
      dist_params = networks.policy_network.apply(
98
          policy_params, transitions.observation)
99
      action = networks.sample(dist_params, key)
100
      log_prob = networks.log_prob(dist_params, action)
101
      alpha = jnp.exp(log_alpha)
102
      alpha_loss = alpha * jax.lax.stop_gradient(
103
          -log_prob - config.target_entropy)
104
      return jnp.mean(alpha_loss)
105

106
    def critic_loss(q_params,
107
                    policy_params,
108
                    target_q_params,
109
                    transitions,
110
                    key):
111
      batch_size = transitions.observation.shape[0]
112
      # Note: We might be able to speed up the computation for some of the
113
      # baselines to making a single network that returns all the values. This
114
      # avoids computing some of the underlying representations multiple times.
115
      if config.use_td:
116
        # For TD learning, the diagonal elements are the immediate next state.
117
        s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
118
        next_s, _ = jnp.split(transitions.next_observation, [config.obs_dim],
119
                              axis=1)
120
        if config.add_mc_to_td:
121
          next_fraction = (1 - config.discount) / ((1 - config.discount) + 1)
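          # E.g., with discount = 0.99, next_fraction = 0.01 / 1.01 ~= 0.0099,
          # so roughly 1% of the rows use the next state as the goal while
          # the rest keep the goals already stored in the batch.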
122
          num_next = int(batch_size * next_fraction)
123
          new_g = jnp.concatenate([
124
              obs_to_goal(next_s[:num_next]),
125
              g[num_next:],
126
          ], axis=0)
127
        else:
128
          new_g = obs_to_goal(next_s)
129
        obs = jnp.concatenate([s, new_g], axis=1)
130
        transitions = transitions._replace(observation=obs)
131
      I = jnp.eye(batch_size)  # pylint: disable=invalid-name
132
      logits = networks.q_network.apply(
133
          q_params, transitions.observation, transitions.action)
134

135
      if config.use_td:
136
        # Make sure to use the twin Q trick.
137
        assert len(logits.shape) == 3
138

139
        # We evaluate the next-state Q function using random goals
140
        s, g = jnp.split(transitions.observation, [config.obs_dim], axis=1)
141
        del s
142
        next_s = transitions.next_observation[:, :config.obs_dim]
143
        goal_indices = jnp.roll(jnp.arange(batch_size, dtype=jnp.int32), -1)
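        # Rolling the index vector by -1 pairs each transition with its
        # neighbor's goal, so row i is evaluated against the goal of row
        # i + 1 (mod batch_size); these act as the "random" goals.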
        g = g[goal_indices]
        transitions = transitions._replace(
            next_observation=jnp.concatenate([next_s, g], axis=1))
        next_dist_params = networks.policy_network.apply(
            policy_params, transitions.next_observation)
        next_action = networks.sample(next_dist_params, key)
        next_q = networks.q_network.apply(target_q_params,
                                          transitions.next_observation,
                                          next_action)  # This outputs logits.
        next_q = jax.nn.sigmoid(next_q)
        next_v = jnp.min(next_q, axis=-1)
        next_v = jax.lax.stop_gradient(next_v)
        next_v = jnp.diag(next_v)
        # diag(logits) are predictions for future states.
        # diag(next_q) are predictions for random states, which correspond to
        # the predictions logits[range(B), goal_indices].
        # So, the only thing that's meaningful for next_q is the diagonal. Off
        # diagonal entries are meaningless and shouldn't be used.
        w = next_v / (1 - next_v)
        w_clipping = 20.0
        w = jnp.clip(w, 0, w_clipping)
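        # next_v is the target classifier's probability p (the min over the
        # twin heads) that the rolled goal is reached from the next state;
        # w = p / (1 - p) converts it to the odds used to weight the
        # bootstrapped positive term below, and clipping at 20 keeps the
        # weight bounded as p approaches 1.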
        # (B, B, 2) --> (B, 2), computes diagonal of each twin Q.
        pos_logits = jax.vmap(jnp.diag, -1, -1)(logits)
        loss_pos = optax.sigmoid_binary_cross_entropy(
            logits=pos_logits, labels=1)  # [B, 2]

        neg_logits = logits[jnp.arange(batch_size), goal_indices]
        loss_neg1 = w[:, None] * optax.sigmoid_binary_cross_entropy(
            logits=neg_logits, labels=1)  # [B, 2]
        loss_neg2 = optax.sigmoid_binary_cross_entropy(
            logits=neg_logits, labels=0)  # [B, 2]

        if config.add_mc_to_td:
          loss = ((1 + (1 - config.discount)) * loss_pos
                  + config.discount * loss_neg1 + 2 * loss_neg2)
        else:
          loss = ((1 - config.discount) * loss_pos
                  + config.discount * loss_neg1 + loss_neg2)
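        # Combined TD objective: the (1 - discount) term treats the actual
        # next state as a positive, the discount * w term bootstraps the
        # target critic's prediction for the rolled goal as a soft positive,
        # and loss_neg2 treats the rolled goal as a negative. add_mc_to_td
        # reweights these terms (the goals were already mixed above).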
        # Take the mean here so that we can compute the accuracy.
        logits = jnp.mean(logits, axis=-1)

      else:  # For the MC losses.
        def loss_fn(_logits):  # pylint: disable=invalid-name
          if config.use_cpc:
            return (optax.softmax_cross_entropy(logits=_logits, labels=I)
                    + 0.01 * jax.nn.logsumexp(_logits, axis=1)**2)
          else:
            return optax.sigmoid_binary_cross_entropy(logits=_logits, labels=I)
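        # Monte Carlo contrastive loss: the identity matrix I marks each
        # row's own (future-state) goal as the positive and every other goal
        # in the batch as a negative. With use_cpc this is the InfoNCE /
        # softmax objective plus a logsumexp penalty; otherwise every
        # (state, goal) pair contributes an independent binary NCE term.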
        if len(logits.shape) == 3:  # twin q
          # loss.shape = [.., num_q]
          loss = jax.vmap(loss_fn, in_axes=2, out_axes=-1)(logits)
          loss = jnp.mean(loss, axis=-1)
          # Take the mean here so that we can compute the accuracy.
          logits = jnp.mean(logits, axis=-1)
        else:
          loss = loss_fn(logits)

      loss = jnp.mean(loss)
      correct = (jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1))
      logits_pos = jnp.sum(logits * I) / jnp.sum(I)
      logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
      if len(logits.shape) == 3:
        logsumexp = jax.nn.logsumexp(logits[:, :, 0], axis=1)**2
      else:
        logsumexp = jax.nn.logsumexp(logits, axis=1)**2
      metrics = {
          'binary_accuracy': jnp.mean((logits > 0) == I),
          'categorical_accuracy': jnp.mean(correct),
          'logits_pos': logits_pos,
          'logits_neg': logits_neg,
          'logsumexp': logsumexp.mean(),
      }

      return loss, metrics

    def actor_loss(policy_params,
220
                   q_params,
221
                   alpha,
222
                   transitions,
223
                   key,
224
                   ):
225
      obs = transitions.observation
226
      if config.use_gcbc:
227
        dist_params = networks.policy_network.apply(
228
            policy_params, obs)
229
        log_prob = networks.log_prob(dist_params, transitions.action)
230
        actor_loss = -1.0 * jnp.mean(log_prob)
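        # With use_gcbc the actor reduces to goal-conditioned behavioral
        # cloning: maximize the log-likelihood of the dataset actions. The
        # critic is skipped entirely in update_step in this case.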
      else:
        state = obs[:, :config.obs_dim]
        goal = obs[:, config.obs_dim:]

        if config.random_goals == 0.0:
          new_state = state
          new_goal = goal
        elif config.random_goals == 0.5:
          new_state = jnp.concatenate([state, state], axis=0)
          new_goal = jnp.concatenate([goal, jnp.roll(goal, 1, axis=0)], axis=0)
        else:
          assert config.random_goals == 1.0
          new_state = state
          new_goal = jnp.roll(goal, 1, axis=0)
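        # random_goals controls which goals the actor is trained on:
        #   0.0 -> only the goals stored with each transition;
        #   0.5 -> duplicate the batch, half with the original goals and half
        #          with goals rolled (shuffled) across the batch;
        #   1.0 -> only rolled goals.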

        new_obs = jnp.concatenate([new_state, new_goal], axis=1)
        dist_params = networks.policy_network.apply(
            policy_params, new_obs)
        action = networks.sample(dist_params, key)
        log_prob = networks.log_prob(dist_params, action)
        q_action = networks.q_network.apply(
            q_params, new_obs, action)
        if len(q_action.shape) == 3:  # twin q trick
          assert q_action.shape[2] == 2
          q_action = jnp.min(q_action, axis=-1)
        actor_loss = alpha * log_prob - jnp.diag(q_action)
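        # SAC-style objective: minimize alpha * log_prob - Q(s, a, g). The
        # critic returns a square [batch, batch] matrix scoring every
        # (state-action, goal) pairing, so jnp.diag keeps only the entries
        # where the goal matches the one the action was conditioned on.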

        assert 0.0 <= config.bc_coef <= 1.0
        if config.bc_coef > 0:
          orig_action = transitions.action
          if config.random_goals == 0.5:
            orig_action = jnp.concatenate([orig_action, orig_action], axis=0)

          bc_loss = -1.0 * networks.log_prob(dist_params, orig_action)
          actor_loss = (config.bc_coef * bc_loss
                        + (1 - config.bc_coef) * actor_loss)

      return jnp.mean(actor_loss)

    alpha_grad = jax.value_and_grad(alpha_loss)
    critic_grad = jax.value_and_grad(critic_loss, has_aux=True)
    actor_grad = jax.value_and_grad(actor_loss)
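    # critic_grad uses has_aux=True because critic_loss returns a
    # (loss, metrics) pair; the metrics are passed through for logging.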

    def update_step(
        state,
        transitions,
    ):

      key, key_alpha, key_critic, key_actor = jax.random.split(state.key, 4)
      if adaptive_entropy_coefficient:
        alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                             state.policy_params, transitions,
                                             key_alpha)
        alpha = jnp.exp(state.alpha_params)
      else:
        alpha = config.entropy_coefficient

      if not config.use_gcbc:
        (critic_loss, critic_metrics), critic_grads = critic_grad(
            state.q_params, state.policy_params, state.target_q_params,
            transitions, key_critic)

      actor_loss, actor_grads = actor_grad(state.policy_params, state.q_params,
                                           alpha, transitions, key_actor)

      # Apply policy gradients
      actor_update, policy_optimizer_state = policy_optimizer.update(
          actor_grads, state.policy_optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, actor_update)

      # Apply critic gradients
      if config.use_gcbc:
        metrics = {}
        critic_loss = 0.0
        q_params = state.q_params
        q_optimizer_state = state.q_optimizer_state
        new_target_q_params = state.target_q_params
      else:
        critic_update, q_optimizer_state = q_optimizer.update(
            critic_grads, state.q_optimizer_state)

        q_params = optax.apply_updates(state.q_params, critic_update)

        new_target_q_params = jax.tree_map(
            lambda x, y: x * (1 - config.tau) + y * config.tau,
            state.target_q_params, q_params)
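        # Polyak-average the target critic parameters:
        #   target <- (1 - tau) * target + tau * online.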
        metrics = critic_metrics

      metrics.update({
          'critic_loss': critic_loss,
          'actor_loss': actor_loss,
      })

      new_state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          q_optimizer_state=q_optimizer_state,
          policy_params=policy_params,
          q_params=q_params,
          target_q_params=new_target_q_params,
          key=key,
      )
      if adaptive_entropy_coefficient:
        # Apply alpha gradients
        alpha_update, alpha_optimizer_state = alpha_optimizer.update(
            alpha_grads, state.alpha_optimizer_state)
        alpha_params = optax.apply_updates(state.alpha_params, alpha_update)
        metrics.update({
            'alpha_loss': alpha_loss,
            'alpha': jnp.exp(alpha_params),
        })
        new_state = new_state._replace(
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params)

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray,
        time_delta=10.0)

    # Iterator on demonstration transitions.
    self._iterator = iterator

    update_step = utils.process_multiple_batches(update_step,
                                                 config.num_sgd_steps_per_step)
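    # acme's process_multiple_batches wraps update_step so that a single call
    # performs config.num_sgd_steps_per_step gradient updates back-to-back,
    # which lets the whole inner loop be compiled below as one computation.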
    # Use the JIT compiler.
    if config.jit:
      self._update_step = jax.jit(update_step)
    else:
      self._update_step = update_step

    def make_initial_state(key):
      """Initialises the training state (parameters and optimiser state)."""
      key_policy, key_q, key = jax.random.split(key, 3)

      policy_params = networks.policy_network.init(key_policy)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      q_params = networks.q_network.init(key_q)
      q_optimizer_state = q_optimizer.init(q_params)

      state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          q_optimizer_state=q_optimizer_state,
          policy_params=policy_params,
          q_params=q_params,
          target_q_params=q_params,
          key=key)

      if adaptive_entropy_coefficient:
        state = state._replace(alpha_optimizer_state=alpha_optimizer_state,
                               alpha_params=log_alpha)
      return state

    # Create initial state.
    self._state = make_initial_state(rng)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online
    # and fill the replay buffer.
    self._timestamp = None

  def step(self):
    with jax.profiler.StepTraceAnnotation('step', step_num=self._counter):
      sample = next(self._iterator)
      transitions = types.Transition(*sample.data)
      self._state, metrics = self._update_step(self._state, transitions)

    # Compute elapsed time.
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(steps=1, walltime=elapsed_time)
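    # Each learner step runs num_sgd_steps_per_step SGD updates, so the
    # steps_per_second metric below is reported in SGD updates per second.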
    if elapsed_time > 0:
      metrics['steps_per_second'] = (
          self._num_sgd_steps_per_step / elapsed_time)
    else:
      metrics['steps_per_second'] = 0.

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})

  def get_variables(self, names):
    variables = {
        'policy': self._state.policy_params,
        'critic': self._state.q_params,
    }
    return [variables[name] for name in names]

  def save(self):
    return self._state

  def restore(self, state):
    self._state = state
