# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-complex-comprehension
"""Encoder+LSTM network for use with MuZero."""

import os

import tensorflow as tf
import tensorflow_addons as tfa

from muzero import core as mzcore


class AbstractEncoderandLSTM(tf.Module):
  """Encoder+stacked LSTM Agent.

  When using this, implement the `_encode_observation` method.
  """

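  # A minimal subclass sketch (illustrative assumption, not part of this
  # module; any encoder architecture can back `_encode_observation`):
  #
  #   class MLPEncoderAgent(AbstractEncoderandLSTM):
  #
  #     def __init__(self, *args, **kwargs):
  #       super().__init__(*args, **kwargs)
  #       self._torso = tf.keras.Sequential(
  #           [tf.keras.layers.Dense(128, 'relu')])
  #
  #     def _encode_observation(self, observation, training=True):
  #       return self._torso(observation, training=training)
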
  def __init__(self,
               parametric_action_distribution,
               rnn_sizes,
               head_hidden_sizes,
               reward_encoder,
               value_encoder,
               normalize_hidden_state=False,
               rnn_cell_type='lstm_norm',
               recurrent_activation='sigmoid',
               head_relu_before_norm=False,
               nonlinear_to_hidden=False,
               embed_actions=True):
    """Creates an Encoder followed by a stacked LSTM agent.

    Args:
      parametric_action_distribution: an object of the ParametricDistribution
        class specifying a parametric distribution over actions to be used
      rnn_sizes: list of integers with sizes of LSTM layers
      head_hidden_sizes: list of integers with sizes of head layers
      reward_encoder: a value encoder for the reward
      value_encoder: a value encoder for the value
      normalize_hidden_state: whether to min-max normalize the hidden state
        on each step
      rnn_cell_type: 'gru', 'simple', 'lstm_norm' or 'lstm'
      recurrent_activation: a keras activation function
      head_relu_before_norm: put ReLU before the normalization
      nonlinear_to_hidden: apply recurrent_activation to the latent projection
      embed_actions: use learned action embeddings instead of one-hot encodings
    """
    super().__init__(name='MuZeroAgent')
    self._parametric_action_distribution = parametric_action_distribution
    self.reward_encoder = reward_encoder
    self.value_encoder = value_encoder
    self.head_hidden_sizes = head_hidden_sizes
    self._normalize_hidden_state = normalize_hidden_state
    self._rnn_cell_type = rnn_cell_type
    self._head_relu_before_norm = head_relu_before_norm

    # LSTMs pass on 2x their state size.
    self._rnn_sizes = rnn_sizes
    if rnn_cell_type in ('gru', 'simple'):
      self.hidden_state_size = sum(rnn_sizes)
    elif rnn_cell_type in ('lstm', 'lstm_norm'):
      self.hidden_state_size = sum(rnn_sizes) * 2

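    # Worked example (illustrative numbers, not a requirement): with
    # rnn_sizes=[256, 256] and the default 'lstm_norm' cell this gives
    # hidden_state_size = (256 + 256) * 2 = 1024, since each LSTM layer
    # carries both a hidden state h and a cell state c. That matches the
    # (None, 1024) hidden-state signature assumed by
    # `export_agent_for_recurrent_inference` below.
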
    self._to_hidden = tf.keras.Sequential(
        [
            # Flatten the representation.
            tf.keras.layers.Flatten(),
            # Map it to the size and domain of the hidden state.
            tf.keras.layers.Dense(
                self.hidden_state_size,
                activation=(recurrent_activation
                            if nonlinear_to_hidden else None),
                name='final')
        ],
        name='to_hidden')

    self._embed_actions = embed_actions
    if self._embed_actions:
      self._action_embeddings = tf.keras.layers.Dense(self.hidden_state_size)

    # RNNs are a convenient choice for MuZero, because they can take the
    # action as input and compute the reward from the output, while computing
    # value and policy from the hidden states.
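    # Concretely (see `recurrent_inference` below): the embedded action is
    # the RNN input for one step, `rnn_output` feeds the reward head, and the
    # next hidden state feeds the value and policy heads.
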
    rnn_cell_cls = {
        'gru': tf.keras.layers.GRUCell,
        'lstm': tf.keras.layers.LSTMCell,
        'lstm_norm': tfa.rnn.LayerNormLSTMCell,
        'simple': tf.keras.layers.SimpleRNNCell,
    }[rnn_cell_type]

    rnn_cells = [
        rnn_cell_cls(
            size,
            recurrent_activation=recurrent_activation,
            name='cell_{}'.format(idx)) for idx, size in enumerate(rnn_sizes)
    ]
    self._core = tf.keras.layers.StackedRNNCells(
        rnn_cells, name='recurrent_core')

    self._policy_head = tf.keras.Sequential(
        self._head_hidden_layers() + [
            tf.keras.layers.Dense(
                parametric_action_distribution.param_size, name='output')
        ],
        name='policy_logits')

    # Note that value and reward are logits, because their values are binned.
    # See utils.ValueEncoder for details.
    self._value_head = tf.keras.Sequential(
        self._head_hidden_layers() + [
            tf.keras.layers.Dense(self.value_encoder.num_steps, name='output'),
        ],
        name='value_logits')
    self._reward_head = tf.keras.Sequential(
        self._head_hidden_layers() + [
            tf.keras.layers.Dense(self.reward_encoder.num_steps, name='output'),
        ],
        name='reward_logits')

  # Each head can have its own hidden layers.
  def _head_hidden_layers(self):

    def _make_layer(size):
      if self._head_relu_before_norm:
        return [
            tf.keras.layers.Dense(size, 'relu'),
            tf.keras.layers.LayerNormalization(),
        ]
      else:
        return [
            tf.keras.layers.Dense(size, use_bias=False),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.ReLU(),
        ]

    layers = [
        tf.keras.Sequential(
            _make_layer(size), name='intermediate_{}'.format(idx))
        for idx, size in enumerate(self.head_hidden_sizes)
    ]
    return layers

  @staticmethod
  def _rnn_to_flat(state):
    """Maps LSTM state to flat vector."""
    states = []
    for cell_state in state:
      if not isinstance(cell_state, (list, tuple)):
        # This is a GRU or SimpleRNNCell.
        cell_state = (cell_state,)
      states.extend(cell_state)
    return tf.concat(states, -1)

  def _flat_to_rnn(self, state):
    """Maps flat vector to LSTM state."""
    tensors = []
    cur_idx = 0
    for size in self._rnn_sizes:
      if self._rnn_cell_type in ('gru', 'simple'):
        states = (state[Ellipsis, cur_idx:cur_idx + size],)
        cur_idx += size
      elif self._rnn_cell_type in ('lstm', 'lstm_norm'):
        states = (state[Ellipsis, cur_idx:cur_idx + size],
                  state[Ellipsis, cur_idx + size:cur_idx + 2 * size])
        cur_idx += 2 * size
      tensors.append(states)
    assert cur_idx == state.shape[-1]
    return tensors

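  # Layout example (illustrative): with rnn_sizes=[2, 3] and an LSTM cell
  # type, the flat vector is laid out as [h0 (2), c0 (2), h1 (3), c1 (3)];
  # `_flat_to_rnn` returns [(h0, c0), (h1, c1)] and `_rnn_to_flat`
  # concatenates the states back in the same order.
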
  def initial_state(self, batch_size):
    return tf.zeros((batch_size, self.hidden_state_size))

  def _encode_observation(self, observation, training=True):
    raise NotImplementedError()

  def pretraining_loss(self, sample, training=True):
    raise NotImplementedError()

  def get_pretraining_trainable_variables(self):
    return self.trainable_variables

  def get_rl_trainable_variables(self):
    return self.trainable_variables

  def get_trainable_variables(self, pretraining=False):
    if pretraining:
      return self.get_pretraining_trainable_variables()
    else:
      return self.get_rl_trainable_variables()

  def initial_inference(self, observation, training=True):
    encoded_observation = self._encode_observation(
        observation, training=training)
    hidden_state = self._to_hidden(encoded_observation, training=training)

    value_logits = self._value_head(hidden_state, training=training)
    value = self.value_encoder.decode(tf.nn.softmax(value_logits))

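    # Decoding sketch (assumed semantics of the value encoder; see
    # utils.ValueEncoder): the head emits one logit per bin, softmax turns
    # them into bin probabilities, and `decode` maps the probabilities back
    # to a scalar value.
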
    # Rewards are only calculated in recurrent_inference.
    reward = tf.zeros_like(value)
    reward_logits = self.reward_encoder.encode(reward)

    policy_logits = self._policy_head(hidden_state, training=training)

    outputs = mzcore.NetworkOutput(
        value_logits=value_logits,
        value=value,
        reward_logits=reward_logits,
        reward=reward,
        policy_logits=policy_logits,
        hidden_state=hidden_state)
    return outputs

  def _maybe_normalize_hidden_state(self, hidden_state):
    if self._normalize_hidden_state:
      # This is in the paper, but probably unnecessary.
      max_hidden_state = tf.reduce_max(hidden_state, -1, keepdims=True)
      min_hidden_state = tf.reduce_min(hidden_state, -1, keepdims=True)
      hidden_state_range = max_hidden_state - min_hidden_state
      hidden_state = hidden_state - min_hidden_state
      hidden_state = tf.math.divide_no_nan(hidden_state, hidden_state_range)
      hidden_state = hidden_state * 2. - 1.
    return hidden_state

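  # Illustrative numbers: a hidden-state row [0., 1., 3.] has range 3, so
  # min-max scaling gives [0., 1/3, 1.] and the affine map to [-1, 1] yields
  # [-1., -1/3, 1.]. `divide_no_nan` guards the constant case, where the
  # range is zero.
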
  def recurrent_inference(self, hidden_state, action, training=True):
    if self._embed_actions:
      one_hot_action = tf.one_hot(
          action, self._parametric_action_distribution.param_size)
      embedded_action = self._action_embeddings(one_hot_action)
    else:
      one_hot_action = tf.one_hot(
          action, self._parametric_action_distribution.param_size, 1., -1.)
      embedded_action = one_hot_action
    hidden_state = self._maybe_normalize_hidden_state(hidden_state)

    rnn_state = self._flat_to_rnn(hidden_state)
    rnn_output, next_rnn_state = self._core(embedded_action, rnn_state)
    next_hidden_state = self._rnn_to_flat(next_rnn_state)

    value_logits = self._value_head(next_hidden_state, training=training)
    value = self.value_encoder.decode(tf.nn.softmax(value_logits))

    reward_logits = self._reward_head(rnn_output, training=training)
    reward = self.reward_encoder.decode(tf.nn.softmax(reward_logits))

    policy_logits = self._policy_head(next_hidden_state, training=training)

    output = mzcore.NetworkOutput(
        value=value,
        value_logits=value_logits,
        reward=reward,
        reward_logits=reward_logits,
        policy_logits=policy_logits,
        hidden_state=next_hidden_state)
    return output


class ExportedAgent(tf.Module):
  """Wraps an Agent for export."""

  def __init__(self, agent_module):
    self._agent = agent_module

  def initial_inference(self, input_ids, segment_ids, features,
                        action_history):
    output = self._agent.initial_inference(
        observation=(input_ids, segment_ids, features, action_history),
        training=False)
    return [
        output.value,
        output.value_logits,
        output.reward,
        output.reward_logits,
        output.policy_logits,
        output.hidden_state,
    ]

  def recurrent_inference(self, hidden_state, action):
    output = self._agent.recurrent_inference(
        hidden_state=hidden_state, action=action, training=False)
    return [
        output.value,
        output.value_logits,
        output.reward,
        output.reward_logits,
        output.policy_logits,
        output.hidden_state,
    ]


def export_agent_for_initial_inference(agent, model_dir):
  """Export `agent` as a TPU servable model for initial inference."""

  def get_initial_inference_fn(model):

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, 512), dtype=tf.int32, name='input_ids'),
        tf.TensorSpec(
            shape=(None, 512, 1), dtype=tf.int32, name='segment_ids'),
        tf.TensorSpec(shape=(None, 512, 2), dtype=tf.float32, name='features'),
        tf.TensorSpec(shape=(None, 10), dtype=tf.int32, name='action_history'),
    ])
    def serve_fn(input_ids, segment_ids, features, action_history):
      return model.initial_inference(
          input_ids=input_ids,
          segment_ids=segment_ids,
          features=features,
          action_history=action_history)

    return serve_fn

  exported_agent = ExportedAgent(agent_module=agent)
  initial_fn = get_initial_inference_fn(exported_agent)

  # Export.
  save_options = tf.saved_model.SaveOptions(function_aliases={
      'initial_inference': initial_fn,
  })
  # Saves the CPU model, which will be rewritten to a TPU model.
  tf.saved_model.save(
      obj=exported_agent,
      export_dir=model_dir,
      signatures={
          'initial_inference': initial_fn,
      },
      options=save_options)


def export_agent_for_recurrent_inference(agent, model_dir):
  """Export `agent` as a TPU servable model for recurrent inference."""

  def get_recurrent_inference_fn(model):

    @tf.function(input_signature=[
        tf.TensorSpec(
            shape=(None, 1024), dtype=tf.float32, name='hidden_state'),
        tf.TensorSpec(shape=(None,), dtype=tf.int32, name='action')
    ])
    def serve_fn(hidden_state, action):
      return model.recurrent_inference(
          hidden_state=hidden_state, action=action)

    return serve_fn

  exported_agent = ExportedAgent(agent_module=agent)
  recurrent_fn = get_recurrent_inference_fn(exported_agent)

  # Export.
  save_options = tf.saved_model.SaveOptions(function_aliases={
      'recurrent_inference': recurrent_fn,
  })
  # Saves the CPU model, which will be rewritten to a TPU model.
  tf.saved_model.save(
      obj=exported_agent,
      export_dir=model_dir,
      signatures={
          'recurrent_inference': recurrent_fn,
      },
      options=save_options)
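
# Usage sketch (illustrative; standard SavedModel loading. `model_dir` and
# the zero-filled batch below are placeholders, with shapes following the
# serving signature above):
#
#   loaded = tf.saved_model.load(model_dir)
#   recurrent = loaded.signatures['recurrent_inference']
#   outputs = recurrent(
#       hidden_state=tf.zeros((1, 1024), tf.float32),
#       action=tf.zeros((1,), tf.int32))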