# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Program definition for a distributed layout based on a builder."""

import dataclasses
import logging
from typing import Any, Callable, Optional, Sequence

from acme import core
from acme import environment_loop
from acme import specs
from acme.agents.jax import builders
from acme.jax import networks as networks_lib
from acme.jax import savers
from acme.jax import types
from acme.jax import utils
from acme.utils import counting
from acme.utils import loggers
from acme.utils import lp_utils
from acme.utils import observers as observers_lib
import dm_env
import jax
import launchpad as lp
import numpy as np
import reverb
import tqdm


ActorId = int
AgentNetwork = Any
PolicyNetwork = Any
NetworkFactory = Callable[[specs.EnvironmentSpec], AgentNetwork]
PolicyFactory = Callable[[AgentNetwork], PolicyNetwork]
Seed = int
EnvironmentFactory = Callable[[Seed], dm_env.Environment]
MakeActorFn = Callable[[types.PRNGKey, PolicyNetwork, core.VariableSource],
                       core.Actor]
LoggerLabel = str
LoggerStepsKey = str
LoggerFn = Callable[[LoggerLabel, LoggerStepsKey], loggers.Logger]
EvaluatorFactory = Callable[[
    types.PRNGKey,
    core.VariableSource,
    counting.Counter,
    MakeActorFn,
], core.Worker]


def get_default_logger_fn(
    log_to_bigtable = False,
    log_every = 10):
  """Creates an actor logger."""

  def create_logger(actor_id):
    return loggers.make_default_logger(
        'actor',
        save_data=(log_to_bigtable and actor_id == 0),
        time_delta=log_every,
        steps_key='actor_steps')
  return create_logger
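
# Hypothetical usage sketch (not part of the original module): the returned
# factory maps an actor id to a logger, and only actor 0 persists data when
# `log_to_bigtable` is enabled, e.g.
#   logger_fn = get_default_logger_fn(log_to_bigtable=True, log_every=10)
#   actor_logger = logger_fn(0)  # the only actor whose data is saved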


def default_evaluator_factory(
    environment_factory,
    network_factory,
    policy_factory,
    observers = (),
    log_to_bigtable = False):
  """Returns a default evaluator process."""
  def evaluator(
      random_key,
      variable_source,
      counter,
      make_actor,
  ):
    """The evaluation process."""

    # Create environment and evaluator networks
    environment_key, actor_key = jax.random.split(random_key)
    # Environments normally require uint32 as a seed.
    environment = environment_factory(utils.sample_uint32(environment_key))
    networks = network_factory(specs.make_environment_spec(environment))

    actor = make_actor(actor_key, policy_factory(networks), variable_source)

    # Create logger and counter.
    counter = counting.Counter(counter, 'evaluator')
    logger = loggers.make_default_logger('evaluator', log_to_bigtable,
                                         steps_key='actor_steps')

    # Create the run loop and return it.
    return environment_loop.EnvironmentLoop(environment, actor, counter,
                                            logger, observers=observers)
  return evaluator
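
# Hypothetical usage sketch: evaluator factories built here are passed to
# `DistributedLayout(evaluator_factories=...)` below; each one becomes its own
# 'evaluator' node in the launchpad program, e.g.
#   eval_factory = default_evaluator_factory(
#       environment_factory=my_environment_factory,  # assumed user-provided
#       network_factory=my_network_factory,          # assumed user-provided
#       policy_factory=my_eval_policy_factory)       # assumed user-provided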
108

109

110
@dataclasses.dataclass
111
class CheckpointingConfig:
112
  """Configuration options for learner checkpointer."""
113
  # The maximum number of checkpoints to keep.
114
  max_to_keep: int = 1
115
  # Which directory to put the checkpoint in.
116
  directory: str = '~/acme'
117
  # If True adds a UID to the checkpoint path, see
118
  # `paths.get_unique_id()` for how this UID is generated.
119
  add_uid: bool = True
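
# Hypothetical usage sketch: the config's fields are forwarded verbatim (via
# `vars(...)`) to `savers.CheckpointingRunner` in `counter()` and `learner()`
# below, e.g.
#   ckpt_config = CheckpointingConfig(
#       max_to_keep=3, directory='/tmp/acme_experiment', add_uid=False)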
120

121

122
class DistributedLayout:
123
  """Program definition for a distributed agent based on a builder."""
124

125
  def __init__(
126
      self,
127
      seed,
128
      environment_factory,
129
      network_factory,
130
      builder,
131
      policy_network,
132
      num_actors,
133
      environment_spec = None,
134
      actor_logger_fn = None,
135
      evaluator_factories = (),
136
      device_prefetch = True,
137
      prefetch_size = 1,
138
      log_to_bigtable = False,
139
      max_number_of_steps = None,
140
      observers = (),
141
      multithreading_colocate_learner_and_reverb = False,
142
      checkpointing_config = None):
143

144
    if prefetch_size < 0:
145
      raise ValueError(f'Prefetch size={prefetch_size} should be non negative')
146

147
    actor_logger_fn = actor_logger_fn or get_default_logger_fn(log_to_bigtable)
148

149
    self._seed = seed
150
    self._builder = builder
151
    self._environment_factory = environment_factory
152
    self._network_factory = network_factory
153
    self._policy_network = policy_network
154
    self._environment_spec = environment_spec
155
    self._num_actors = num_actors
156
    self._device_prefetch = device_prefetch
157
    self._log_to_bigtable = log_to_bigtable
158
    self._prefetch_size = prefetch_size
159
    self._max_number_of_steps = max_number_of_steps
160
    self._actor_logger_fn = actor_logger_fn
161
    self._evaluator_factories = evaluator_factories
162
    self._observers = observers
163
    self._multithreading_colocate_learner_and_reverb = (
164
        multithreading_colocate_learner_and_reverb)
165
    self._checkpointing_config = checkpointing_config
166

167
  def replay(self):
168
    """The replay storage."""
169
    dummy_seed = 1
170
    environment_spec = (
171
        self._environment_spec or
172
        specs.make_environment_spec(self._environment_factory(dummy_seed)))
173
    return self._builder.make_replay_tables(environment_spec)
174

175
  def counter(self):
176
    kwargs = {}
177
    if self._checkpointing_config:
178
      kwargs = vars(self._checkpointing_config)
179
    return savers.CheckpointingRunner(
180
        counting.Counter(),
181
        key='counter',
182
        subdirectory='counter',
183
        time_delta_minutes=5,
184
        **kwargs)
185

186
  def learner(
187
      self,
188
      random_key,
189
      replay,
190
      counter,
191
  ):
192
    """The Learning part of the agent."""
193

    # For the offline 'offline_ant*' tasks, pre-populate the replay buffer with
    # the environment's dataset before training starts. Episode boundaries are
    # read from `timeouts`, and the goal is appended to every observation.
    if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
      adder = self._builder.make_adder(replay)
      env = self._environment_factory(0)
      dataset = env.get_dataset()  # pytype: disable=attribute-error
      for t in tqdm.trange(dataset['observations'].shape[0]):
        discount = 1.0
        if t == 0 or dataset['timeouts'][t - 1]:
          step_type = dm_env.StepType.FIRST
        elif dataset['timeouts'][t]:
          step_type = dm_env.StepType.LAST
          discount = 0.0
        else:
          step_type = dm_env.StepType.MID

        ts = dm_env.TimeStep(
            step_type=step_type,
            reward=dataset['rewards'][t],
            discount=discount,
            observation=np.concatenate([dataset['observations'][t],
                                        dataset['infos/goal'][t]]),
        )
        if t == 0 or dataset['timeouts'][t - 1]:
          adder.add_first(ts)  # pytype: disable=attribute-error
        else:
          adder.add(action=dataset['actions'][t-1], next_timestep=ts)  # pytype: disable=attribute-error

        # In local runs, stop after a small prefix of the dataset.
        if self._builder._config.local and t > 10_000:  # pytype: disable=attribute-error, pylint: disable=protected-access
          break

    iterator = self._builder.make_dataset_iterator(replay)

    dummy_seed = 1
    environment_spec = (
        self._environment_spec or
        specs.make_environment_spec(self._environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = self._network_factory(environment_spec)

    if self._prefetch_size > 1:
      # When working with a single GPU we should prefetch to device for
      # efficiency. If running on TPU this isn't necessary as the computation
      # and input placement can be done automatically. For multi-GPU, the best
      # solution is currently to prefetch to host, although this may change in
      # the future.
      device = jax.devices()[0] if self._device_prefetch else None
      iterator = utils.prefetch(
          iterator, buffer_size=self._prefetch_size, device=device)
    else:
      logging.info('Not prefetching the iterator.')

    counter = counting.Counter(counter, 'learner')

    learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                         counter)
    kwargs = {}
    if self._checkpointing_config:
      kwargs = vars(self._checkpointing_config)
    # Return the learning agent.
    return savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        **kwargs)
259

260
  def actor(self, random_key, replay,
261
            variable_source, counter,
262
            actor_id):
263
    """The actor process."""
264
    adder = self._builder.make_adder(replay)
265

266
    environment_key, actor_key = jax.random.split(random_key)
267
    # Create environment and policy core.
268

269
    # Environments normally require uint32 as a seed.
270
    environment = self._environment_factory(
271
        utils.sample_uint32(environment_key))
272

273
    networks = self._network_factory(specs.make_environment_spec(environment))
274
    policy_network = self._policy_network(networks)
275
    actor = self._builder.make_actor(actor_key, policy_network, adder,
276
                                     variable_source)
277

278
    # Create logger and counter.
279
    counter = counting.Counter(counter, 'actor')
280
    # Only actor #0 will write to bigtable in order not to spam it too much.
281
    logger = self._actor_logger_fn(actor_id)
282
    # Create the loop to connect environment and agent.
283
    return environment_loop.EnvironmentLoop(environment, actor, counter,
284
                                            logger, observers=self._observers)
285

286
  def coordinator(self, counter, max_actor_steps):
287
    if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
288
      steps_key = 'learner_steps'
289
    else:
290
      steps_key = 'actor_steps'
291
    return lp_utils.StepsLimiter(counter, max_actor_steps, steps_key=steps_key)
292

293
  def build(self, name='agent', program = None):
294
    """Build the distributed agent topology."""
295
    if not program:
296
      program = lp.Program(name=name)
297

298
    key = jax.random.PRNGKey(self._seed)
299

300
    replay_node = lp.ReverbNode(self.replay)
301
    with program.group('replay'):
302
      if self._multithreading_colocate_learner_and_reverb:
303
        replay = replay_node.create_handle()
304
      else:
305
        replay = program.add_node(replay_node)
306

307
    with program.group('counter'):
308
      counter = program.add_node(lp.CourierNode(self.counter))
309
      if self._max_number_of_steps is not None:
310
        _ = program.add_node(
311
            lp.CourierNode(self.coordinator, counter,
312
                           self._max_number_of_steps))
313

314
    learner_key, key = jax.random.split(key)
315
    learner_node = lp.CourierNode(self.learner, learner_key, replay, counter)
316
    with program.group('learner'):
317
      if self._multithreading_colocate_learner_and_reverb:
318
        learner = learner_node.create_handle()
319
        program.add_node(
320
            lp.MultiThreadingColocation([learner_node, replay_node]))
321
      else:
322
        learner = program.add_node(learner_node)
323

324
    def make_actor(random_key,
325
                   policy_network,
326
                   variable_source):
327
      return self._builder.make_actor(
328
          random_key, policy_network, variable_source=variable_source)
329

330
    with program.group('evaluator'):
331
      for evaluator in self._evaluator_factories:
332
        evaluator_key, key = jax.random.split(key)
333
        program.add_node(
334
            lp.CourierNode(evaluator, evaluator_key, learner, counter,
335
                           make_actor))
336

337
    with program.group('actor'):
338
      for actor_id in range(self._num_actors):
339
        actor_key, key = jax.random.split(key)
340
        program.add_node(
341
            lp.CourierNode(self.actor, actor_key, replay, learner, counter,
342
                           actor_id))
343

344
    return program
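

# Hypothetical usage sketch (not part of the original module): a launch script
# would construct the layout with user-provided factories and hand the built
# program to launchpad, e.g.
#   layout = DistributedLayout(
#       seed=0,
#       environment_factory=my_environment_factory,  # assumed user-provided
#       network_factory=my_network_factory,          # assumed user-provided
#       builder=my_builder,                          # assumed user-provided builder
#       policy_network=my_policy_factory,            # assumed user-provided
#       num_actors=4,
#       max_number_of_steps=1_000_000,
#       checkpointing_config=CheckpointingConfig())
#   program = layout.build()
#   lp.launch(program)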