# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-complex-comprehension
r"""SEED actor."""

import collections
import os
import random

from absl import flags
from absl import logging
import numpy as np
from seed_rl import grpc
from seed_rl.common import common_flags  # pylint: disable=unused-import
from seed_rl.common import profiling
from seed_rl.common import utils
import tensorflow as tf

from muzero import core
from muzero import utils as mzutils

FLAGS = flags.FLAGS

TASK = flags.DEFINE_integer('task', 0, 'Task id.')
USE_SOFTMAX_FOR_TARGET = flags.DEFINE_integer(
    'use_softmax_for_target', 0,
    'If True (1), use a softmax for the child_visit count distribution that '
    'is used as a target for the policy.')
NUM_TEST_ACTORS = flags.DEFINE_integer(
    'num_test_actors', 2, 'Number of actors that are used for testing.')
NUM_ACTORS_WITH_SUMMARIES = flags.DEFINE_integer(
    'num_actors_with_summaries', 1,
    'Number of actors that will log debug/profiling TF '
    'summaries.')
ACTOR_LOG_FREQUENCY = flags.DEFINE_integer('actor_log_frequency', 10,
                                           'in number of training steps')
MCTS_VIS_FILE = flags.DEFINE_string(
    'mcts_vis_file', None, 'File in which to log the mcts visualizations.')
FLAG_FILE = flags.DEFINE_string('flag_file', None,
                                'File in which to log the parameters.')
ENABLE_ACTOR_LOGGING = flags.DEFINE_boolean('enable_actor_logging', True,
                                            'Verbose logging for the actor.')
MAX_NUM_ACTION_EXPANSION = flags.DEFINE_integer(
    'max_num_action_expansion', 0,
    'Maximum number of new nodes for a node expansion. 0 for no limit. '
    'This is important for the full vocabulary.')
ACTOR_ENQUEUE_EVERY = flags.DEFINE_integer(
    'actor_enqueue_every', 0,
    'After how many steps the actor enqueues samples. 0 for at episode end.')
ACTOR_SKIP = flags.DEFINE_integer('actor_skip', 0,
                                  'How many target samples the actor skips.')


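# Tasks [0, num_test_actors) only evaluate; every later task produces training
# data. The two helpers below encode that split.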
def is_training_actor():
  return TASK.value >= NUM_TEST_ACTORS.value


def are_summaries_enabled():
  return is_training_actor(
  ) and TASK.value < NUM_TEST_ACTORS.value + NUM_ACTORS_WITH_SUMMARIES.value


def actor_loop(create_env_fn,
               mzconfig,
               share_of_supervised_episodes_fn=lambda _: 0.):
  """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
    mzconfig: MuZeroConfig instance.
    share_of_supervised_episodes_fn: Function that specifies the share of
      episodes that should be supervised based on the learner iteration.
  """

  logging.info('Starting actor loop')

  actor_log_dir = os.path.join(FLAGS.logdir, 'actor_{}'.format(TASK.value))
  if are_summaries_enabled():
    summary_writer = tf.summary.create_file_writer(
        actor_log_dir, flush_millis=20000, max_queue=1000)
    timer_cls = profiling.ExportingTimer
    if FLAG_FILE.value:
      mzutils.write_flags(FLAGS.__flags, FLAG_FILE.value)  # pylint: disable=protected-access
  else:
    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

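  # Local FIFO of training samples; it is flushed to the learner's replay
  # buffer once at least `train_batch_size` samples have accumulated.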
  batch_queue = collections.deque()

  actor_step = tf.Variable(0, dtype=tf.int64)
  num_episodes = tf.Variable(0, dtype=tf.int64)

  # We use the checkpoint to keep track of the actor_step and num_episodes.
  actor_checkpoint = tf.train.Checkpoint(
      actor_step=actor_step, num_episodes=num_episodes)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=actor_checkpoint, directory=actor_log_dir, max_to_keep=1)
  if ckpt_manager.latest_checkpoint:
    logging.info('Restoring actor checkpoint: %s',
                 ckpt_manager.latest_checkpoint)
    actor_checkpoint.restore(ckpt_manager.latest_checkpoint).assert_consumed()
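  # assert_consumed() verifies the checkpoint matched both variables, so a
  # restarted actor (e.g. after preemption) resumes its step and episode
  # counters instead of starting over.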

  reward_agg, length_agg = profiling.Aggregator(), profiling.Aggregator()
  with summary_writer.as_default():
    tf.summary.experimental.set_step(actor_step)
    while True:
      try:
        # Client to communicate with the learner.
        client = grpc.Client(FLAGS.server_address)

        def _create_training_samples(episode, start_idx=0):
          start_idx += random.choice(range(ACTOR_SKIP.value + 1))
          for i in range(start_idx, len(episode.history), ACTOR_SKIP.value + 1):
            target = episode.make_target(
                state_index=i,
                num_unroll_steps=mzconfig.num_unroll_steps,
                td_steps=mzconfig.td_steps,
                rewards=episode.rewards,
                policy_distributions=episode.child_visits,
                discount=episode.discount,
                value_approximations=episode.root_values)
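            # Replay priority in the spirit of prioritized experience replay:
            # the absolute gap between the stored MCTS root value and the
            # value target this sample will be trained towards.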
            priority = np.float32(1e-2)  # preventing all zero priorities
            if len(episode) > 0:  # pylint: disable=g-explicit-length-test
              last_value_idx = min(len(episode) - 1 - i, len(target.value) - 1)
              priority = np.maximum(
                  priority,
                  np.float32(
                      np.abs(episode.root_values[i + last_value_idx] -
                             target.value[last_value_idx])))

            # This will be batched and given to add_to_replay_buffer on the
            # learner.
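            # Layout: (priority, observation, action history for the unroll,
            # *target fields cast to float32).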
            sample = (
                priority,
                episode.make_image(i),
                tf.stack(
                    episode.history_range(i, i + mzconfig.num_unroll_steps)),
            ) + tuple(map(lambda x: tf.cast(tf.stack(x), tf.float32), target))
            batch_queue.append(sample)
          if ENABLE_ACTOR_LOGGING.value:
            logging.info(
                'Added %d samples to the batch_queue. Size: %d of needed %d',
                len(episode.history) - start_idx, len(batch_queue),
                mzconfig.train_batch_size)

        def _add_queue_to_replay_buffer():
          with timer_cls('actor/elapsed_add_to_buffer_s',
                         10 * ACTOR_LOG_FREQUENCY.value):
            while len(batch_queue) >= mzconfig.train_batch_size:
              batch = [
                  batch_queue.popleft()
                  for _ in range(mzconfig.train_batch_size)
              ]
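              # Transpose the list of nested samples into a single nested
              # structure of batched tensors: flatten each sample, stack
              # component-wise, then restore the original nesting.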
              flat_batch = [tf.nest.flatten(b) for b in batch]
              stacked_batch = list(map(tf.stack, zip(*flat_batch)))
              structured_batch = tf.nest.pack_sequence_as(
                  batch[0], stacked_batch)
              client.add_to_replay_buffer(*structured_batch)
              if ENABLE_ACTOR_LOGGING.value:
                logging.info('Added batch of size %d into replay_buffer.',
                             len(batch))

        env = create_env_fn(TASK.value, training=is_training_actor())

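        # MCTS reaches the learner through two RPCs: initial_inference (below,
        # in the play loop) embeds a real observation into the learned hidden
        # state, and recurrent_inference unrolls the learned dynamics model
        # from a hidden state and an action.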
        def recurrent_inference_fn(*args, **kwargs):
          with timer_cls('actor/elapsed_recurrent_inference_s',
                         100 * ACTOR_LOG_FREQUENCY.value):
            output = client.recurrent_inference(*args, **kwargs)
            output = tf.nest.map_structure(lambda t: t.numpy(), output)
          return output

        def get_legal_actions_fn(episode):

          def legal_actions_fn(*args, **kwargs):
            with timer_cls('actor/elapsed_get_legal_actions_s',
                           100 * ACTOR_LOG_FREQUENCY.value):
              output = episode.legal_actions(*args, **kwargs)
            return output

          return legal_actions_fn

        while True:
          episode = mzconfig.new_episode(env)
          is_supervised_episode = is_training_actor() and \
              random.random() < share_of_supervised_episodes_fn(
                  client.learning_iteration().numpy())

          if is_supervised_episode:
            if ENABLE_ACTOR_LOGGING.value:
              logging.info('Supervised Episode.')
            try:
              with timer_cls('actor/elapsed_load_supervised_episode_s',
                             ACTOR_LOG_FREQUENCY.value):
                episode_example = env.load_supervised_episode()
              with timer_cls('actor/elapsed_run_supervised_episode_s',
                             ACTOR_LOG_FREQUENCY.value):
                targets, samples = env.run_supervised_episode(episode_example)
              episode.rewards = samples['reward']
              episode.history = samples['to_predict']
              for target in targets:
                batch_queue.append(target)
            except core.RLEnvironmentError as e:
              logging.warning('Environment not ready: %s', str(e))
              # restart episode
              continue
            except core.BadSupervisedEpisodeError as e:
              logging.warning('Abort supervised episode: %s', str(e))
              # restart episode
              continue
          else:
            if ENABLE_ACTOR_LOGGING.value:
              logging.info('RL Episode.')
            try:
              last_enqueued_idx = 0
              legal_actions_fn = get_legal_actions_fn(episode)
            except core.RLEnvironmentError as e:
              logging.warning('Environment not ready: %s', str(e))
              # restart episode
              continue
            except core.SkipEpisode as e:
              logging.warning('Episode is skipped due to: %s', str(e))
              # restart episode
              continue
            while (not episode.terminal() and
                   len(episode.history) < mzconfig.max_moves):
              # This loop is the agent playing the episode.
              current_observation = episode.make_image(-1)

              # Map the observation to hidden space.
              with timer_cls('actor/elapsed_initial_inference_s',
                             10 * ACTOR_LOG_FREQUENCY.value):
                initial_inference_output = client.initial_inference(
                    current_observation)
              initial_inference_output = tf.nest.map_structure(
                  lambda t: t.numpy(), initial_inference_output)

              # Run MCTS using recurrent_inference_fn.
              with timer_cls('actor/elapsed_mcts_s',
                             10 * ACTOR_LOG_FREQUENCY.value):
                legal_actions = legal_actions_fn()
                root = core.prepare_root_node(mzconfig, legal_actions,
                                              initial_inference_output)
                with timer_cls('actor/elapsed_run_mcts_s',
                               10 * ACTOR_LOG_FREQUENCY.value):
                  core.run_mcts(mzconfig, root, episode.action_history(),
                                legal_actions_fn, recurrent_inference_fn,
                                episode.visualize_mcts)
              action = core.select_action(
                  mzconfig,
                  len(episode.history),
                  root,
                  train_step=actor_step.numpy(),
                  use_softmax=mzconfig.use_softmax_for_action_selection,
                  is_training=is_training_actor())
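              # The action is drawn from the root's visit-count distribution
              # (optionally softmax-smoothed); train_step is passed so that
              # core.select_action can vary exploration with training progress.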

              try:
                # Perform chosen action.
                with timer_cls('actor/elapsed_env_step_s',
                               10 * ACTOR_LOG_FREQUENCY.value):
                  training_steps = client.learning_iteration().numpy()
                  episode.apply(action=action, training_steps=training_steps)
              except core.RLEnvironmentError as env_error:
                logging.warning('Environment failed: %s', str(env_error))
                episode.failed = True
                # terminate episode
                break

              episode.store_search_statistics(
                  root, use_softmax=(USE_SOFTMAX_FOR_TARGET.value == 1))
              actor_step.assign_add(delta=1)
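              # With --actor_enqueue_every > 0, long episodes already
              # contribute samples mid-episode instead of only at episode end.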
              if is_training_actor() and ACTOR_ENQUEUE_EVERY.value > 0 and (
                  len(episode.history) -
                  last_enqueued_idx) >= ACTOR_ENQUEUE_EVERY.value:
                _create_training_samples(episode, start_idx=last_enqueued_idx)
                last_enqueued_idx = len(episode.history)
                _add_queue_to_replay_buffer()

            if episode.failed:
              # restart episode
              logging.warning('Episode failed, restarting.')
              continue

          # Post-episode stats.
          num_episodes.assign_add(delta=1)
          reward_agg.add(episode.total_reward())
          length_agg.add(len(episode))
          if ENABLE_ACTOR_LOGGING.value:
            logging.info(
                'Episode done. Length: %d, '
                'Total Reward: %d, Min Reward: %d, Max Reward: %d',
                len(episode), episode.total_reward(), min(episode.rewards),
                max(episode.rewards))
          if reward_agg.count % ACTOR_LOG_FREQUENCY.value == 0:
            tf.summary.experimental.set_step(actor_step)
            tf.summary.scalar('actor/total_reward', reward_agg.average())
            tf.summary.scalar('actor/episode_length', length_agg.average())
            tf.summary.scalar('actor/num_episodes', num_episodes)
            tf.summary.scalar('actor/step', actor_step)
            tf.summary.scalar(
                'actor/share_of_supervised_episodes',
                share_of_supervised_episodes_fn(
                    client.learning_iteration().numpy()))
            if episode.mcts_visualizations:
              tf.summary.text('actor/mcts_vis',
                              '\n\n'.join(episode.mcts_visualizations))
              if are_summaries_enabled() and MCTS_VIS_FILE.value is not None:
                # Also write the visualizations into a text file.
                with tf.io.gfile.GFile(MCTS_VIS_FILE.value, 'a') as f:
                  f.write('Step {}\n{}\n\n\n\n'.format(
                      actor_step, '\n\n'.join(episode.mcts_visualizations)))

            special_stats = episode.special_statistics()
            for stat_name, stat_value in special_stats.items():
              if isinstance(stat_value, (float, int)):
                tf.summary.scalar('actor/{}'.format(stat_name), stat_value)
              elif isinstance(stat_value, str):
                tf.summary.text('actor/{}'.format(stat_name), stat_value)
              else:
                logging.warning(
                    'Special statistic %s could not be tracked. '
                    'Type %s is not supported.', stat_name, type(stat_value))

            ckpt_manager.save()
            reward_agg.reset()
            length_agg.reset()

          if is_training_actor():
            # Create samples for training.
            _create_training_samples(episode, start_idx=last_enqueued_idx)

          # Send training samples to the learner after the episode finishes.
          if is_training_actor():
            _add_queue_to_replay_buffer()
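          # Any remainder short of a full train_batch_size stays in
          # batch_queue and is shipped after a later episode.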

          summary_name = 'train' if is_training_actor() else 'test'
          if is_supervised_episode:
            summary_name += ' (supervised)'
          with timer_cls('actor/elapsed_add_to_reward_s',
                         10 * ACTOR_LOG_FREQUENCY.value):
            # This is just for statistics.
            client.add_to_reward_queue(summary_name,
                                       np.float32(episode.total_reward()),
                                       np.int64(len(episode)),
                                       *episode.special_statistics_learner())
          del episode

      except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
        logging.exception(e)
        env.close()
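        # The surrounding while-loop reconnects to the learner and recreates
        # the environment after transient gRPC failures.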