# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-docstring
# pylint: disable=g-complex-comprehension
"""MuZero Core.

Based partially on https://arxiv.org/src/1911.08265v1/anc/pseudocode.py
"""

import collections
import logging
import math
from typing import List, Optional, Dict, Any, Tuple

from absl import flags
import attr
import gym
import numpy as np
import tensorflow as tf

FLAGS = flags.FLAGS
MAXIMUM_FLOAT_VALUE = float('inf')

KnownBounds = collections.namedtuple('KnownBounds', 'min max')

NetworkOutput = collections.namedtuple(
    'NetworkOutput',
    'value value_logits reward reward_logits policy_logits hidden_state')

Prediction = collections.namedtuple(
    'Prediction',
    'gradient_scale value value_logits reward reward_logits policy_logits')

Target = collections.namedtuple(
    'Target', 'value_mask reward_mask policy_mask value reward visits')

Range = collections.namedtuple('Range', 'low high')


class RLEnvironmentError(Exception):
  pass


class BadSupervisedEpisodeError(Exception):
  pass


class SkipEpisode(Exception):  # pylint: disable=g-bad-exception-name
  pass


class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value):
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      value = (value - self.minimum) / (self.maximum - self.minimum)
    value = max(min(1.0, value), 0.0)
    return value


class MuZeroConfig:
  """Config object for MuZero."""

  def __init__(self,
               action_space_size,
               max_moves,
               recurrent_inference_batch_size,
               initial_inference_batch_size,
               train_batch_size,
               discount = 0.99,
               dirichlet_alpha = 0.25,
               root_exploration_fraction = 0.25,
               num_simulations = 11,
               td_steps = 5,
               num_unroll_steps = 5,
               pb_c_base = 19652,
               pb_c_init = 1.25,
               visit_softmax_temperature_fn=None,
               known_bounds = None,
               use_softmax_for_action_selection = False,
               parent_base_visit_count=1,
               max_num_action_expansion = 0):

    ### Play
    self.action_space_size = action_space_size

    self.visit_softmax_temperature_fn = (
        visit_softmax_temperature_fn
        if visit_softmax_temperature_fn is not None
        else lambda num_moves, training_steps, is_training: 1.0)
    self.max_moves = max_moves
    self.num_simulations = num_simulations
    self.discount = discount

    # Root prior exploration noise.
    self.root_dirichlet_alpha = dirichlet_alpha
    self.root_exploration_fraction = root_exploration_fraction

    # UCB formula
    self.pb_c_base = pb_c_base
    self.pb_c_init = pb_c_init

    # If we already have some information about which values occur in the
    # environment, we can use them to initialize the rescaling.
    # This is not strictly necessary, but establishes identical behaviour to
    # AlphaZero in board games.
    self.known_bounds = known_bounds

    ### Training
    self.recurrent_inference_batch_size = recurrent_inference_batch_size
    self.initial_inference_batch_size = initial_inference_batch_size
    self.train_batch_size = train_batch_size
    self.num_unroll_steps = num_unroll_steps
    self.td_steps = td_steps

    self.use_softmax_for_action_selection = use_softmax_for_action_selection

    # This is 0 in the MuZero paper.
    self.parent_base_visit_count = parent_base_visit_count
    self.max_num_action_expansion = max_num_action_expansion

  def new_episode(self, environment, index=None):
    return Episode(
        environment, self.action_space_size, self.discount, index=index)
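
# Illustrative example only: a minimal config for a hypothetical environment
# with six discrete actions and episodes capped at 100 moves. The batch sizes
# are arbitrary placeholder values, not tuned recommendations; every other
# argument falls back to the defaults above.
#
#   config = MuZeroConfig(
#       action_space_size=6,
#       max_moves=100,
#       recurrent_inference_batch_size=32,
#       initial_inference_batch_size=32,
#       train_batch_size=64,
#       known_bounds=KnownBounds(min=-1.0, max=1.0))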


Action = np.int64  # pylint: disable=invalid-name


class TransitionModel:
  """Transition model providing additional information for MCTS transitions.

  An environment can provide a specialized version of a transition model via the
  info dict. This model then provides additional information, e.g. on the legal
  actions, between transitions in the MCTS.
  """

  def __init__(self, full_action_space_size):
    self.full_action_space_size = full_action_space_size

  def legal_actions_after_sequence(self,
                                   actions_sequence):  # pylint: disable=unused-argument
    """Returns the legal action space after a sequence of actions."""
    return range(self.full_action_space_size)

  def full_action_space(self):
    return range(self.full_action_space_size)

  def legal_actions_mask_after_sequence(self,
                                        actions_sequence):
    """Returns the legal action space after a sequence of actions as a mask."""
    mask = np.zeros(self.full_action_space_size, dtype=np.int64)
    for action in self.legal_actions_after_sequence(actions_sequence):
      mask[action] = 1
    return mask
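
# Illustrative example only: a hypothetical specialized transition model. An
# environment could return an instance like this under the 'transition_model'
# key of its info dict (see Episode.legal_actions below) to restrict which
# actions the search expands.
#
#   class EvenActionsOnly(TransitionModel):
#
#     def legal_actions_after_sequence(self, actions_sequence):
#       # Only even-numbered actions are legal, regardless of the history.
#       return range(0, self.full_action_space_size, 2)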


class Node:
  """Node for MCTS."""

  def __init__(self, prior, config, is_root=False):
    self.visit_count = 0
    self.prior = prior
    self.is_root = is_root
    self.value_sum = 0
    self.children = {}
    self.hidden_state = None
    self.reward = 0
    self.discount = config.discount

  def expanded(self):
    return len(self.children) > 0  # pylint: disable=g-explicit-length-test

  def value(self):
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count

  def qvalue(self):
    return self.discount * self.value() + self.reward


class ActionHistory:
  """Simple history container used inside the search.

  Only used to keep track of the actions executed.
  """

  def __init__(self, history, action_space_size):
    self.history = list(history)
    self.action_space_size = action_space_size

  def clone(self):
    return ActionHistory(self.history, self.action_space_size)

  def add_action(self, action):
    self.history.append(Action(action))

  def last_action(self):
    return self.history[-1]

  def action_space(self):
    return [Action(i) for i in range(self.action_space_size)]


class Episode:
  """A single episode of interaction with the environment."""

  def __init__(self,
               environment,
               action_space_size,
               discount,
               index=None):
    self.environment = environment
    self.history = []
    self.observations = []
    self.rewards = []
    self.child_visits = []
    self.root_values = []
    self.mcts_visualizations = []
    self.action_space_size = action_space_size
    self.discount = discount
    self.failed = False

    if index is None:
      self._observation, self._info = self.environment.reset()
    else:
      self._observation, self._info = self.environment.reset(index)
    self.observations.append(self._observation)
    self._reward = None
    self._done = False

  def terminal(self):
    return self._done

  def get_info(self, kword):
    if not self._info:
      return None
    return self._info.get(kword, None)

  def total_reward(self):
    return sum(self.rewards)

  def __len__(self):
    return len(self.history)

  def special_statistics(self):
    try:
      return self.environment.special_episode_statistics()
    except AttributeError:
      return {}

  def special_statistics_learner(self):
    try:
      return self.environment.special_episode_statistics_learner()
    except AttributeError:
      return ()

  def visualize_mcts(self, root):
    history = self.action_history().history
    try:
      treestr = self.environment.visualize_mcts(root, history)
    except AttributeError:
      treestr = ''
    if treestr:
      self.mcts_visualizations.append(treestr)

  def legal_actions(self,
                    actions_sequence = None
                   ):
    """Returns the legal actions after an actions sequence.

    Args:
      actions_sequence: Past sequence of actions.

    Returns:
      An array of length full_action_space_size. At each index a 1 corresponds
      to a legal action and a 0 to an illegal action.
    """
    transition_model = self.get_info('transition_model') or TransitionModel(
        self.action_space_size)
    actions_sequence = tuple(actions_sequence or [])
    return transition_model.legal_actions_mask_after_sequence(actions_sequence)

  def apply(self, action, training_steps = 0):
    (self._observation, self._reward, self._done,
     self._info) = self.environment.step(
         action, training_steps=training_steps)
    self.rewards.append(self._reward)
    self.history.append(action)
    self.observations.append(self._observation)

  def history_range(self, start, end):
    rng = []
    for i in range(start, end):
      if i < len(self.history):
        rng.append(self.history[i])
      else:
        rng.append(0)
    return np.array(rng, np.int64)

  def store_search_statistics(self, root, use_softmax=False):
    sum_visits = sum(child.visit_count for child in root.children.values())
    sum_visits = max(sum_visits, 1)
    action_space = (Action(index) for index in range(self.action_space_size))
    if use_softmax:
      child_visits, mask = zip(*[(root.children[a].visit_count,
                                  1) if a in root.children else (0, 0)
                                 for a in action_space])
      child_visits_distribution = masked_softmax(child_visits, mask)
    else:
      child_visits_distribution = [
          root.children[a].visit_count / sum_visits if a in root.children else 0
          for a in action_space
      ]

    self.child_visits.append(child_visits_distribution)
    self.root_values.append(root.value())

  def make_image(self, state_index):
    if state_index == -1 or state_index < len(self.observations):
      return self.observations[state_index]
    return self._observation

  @staticmethod
  def make_target(state_index,
                  num_unroll_steps,
                  td_steps,
                  rewards,
                  policy_distributions,
                  discount,
                  value_approximations = None):
    num_steps = len(rewards)
    if td_steps == -1:
      td_steps = num_steps  # bootstrap all the way to the end of the episode

    # The value target is the discounted root value of the search tree N steps
    # into the future, plus the discounted sum of all rewards until then.
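    # Concretely, for each unrolled position t the target value computed below
    # is
    #   z_t = sum_{i=0}^{td_steps-1} discount^i * rewards[t + i]
    #         + discount^td_steps * value_approximations[t + td_steps],
    # with the bootstrap term dropped when t + td_steps runs past the end of
    # the episode or when no value approximations are provided.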
    targets = []
    for current_index in range(state_index, state_index + num_unroll_steps + 1):
      bootstrap_index = current_index + td_steps
      if bootstrap_index < num_steps and value_approximations is not None:
        value = value_approximations[bootstrap_index] * discount**td_steps
      else:
        value = 0

      for i, reward in enumerate(rewards[current_index:bootstrap_index]):
        value += reward * discount**i  # pytype: disable=unsupported-operands

      reward_mask = 1.0 if current_index > state_index else 0.0
      if current_index < num_steps:
        targets.append(
            (1.0, reward_mask, 1.0, value, rewards[current_index - 1],
             policy_distributions[current_index]))
      elif current_index == num_steps:
        targets.append((1.0, reward_mask, 0.0, 0.0, rewards[current_index - 1],
                        policy_distributions[0]))
      else:
        # States past the end of games are treated as absorbing states.
        targets.append((1.0, 0.0, 0.0, 0.0, 0.0, policy_distributions[0]))
    target = Target(*zip(*targets))
    return target

  def action_history(self):
    return ActionHistory(self.history, self.action_space_size)


def prepare_root_node(config, legal_actions,
                      initial_inference_output):
  root = Node(0, config, is_root=True)
  expand_node(root, legal_actions, initial_inference_output, config)
  add_exploration_noise(config, root)
  return root


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config,
             root,
             action_history,
             legal_actions_fn,
             recurrent_inference_fn,
             visualization_fn=None):
  min_max_stats = MinMaxStats(config.known_bounds)

  for _ in range(config.num_simulations):
    history = action_history.clone()
    node = root
    search_path = [node]

    while node.expanded():
      action, node = select_child(config, node, min_max_stats)
      history.add_action(action)
      search_path.append(node)

    # Inside the search tree we use the dynamics function to obtain the next
    # hidden state given an action and the previous hidden state.
    parent = search_path[-2]
    network_output = recurrent_inference_fn(parent.hidden_state,
                                            history.last_action())
    legal_actions = legal_actions_fn(
        history.history[len(action_history.history):])
    expand_node(node, legal_actions, network_output, config)

    backpropagate(search_path, network_output.value, config.discount,
                  min_max_stats)

  if visualization_fn:
    visualization_fn(root)
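
# Illustrative sketch only: how the pieces above are typically wired together
# when acting in an environment. `initial_inference_fn` and
# `recurrent_inference_fn` stand for network calls returning NetworkOutput
# tuples; neither is defined in this module.
#
#   episode = config.new_episode(environment)
#   while not episode.terminal() and len(episode) < config.max_moves:
#     root = prepare_root_node(config, episode.legal_actions(),
#                              initial_inference_fn(episode.make_image(-1)))
#     run_mcts(config, root, episode.action_history(), episode.legal_actions,
#              recurrent_inference_fn)
#     action = select_action(config, len(episode), root, train_step=0,
#                            use_softmax=config.use_softmax_for_action_selection)
#     episode.apply(action)
#     episode.store_search_statistics(root)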


def masked_distribution(x,
                        use_exp,
                        mask = None):
  if mask is None:
    mask = [1] * len(x)
  assert sum(mask) > 0, 'Not all values can be masked.'
  assert len(mask) == len(x), (
      'The dimensions of the mask and x need to be the same.')
  x = np.exp(x) if use_exp else np.array(x, dtype=np.float64)
  mask = np.array(mask, dtype=np.float64)
  x *= mask
  if sum(x) == 0:
    # No unmasked value has any weight. Use uniform distribution over unmasked
    # tokens.
    x = mask
  return x / np.sum(x, keepdims=True)


def masked_softmax(x, mask=None):
  x = np.array(x) - np.max(x, axis=-1)  # to avoid overflow
  return masked_distribution(x, use_exp=True, mask=mask)
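
# For example, masked_softmax([1.0, 2.0, 3.0], mask=[1, 0, 1]) renormalizes the
# softmax over the unmasked entries only, returning approximately
# [0.12, 0.0, 0.88].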
457

458

459
def masked_count_distribution(x, mask=None):
460
  return masked_distribution(x, use_exp=False, mask=mask)
461

462

463
def histogram_sample(distribution,
464
                     temperature,
465
                     use_softmax=False,
466
                     mask=None):
467
  actions = [d[1] for d in distribution]
468
  visit_counts = np.array([d[0] for d in distribution], dtype=np.float64)
469
  if temperature == 0.:
470
    probs = masked_count_distribution(visit_counts, mask=mask)
471
    return actions[np.argmax(probs)]
472
  if use_softmax:
473
    logits = visit_counts / temperature
474
    probs = masked_softmax(logits, mask)
475
  else:
476
    logits = visit_counts**(1. / temperature)
477
    probs = masked_count_distribution(logits, mask)
478
  return np.random.choice(actions, p=probs)
479

480

481
def select_action(config,
482
                  num_moves,
483
                  node,
484
                  train_step,
485
                  use_softmax=False,
486
                  is_training=True):
487
  visit_counts = [
488
      (child.visit_count, action) for action, child in node.children.items()
489
  ]
490
  t = config.visit_softmax_temperature_fn(
491
      num_moves=num_moves, training_steps=train_step, is_training=is_training)
492
  action = histogram_sample(visit_counts, t, use_softmax=use_softmax)
493
  return action
494

495

496
# Select the child with the highest UCB score.
497
def select_child(config, node, min_max_stats):
498
  ucb_scores = [(ucb_score(config, node, child, min_max_stats), action, child)
499
                for action, child in node.children.items()]
500
  _, action, child = max(ucb_scores)
501
  return action, child
502

503

504
# The score for a node is based on its value, plus an exploration bonus based on
505
# the prior.
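# With parent_base_visit_count == 0 this is the pUCT rule from the MuZero
# paper:
#   score(s, a) = q_norm(s, a)
#                 + P(s, a) * sqrt(N(s)) / (1 + N(s, a))
#                   * (c_init + log((N(s) + c_base + 1) / c_base)),
# where q_norm is the min-max-normalized Q-value, P is the prior and N counts
# visits.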
def ucb_score(config, parent, child,
              min_max_stats):
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count + config.parent_base_visit_count) / (
      child.visit_count + 1)

  prior_score = pb_c * child.prior
  if child.visit_count > 0:
    value_score = min_max_stats.normalize(child.qvalue())
  else:
    value_score = 0.
  return prior_score + value_score


# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node, legal_actions,
                network_output, config):
  node.hidden_state = network_output.hidden_state
  node.reward = network_output.reward
  policy_probs = masked_softmax(
      network_output.policy_logits, mask=legal_actions.astype(np.float64))
  actions = np.where(legal_actions == 1)[0]

  if (config.max_num_action_expansion > 0 and
      len(actions) > config.max_num_action_expansion):
    # get indices of the max_num_action_expansion largest probabilities
    actions = np.argpartition(
        policy_probs,
        -config.max_num_action_expansion)[-config.max_num_action_expansion:]

  policy = {a: policy_probs[a] for a in actions}
  for action, p in policy.items():
    node.children[action] = Node(p, config)


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path, value, discount,
                  min_max_stats):
  for node in search_path[::-1]:
    node.value_sum += value
    node.visit_count += 1
    min_max_stats.update(node.qvalue())
    value = node.reward + discount * value


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config, node):
  actions = list(node.children.keys())
  noise = np.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


class ValueEncoder:
  """Encoder for reward and value targets from Appendix of MuZero Paper."""

  def __init__(self,
               min_value,
               max_value,
               num_steps,
               use_contractive_mapping=True):
    if not max_value > min_value:
      raise ValueError('max_value must be > min_value')
    min_value = float(min_value)
    max_value = float(max_value)
    if use_contractive_mapping:
      max_value = contractive_mapping(max_value)
      min_value = contractive_mapping(min_value)
    if num_steps <= 0:
      num_steps = int(math.ceil(max_value) + 1 - math.floor(min_value))
    logging.info('Initializing ValueEncoder with range (%d, %d) and %d steps',
                 min_value, max_value, num_steps)
    self.min_value = min_value
    self.max_value = max_value
    self.value_range = max_value - min_value
    self.num_steps = num_steps
    self.step_size = self.value_range / (num_steps - 1)
    self.step_range_int = tf.range(self.num_steps, dtype=tf.int32)
    self.step_range_float = tf.cast(self.step_range_int, tf.float32)
    self.use_contractive_mapping = use_contractive_mapping

  def encode(self, value):
    if len(value.shape) != 1:
      raise ValueError(
          'Expected value to be 1D Tensor [batch_size], but got {}.'.format(
              value.shape))
    if self.use_contractive_mapping:
      value = contractive_mapping(value)
    value = tf.expand_dims(value, -1)
    clipped_value = tf.clip_by_value(value, self.min_value, self.max_value)
    above_min = clipped_value - self.min_value
    num_steps = above_min / self.step_size
    lower_step = tf.math.floor(num_steps)
    upper_mod = num_steps - lower_step
    lower_step = tf.cast(lower_step, tf.int32)
    upper_step = lower_step + 1
    lower_mod = 1.0 - upper_mod
    lower_encoding, upper_encoding = (
        tf.cast(tf.math.equal(step, self.step_range_int), tf.float32) * mod
        for step, mod in (
            (lower_step, lower_mod),
            (upper_step, upper_mod),
        ))
    return lower_encoding + upper_encoding
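
  # Worked example (with use_contractive_mapping=False): for min_value=0,
  # max_value=10 and num_steps=11, step_size is 1.0, so encoding the value 3.7
  # yields a two-hot vector with weight 0.3 at index 3 and weight 0.7 at
  # index 4; decode() below recovers 3.7 as the weighted sum of bin positions.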

  def decode(self, logits):
    if len(logits.shape) != 2:
      raise ValueError(
          'Expected logits to be 2D Tensor [batch_size, steps], but got {}.'
          .format(logits.shape))
    num_steps = tf.reduce_sum(logits * self.step_range_float, -1)
    above_min = num_steps * self.step_size
    value = above_min + self.min_value
    if self.use_contractive_mapping:
      value = inverse_contractive_mapping(value)
    return value


# From the MuZero paper.
def contractive_mapping(x, eps=0.001):
  return tf.math.sign(x) * (tf.math.sqrt(tf.math.abs(x) + 1.) - 1.) + eps * x


# From the MuZero paper.
def inverse_contractive_mapping(x, eps=0.001):
  return tf.math.sign(x) * (
      tf.math.square(
          (tf.sqrt(4 * eps *
                   (tf.math.abs(x) + 1. + eps) + 1.) - 1.) / (2. * eps)) - 1.)
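
# In closed form, contractive_mapping is h(x) = sign(x) * (sqrt(|x| + 1) - 1)
# + eps * x, and inverse_contractive_mapping is its analytic inverse, obtained
# by solving h(x) = y for x; the eps term keeps the mapping strictly increasing
# and hence invertible.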


@attr.s(auto_attribs=True)
class EnvironmentDescriptor:
  """Descriptor for Environments."""
  observation_space: gym.spaces.Space
  action_space: gym.spaces.Space
  reward_range: Range
  value_range: Range
  pretraining_space: gym.spaces.Space = None
  extras: Dict[str, Any] = None