google-research
41 строка · 1.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-complex-comprehension
# pylint: disable=missing-docstring
"""MLP+LSTM network for use with MuZero."""
19
20import tensorflow as tf
21
22from muzero import network
23
24
class MLPandLSTM(network.AbstractEncoderandLSTM):
  """MuZero network whose observation encoder is a stack of MLP blocks.

  Each entry of `mlp_sizes` yields one named block: a bias-free `Dense`
  layer with ReLU activation followed by `LayerNormalization`. Every other
  constructor argument is passed through untouched to
  `network.AbstractEncoderandLSTM`.
  """

  def __init__(self, mlp_sizes, *args, **kwargs):
    """Builds the MLP observation encoder.

    Args:
      mlp_sizes: Iterable of unit counts, one per MLP block, applied in order.
      *args: Positional arguments forwarded to the base class constructor.
      **kwargs: Keyword arguments forwarded to the base class constructor.
    """
    super().__init__(*args, **kwargs)

    blocks = []
    for block_index, units in enumerate(mlp_sizes):
      # Presumably bias-free because the LayerNormalization right after it
      # supplies its own learned offset -- NOTE(review): confirm intent.
      dense = tf.keras.layers.Dense(units, activation='relu', use_bias=False)
      norm = tf.keras.layers.LayerNormalization()
      blocks.append(
          tf.keras.Sequential([dense, norm],
                              name='intermediate_{}'.format(block_index)))

    self._observation_encoder = tf.keras.Sequential(
        blocks, name='observation_encoder')

  def _encode_observation(self, observation, training=True):
    """Runs `observation` through the stacked MLP encoder.

    Args:
      observation: Input passed directly to the Keras encoder model.
      training: Forwarded as the Keras `training` call argument.

    Returns:
      The output of the final MLP block.
    """
    return self._observation_encoder(observation, training=training)
42