# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Program definition for a distributed layout based on a builder."""

import dataclasses
import logging
from typing import Any, Callable, Optional, Sequence

from acme import core
from acme import environment_loop
from acme import specs
from acme.agents.jax import builders
from acme.jax import networks as networks_lib
from acme.jax import savers
from acme.jax import types
from acme.jax import utils
from acme.utils import counting
from acme.utils import loggers
from acme.utils import lp_utils
from acme.utils import observers as observers_lib
import dm_env
import jax
import launchpad as lp
import numpy as np
import reverb
import tqdm


ActorId = int
AgentNetwork = Any
PolicyNetwork = Any
NetworkFactory = Callable[[specs.EnvironmentSpec], AgentNetwork]
PolicyFactory = Callable[[AgentNetwork], PolicyNetwork]
Seed = int
EnvironmentFactory = Callable[[Seed], dm_env.Environment]
MakeActorFn = Callable[[types.PRNGKey, PolicyNetwork, core.VariableSource],
                       core.Actor]
LoggerLabel = str
LoggerStepsKey = str
LoggerFn = Callable[[LoggerLabel, LoggerStepsKey], loggers.Logger]
EvaluatorFactory = Callable[[
    types.PRNGKey,
    core.VariableSource,
    counting.Counter,
    MakeActorFn,
], core.Worker]

def get_default_logger_fn(
    log_to_bigtable = False,
    log_every = 10):
  """Creates an actor logger."""

  def create_logger(actor_id):
    return loggers.make_default_logger(
        'actor',
        save_data=(log_to_bigtable and actor_id == 0),
        time_delta=log_every,
        steps_key='actor_steps')
  return create_logger

def default_evaluator_factory(
    environment_factory,
    network_factory,
    policy_factory,
    observers = (),
    log_to_bigtable = False):
  """Returns a default evaluator process."""
  def evaluator(
      random_key,
      variable_source,
      counter,
      make_actor,
  ):
    """The evaluation process."""

    # Create environment and evaluator networks.
    environment_key, actor_key = jax.random.split(random_key)
    # Environments normally require uint32 as a seed.
    environment = environment_factory(utils.sample_uint32(environment_key))
    networks = network_factory(specs.make_environment_spec(environment))

    actor = make_actor(actor_key, policy_factory(networks), variable_source)

    # Create logger and counter.
    counter = counting.Counter(counter, 'evaluator')
    logger = loggers.make_default_logger('evaluator', log_to_bigtable,
                                         steps_key='actor_steps')

    # Create the run loop and return it.
    return environment_loop.EnvironmentLoop(environment, actor, counter,
                                            logger, observers=observers)
  return evaluator

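# A minimal usage sketch (not part of the original file): `make_environment`,
# `make_networks`, and `make_eval_policy` are hypothetical agent-specific
# factories; a sequence of such evaluators is what `DistributedLayout`
# accepts via its `evaluator_factories` argument.
#
#   evaluator_factories = [
#       default_evaluator_factory(
#           environment_factory=make_environment,
#           network_factory=make_networks,
#           policy_factory=make_eval_policy)
#   ]
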
@dataclasses.dataclass
class CheckpointingConfig:
  """Configuration options for learner checkpointer."""
  # The maximum number of checkpoints to keep.
  max_to_keep: int = 1
  # Which directory to put the checkpoint in.
  directory: str = '~/acme'
  # If True adds a UID to the checkpoint path, see
  # `paths.get_unique_id()` for how this UID is generated.
  add_uid: bool = True

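# Example (a sketch, values illustrative only): checkpointing behaviour is
# customised by passing an instance of this config to `DistributedLayout`.
#
#   checkpointing_config = CheckpointingConfig(
#       max_to_keep=3, directory='/tmp/acme', add_uid=False)
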
class DistributedLayout:
  """Program definition for a distributed agent based on a builder."""

  def __init__(
      self,
      seed,
      environment_factory,
      network_factory,
      builder,
      policy_network,
      num_actors,
      environment_spec = None,
      actor_logger_fn = None,
      evaluator_factories = (),
      device_prefetch = True,
      prefetch_size = 1,
      log_to_bigtable = False,
      max_number_of_steps = None,
      observers = (),
      multithreading_colocate_learner_and_reverb = False,
      checkpointing_config = None):

    if prefetch_size < 0:
      raise ValueError(f'Prefetch size={prefetch_size} should be non-negative')

    actor_logger_fn = actor_logger_fn or get_default_logger_fn(log_to_bigtable)

    self._seed = seed
    self._builder = builder
    self._environment_factory = environment_factory
    self._network_factory = network_factory
    self._policy_network = policy_network
    self._environment_spec = environment_spec
    self._num_actors = num_actors
    self._device_prefetch = device_prefetch
    self._log_to_bigtable = log_to_bigtable
    self._prefetch_size = prefetch_size
    self._max_number_of_steps = max_number_of_steps
    self._actor_logger_fn = actor_logger_fn
    self._evaluator_factories = evaluator_factories
    self._observers = observers
    self._multithreading_colocate_learner_and_reverb = (
        multithreading_colocate_learner_and_reverb)
    self._checkpointing_config = checkpointing_config

  def replay(self):
    """The replay storage."""
    dummy_seed = 1
    environment_spec = (
        self._environment_spec or
        specs.make_environment_spec(self._environment_factory(dummy_seed)))
    return self._builder.make_replay_tables(environment_spec)

  def counter(self):
    kwargs = {}
    if self._checkpointing_config:
      kwargs = vars(self._checkpointing_config)
    return savers.CheckpointingRunner(
        counting.Counter(),
        key='counter',
        subdirectory='counter',
        time_delta_minutes=5,
        **kwargs)

  def learner(
      self,
      random_key,
      replay,
      counter,
  ):
    """The Learning part of the agent."""

    if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
      adder = self._builder.make_adder(replay)
      env = self._environment_factory(0)
      dataset = env.get_dataset()  # pytype: disable=attribute-error
      for t in tqdm.trange(dataset['observations'].shape[0]):
        discount = 1.0
        if t == 0 or dataset['timeouts'][t - 1]:
          step_type = dm_env.StepType.FIRST
        elif dataset['timeouts'][t]:
          step_type = dm_env.StepType.LAST
          discount = 0.0
        else:
          step_type = dm_env.StepType.MID

        ts = dm_env.TimeStep(
            step_type=step_type,
            reward=dataset['rewards'][t],
            discount=discount,
            observation=np.concatenate([dataset['observations'][t],
                                        dataset['infos/goal'][t]]),
        )
        if t == 0 or dataset['timeouts'][t - 1]:
          adder.add_first(ts)  # pytype: disable=attribute-error
        else:
          adder.add(action=dataset['actions'][t - 1], next_timestep=ts)  # pytype: disable=attribute-error

        if self._builder._config.local and t > 10_000:  # pytype: disable=attribute-error, pylint: disable=protected-access
          break

    iterator = self._builder.make_dataset_iterator(replay)

    dummy_seed = 1
    environment_spec = (
        self._environment_spec or
        specs.make_environment_spec(self._environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = self._network_factory(environment_spec)

    if self._prefetch_size > 1:
      # When working with a single GPU we should prefetch to device for
      # efficiency. If running on TPU this isn't necessary as the computation
      # and input placement can be done automatically. For multi-GPU, currently
      # the best solution is to prefetch to host, although this may change in
      # the future.
      device = jax.devices()[0] if self._device_prefetch else None
      iterator = utils.prefetch(
          iterator, buffer_size=self._prefetch_size, device=device)
    else:
      logging.info('Not prefetching the iterator.')

    counter = counting.Counter(counter, 'learner')

    learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                         counter)
    kwargs = {}
    if self._checkpointing_config:
      kwargs = vars(self._checkpointing_config)
    # Return the learning agent.
    return savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        **kwargs)

  def actor(self, random_key, replay,
            variable_source, counter,
            actor_id):
    """The actor process."""
    adder = self._builder.make_adder(replay)

    # Create environment and policy core.
    environment_key, actor_key = jax.random.split(random_key)

    # Environments normally require uint32 as a seed.
    environment = self._environment_factory(
        utils.sample_uint32(environment_key))

    networks = self._network_factory(specs.make_environment_spec(environment))
    policy_network = self._policy_network(networks)
    actor = self._builder.make_actor(actor_key, policy_network, adder,
                                     variable_source)

    # Create logger and counter.
    counter = counting.Counter(counter, 'actor')
    # Only actor #0 will write to bigtable in order not to spam it too much.
    logger = self._actor_logger_fn(actor_id)
    # Create the loop to connect environment and agent.
    return environment_loop.EnvironmentLoop(environment, actor, counter,
                                            logger, observers=self._observers)

  def coordinator(self, counter, max_actor_steps):
    if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
      steps_key = 'learner_steps'
    else:
      steps_key = 'actor_steps'
    return lp_utils.StepsLimiter(counter, max_actor_steps, steps_key=steps_key)

  def build(self, name='agent', program = None):
    """Build the distributed agent topology."""
    if not program:
      program = lp.Program(name=name)

    key = jax.random.PRNGKey(self._seed)

    replay_node = lp.ReverbNode(self.replay)
    with program.group('replay'):
      if self._multithreading_colocate_learner_and_reverb:
        replay = replay_node.create_handle()
      else:
        replay = program.add_node(replay_node)

    with program.group('counter'):
      counter = program.add_node(lp.CourierNode(self.counter))
      if self._max_number_of_steps is not None:
        _ = program.add_node(
            lp.CourierNode(self.coordinator, counter,
                           self._max_number_of_steps))

    learner_key, key = jax.random.split(key)
    learner_node = lp.CourierNode(self.learner, learner_key, replay, counter)
    with program.group('learner'):
      if self._multithreading_colocate_learner_and_reverb:
        learner = learner_node.create_handle()
        program.add_node(
            lp.MultiThreadingColocation([learner_node, replay_node]))
      else:
        learner = program.add_node(learner_node)

    def make_actor(random_key,
                   policy_network,
                   variable_source):
      return self._builder.make_actor(
          random_key, policy_network, variable_source=variable_source)

    with program.group('evaluator'):
      for evaluator in self._evaluator_factories:
        evaluator_key, key = jax.random.split(key)
        program.add_node(
            lp.CourierNode(evaluator, evaluator_key, learner, counter,
                           make_actor))

    with program.group('actor'):
      for actor_id in range(self._num_actors):
        actor_key, key = jax.random.split(key)
        program.add_node(
            lp.CourierNode(self.actor, actor_key, replay, learner, counter,
                           actor_id))

    return program
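
# Example usage (a minimal sketch, not from the original file):
# `make_environment`, `make_networks`, `make_policy`, and `my_builder` are
# hypothetical stand-ins for the agent-specific factories and builder this
# layout is parameterized by.
#
#   layout = DistributedLayout(
#       seed=0,
#       environment_factory=make_environment,
#       network_factory=make_networks,
#       builder=my_builder,
#       policy_network=make_policy,
#       num_actors=4)
#   program = layout.build()
#   lp.launch(program)  # Launch with Launchpad's default local launcher.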