# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-complex-comprehension
"""Encoder+LSTM network for use with MuZero."""

import os

import tensorflow as tf
import tensorflow_addons as tfa

from muzero import core as mzcore


class AbstractEncoderandLSTM(tf.Module):
  """Encoder+stacked LSTM Agent.

  When using this, implement the `_encode_observation` method (a minimal
  example subclass is sketched after this class).
  """

  def __init__(self,
               parametric_action_distribution,
               rnn_sizes,
               head_hidden_sizes,
               reward_encoder,
               value_encoder,
               normalize_hidden_state=False,
               rnn_cell_type='lstm_norm',
               recurrent_activation='sigmoid',
               head_relu_before_norm=False,
               nonlinear_to_hidden=False,
               embed_actions=True):
    """Creates an Encoder followed by a stacked LSTM agent.

    Args:
      parametric_action_distribution: an object of the ParametricDistribution
        class specifying the parametric distribution over actions to be used
      rnn_sizes: list of integers with sizes of LSTM layers
      head_hidden_sizes: list of integers with sizes of head layers
      reward_encoder: a value encoder for the reward
      value_encoder: a value encoder for the value
      normalize_hidden_state: boolean to normalize the hidden state each step
      rnn_cell_type: 'gru', 'simple', 'lstm_norm' or 'lstm'
      recurrent_activation: a keras activation function
      head_relu_before_norm: put the ReLU before the normalization in the
        head layers
      nonlinear_to_hidden: apply recurrent_activation to the latent-to-hidden
        projection
      embed_actions: use action embeddings instead of one-hot encodings
    """
    super().__init__(name='MuZeroAgent')
    self._parametric_action_distribution = parametric_action_distribution
    self.reward_encoder = reward_encoder
    self.value_encoder = value_encoder
    self.head_hidden_sizes = head_hidden_sizes
    self._normalize_hidden_state = normalize_hidden_state
    self._rnn_cell_type = rnn_cell_type
    self._head_relu_before_norm = head_relu_before_norm

    # LSTMs pass on 2x their state size
    self._rnn_sizes = rnn_sizes
    if rnn_cell_type in ('gru', 'simple'):
      self.hidden_state_size = sum(rnn_sizes)
    elif rnn_cell_type in ('lstm', 'lstm_norm'):
      self.hidden_state_size = sum(rnn_sizes) * 2

    self._to_hidden = tf.keras.Sequential(
        [
            # flattening the representation
            tf.keras.layers.Flatten(),
            # mapping it to the size and domain of the hidden state
            tf.keras.layers.Dense(
                self.hidden_state_size,
                activation=(recurrent_activation
                            if nonlinear_to_hidden else None),
                name='final')
        ],
        name='to_hidden')

    self._embed_actions = embed_actions
    if self._embed_actions:
      self._action_embeddings = tf.keras.layers.Dense(self.hidden_state_size)

    # RNNs are a convenient choice for muzero, because they can take the action
    # as input and compute the reward from the output, while computing
    # value and policy from the hidden states.

    rnn_cell_cls = {
        'gru': tf.keras.layers.GRUCell,
        'lstm': tf.keras.layers.LSTMCell,
        'lstm_norm': tfa.rnn.LayerNormLSTMCell,
        'simple': tf.keras.layers.SimpleRNNCell,
    }[rnn_cell_type]

    rnn_cells = [
        rnn_cell_cls(
            size,
            recurrent_activation=recurrent_activation,
            name='cell_{}'.format(idx)) for idx, size in enumerate(rnn_sizes)
    ]
    self._core = tf.keras.layers.StackedRNNCells(
        rnn_cells, name='recurrent_core')

    self._policy_head = tf.keras.Sequential(
        self._head_hidden_layers() + [
            tf.keras.layers.Dense(
                parametric_action_distribution.param_size, name='output')
        ],
        name='policy_logits')

    # Note that value and reward are logits, because their values are binned.
    # See utils.ValueEncoder for details.
    self._value_head = tf.keras.Sequential(
        self._head_hidden_layers() + [
            tf.keras.layers.Dense(self.value_encoder.num_steps, name='output'),
        ],
        name='value_logits')
    self._reward_head = tf.keras.Sequential(
        self._head_hidden_layers() + [
            tf.keras.layers.Dense(self.reward_encoder.num_steps, name='output'),
        ],
        name='reward_logits')
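
    # Decoding these binned logits back to scalars happens in
    # initial_inference/recurrent_inference via value_encoder.decode /
    # reward_encoder.decode applied to the softmax of the logits; for a
    # MuZero-style categorical encoder this is typically the expectation over
    # the encoder's support bins.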

  # Each head can have its own hidden layers.
  def _head_hidden_layers(self):

    def _make_layer(size):
      if self._head_relu_before_norm:
        return [
            tf.keras.layers.Dense(size, 'relu'),
            tf.keras.layers.LayerNormalization(),
        ]
      else:
        return [
            tf.keras.layers.Dense(size, use_bias=False),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.ReLU(),
        ]

    layers = [
        tf.keras.Sequential(
            _make_layer(size), name='intermediate_{}'.format(idx))
        for idx, size in enumerate(self.head_hidden_sizes)
    ]
    return layers

  @staticmethod
  def _rnn_to_flat(state):
    """Maps LSTM state to flat vector."""
    states = []
    for cell_state in state:
      if not (isinstance(cell_state, list) or isinstance(cell_state, tuple)):
        # This is a GRU or SimpleRNNCell
        cell_state = (cell_state,)
      states.extend(cell_state)
    return tf.concat(states, -1)

  def _flat_to_rnn(self, state):
    """Maps flat vector to LSTM state."""
    tensors = []
    cur_idx = 0
    for size in self._rnn_sizes:
      if self._rnn_cell_type in ('gru', 'simple'):
        states = (state[Ellipsis, cur_idx:cur_idx + size],)
        cur_idx += size
      elif self._rnn_cell_type in ('lstm', 'lstm_norm'):
        states = (state[Ellipsis, cur_idx:cur_idx + size],
                  state[Ellipsis, cur_idx + size:cur_idx + 2 * size])
        cur_idx += 2 * size
      tensors.append(states)
    assert cur_idx == state.shape[-1]
    return tensors
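
  # For example, with rnn_sizes=[256, 256] and an LSTM-style cell, the flat
  # state handled by the two helpers above is the concatenation
  # [h_0, c_0, h_1, c_1] of per-layer outputs and cell states, i.e.
  # 2 * (256 + 256) = 1024 features, matching self.hidden_state_size.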

  def initial_state(self, batch_size):
    return tf.zeros((batch_size, self.hidden_state_size))

  def _encode_observation(self, observation, training=True):
    raise NotImplementedError()

  def pretraining_loss(self, sample, training=True):
    raise NotImplementedError()

  def get_pretraining_trainable_variables(self):
    return self.trainable_variables

  def get_rl_trainable_variables(self):
    return self.trainable_variables

  def get_trainable_variables(self, pretraining=False):
    if pretraining:
      return self.get_pretraining_trainable_variables()
    else:
      return self.get_rl_trainable_variables()

  def initial_inference(self, observation, training=True):
    encoded_observation = self._encode_observation(
        observation, training=training)
    hidden_state = self._to_hidden(encoded_observation, training=training)

    value_logits = self._value_head(hidden_state, training=training)
    value = self.value_encoder.decode(tf.nn.softmax(value_logits))

    # Rewards are only calculated in recurrent_inference.
    reward = tf.zeros_like(value)
    reward_logits = self.reward_encoder.encode(reward)

    policy_logits = self._policy_head(hidden_state, training=training)

    outputs = mzcore.NetworkOutput(
        value_logits=value_logits,
        value=value,
        reward_logits=reward_logits,
        reward=reward,
        policy_logits=policy_logits,
        hidden_state=hidden_state)
    return outputs

  def _maybe_normalize_hidden_state(self, hidden_state):
    if self._normalize_hidden_state:
      # This is in the paper, but probably unnecessary.
      max_hidden_state = tf.reduce_max(hidden_state, -1, keepdims=True)
      min_hidden_state = tf.reduce_min(hidden_state, -1, keepdims=True)
      hidden_state_range = max_hidden_state - min_hidden_state
      hidden_state = hidden_state - min_hidden_state
      hidden_state = tf.math.divide_no_nan(hidden_state, hidden_state_range)
      hidden_state = hidden_state * 2. - 1.
    return hidden_state
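
  # When normalization is enabled, the block above min-max rescales each
  # hidden state vector to [-1, 1]:
  #   s' = 2 * (s - min(s)) / (max(s) - min(s)) - 1,
  # with divide_no_nan guarding the constant-vector edge case.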

  def recurrent_inference(self, hidden_state, action, training=True):
    if self._embed_actions:
      one_hot_action = tf.one_hot(
          action, self._parametric_action_distribution.param_size)
      embedded_action = self._action_embeddings(one_hot_action)
    else:
      one_hot_action = tf.one_hot(
          action, self._parametric_action_distribution.param_size, 1., -1.)
      embedded_action = one_hot_action
    hidden_state = self._maybe_normalize_hidden_state(hidden_state)

    rnn_state = self._flat_to_rnn(hidden_state)
    rnn_output, next_rnn_state = self._core(embedded_action, rnn_state)
    next_hidden_state = self._rnn_to_flat(next_rnn_state)

    value_logits = self._value_head(next_hidden_state, training=training)
    value = self.value_encoder.decode(tf.nn.softmax(value_logits))

    reward_logits = self._reward_head(rnn_output, training=training)
    reward = self.reward_encoder.decode(tf.nn.softmax(reward_logits))

    policy_logits = self._policy_head(next_hidden_state, training=training)

    output = mzcore.NetworkOutput(
        value=value,
        value_logits=value_logits,
        reward=reward,
        reward_logits=reward_logits,
        policy_logits=policy_logits,
        hidden_state=next_hidden_state)
    return output
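

# A minimal usage sketch. The two stub classes below are hypothetical
# stand-ins for the parametric action distribution and the value/reward
# encoders used by the real training setup; they expose only the attributes
# this module relies on (`param_size`, `num_steps`, `encode`, `decode`).


class _StubActionSpec:
  """Hypothetical action spec exposing only `param_size`."""

  def __init__(self, num_actions):
    self.param_size = num_actions


class _StubValueEncoder:
  """Hypothetical categorical encoder over a uniform support of bins."""

  def __init__(self, min_value=-1., max_value=1., num_steps=11):
    self.num_steps = num_steps
    self._support = tf.linspace(float(min_value), float(max_value), num_steps)

  def encode(self, value):
    # Nearest-bin one-hot encoding; shape [..., num_steps].
    bin_idx = tf.argmin(
        tf.abs(value[Ellipsis, None] - self._support),
        axis=-1,
        output_type=tf.int32)
    return tf.one_hot(bin_idx, self.num_steps)

  def decode(self, probs):
    # Expected value of the support under the bin distribution.
    return tf.reduce_sum(probs * self._support, axis=-1)


class _ExampleMLPEncoderLSTM(AbstractEncoderandLSTM):
  """Example subclass encoding flat float observations with a small MLP."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._observation_encoder = tf.keras.Sequential(
        [tf.keras.layers.Dense(64, 'relu'),
         tf.keras.layers.Dense(64, 'relu')],
        name='observation_encoder')

  def _encode_observation(self, observation, training=True):
    return self._observation_encoder(observation, training=training)


def _example_inference():
  """Runs one initial and one recurrent inference step on random data."""
  agent = _ExampleMLPEncoderLSTM(
      parametric_action_distribution=_StubActionSpec(num_actions=4),
      rnn_sizes=[256, 256],
      head_hidden_sizes=[64],
      reward_encoder=_StubValueEncoder(),
      value_encoder=_StubValueEncoder())
  observation = tf.random.normal([8, 32])
  initial = agent.initial_inference(observation, training=False)
  action = tf.random.uniform([8], maxval=4, dtype=tf.int32)
  recurrent = agent.recurrent_inference(
      initial.hidden_state, action, training=False)
  return initial, recurrent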


class ExportedAgent(tf.Module):
  """Wraps an Agent for export."""

  def __init__(self, agent_module):
    super().__init__(name='exported_agent')
    self._agent = agent_module

  def initial_inference(self, input_ids, segment_ids, features, action_history):
    output = self._agent.initial_inference(
        observation=(input_ids, segment_ids, features, action_history),
        training=False)
    return [
        output.value,
        output.value_logits,
        output.reward,
        output.reward_logits,
        output.policy_logits,
        output.hidden_state,
    ]

  def recurrent_inference(self, hidden_state, action):
    output = self._agent.recurrent_inference(
        hidden_state=hidden_state, action=action, training=False)
    return [
        output.value,
        output.value_logits,
        output.reward,
        output.reward_logits,
        output.policy_logits,
        output.hidden_state,
    ]


def export_agent_for_initial_inference(agent,
                                       model_dir):
  """Export `agent` as a TPU servable model for initial inference."""

  def get_initial_inference_fn(model):

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, 512), dtype=tf.int32, name='input_ids'),
        tf.TensorSpec(shape=(None, 512, 1), dtype=tf.int32, name='segment_ids'),
        tf.TensorSpec(shape=(None, 512, 2), dtype=tf.float32, name='features'),
        tf.TensorSpec(shape=(None, 10), dtype=tf.int32, name='action_history'),
    ])
    def serve_fn(input_ids, segment_ids, features, action_history):
      return model.initial_inference(
          input_ids=input_ids,
          segment_ids=segment_ids,
          features=features,
          action_history=action_history)

    return serve_fn

  exported_agent = ExportedAgent(agent_module=agent)
  initial_fn = get_initial_inference_fn(exported_agent)

  # Export.
  save_options = tf.saved_model.SaveOptions(function_aliases={
      'initial_inference': initial_fn,
  })
  # Saves the CPU model, which will be rewritten to a TPU model.
  tf.saved_model.save(
      obj=exported_agent,
      export_dir=model_dir,
      signatures={
          'initial_inference': initial_fn,
      },
      options=save_options)


def export_agent_for_recurrent_inference(agent,
                                         model_dir):
  """Export `agent` as a TPU servable model for recurrent inference."""

  def get_recurrent_inference_fn(model):

    @tf.function(input_signature=[
        tf.TensorSpec(
            shape=(None, 1024), dtype=tf.float32, name='hidden_state'),
        tf.TensorSpec(shape=(None,), dtype=tf.int32, name='action')
    ])
    def serve_fn(hidden_state, action):
      return model.recurrent_inference(hidden_state=hidden_state, action=action)

    return serve_fn

  exported_agent = ExportedAgent(agent_module=agent)
  recurrent_fn = get_recurrent_inference_fn(exported_agent)

  # Export.
  save_options = tf.saved_model.SaveOptions(function_aliases={
      'recurrent_inference': recurrent_fn,
  })
  # Saves the CPU model, which will be rewritten to a TPU model.
  tf.saved_model.save(
      obj=exported_agent,
      export_dir=model_dir,
      signatures={
          'recurrent_inference': recurrent_fn,
      },
      options=save_options)
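

# Illustrative export flow (paths are placeholders): each helper writes its
# own SavedModel so that initial and recurrent inference can be rewritten to
# TPU models and served independently, e.g.
#
#   export_agent_for_initial_inference(agent, '/tmp/muzero/initial')
#   export_agent_for_recurrent_inference(agent, '/tmp/muzero/recurrent')
#   recurrent_serve = tf.saved_model.load(
#       '/tmp/muzero/recurrent').signatures['recurrent_inference']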