google-research
251 lines · 7.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file coordinates with all other files to get the evaluation metrics."""

import functools

from absl import logging
from clu import metric_writers
from flax import jax_utils
from flax import linen as nn
from flax.training import checkpoints
import jax
from jax import random
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from sudoku_gpt import model
from sudoku_gpt import trainer
from sudoku_gpt.inference import data
from sudoku_gpt.inference import inference_eval_utils
37
38def log_hyperparams_tb(
39config, model_config, initial_variables, tf_summary_writer
40):
41"""Log hyperparameters to TensorBoard."""
42config.num_model_parameters = sum(
43x.size for x in jax.tree_util.tree_leaves(initial_variables)
44)
45
46config_hyperparameters = [
47tf.convert_to_tensor([k, str(v)]) for k, v in config.items()
48]
49model_config_hyperparameters = [
50tf.convert_to_tensor([k, str(v)])
51for k, v in model_config.__dict__.items()
52]
53
54with tf_summary_writer.as_default():
55tf.summary.text(
56'Model hyperparameters', tf.stack(model_config_hyperparameters), step=0
57)
58tf.summary.text(
59'Config hyperparameters', tf.stack(config_hyperparameters), step=0
60)
61
62return tf_summary_writer, config
63
64
65def inference_evaluate(config, workdir, ckpt_loc):
66"""Perform inference evaluation on a checkpoint."""
67logging.info('Creating training dataset iterator')
68eval_data_iter = data.create_iter(
69config.dataset_path, config, config.minibatch_size, eval=True
70)
71
72logging.info('Finished creating training dataset iterator')
73
74model_config = model.TransformerConfig(
75dataset_fn=config.dataset,
76dtype=config.dtype,
77vocab_size=config.vocab_size,
78seq_len=config.seq_len,
79num_heads=config.num_heads,
80num_layers=config.num_layers,
81emb_dim=config.emb_dim,
82qkv_dim=config.qkv_dim,
83mlp_dim=config.mlp_dim,
84dropout_rate=config.dropout_rate,
85attention_dropout_rate=config.attention_dropout_rate,
86deterministic=False,
87kernel_init=nn.initializers.xavier_uniform(),
88bias_init=nn.initializers.normal(stddev=1e-6),
89)
90
91logging.info('Starting print training config')
92logging.info('train_config: %s', str(model_config.__dict__))
93print(str(model_config.__dict__), flush=True)
94
95rng = jax.random.PRNGKey(config.seed)
96rng, init_rng, inference_rng = random.split(rng, num=3)
97
98rng, dropout_rng = jax.random.split(rng)
99input_shape = (config.minibatch_size, config.seq_len)
100net = model.TransformerLMHeadModel(model_config)
101rng_keys = {'params': init_rng, 'dropout': dropout_rng}
102_, initial_variables = jax.jit(net.init_with_output)(
103rng_keys, jnp.ones(input_shape, jnp.int32)
104)
105
106### Defines optimizer and learning rate function
107state, _ = trainer.get_state(config, net, initial_variables)
108state = checkpoints.restore_checkpoint(ckpt_loc, state)
109
110writer = metric_writers.create_default_writer(
111workdir, asynchronous=False, just_logging=(jax.process_index() > 0)
112)
113tf_summary_writer = tf.summary.create_file_writer(workdir)
114
115logging.info('config: %s', str(config.__dict__))
116state = jax_utils.replicate(state)
117
118_ = jax.random.split(rng, jax.local_device_count())
119
120p_eval_step = jax.pmap(
121functools.partial(
122inference_eval_utils.eval_step,
123config=model_config.replace(deterministic=True), # pylint: disable=attribute-error
124rng=inference_rng,
125),
126axis_name='batch',
127donate_argnums=(0,),
128)
129
130_ = trainer.get_metrics_report_progress(
131config, workdir, writer)
132
133tf_summary_writer, config = log_hyperparams_tb(
134config, model_config, initial_variables, tf_summary_writer
135)
136
137step = 0
138with metric_writers.ensure_flushes(writer):
139eval_metrics, mistakes_metrics = inference_eval_utils.get_eval_metrics(
140step,
141state,
142eval_data_iter,
143p_eval_step,
144config,
145)
146
147with tf_summary_writer.as_default():
148for key in eval_metrics.keys():
149tf.summary.scalar(
150'eval_' + key, np.array(eval_metrics[key]).mean(), step=step
151)
152
153fig_mistake_pos = plt.figure()
154plt.plot(np.arange(81), mistakes_metrics['mistake_pos'])
155tf.summary.image(
156'mistakes position',
157inference_eval_utils.plot_to_image(fig_mistake_pos),
158step=step,
159)
160
161fig_first_mistake_pos = plt.figure()
162plt.plot(np.arange(81), mistakes_metrics['first_mistake_pos'])
163tf.summary.image(
164'first mistakes position',
165inference_eval_utils.plot_to_image(fig_first_mistake_pos),
166step=step,
167)
168
169fig_first_mistake_strategies = plt.figure()
170fig_label = [
171str(element)
172for element in mistakes_metrics['first_mistake_strategies']
173]
174fig_color = [
175'red',
176'tan',
177'lime',
178'lightblue',
179'blue',
180'purple',
181'darkred',
182'orange',
183]
184
185plt.bar(
186np.arange(8),
187mistakes_metrics['first_mistake_strategies'],
188label=fig_label,
189color=fig_color,
190)
191
192plt.legend()
193tf.summary.image(
194'first mistakes strategies',
195inference_eval_utils.plot_to_image(fig_first_mistake_strategies),
196step=step,
197)
198
199fig_strategies_list = plt.figure()
200fig_label = [
201str(element) for element in mistakes_metrics['total_strategies']
202]
203
204plt.bar(
205np.arange(8),
206mistakes_metrics['total_strategies'],
207label=fig_label,
208color=fig_color,
209)
210
211plt.legend()
212tf.summary.image(
213'Total strategies',
214inference_eval_utils.plot_to_image(fig_strategies_list),
215step=step,
216)
217
218for eid in range(config.num_examples):
219fig, axs = plt.subplots(1, 2, figsize=(20, 10))
220cur_board = np.zeros((9, 9))
221puzzle_sol_board = np.zeros((9, 9))
222
223for j in range(0, 3 * 81, 3):
224row_num = mistakes_metrics['mistakes'][eid][0][j] - 1
225col_num = mistakes_metrics['mistakes'][eid][0][j + 1] - 1
226val = mistakes_metrics['mistakes'][eid][0][j + 2]
227cur_board[row_num, col_num] = val
228
229for j in range(9):
230for k in range(9):
231puzzle_sol_board[k, j] = mistakes_metrics['mistakes'][eid][1][
2329 * j + k
233]
234
235wr, wc = 0, 0
236for j in range(9):
237for k in range(9):
238if cur_board[j, k] == 0:
239continue
240if cur_board[j, k] != puzzle_sol_board[k, j]:
241wr = j
242wc = k
243
244inference_eval_utils.plot_ax(axs[0], cur_board, wr, wc)
245inference_eval_utils.plot_ax(axs[1], puzzle_sol_board.T, wr, wc)
246
247tf.summary.image(
248'Mistakes ' + str(eid),
249inference_eval_utils.plot_to_image(fig),
250step=step,
251)
252