# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-docstring
# pylint: disable=g-complex-comprehension
"""MuZero Core.

Based partially on https://arxiv.org/src/1911.08265v1/anc/pseudocode.py
"""

import collections
import logging
import math
from typing import List, Optional, Dict, Any, Tuple

from absl import flags
import attr
import gym
import numpy as np
import tensorflow as tf

FLAGS = flags.FLAGS
MAXIMUM_FLOAT_VALUE = float('inf')

KnownBounds = collections.namedtuple('KnownBounds', 'min max')

NetworkOutput = collections.namedtuple(
    'NetworkOutput',
    'value value_logits reward reward_logits policy_logits hidden_state')

Prediction = collections.namedtuple(
    'Prediction',
    'gradient_scale value value_logits reward reward_logits policy_logits')

Target = collections.namedtuple(
    'Target', 'value_mask reward_mask policy_mask value reward visits')

Range = collections.namedtuple('Range', 'low high')

class RLEnvironmentError(Exception):
  pass


class BadSupervisedEpisodeError(Exception):
  pass


class SkipEpisode(Exception):  # pylint: disable=g-bad-exception-name
  pass

class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value):
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      value = (value - self.minimum) / (self.maximum - self.minimum)
      value = max(min(1.0, value), 0.0)
    return value

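# Illustrative sketch (not part of the original module): without known bounds,
# MinMaxStats starts empty and rescales observed Q-values into [0, 1] once at
# least two distinct values have been recorded.
#
#   stats = MinMaxStats(known_bounds=None)
#   stats.update(2.0)
#   stats.update(6.0)
#   stats.normalize(4.0)  # -> 0.5; values outside [2, 6] are clipped to [0, 1]
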
class MuZeroConfig:
  """Config object for MuZero."""

  def __init__(self,
               action_space_size,
               max_moves,
               recurrent_inference_batch_size,
               initial_inference_batch_size,
               train_batch_size,
               discount=0.99,
               dirichlet_alpha=0.25,
               root_exploration_fraction=0.25,
               num_simulations=11,
               td_steps=5,
               num_unroll_steps=5,
               pb_c_base=19652,
               pb_c_init=1.25,
               visit_softmax_temperature_fn=None,
               known_bounds=None,
               use_softmax_for_action_selection=False,
               parent_base_visit_count=1,
               max_num_action_expansion=0):

    ### Play
    self.action_space_size = action_space_size

    # The default temperature function must accept the keyword arguments used
    # at the call site in `select_action` (num_moves, training_steps,
    # is_training).
    self.visit_softmax_temperature_fn = (
        visit_softmax_temperature_fn if visit_softmax_temperature_fn
        is not None else lambda num_moves, training_steps, is_training: 1.0)
    self.max_moves = max_moves
    self.num_simulations = num_simulations
    self.discount = discount

    # Root prior exploration noise.
    self.root_dirichlet_alpha = dirichlet_alpha
    self.root_exploration_fraction = root_exploration_fraction

    # UCB formula
    self.pb_c_base = pb_c_base
    self.pb_c_init = pb_c_init

    # If we already have some information about which values occur in the
    # environment, we can use them to initialize the rescaling.
    # This is not strictly necessary, but establishes identical behaviour to
    # AlphaZero in board games.
    self.known_bounds = known_bounds

    ### Training
    self.recurrent_inference_batch_size = recurrent_inference_batch_size
    self.initial_inference_batch_size = initial_inference_batch_size
    self.train_batch_size = train_batch_size
    self.num_unroll_steps = num_unroll_steps
    self.td_steps = td_steps

    self.use_softmax_for_action_selection = use_softmax_for_action_selection

    # This is 0 in the MuZero paper.
    self.parent_base_visit_count = parent_base_visit_count
    self.max_num_action_expansion = max_num_action_expansion

  def new_episode(self, environment, index=None):
    return Episode(
        environment, self.action_space_size, self.discount, index=index)

Action = np.int64  # pylint: disable=invalid-name

class TransitionModel:
  """Transition model providing additional information for MCTS transitions.

  An environment can provide a specialized version of a transition model via
  the info dict. This model then provides additional information, e.g. on the
  legal actions, between transitions in the MCTS.
  """

  def __init__(self, full_action_space_size):
    self.full_action_space_size = full_action_space_size

  def legal_actions_after_sequence(self,
                                   actions_sequence):  # pylint: disable=unused-argument
    """Returns the legal action space after a sequence of actions."""
    return range(self.full_action_space_size)

  def full_action_space(self):
    return range(self.full_action_space_size)

  def legal_actions_mask_after_sequence(self, actions_sequence):
    """Returns the legal action space after a sequence of actions as a mask."""
    mask = np.zeros(self.full_action_space_size, dtype=np.int64)
    for action in self.legal_actions_after_sequence(actions_sequence):
      mask[action] = 1
    return mask

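# Illustrative sketch (not part of the original module): a hypothetical
# environment could restrict the search to legal moves by subclassing
# TransitionModel and placing an instance in the info dict returned by
# reset()/step() under the 'transition_model' key, which is the key that
# `Episode.legal_actions` looks up.


class _ExampleNoRepeatTransitionModel(TransitionModel):
  """Hypothetical model that forbids repeating the most recent action."""

  def legal_actions_after_sequence(self, actions_sequence):
    if not actions_sequence:
      return range(self.full_action_space_size)
    last = actions_sequence[-1]
    return [a for a in range(self.full_action_space_size) if a != last]
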
class Node:
  """Node for MCTS."""

  def __init__(self, prior, config, is_root=False):
    self.visit_count = 0
    self.prior = prior
    self.is_root = is_root
    self.value_sum = 0
    self.children = {}
    self.hidden_state = None
    self.reward = 0
    self.discount = config.discount

  def expanded(self):
    return len(self.children) > 0  # pylint: disable=g-explicit-length-test

  def value(self):
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count

  def qvalue(self):
    return self.discount * self.value() + self.reward

class ActionHistory:
  """Simple history container used inside the search.

  Only used to keep track of the actions executed.
  """

  def __init__(self, history, action_space_size):
    self.history = list(history)
    self.action_space_size = action_space_size

  def clone(self):
    return ActionHistory(self.history, self.action_space_size)

  def add_action(self, action):
    self.history.append(Action(action))

  def last_action(self):
    return self.history[-1]

  def action_space(self):
    return [Action(i) for i in range(self.action_space_size)]

class Episode:
  """A single episode of interaction with the environment."""

  def __init__(self,
               environment,
               action_space_size,
               discount,
               index=None):
    self.environment = environment
    self.history = []
    self.observations = []
    self.rewards = []
    self.child_visits = []
    self.root_values = []
    self.mcts_visualizations = []
    self.action_space_size = action_space_size
    self.discount = discount
    self.failed = False

    if index is None:
      self._observation, self._info = self.environment.reset()
    else:
      self._observation, self._info = self.environment.reset(index)
    self.observations.append(self._observation)
    self._reward = None
    self._done = False

  def terminal(self):
    return self._done

  def get_info(self, kword):
    if not self._info:
      return None
    return self._info.get(kword, None)

  def total_reward(self):
    return sum(self.rewards)

  def __len__(self):
    return len(self.history)

  def special_statistics(self):
    try:
      return self.environment.special_episode_statistics()
    except AttributeError:
      return {}

  def special_statistics_learner(self):
    try:
      return self.environment.special_episode_statistics_learner()
    except AttributeError:
      return ()

  def visualize_mcts(self, root):
    history = self.action_history().history
    try:
      treestr = self.environment.visualize_mcts(root, history)
    except AttributeError:
      treestr = ''
    if treestr:
      self.mcts_visualizations.append(treestr)

  def legal_actions(self, actions_sequence=None):
    """Returns the legal actions after an actions sequence.

    Args:
      actions_sequence: Past sequence of actions.

    Returns:
      A list of full_action_space size. At each index a 1 corresponds to a
      legal action and a 0 to an illegal action.
    """
    transition_model = self.get_info('transition_model') or TransitionModel(
        self.action_space_size)
    actions_sequence = tuple(actions_sequence or [])
    return transition_model.legal_actions_mask_after_sequence(actions_sequence)

  def apply(self, action, training_steps=0):
    (self._observation, self._reward, self._done,
     self._info) = self.environment.step(
         action, training_steps=training_steps)
    self.rewards.append(self._reward)
    self.history.append(action)
    self.observations.append(self._observation)

  def history_range(self, start, end):
    rng = []
    for i in range(start, end):
      if i < len(self.history):
        rng.append(self.history[i])
      else:
        rng.append(0)
    return np.array(rng, np.int64)

  def store_search_statistics(self, root, use_softmax=False):
    sum_visits = sum(child.visit_count for child in root.children.values())
    sum_visits = max(sum_visits, 1)
    action_space = (Action(index) for index in range(self.action_space_size))
    if use_softmax:
      child_visits, mask = zip(*[(root.children[a].visit_count, 1)
                                 if a in root.children else (0, 0)
                                 for a in action_space])
      child_visits_distribution = masked_softmax(child_visits, mask)
    else:
      child_visits_distribution = [
          root.children[a].visit_count / sum_visits if a in root.children else 0
          for a in action_space
      ]

    self.child_visits.append(child_visits_distribution)
    self.root_values.append(root.value())

  def make_image(self, state_index):
    if state_index == -1 or state_index < len(self.observations):
      return self.observations[state_index]
    return self._observation

  @staticmethod
  def make_target(state_index,
                  num_unroll_steps,
                  td_steps,
                  rewards,
                  policy_distributions,
                  discount,
                  value_approximations=None):
    num_steps = len(rewards)
    if td_steps == -1:
      td_steps = num_steps  # for sure go to the end of the game

    # The value target is the discounted root value of the search tree N steps
    # into the future, plus the discounted sum of all rewards until then.
    targets = []
    for current_index in range(state_index,
                               state_index + num_unroll_steps + 1):
      bootstrap_index = current_index + td_steps
      if bootstrap_index < num_steps and value_approximations is not None:
        value = value_approximations[bootstrap_index] * discount**td_steps
      else:
        value = 0

      for i, reward in enumerate(rewards[current_index:bootstrap_index]):
        value += reward * discount**i  # pytype: disable=unsupported-operands

      reward_mask = 1.0 if current_index > state_index else 0.0
      if current_index < num_steps:
        targets.append(
            (1.0, reward_mask, 1.0, value, rewards[current_index - 1],
             policy_distributions[current_index]))
      elif current_index == num_steps:
        targets.append((1.0, reward_mask, 0.0, 0.0, rewards[current_index - 1],
                        policy_distributions[0]))
      else:
        # States past the end of games are treated as absorbing states.
        targets.append((1.0, 0.0, 0.0, 0.0, 0.0, policy_distributions[0]))
    target = Target(*zip(*targets))
    return target

  def action_history(self):
    return ActionHistory(self.history, self.action_space_size)

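# Worked example for `make_target` (illustrative, values chosen for this
# sketch): with rewards=[1, 0, 0, 1], discount=0.5, td_steps=2 and root value
# estimates value_approximations=[v0, v1, v2, v3], the value target for
# state 0 is the discounted rewards up to the bootstrap step plus the
# bootstrapped value two steps ahead:
#
#   target_0 = 1 * 0.5**0 + 0 * 0.5**1 + v2 * 0.5**2 = 1 + 0.25 * v2
#
# Entries past the end of the episode become absorbing states with zero
# value/reward targets and a dummy policy (policy_distributions[0]).
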
def prepare_root_node(config, legal_actions, initial_inference_output):
  root = Node(0, config, is_root=True)
  expand_node(root, legal_actions, initial_inference_output, config)
  add_exploration_noise(config, root)
  return root

# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config,
             root,
             action_history,
             legal_actions_fn,
             recurrent_inference_fn,
             visualization_fn=None):
  min_max_stats = MinMaxStats(config.known_bounds)

  for _ in range(config.num_simulations):
    history = action_history.clone()
    node = root
    search_path = [node]

    while node.expanded():
      action, node = select_child(config, node, min_max_stats)
      history.add_action(action)
      search_path.append(node)

    # Inside the search tree we use the dynamics function to obtain the next
    # hidden state given an action and the previous hidden state.
    parent = search_path[-2]
    network_output = recurrent_inference_fn(parent.hidden_state,
                                            history.last_action())
    legal_actions = legal_actions_fn(
        history.history[len(action_history.history):])
    expand_node(node, legal_actions, network_output, config)

    backpropagate(search_path, network_output.value, config.discount,
                  min_max_stats)

  if visualization_fn:
    visualization_fn(root)

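# Illustrative sketch (not part of the original module): how the pieces fit
# together for one move. `_uniform_network_output` is a hypothetical stand-in
# for real initial/recurrent inference networks; only the NetworkOutput fields
# that the search reads (value, reward, policy_logits, hidden_state) are
# populated meaningfully.


def _example_search_one_move(config, episode, train_step=0):
  """Runs MCTS from the current episode state with a dummy uniform network."""

  def _uniform_network_output(unused_hidden_state=None, unused_action=None):
    n = config.action_space_size
    return NetworkOutput(
        value=0.0,
        value_logits=None,
        reward=0.0,
        reward_logits=None,
        policy_logits=np.zeros(n, dtype=np.float64),
        hidden_state=np.zeros(1, dtype=np.float64))

  root = prepare_root_node(config, episode.legal_actions(),
                           _uniform_network_output())
  run_mcts(config, root, episode.action_history(), episode.legal_actions,
           _uniform_network_output)
  return select_action(
      config, len(episode), root, train_step,
      use_softmax=config.use_softmax_for_action_selection)
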
def masked_distribution(x, use_exp, mask=None):
  if mask is None:
    mask = [1] * len(x)
  assert sum(mask) > 0, 'Not all values can be masked.'
  assert len(mask) == len(x), (
      'The dimensions of the mask and x need to be the same.')
  x = np.exp(x) if use_exp else np.array(x, dtype=np.float64)
  mask = np.array(mask, dtype=np.float64)
  x *= mask
  if sum(x) == 0:
    # No unmasked value has any weight. Use uniform distribution over unmasked
    # tokens.
    x = mask
  return x / np.sum(x, keepdims=True)


def masked_softmax(x, mask=None):
  x = np.array(x) - np.max(x, axis=-1)  # to avoid overflow
  return masked_distribution(x, use_exp=True, mask=mask)


def masked_count_distribution(x, mask=None):
  return masked_distribution(x, use_exp=False, mask=mask)

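# Worked example (illustrative): masking removes illegal actions before
# normalization, so probability mass is redistributed over the legal ones.
#
#   masked_softmax([0., 0., 0.], mask=[1, 0, 1])  # -> [0.5, 0., 0.5]
#   masked_count_distribution([3, 1, 0], mask=[1, 1, 0])  # -> [0.75, 0.25, 0.]
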
def histogram_sample(distribution,
                     temperature,
                     use_softmax=False,
                     mask=None):
  actions = [d[1] for d in distribution]
  visit_counts = np.array([d[0] for d in distribution], dtype=np.float64)
  if temperature == 0.:
    probs = masked_count_distribution(visit_counts, mask=mask)
    return actions[np.argmax(probs)]
  if use_softmax:
    logits = visit_counts / temperature
    probs = masked_softmax(logits, mask)
  else:
    logits = visit_counts**(1. / temperature)
    probs = masked_count_distribution(logits, mask)
  return np.random.choice(actions, p=probs)


def select_action(config,
                  num_moves,
                  node,
                  train_step,
                  use_softmax=False,
                  is_training=True):
  visit_counts = [
      (child.visit_count, action) for action, child in node.children.items()
  ]
  t = config.visit_softmax_temperature_fn(
      num_moves=num_moves, training_steps=train_step, is_training=is_training)
  action = histogram_sample(visit_counts, t, use_softmax=use_softmax)
  return action

# Select the child with the highest UCB score.
def select_child(config, node, min_max_stats):
  ucb_scores = [(ucb_score(config, node, child, min_max_stats), action, child)
                for action, child in node.children.items()]
  _, action, child = max(ucb_scores)
  return action, child


# The score for a node is based on its value, plus an exploration bonus based
# on the prior.
def ucb_score(config, parent, child, min_max_stats):
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count + config.parent_base_visit_count) / (
      child.visit_count + 1)

  prior_score = pb_c * child.prior
  if child.visit_count > 0:
    value_score = min_max_stats.normalize(child.qvalue())
  else:
    value_score = 0.
  return prior_score + value_score

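# For reference, the bonus above corresponds to the pUCT rule from the MuZero
# paper (with this code's extra `parent_base_visit_count` term in the square
# root, which is 0 in the paper):
#
#   score(s, a) = norm(Q(s, a)) + P(s, a)
#       * sqrt(N(s) + parent_base_visit_count) / (1 + N(s, a))
#       * (c_init + log((N(s) + c_base + 1) / c_base))
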
# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node, legal_actions, network_output, config):
  node.hidden_state = network_output.hidden_state
  node.reward = network_output.reward
  policy_probs = masked_softmax(
      network_output.policy_logits, mask=legal_actions.astype(np.float64))
  actions = np.where(legal_actions == 1)[0]

  if (config.max_num_action_expansion > 0 and
      len(actions) > config.max_num_action_expansion):
    # Get indices of the max_num_action_expansion largest probabilities.
    actions = np.argpartition(
        policy_probs,
        -config.max_num_action_expansion)[-config.max_num_action_expansion:]

  policy = {a: policy_probs[a] for a in actions}
  for action, p in policy.items():
    node.children[action] = Node(p, config)

# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path, value, discount, min_max_stats):
  for node in search_path[::-1]:
    node.value_sum += value
    node.visit_count += 1
    min_max_stats.update(node.qvalue())
    value = node.reward + discount * value

# At the start of each search, we add Dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config, node):
  actions = list(node.children.keys())
  noise = np.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac

class ValueEncoder:
  """Encoder for reward and value targets from Appendix of MuZero Paper."""

  def __init__(self,
               min_value,
               max_value,
               num_steps,
               use_contractive_mapping=True):
    if not max_value > min_value:
      raise ValueError('max_value must be > min_value')
    min_value = float(min_value)
    max_value = float(max_value)
    if use_contractive_mapping:
      max_value = contractive_mapping(max_value)
      min_value = contractive_mapping(min_value)
    if num_steps <= 0:
      num_steps = int(math.ceil(max_value) + 1 - math.floor(min_value))
    logging.info('Initializing ValueEncoder with range (%d, %d) and %d steps',
                 min_value, max_value, num_steps)
    self.min_value = min_value
    self.max_value = max_value
    self.value_range = max_value - min_value
    self.num_steps = num_steps
    self.step_size = self.value_range / (num_steps - 1)
    self.step_range_int = tf.range(self.num_steps, dtype=tf.int32)
    self.step_range_float = tf.cast(self.step_range_int, tf.float32)
    self.use_contractive_mapping = use_contractive_mapping

  def encode(self, value):
    if len(value.shape) != 1:
      raise ValueError(
          'Expected value to be 1D Tensor [batch_size], but got {}.'.format(
              value.shape))
    if self.use_contractive_mapping:
      value = contractive_mapping(value)
    value = tf.expand_dims(value, -1)
    clipped_value = tf.clip_by_value(value, self.min_value, self.max_value)
    above_min = clipped_value - self.min_value
    num_steps = above_min / self.step_size
    lower_step = tf.math.floor(num_steps)
    upper_mod = num_steps - lower_step
    lower_step = tf.cast(lower_step, tf.int32)
    upper_step = lower_step + 1
    lower_mod = 1.0 - upper_mod
    lower_encoding, upper_encoding = (
        tf.cast(tf.math.equal(step, self.step_range_int), tf.float32) * mod
        for step, mod in (
            (lower_step, lower_mod),
            (upper_step, upper_mod),
        ))
    return lower_encoding + upper_encoding

  def decode(self, logits):
    if len(logits.shape) != 2:
      raise ValueError(
          'Expected logits to be 2D Tensor [batch_size, steps], but got {}.'
          .format(logits.shape))
    num_steps = tf.reduce_sum(logits * self.step_range_float, -1)
    above_min = num_steps * self.step_size
    value = above_min + self.min_value
    if self.use_contractive_mapping:
      value = inverse_contractive_mapping(value)
    return value

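# Illustrative sketch (not part of the original module): values are encoded as
# a "two-hot" vector over a discrete support, splitting the weight between the
# two nearest support steps, so `decode(encode(x))` approximately recovers x.


def _example_value_encoder_roundtrip():
  encoder = ValueEncoder(min_value=-300, max_value=300, num_steps=0)
  values = tf.constant([3.7, -12.0], dtype=tf.float32)
  encoding = encoder.encode(values)  # shape [2, num_steps], rows sum to 1.
  return encoder.decode(encoding)  # approximately [3.7, -12.0].
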
# From the MuZero paper.
def contractive_mapping(x, eps=0.001):
  return tf.math.sign(x) * (tf.math.sqrt(tf.math.abs(x) + 1.) - 1.) + eps * x


# From the MuZero paper.
def inverse_contractive_mapping(x, eps=0.001):
  return tf.math.sign(x) * (
      tf.math.square(
          (tf.sqrt(4 * eps *
                   (tf.math.abs(x) + 1. + eps) + 1.) - 1.) / (2. * eps)) - 1.)

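# For reference, these implement the invertible transform h and its inverse
# from the MuZero paper appendix, which compress large targets so the
# value/reward supports stay small:
#
#   h(x)      = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
#   h^{-1}(x) = sign(x) * (((sqrt(4 * eps * (|x| + 1 + eps) + 1) - 1)
#                           / (2 * eps))**2 - 1)
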
@attr.s(auto_attribs=True)
class EnvironmentDescriptor:
  """Descriptor for Environments."""
  observation_space: gym.spaces.Space
  action_space: gym.spaces.Space
  reward_range: Range
  value_range: Range
  pretraining_space: gym.spaces.Space = None
  extras: Dict[str, Any] = None