# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contrastive Value Learning builder."""
from typing import Callable, Iterator, List, Optional

import acme
from acme import adders
from acme import core
from acme import specs
from acme import types
from acme.adders import reverb as adders_reverb
from acme.agents.jax import actor_core as actor_core_lib
from acme.agents.jax import actors
from acme.agents.jax import builders
from acme.jax import networks as networks_lib
from acme.jax import variable_utils
from acme.utils import counting
from acme.utils import loggers
import optax
import reverb
from reverb import rate_limiters
import tensorflow as tf
import tree

from cvl_public import config as contrastive_config
from cvl_public import learning
from cvl_public import networks as contrastive_networks
from cvl_public import utils as contrastive_utils


class ContrastiveBuilder(builders.ActorLearnerBuilder):
  """Contrastive RL builder."""

  def __init__(
      self,
      config,
      logger_fn = lambda: None,
  ):
    """Creates a contrastive RL learner, a behavior policy and an eval actor.

    Args:
      config: a config with contrastive RL hyperparameters
      logger_fn: a logger factory for the learner
    """
    self._config = config
    self._logger_fn = logger_fn

  def make_learner(
      self,
      random_key,
      networks,
      dataset,
      replay_client = None,
      counter = None,
  ):
    # Create optimizers
    policy_optimizer = optax.adam(
        learning_rate=self._config.actor_learning_rate, eps=1e-7)
    q_optimizer = optax.adam(learning_rate=self._config.learning_rate, eps=1e-7)
    return learning.ContrastiveLearner(
        networks=networks,
        rng=random_key,
        policy_optimizer=policy_optimizer,
        q_optimizer=q_optimizer,
        iterator=dataset,
        counter=counter,
        logger=self._logger_fn(),
        config=self._config)

  def make_actor(
      self,
      random_key,
      policy_network,
      adder = None,
      variable_source = None):
    assert variable_source is not None
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy_network)
    variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                    device='cpu')
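    # Assumption (not verified against cvl_public.utils): InitiallyRandomActor
    # is expected to act like actors.GenericActor but take uniformly random
    # actions at the start of training, seeding replay with diverse episodes.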
    if self._config.use_random_actor:
      ACTOR = contrastive_utils.InitiallyRandomActor  # pylint: disable=invalid-name
    else:
      ACTOR = actors.GenericActor  # pylint: disable=invalid-name
    return ACTOR(
        actor_core, random_key, variable_client, adder, backend='cpu')

  def make_replay_tables(
      self,
      environment_spec,
  ):
    """Create tables to insert data into."""
    samples_per_insert_tolerance = (
        self._config.samples_per_insert_tolerance_rate
        * self._config.samples_per_insert)
    min_replay_traj = self._config.min_replay_size // self._config.max_episode_steps  # pylint: disable=line-too-long
    max_replay_traj = self._config.max_replay_size // self._config.max_episode_steps  # pylint: disable=line-too-long
    error_buffer = min_replay_traj * samples_per_insert_tolerance
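    # Illustrative arithmetic (hypothetical config values, not defaults): with
    # min_replay_size=10_000, max_replay_size=1_000_000,
    # max_episode_steps=1_000, samples_per_insert=256 and
    # samples_per_insert_tolerance_rate=0.1, the table stores whole episodes,
    # so min_replay_traj=10 and max_replay_traj=1_000, and the rate limiter
    # lets the sample/insert ratio drift by error_buffer = 10 * 25.6 = 256.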
    limiter = rate_limiters.SampleToInsertRatio(
        min_size_to_sample=min_replay_traj,
        samples_per_insert=self._config.samples_per_insert,
        error_buffer=error_buffer)
    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_traj,
            rate_limiter=limiter,
            signature=adders_reverb.EpisodeAdder.signature(environment_spec, {}))  # pylint: disable=line-too-long
    ]

  def make_dataset_iterator(
      self, replay_client):
    """Create a dataset iterator to use for learning/updating the agent."""
    @tf.function
    def flatten_fn(sample):
      seq_len = tf.shape(sample.data.observation)[0]
      arange = tf.range(seq_len)
      is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32)
      discount = self._config.discount ** tf.cast(arange[None] - arange[:, None], tf.float32)  # pylint: disable=line-too-long
      probs = is_future_mask * discount
      # The indexing changes the shape from [seq_len, 1] to [seq_len]
      goal_index = tf.random.categorical(logits=tf.math.log(probs),
                                         num_samples=1)[:, 0]
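      # For each step t, probs[t, t'] is proportional to discount**(t' - t) for
      # strictly future steps t' > t and is zero otherwise, so goal_index[t] is
      # drawn from a geometric distribution over the remaining steps of the
      # episode (truncated at the episode boundary).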
      state = sample.data.observation[:-1]
      next_state = sample.data.observation[1:]

      # Create the goal observations in three steps.
      # 1. Take all future states (not future goals).
      # 2. Apply obs_to_goal.
      # 3. Sample one of the future states. Note that we don't look for a goal
      # for the final state, because there are no future states.
      goal = tf.gather(sample.data.observation, goal_index[:-1])
      goal_reward = tf.gather(sample.data.reward, goal_index[:-1])
      # new_obs = tf.concat([state, goal], axis=1)
      # new_next_obs = tf.concat([next_state, goal], axis=1)
      transition = types.Transition(
          observation=state,
          action=sample.data.action[:-1],
          reward=sample.data.reward[:-1],
          discount=sample.data.discount[:-1],
          next_observation=next_state,
          extras={
              'next_action': sample.data.action[1:],
              'goal': goal,
              'goal_reward': goal_reward
          })
      # Shift for the transpose_shuffle.
      shift = tf.random.uniform((), 0, seq_len, tf.int32)
      transition = tree.map_structure(lambda t: tf.roll(t, shift, axis=0),
                                      transition)
      return transition

    if self._config.num_parallel_calls:
      num_parallel_calls = self._config.num_parallel_calls
    else:
      num_parallel_calls = tf.data.AUTOTUNE

    def _make_dataset(unused_idx):
      dataset = reverb.TrajectoryDataset.from_table_signature(
          server_address=replay_client.server_address,
          table=self._config.replay_table_name,
          max_in_flight_samples_per_worker=100)
      dataset = dataset.map(flatten_fn)
      # transpose_shuffle
      def _transpose_fn(t):
        dims = tf.range(tf.shape(tf.shape(t))[0])
        perm = tf.concat([[1, 0], dims[2:]], axis=0)
        return tf.transpose(t, perm)
      dataset = dataset.batch(self._config.batch_size, drop_remainder=True)
      dataset = dataset.map(
          lambda transition: tree.map_structure(_transpose_fn, transition))
      dataset = dataset.unbatch()
      # end transpose_shuffle
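      # The batch -> transpose -> unbatch sequence above regroups the data so
      # that, after the final unbatch below, consecutive transitions come from
      # `batch_size` different episodes rather than from one episode in
      # temporal order. Combined with the random `tf.roll` shift applied in
      # flatten_fn, this acts as a cheap approximate shuffle without needing a
      # large shuffle buffer.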

      dataset = dataset.unbatch()
      return dataset
    dataset = tf.data.Dataset.from_tensors(0).repeat()
    dataset = dataset.interleave(
        map_func=_make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls,
        deterministic=False)

    dataset = dataset.batch(
        self._config.batch_size * self._config.num_sgd_steps_per_step,
        drop_remainder=True)
    @tf.function
    def add_info_fn(data):
      info = reverb.SampleInfo(key=0,
                               probability=0.0,
                               table_size=0,
                               priority=0.0,
                               times_sampled=0)
      return reverb.ReplaySample(info=info, data=data)
    dataset = dataset.map(add_info_fn, num_parallel_calls=tf.data.AUTOTUNE,
                          deterministic=False)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset.as_numpy_iterator()

  def make_adder(self,
                 replay_client):
    """Create an adder to record data generated by the actor/environment."""
    return adders_reverb.EpisodeAdder(
        client=replay_client,
        priority_fns={self._config.replay_table_name: None},
        max_sequence_length=self._config.max_episode_steps + 1)
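

# Minimal wiring sketch (illustrative only). The names
# contrastive_config.ContrastiveConfig and contrastive_networks.make_networks,
# and their signatures, are assumptions about the cvl_public package rather
# than facts established by this file; the policy_network passed to make_actor
# is likewise assumed to be a batched feed-forward policy built from
# `networks`.
#
#   config = contrastive_config.ContrastiveConfig()
#   networks = contrastive_networks.make_networks(environment_spec)
#   builder = ContrastiveBuilder(config)
#   replay_tables = builder.make_replay_tables(environment_spec)
#   replay_server = reverb.Server(replay_tables, port=None)
#   replay_client = reverb.Client(f'localhost:{replay_server.port}')
#   dataset = builder.make_dataset_iterator(replay_client)
#   learner = builder.make_learner(learner_key, networks, dataset)
#   adder = builder.make_adder(replay_client)
#   actor = builder.make_actor(actor_key, policy_network, adder,
#                              variable_source=learner)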