# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-complex-comprehension
r"""SEED actor for MuZero.

Plays episodes with MCTS, using the learner for network inference, and sends
prioritized training samples to the learner's replay buffer.
"""

import collections
import os
import random

from absl import flags
from absl import logging
import numpy as np
from seed_rl import grpc
from seed_rl.common import common_flags  # pylint: disable=unused-import
from seed_rl.common import profiling
from seed_rl.common import utils
import tensorflow as tf

from muzero import core
from muzero import utils as mzutils

FLAGS = flags.FLAGS

TASK = flags.DEFINE_integer('task', 0, 'Task id.')
USE_SOFTMAX_FOR_TARGET = flags.DEFINE_integer(
    'use_softmax_for_target', 0,
    'If True (1), use a softmax for the child_visit count distribution that '
    'is used as a target for the policy.')
NUM_TEST_ACTORS = flags.DEFINE_integer(
    'num_test_actors', 2, 'Number of actors that are used for testing.')
NUM_ACTORS_WITH_SUMMARIES = flags.DEFINE_integer(
    'num_actors_with_summaries', 1,
    'Number of actors that will log debug/profiling TF '
    'summaries.')
ACTOR_LOG_FREQUENCY = flags.DEFINE_integer(
    'actor_log_frequency', 10,
    'Actor log frequency, in number of training steps.')
MCTS_VIS_FILE = flags.DEFINE_string(
    'mcts_vis_file', None, 'File in which to log the mcts visualizations.')
FLAG_FILE = flags.DEFINE_string('flag_file', None,
                                'File in which to log the parameters.')
ENABLE_ACTOR_LOGGING = flags.DEFINE_boolean('enable_actor_logging', True,
                                            'Verbose logging for the actor.')
MAX_NUM_ACTION_EXPANSION = flags.DEFINE_integer(
    'max_num_action_expansion', 0,
    'Maximum number of new nodes for a node expansion. 0 for no limit. '
    'This is important for the full vocabulary.')
ACTOR_ENQUEUE_EVERY = flags.DEFINE_integer(
    'actor_enqueue_every', 0,
    'After how many steps the actor enqueues samples. 0 for at episode end.')
ACTOR_SKIP = flags.DEFINE_integer('actor_skip', 0,
                                  'How many target samples the actor skips.')


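# Tasks [0, num_test_actors) run evaluation-only episodes; all remaining tasks
# generate training data for the learner.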
def is_training_actor():
  return TASK.value >= NUM_TEST_ACTORS.value


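# Of the training actors, only the first `num_actors_with_summaries` export
# debug/profiling TF summaries.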
def are_summaries_enabled():
  return is_training_actor() and (
      TASK.value < NUM_TEST_ACTORS.value + NUM_ACTORS_WITH_SUMMARIES.value)


def actor_loop(create_env_fn,
               mzconfig,
               share_of_supervised_episodes_fn=lambda _: 0.):
  """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID and a `training` keyword
      argument) that must return a newly created environment.
    mzconfig: MuZeroConfig instance.
    share_of_supervised_episodes_fn: Function that specifies the share of
      episodes that should be supervised based on the learner iteration.
  """
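
  # Usage sketch (illustrative only; `my_envs.make` is hypothetical and not
  # part of this module):
  #   def create_env(task, training=True):
  #     return my_envs.make(seed=task, training=training)
  #   actor_loop(create_env, mzconfig)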

  logging.info('Starting actor loop')

  actor_log_dir = os.path.join(FLAGS.logdir, 'actor_{}'.format(TASK.value))
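  # Actors with summaries enabled export profiling timers and TF summaries;
  # all other actors use a no-op writer and a no-op timer.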
  if are_summaries_enabled():
    summary_writer = tf.summary.create_file_writer(
        actor_log_dir, flush_millis=20000, max_queue=1000)
    timer_cls = profiling.ExportingTimer
    if FLAG_FILE.value:
      mzutils.write_flags(FLAGS.__flags, FLAG_FILE.value)  # pylint: disable=protected-access
  else:
    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

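  # Queue of pending training samples; it is flushed to the learner's replay
  # buffer in batches of `mzconfig.train_batch_size`.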
  batch_queue = collections.deque()

  actor_step = tf.Variable(0, dtype=tf.int64)
  num_episodes = tf.Variable(0, dtype=tf.int64)

  # We use the checkpoint to keep track of the actor_step and num_episodes.
  actor_checkpoint = tf.train.Checkpoint(
      actor_step=actor_step, num_episodes=num_episodes)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=actor_checkpoint, directory=actor_log_dir, max_to_keep=1)
  if ckpt_manager.latest_checkpoint:
    logging.info('Restoring actor checkpoint: %s',
                 ckpt_manager.latest_checkpoint)
    actor_checkpoint.restore(ckpt_manager.latest_checkpoint).assert_consumed()

  reward_agg, length_agg = profiling.Aggregator(), profiling.Aggregator()
  with summary_writer.as_default():
    tf.summary.experimental.set_step(actor_step)
    while True:
      try:
        # Client to communicate with the learner.
        client = grpc.Client(FLAGS.server_address)

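        # Converts (part of) an episode into prioritized training samples and
        # appends them to `batch_queue`. The priority is the absolute error
        # between the MCTS root value and the bootstrapped target value.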
        def _create_training_samples(episode, start_idx=0):
          start_idx += random.choice(range(ACTOR_SKIP.value + 1))
          for i in range(start_idx, len(episode.history), ACTOR_SKIP.value + 1):
            target = episode.make_target(
                state_index=i,
                num_unroll_steps=mzconfig.num_unroll_steps,
                td_steps=mzconfig.td_steps,
                rewards=episode.rewards,
                policy_distributions=episode.child_visits,
                discount=episode.discount,
                value_approximations=episode.root_values)
            priority = np.float32(1e-2)  # preventing all zero priorities
            if len(episode) > 0:  # pylint: disable=g-explicit-length-test
              last_value_idx = min(len(episode) - 1 - i, len(target.value) - 1)
              priority = np.maximum(
                  priority,
                  np.float32(
                      np.abs(episode.root_values[i + last_value_idx] -
                             target.value[last_value_idx])))

            # This will be batched and given to add_to_replay_buffer on the
            # learner.
            sample = (
                priority,
                episode.make_image(i),
                tf.stack(
                    episode.history_range(i, i + mzconfig.num_unroll_steps)),
            ) + tuple(map(lambda x: tf.cast(tf.stack(x), tf.float32), target))
            batch_queue.append(sample)
          if ENABLE_ACTOR_LOGGING.value:
            num_added = len(
                range(start_idx, len(episode.history), ACTOR_SKIP.value + 1))
            logging.info(
                'Added %d samples to the batch_queue. Size: %d of needed %d',
                num_added, len(batch_queue), mzconfig.train_batch_size)

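        # Flushes full batches from `batch_queue` to the learner: each sample
        # is flattened, stacked along a new batch dimension, re-packed into
        # its original structure, and sent via the gRPC client.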
        def _add_queue_to_replay_buffer():
          with timer_cls('actor/elapsed_add_to_buffer_s',
                         10 * ACTOR_LOG_FREQUENCY.value):
            while len(batch_queue) >= mzconfig.train_batch_size:
              batch = [
                  batch_queue.popleft()
                  for _ in range(mzconfig.train_batch_size)
              ]
              flat_batch = [tf.nest.flatten(b) for b in batch]
              stacked_batch = list(map(tf.stack, zip(*flat_batch)))
              structured_batch = tf.nest.pack_sequence_as(
                  batch[0], stacked_batch)
              client.add_to_replay_buffer(*structured_batch)
              if ENABLE_ACTOR_LOGGING.value:
                logging.info('Added batch of size %d into replay_buffer.',
                             len(batch))

        env = create_env_fn(TASK.value, training=is_training_actor())

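        # Remote recurrent inference on the learner, used by MCTS to expand
        # nodes; outputs are converted from tensors to numpy arrays.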
        def recurrent_inference_fn(*args, **kwargs):
          with timer_cls('actor/elapsed_recurrent_inference_s',
                         100 * ACTOR_LOG_FREQUENCY.value):
            output = client.recurrent_inference(*args, **kwargs)
            output = tf.nest.map_structure(lambda t: t.numpy(), output)
          return output

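        # Wraps `episode.legal_actions` with a profiling timer.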
        def get_legal_actions_fn(episode):

          def legal_actions_fn(*args, **kwargs):
            with timer_cls('actor/elapsed_get_legal_actions_s',
                           100 * ACTOR_LOG_FREQUENCY.value):
              output = episode.legal_actions(*args, **kwargs)
            return output

          return legal_actions_fn

        while True:
          episode = mzconfig.new_episode(env)
          is_supervised_episode = is_training_actor() and \
              random.random() < share_of_supervised_episodes_fn(
                  client.learning_iteration().numpy())

          if is_supervised_episode:
            if ENABLE_ACTOR_LOGGING.value:
              logging.info('Supervised Episode.')
            try:
              with timer_cls('actor/elapsed_load_supervised_episode_s',
                             ACTOR_LOG_FREQUENCY.value):
                episode_example = env.load_supervised_episode()
              with timer_cls('actor/elapsed_run_supervised_episode_s',
                             ACTOR_LOG_FREQUENCY.value):
                targets, samples = env.run_supervised_episode(episode_example)
              episode.rewards = samples['reward']
              episode.history = samples['to_predict']
              for target in targets:
                batch_queue.append(target)
            except core.RLEnvironmentError as e:
              logging.warning('Environment not ready: %s', str(e))
              # restart episode
              continue
            except core.BadSupervisedEpisodeError as e:
              logging.warning('Abort supervised episode: %s', str(e))
              # restart episode
              continue
          else:
            if ENABLE_ACTOR_LOGGING.value:
              logging.info('RL Episode.')
            try:
              last_enqueued_idx = 0
              legal_actions_fn = get_legal_actions_fn(episode)
            except core.RLEnvironmentError as e:
              logging.warning('Environment not ready: %s', str(e))
              # restart episode
              continue
            except core.SkipEpisode as e:
              logging.warning('Episode is skipped due to: %s', str(e))
              # restart episode
              continue
            while (not episode.terminal() and
                   len(episode.history) < mzconfig.max_moves):
              # This loop is the agent playing the episode.
              current_observation = episode.make_image(-1)

              # Map the observation to hidden space.
              with timer_cls('actor/elapsed_initial_inference_s',
                             10 * ACTOR_LOG_FREQUENCY.value):
                initial_inference_output = client.initial_inference(
                    current_observation)
                initial_inference_output = tf.nest.map_structure(
                    lambda t: t.numpy(), initial_inference_output)

              # Run MCTS using recurrent_inference_fn.
              with timer_cls('actor/elapsed_mcts_s',
                             10 * ACTOR_LOG_FREQUENCY.value):
                legal_actions = legal_actions_fn()
                root = core.prepare_root_node(mzconfig, legal_actions,
                                              initial_inference_output)
                with timer_cls('actor/elapsed_run_mcts_s',
                               10 * ACTOR_LOG_FREQUENCY.value):
                  core.run_mcts(mzconfig, root, episode.action_history(),
                                legal_actions_fn, recurrent_inference_fn,
                                episode.visualize_mcts)
                action = core.select_action(
                    mzconfig,
                    len(episode.history),
                    root,
                    train_step=actor_step.numpy(),
                    use_softmax=mzconfig.use_softmax_for_action_selection,
                    is_training=is_training_actor())

              try:
                # Perform chosen action.
                with timer_cls('actor/elapsed_env_step_s',
                               10 * ACTOR_LOG_FREQUENCY.value):
                  training_steps = client.learning_iteration().numpy()
                  episode.apply(action=action, training_steps=training_steps)
              except core.RLEnvironmentError as env_error:
                logging.warning('Environment failed: %s', str(env_error))
                episode.failed = True
                # terminate episode
                break

              episode.store_search_statistics(
                  root, use_softmax=(USE_SOFTMAX_FOR_TARGET.value == 1))
              actor_step.assign_add(delta=1)
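              # Optionally enqueue samples mid-episode every
              # `actor_enqueue_every` steps instead of waiting for episode end.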
              if is_training_actor() and ACTOR_ENQUEUE_EVERY.value > 0 and (
                  len(episode.history) -
                  last_enqueued_idx) >= ACTOR_ENQUEUE_EVERY.value:
                _create_training_samples(episode, start_idx=last_enqueued_idx)
                last_enqueued_idx = len(episode.history)
                _add_queue_to_replay_buffer()

            if episode.failed:
              # restart episode
              logging.warning('Episode failed, restarting.')
              continue
            # Post-episode stats
            num_episodes.assign_add(delta=1)
            reward_agg.add(episode.total_reward())
            length_agg.add(len(episode))
            if ENABLE_ACTOR_LOGGING.value:
              logging.info(
                  'Episode done. Length: %d, '
                  'Total Reward: %d, Min Reward: %d, Max Reward: %d',
                  len(episode), episode.total_reward(), min(episode.rewards),
                  max(episode.rewards))
            if reward_agg.count % ACTOR_LOG_FREQUENCY.value == 0:
              tf.summary.experimental.set_step(actor_step)
              tf.summary.scalar('actor/total_reward', reward_agg.average())
              tf.summary.scalar('actor/episode_length', length_agg.average())
              tf.summary.scalar('actor/num_episodes', num_episodes)
              tf.summary.scalar('actor/step', actor_step)
              tf.summary.scalar(
                  'actor/share_of_supervised_episodes',
                  share_of_supervised_episodes_fn(
                      client.learning_iteration().numpy()))
              if episode.mcts_visualizations:
                tf.summary.text('actor/mcts_vis',
                                '\n\n'.join(episode.mcts_visualizations))
                if are_summaries_enabled() and MCTS_VIS_FILE.value is not None:
                  # Also write the visualizations to a text file.
                  with tf.io.gfile.GFile(MCTS_VIS_FILE.value, 'a') as f:
                    f.write('Step {}\n{}\n\n\n\n'.format(
                        actor_step, '\n\n'.join(episode.mcts_visualizations)))

              special_stats = episode.special_statistics()
              for stat_name, stat_value in special_stats.items():
                if isinstance(stat_value, float) or isinstance(stat_value, int):
                  tf.summary.scalar('actor/{}'.format(stat_name), stat_value)
                elif isinstance(stat_value, str):
                  tf.summary.text('actor/{}'.format(stat_name), stat_value)
                else:
                  logging.warning(
                      'Special statistic %s could not be tracked. '
                      'Type %s is not supported.', stat_name, type(stat_value))

              ckpt_manager.save()
              reward_agg.reset()
              length_agg.reset()

            if is_training_actor():
              # Create samples for training.
              _create_training_samples(episode, start_idx=last_enqueued_idx)

          # Send training samples to the learner after the episode is finished
          if is_training_actor():
            _add_queue_to_replay_buffer()

          summary_name = 'train' if is_training_actor() else 'test'
          if is_supervised_episode:
            summary_name += ' (supervised)'
          with timer_cls('actor/elapsed_add_to_reward_s',
                         10 * ACTOR_LOG_FREQUENCY.value):
            # This is just for statistics.
            client.add_to_reward_queue(summary_name,
                                       np.float32(episode.total_reward()),
                                       np.int64(len(episode)),
                                       *episode.special_statistics_learner())
          del episode

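      # The learner can become unavailable (e.g. during preemption); close the
      # environment and let the outer loop reconnect with a fresh client.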
      except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
        logging.exception(e)
        env.close()