google-research

Форк
0
/
inference_evaluater.py 
251 строка · 7.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""This file coordinates with all other files to get the evaluation metrics."""
17

18
import functools
19

20
from absl import logging
21
from clu import metric_writers
22
from flax import jax_utils
23
from flax import linen as nn
24
from flax.training import checkpoints
25
import jax
26
from jax import random
27
import jax.numpy as jnp
28
import matplotlib.pyplot as plt
29
import numpy as np
30
import tensorflow as tf
31

32
from sudoku_gpt import model
33
from sudoku_gpt import trainer
34
from sudoku_gpt.inference import data
35
from sudoku_gpt.inference import inference_eval_utils
36

37

38
def log_hyperparams_tb(
    config, model_config, initial_variables, tf_summary_writer
    ):
  """Record config and model hyperparameters as TensorBoard text summaries.

  Counts the model parameters and stores the total on ``config`` as
  ``num_model_parameters`` *before* reading ``config.items()``, so the count
  is included in the logged config table.

  Args:
    config: Experiment config; mutated in place with the parameter count.
    model_config: Model configuration whose ``__dict__`` entries are logged.
    initial_variables: Model variable pytree used to count parameters.
    tf_summary_writer: TF summary writer that receives the text summaries.

  Returns:
    The (unchanged) summary writer and the updated config.
  """
  param_count = 0
  for leaf in jax.tree_util.tree_leaves(initial_variables):
    param_count += leaf.size
  config.num_model_parameters = param_count

  def _as_rows(items):
    # One [name, value] string tensor per hyperparameter entry.
    return [tf.convert_to_tensor([name, str(value)]) for name, value in items]

  config_rows = _as_rows(config.items())
  model_rows = _as_rows(model_config.__dict__.items())

  # Both tables are written at step 0: model hyperparameters first, then the
  # experiment config (which now includes num_model_parameters).
  with tf_summary_writer.as_default():
    tf.summary.text('Model hyperparameters', tf.stack(model_rows), step=0)
    tf.summary.text('Config hyperparameters', tf.stack(config_rows), step=0)

  return tf_summary_writer, config
63

64

65
def inference_evaluate(config, workdir, ckpt_loc):
  """Perform inference evaluation on a checkpoint.

  Rebuilds the transformer from ``config``, restores parameters from
  ``ckpt_loc``, runs the pmapped eval step over the eval dataset, and writes
  scalar metrics plus several diagnostic figures (mistake positions,
  strategy histograms, per-example board renderings) to TensorBoard under
  ``workdir``.

  Args:
    config: Experiment config (dataset path, model sizes, seed, minibatch
      size, num_examples, ...). Mutated by ``log_hyperparams_tb``.
    workdir: Directory for metric/summary writers.
    ckpt_loc: Flax checkpoint path/dir to restore the model state from.
  """
  logging.info('Creating training dataset iterator')
  # eval=True selects the evaluation split in the project's data pipeline.
  eval_data_iter = data.create_iter(
      config.dataset_path, config, config.minibatch_size, eval=True
  )

  logging.info('Finished creating training dataset iterator')

  model_config = model.TransformerConfig(
      dataset_fn=config.dataset,
      dtype=config.dtype,
      vocab_size=config.vocab_size,
      seq_len=config.seq_len,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      emb_dim=config.emb_dim,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      # Non-deterministic (dropout on) at init; eval uses a deterministic
      # replace() of this config below.
      deterministic=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
  )

  logging.info('Starting print training config')
  logging.info('train_config: %s', str(model_config.__dict__))
  print(str(model_config.__dict__), flush=True)

  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng, inference_rng = random.split(rng, num=3)

  rng, dropout_rng = jax.random.split(rng)
  input_shape = (config.minibatch_size, config.seq_len)
  net = model.TransformerLMHeadModel(model_config)
  rng_keys = {'params': init_rng, 'dropout': dropout_rng}
  # Initialize with a dummy all-ones batch just to build the variable tree;
  # the values are overwritten by the checkpoint restore below.
  _, initial_variables = jax.jit(net.init_with_output)(
      rng_keys, jnp.ones(input_shape, jnp.int32)
  )

  ### Defines optimizer and learning rate function
  state, _ = trainer.get_state(config, net, initial_variables)
  state = checkpoints.restore_checkpoint(ckpt_loc, state)

  # Only process 0 writes metrics; other hosts get a logging-only writer.
  writer = metric_writers.create_default_writer(
      workdir, asynchronous=False, just_logging=(jax.process_index() > 0)
  )
  tf_summary_writer = tf.summary.create_file_writer(workdir)

  logging.info('config: %s', str(config.__dict__))
  # Replicate the restored state across local devices for pmap.
  state = jax_utils.replicate(state)

  # NOTE(review): this split result is discarded — presumably left over from
  # a per-device RNG setup; kept for identical RNG stream consumption.
  _ = jax.random.split(rng, jax.local_device_count())

  p_eval_step = jax.pmap(
      functools.partial(
          inference_eval_utils.eval_step,
          # Dropout disabled for evaluation.
          config=model_config.replace(deterministic=True),  # pylint: disable=attribute-error
          rng=inference_rng,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )

  _ = trainer.get_metrics_report_progress(
      config, workdir, writer)

  tf_summary_writer, config = log_hyperparams_tb(
      config, model_config, initial_variables, tf_summary_writer
  )

  # Single evaluation pass at a fixed step index.
  step = 0
  with metric_writers.ensure_flushes(writer):
    eval_metrics, mistakes_metrics = inference_eval_utils.get_eval_metrics(
        step,
        state,
        eval_data_iter,
        p_eval_step,
        config,
    )

    with tf_summary_writer.as_default():
      # Scalar metrics: mean over whatever per-batch values were collected.
      for key in eval_metrics.keys():
        tf.summary.scalar(
            'eval_' + key, np.array(eval_metrics[key]).mean(), step=step
        )

      # Histogram over the 81 cell positions of where mistakes occur.
      fig_mistake_pos = plt.figure()
      plt.plot(np.arange(81), mistakes_metrics['mistake_pos'])
      tf.summary.image(
          'mistakes position',
          inference_eval_utils.plot_to_image(fig_mistake_pos),
          step=step,
      )

      # Same, but only the first mistake in each puzzle.
      fig_first_mistake_pos = plt.figure()
      plt.plot(np.arange(81), mistakes_metrics['first_mistake_pos'])
      tf.summary.image(
          'first mistakes position',
          inference_eval_utils.plot_to_image(fig_first_mistake_pos),
          step=step,
      )

      # Bar chart of which solving strategy was in play at the first mistake.
      # 8 bars — assumes 8 strategy categories; TODO confirm against
      # inference_eval_utils.
      fig_first_mistake_strategies = plt.figure()
      fig_label = [
          str(element)
          for element in mistakes_metrics['first_mistake_strategies']
      ]
      fig_color = [
          'red',
          'tan',
          'lime',
          'lightblue',
          'blue',
          'purple',
          'darkred',
          'orange',
      ]

      plt.bar(
          np.arange(8),
          mistakes_metrics['first_mistake_strategies'],
          label=fig_label,
          color=fig_color,
      )

      plt.legend()
      tf.summary.image(
          'first mistakes strategies',
          inference_eval_utils.plot_to_image(fig_first_mistake_strategies),
          step=step,
      )

      # Bar chart of strategy usage across all moves (reuses fig_color).
      fig_strategies_list = plt.figure()
      fig_label = [
          str(element) for element in mistakes_metrics['total_strategies']
      ]

      plt.bar(
          np.arange(8),
          mistakes_metrics['total_strategies'],
          label=fig_label,
          color=fig_color,
      )

      plt.legend()
      tf.summary.image(
          'Total strategies',
          inference_eval_utils.plot_to_image(fig_strategies_list),
          step=step,
      )

      # Per-example side-by-side rendering: model's board vs. the solution.
      for eid in range(config.num_examples):
        fig, axs = plt.subplots(1, 2, figsize=(20, 10))
        cur_board = np.zeros((9, 9))
        puzzle_sol_board = np.zeros((9, 9))

        # mistakes[eid][0] is a flat (row, col, value) triple sequence using
        # 1-based row/col indices; decode it into the 9x9 model board.
        for j in range(0, 3 * 81, 3):
          row_num = mistakes_metrics['mistakes'][eid][0][j] - 1
          col_num = mistakes_metrics['mistakes'][eid][0][j + 1] - 1
          val = mistakes_metrics['mistakes'][eid][0][j + 2]
          cur_board[row_num, col_num] = val

        # mistakes[eid][1] is the flat solution; note it is stored transposed
        # here ([k, j] from index 9*j+k), hence the .T when plotting below.
        for j in range(9):
          for k in range(9):
            puzzle_sol_board[k, j] = mistakes_metrics['mistakes'][eid][1][
                9 * j + k
            ]

        # Find the LAST filled cell that disagrees with the solution (the
        # transposed [k, j] lookup matches the transposed storage above);
        # defaults to (0, 0) when every filled cell is correct.
        wr, wc = 0, 0
        for j in range(9):
          for k in range(9):
            if cur_board[j, k] == 0:
              continue
            if cur_board[j, k] != puzzle_sol_board[k, j]:
              wr = j
              wc = k

        inference_eval_utils.plot_ax(axs[0], cur_board, wr, wc)
        inference_eval_utils.plot_ax(axs[1], puzzle_sol_board.T, wr, wc)

        tf.summary.image(
            'Mistakes ' + str(eid),
            inference_eval_utils.plot_to_image(fig),
            step=step,
        )
252

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.