# google-research
# (930 lines · 28.4 KB — code-viewer metadata)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""This file contains functions for evaluation of the trained model."""
17
18import io
19
20from flax.training import common_utils
21import jax
22from jax import numpy as jnp
23import matplotlib
24import matplotlib.pyplot as plt
25import matplotlib.ticker as mticker
26import numpy as np
27import tensorflow as tf
28
29from sudoku_gpt import model
30from sudoku_gpt import othello
31from sudoku_gpt import sudoku
32
33
def valid_solution(output_seq):
  """Checks whether a decoded sequence is a complete, valid Sudoku solution.

  Args:
    output_seq: flat sequence of at least 243 tokens laid out as 81
      consecutive (row, column, value) triples, each token expected to be
      in {1, ..., 9}.

  Returns:
    True iff every triple is in range and each row, column and 3x3 box
    contains each value exactly once; False otherwise.
  """
  rows = np.zeros((9, 9))
  cols = np.zeros((9, 9))
  boxes = np.zeros((9, 9))

  for j in range(81):
    row_num = int(output_seq[3 * j] - 1)
    col_num = int(output_seq[3 * j + 1] - 1)
    val = int(output_seq[3 * j + 2] - 1)

    # Reject out-of-range tokens. The original check only caught tokens
    # greater than 9; a token of 0 produced index -1 and silently wrapped
    # to the last row/column/value via negative indexing.
    if not 0 <= row_num <= 8:
      return False
    if not 0 <= col_num <= 8:
      return False
    if not 0 <= val <= 8:
      return False

    rows[row_num, val] += 1
    cols[col_num, val] += 1
    boxes[3 * (row_num // 3) + (col_num // 3), val] += 1

  # With exactly 81 placements, "every (unit, value) cell is non-zero"
  # implies each value appears exactly once per row/column/box.
  return bool(np.all(rows) and np.all(cols) and np.all(boxes))
60
61
def eval_step(state, batch, config):
  """Runs one forward pass of the transformer LM and returns its logits."""
  pred_logits = model.TransformerLMHeadModel(config).apply(
      {"params": state.params}, batch
  )
  return pred_logits
67
68
def get_othello_eval_metrics(
    state, eval_data_iter, p_eval_step, config
):
  """Get eval metrics for Othello dataset.

  For every prefix length, greedily predicts the next move and checks its
  legality by replaying the extended game on an Othello board.

  Args:
    state: replicated train state holding the model parameters.
    eval_data_iter: iterator over evaluation batches of move sequences.
    p_eval_step: pmapped forward-pass function.
    config: experiment config (eval_epochs, seq_len).

  Returns:
    dict with key "acc": one legal-move accuracy per evaluated batch.
  """
  eval_metrics = {"acc": []}
  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):

      batch = np.array(next(eval_data_iter))
      total_pred, sucess_pred = 0, 0

      for i in range(config.seq_len):
        # Pad the (i + 1)-token prefix up to seq_len for the forward pass.
        padding = np.zeros(
            (batch.shape[0], config.seq_len - (i + 1)), dtype=np.int32
        )
        concat_batch = np.hstack((batch[:, : (i + 1)], padding))
        concat_batch = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, concat_batch)
        )
        pred_logits = p_eval_step(state, concat_batch)

        # Greedy next move appended to the prefix.
        max_action = pred_logits[:, :, i, :].argmax(axis=-1)
        pred_seq = np.hstack(
            (batch[:, : (i + 1)], jnp.reshape(max_action, newshape=(-1, 1)))
        )

        for j in range(pred_seq.shape[0]):
          ## When length of the game is small, then the model can simply keep
          ## predicting the next token which will increase the accuracy
          total_pred += 1
          try:
            # Replaying an illegal move raises AssertionError.
            othello.OthelloBoardState().update(pred_seq[j], prt=False)
          except AssertionError:
            ### Wrong prediction
            pass
          else:
            sucess_pred += 1

      eval_metrics["acc"].append(sucess_pred * 1.0 / total_pred)

  return eval_metrics
110
111
def get_edit_distance(config, generated_input_seq, original_input_seq):
  """Get the edit distance between model's output and solver's output.

  For each cell the model filled at step i, finds the step j at which the
  solver filled the same (row, col) cell and accumulates |j - i|. Cells
  the solver never filled are charged the maximum displacement
  |block_size - i|.
  """
  distance = 0
  lo, hi = config.start_index, config.block_size

  for i in range(lo, hi):
    # Locate the model's i-th cell within the solver's ordering.
    match = next(
        (
            j
            for j in range(lo, hi)
            if generated_input_seq[3 * i] == original_input_seq[3 * j]
            and generated_input_seq[3 * i + 1] == original_input_seq[3 * j + 1]
        ),
        None,
    )
    if match is None:
      # Cell absent from the solver's output: maximum penalty.
      distance += abs(hi - i)
    else:
      distance += abs(match - i)

  return distance
138
139
def get_set_accuracy_for_pairs(
    pairs,
    state,
    p_eval_step,
    input_seq,
    possible_vals,
    given_vals,
    config,
):
  """Get the accuracy of the set of possible values for different cell positions.

  Args:
    pairs: int array (num_pairs, batch, 2) of zero-based (row, col) cells.
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    input_seq: (batch, seq_len) ground-truth token sequences.
    possible_vals: (batch, 9, 9, 9) 0/1 mask of candidate values per cell.
    given_vals: (batch, 9, 9) 0/1 mask of cells given in the prompt.
    config: experiment config (start_index, seq_len, ...).

  Returns:
    (accuracy, correct_cnt, total_cnt), each of length 9, indexed by the
    rank of the prediction among the model's top logits.
  """
  correct_cnt = np.zeros(9)
  total_cnt = np.zeros(9)

  for i in range(len(pairs)):
    # Append the queried (row, col) tokens to the puzzle prefix
    # (pairs are zero-based, tokens are one-based, hence the +1).
    cur_input_seq = np.hstack(
        (input_seq[:, : (config.start_index * 3)], pairs[i] + 1)
    )

    padding = np.zeros(
        (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )
    concat_batch = np.hstack((cur_input_seq, padding))

    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    pred_logits = p_eval_step(state, concat_batch)

    # Logits for the value token that follows the queried (row, col) pair.
    cur_pred_logits = pred_logits[:, :, 3 * config.start_index + 1, :].reshape(
        (-1, pred_logits.shape[-1])
    )

    for k in range(input_seq.shape[0]):
      # Skip cells whose value was already given in the prompt.
      if given_vals[k, pairs[i, k, 0], pairs[i, k, 1]] == 1:
        continue
      total_possible_vals = np.int32(
          np.sum(possible_vals[k, pairs[i, k, 0], pairs[i, k, 1], :])
      )
      # Model's top `total_possible_vals` tokens, most probable first.
      ordering_ind = np.argsort(cur_pred_logits[np.int32(k), :])[::-1][
          :total_possible_vals
      ]

      assert len(ordering_ind) <= 9

      for t, ind in enumerate(ordering_ind):
        # Only digit tokens 1..9 can be correct; other tokens count as
        # misses (total_cnt grows but correct_cnt does not).
        if ind <= 9 and ind >= 1:
          correct_cnt[t] += (
              possible_vals[k, pairs[i, k, 0], pairs[i, k, 1], ind - 1] == 1
          )

        total_cnt[t] += 1

  # Ranks never observed keep accuracy 1.0 by construction.
  accuracy = np.ones(9)
  for i in range(9):
    if total_cnt[i] > 0:
      accuracy[i] = correct_cnt[i] / total_cnt[i]

  return accuracy, correct_cnt, total_cnt
200
201
def get_sampled_pairs(input_seq, pred_logits, state, p_eval_step, config, key):
  """Get sampled pairs in a sequence.

  Repeatedly samples (row, col) token pairs from the model until every
  example in the batch has accumulated `config.set_accuracy_top_k`
  distinct in-range pairs, then packs them into a zero-based index array.

  Args:
    input_seq: (batch, seq_len) ground-truth token sequences.
    pred_logits: model logits for the puzzle prefix; refreshed on every
      sampling round.
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    config: experiment config (start_index, seq_len, set_accuracy_top_k).
    key: JAX PRNG key used for categorical sampling.

  Returns:
    int array (set_accuracy_top_k, batch, 2) of zero-based (row, col)
    pairs.
  """
  pairs_set = []
  for _ in range(input_seq.shape[0]):
    pairs_set.append(set())

  pairs = np.zeros(
      (config.set_accuracy_top_k, input_seq.shape[0], 2), dtype=np.int32
  )
  flag = True  # Denotes if we want to sample next time or not.

  while flag:
    # Sample a row token from the logits at the end of the puzzle prefix.
    pred_logits_row = pred_logits[:, :, 3 * config.start_index - 1, :].reshape(
        (-1, pred_logits.shape[-1])
    )
    rkey, key = jax.random.split(key, 2)

    pair_row = jax.random.categorical(rkey, pred_logits_row)

    assert len(pair_row) == input_seq.shape[0] and pair_row.ndim == 1

    # Append the sampled row and rerun the model to obtain column logits
    # conditioned on that row.
    cur_input_seq = np.hstack(
        (input_seq[:, : (config.start_index * 3)], pair_row.reshape(-1, 1))
    )

    padding = np.zeros(
        (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )

    concat_batch = np.hstack((cur_input_seq, padding))

    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    pred_logits = p_eval_step(state, concat_batch)
    pred_logits_col = pred_logits[:, :, 3 * config.start_index, :].reshape(
        (-1, pred_logits.shape[-1])
    )

    rkey, key = jax.random.split(key, 2)
    pair_col = jax.random.categorical(rkey, pred_logits_col)

    assert len(pair_col) == input_seq.shape[0] and pair_col.ndim == 1

    flag = False
    for i in range(input_seq.shape[0]):
      # Discard samples outside the valid digit-token range 1..9.
      if pair_row[i] < 1 or pair_row[i] > 9:
        continue

      if pair_col[i] < 1 or pair_col[i] > 9:
        continue
      pairs_set[i].add(tuple((int(pair_row[i]), int(pair_col[i]))))
      # Keep sampling while any example still lacks enough distinct pairs.
      if len(pairs_set[i]) < config.set_accuracy_top_k:
        flag = True

  for i in range(input_seq.shape[0]):
    j = 0
    for a_pair in pairs_set[i]:
      # Convert one-based tokens to zero-based (row, col) indices.
      pairs[j, i, 0] = int(a_pair[0] - 1)
      pairs[j, i, 1] = int(a_pair[1] - 1)
      j += 1

      if j == config.set_accuracy_top_k:
        break

  return pairs
270
271
def get_topk_probability_pairs(
    input_seq, pred_logits, state, p_eval_step, config
):
  """Get topk probability pairs in a sequence.

  Scores all 81 (row, col) pairs by summing the row-token log-probability
  with the column-token log-probability conditioned on that row, then
  keeps the `config.set_accuracy_top_k` highest-scoring pairs per example.

  Args:
    input_seq: (batch, seq_len) ground-truth token sequences.
    pred_logits: model logits for the puzzle prefix (row distribution).
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    config: experiment config (start_index, seq_len, set_accuracy_top_k).

  Returns:
    int array (set_accuracy_top_k, batch, 2) of zero-based (row, col)
    pairs.
  """
  pairs = np.zeros(
      (config.set_accuracy_top_k, input_seq.shape[0], 2), dtype=np.int32
  )
  pred_logits_row = pred_logits[:, :, 3 * config.start_index - 1, :].reshape(
      (-1, pred_logits.shape[-1])
  )
  # Slice [1:10] restricts normalization to the digit tokens 1..9.
  row_log_prob = jax.nn.log_softmax(pred_logits_row[:, 1:10])

  pairs_log_prob = np.zeros((input_seq.shape[0], 81))

  for i in range(9):
    # Condition the model on row token (i + 1) to obtain column logits.
    row_num = np.ones((input_seq.shape[0], 1), dtype=np.int32) * (i + 1)
    cur_input_seq = np.hstack(
        (input_seq[:, : (config.start_index * 3)], row_num)
    )

    padding = np.zeros(
        (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )

    concat_batch = np.hstack((cur_input_seq, padding))

    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    pred_logits_col = p_eval_step(state, concat_batch)
    pred_logits_col = pred_logits_col[:, :, 3 * config.start_index, :].reshape(
        (-1, pred_logits.shape[-1])
    )

    col_log_prob = jax.nn.log_softmax(pred_logits_col[:, 1:10])

    # Joint log-probability of pair (row i, col k) for each example j.
    for j in range(input_seq.shape[0]):
      for k in range(9):
        pairs_log_prob[j, i * 9 + k] = col_log_prob[j, k] + row_log_prob[j, i]

    # (pairs_log_prob index p encodes row p // 9, col p % 9.)

  for i in range(input_seq.shape[0]):
    topk_indices = np.argsort(pairs_log_prob[i, :])[::-1][
        : config.set_accuracy_top_k
    ]
    for j, ind in enumerate(topk_indices):
      pairs[j, i, 0] = ind // 9
      pairs[j, i, 1] = ind % 9

  return pairs
323
324
def get_set_accuracies(state, p_eval_step, input_seq, config):
  """Get set accuracies in a sequence.

  Builds, per example, the set of values still possible for each cell
  (by eliminating prompt values from their rows and columns), then
  measures how well the model's top predictions agree with those sets
  over a chosen collection of cells.

  Args:
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    input_seq: (batch, seq_len) ground-truth token sequences.
    config: experiment config; `config.set_accuracy` selects the cell
      collection ("top-k" model-ranked pairs, or "all" 81 cells).

  Returns:
    (accuracy, correct_cnt, total_cnt) from get_set_accuracy_for_pairs.
  """
  possible_vals = np.ones((input_seq.shape[0], 9, 9, 9))
  given_vals = np.zeros((input_seq.shape[0], 9, 9))

  for i in range(input_seq.shape[0]):
    for j in range(config.start_index):
      row_num = input_seq[i, 3 * j] - 1
      col_num = input_seq[i, 3 * j + 1] - 1
      val = input_seq[i, 3 * j + 2] - 1

      # Eliminate this value from the rest of its row and column.
      # NOTE(review): 3x3 box elimination is not applied here — confirm
      # whether that is intentional.
      possible_vals[i, row_num, :, val] = 0
      possible_vals[i, :, col_num, val] = 0

      given_vals[i, row_num, col_num] = 1

  if config.set_accuracy == "top-k":
    cur_input_seq = input_seq[:, : (config.start_index * 3)]
    padding = np.zeros(
        (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )

    concat_batch = np.hstack((cur_input_seq, padding))

    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    _ = jax.random.PRNGKey(98)
    pred_logits = p_eval_step(state, concat_batch)

    print("get_topk_probability_pairs", flush=True)
    pairs = get_topk_probability_pairs(
        input_seq, pred_logits, state, p_eval_step, config
    )
    print("got_topk_probability_pairs", flush=True)
    return get_set_accuracy_for_pairs(
        pairs,
        state,
        p_eval_step,
        input_seq,
        possible_vals,
        given_vals,
        config,
    )

  elif config.set_accuracy == "all":
    # Enumerate every cell: pair i covers (row i // 9, col i % 9).
    pairs = np.zeros((81, input_seq.shape[0], 2), dtype=np.int32)
    for i in range(81):
      pairs[i, :, 0] = np.ones(input_seq.shape[0], dtype=np.int32) * (i // 9)
      pairs[i, :, 1] = np.ones(input_seq.shape[0], dtype=np.int32) * (i % 9)

    return get_set_accuracy_for_pairs(
        pairs,
        state,
        p_eval_step,
        input_seq,
        possible_vals,
        given_vals,
        config,
    )
387
388
def get_pred_logits(cur_input_seq, input_seq, state, p_eval_step, config):
  """Pads a partial sequence to seq_len, shards it, and runs the model."""
  padding = np.zeros(
      (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
      dtype=np.int32,
  )
  concat_batch = np.hstack((cur_input_seq, padding))
  concat_batch = common_utils.shard(
      jax.tree_util.tree_map(np.asarray, concat_batch)
  )

  pred_logits = p_eval_step(state, concat_batch)
  return pred_logits
401
402
def get_beam_search_candidates(
    input_seq, beam_search_candidates, state, p_eval_step, pos, config
):
  """Get beam search candidates for decoding.

  Expands every live candidate with its `config.beam_search_n` most
  probable next tokens at sequence position `pos`, then keeps the overall
  top `config.beam_search_n` candidates per example by cumulative
  log-likelihood.

  Args:
    input_seq: (batch, seq_len) ground-truth sequences (used for shapes
      and padding only).
    beam_search_candidates: list of (sequence, log_likelihood,
      success_pred) tuples, one per live beam.
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    pos: sequence position whose logits are expanded.
    config: experiment config (beam_search_n, minibatch_size, ...).

  Returns:
    list of `config.beam_search_n` (sequence, log_likelihood,
    success_pred) tuples.
  """
  new_beam_candidate_list = []
  new_beam_candidate_likelihood_list = []
  for i in range(len(beam_search_candidates)):
    ### Iterate through all the beam search candidates

    # predict the logits for row/column/value
    pred_logits = get_pred_logits(
        beam_search_candidates[i][0], input_seq, state, p_eval_step, config
    )

    # Choose top beam_search_n most probable predictions
    max_pos = (
        pred_logits[:, :, pos, :]
        .argpartition(-config.beam_search_n, axis=-1)[
            :, :, -config.beam_search_n :
        ]
        .reshape((-1, config.beam_search_n))
    )
    log_likelihood = jax.nn.log_softmax(pred_logits[:, :, pos, :]).reshape(
        (-1, pred_logits.shape[-1])
    )
    log_likelihood = np.take_along_axis(log_likelihood, max_pos, 1)

    # Append all of the candidates in new_beam_candidate_list
    for j in range(config.beam_search_n):
      cur_candidate = beam_search_candidates[i]
      new_beam_candidate = np.hstack(
          (cur_candidate[0], jnp.reshape(max_pos[:, j], newshape=(-1, 1)))
      )
      # Cumulative log-likelihood of the extended sequence.
      new_beam_candidate_likelihood = cur_candidate[1] + log_likelihood[:, j]
      new_beam_candidate_likelihood_list.append(new_beam_candidate_likelihood)
      new_beam_candidate_list.append(
          (new_beam_candidate, new_beam_candidate_likelihood, cur_candidate[2])
      )

  # Likelihood list for new candidates
  new_beam_candidate_likelihood_list = np.stack(
      new_beam_candidate_likelihood_list, axis=0
  )
  assert new_beam_candidate_likelihood_list.shape == (
      len(beam_search_candidates) * config.beam_search_n,
      config.minibatch_size,
  ), new_beam_candidate_likelihood_list.shape

  # Find index of top beam_search_n in new candidates
  new_beam_candidate_ind = new_beam_candidate_likelihood_list.argpartition(
      -config.beam_search_n, axis=0
  )[-config.beam_search_n :, :]
  assert new_beam_candidate_ind.shape == (
      config.beam_search_n,
      config.minibatch_size,
  ), new_beam_candidate_ind.shape

  # Create the new list by truncating to top beam_search_n candidate
  truncated_candidate_list = []
  for i in range(config.beam_search_n):
    new_candidate = np.zeros_like(new_beam_candidate_list[0][0])
    new_candidate_likelihood = np.zeros_like(new_beam_candidate_list[0][1])
    new_candidate_success_pred = np.zeros_like(new_beam_candidate_list[0][2])

    # Beams are selected per example: example j's i-th surviving beam may
    # come from a different parent candidate than example j+1's.
    for j in range(config.minibatch_size):
      index = new_beam_candidate_ind[i, j]

      new_candidate[j] = new_beam_candidate_list[index][0][j]
      new_candidate_likelihood[j] = new_beam_candidate_list[index][1][j]
      new_candidate_success_pred[j] = new_beam_candidate_list[index][2][j]

    truncated_candidate_list.append(
        (new_candidate, new_candidate_likelihood, new_candidate_success_pred)
    )

  return truncated_candidate_list
479
480
def get_greedy_row_col(
    beam_search_candidates, pos, input_seq, state, p_eval_step, config
):
  """Perform greedy row and column decoding using beam search candidates.

  Expands the beams for the row token at (pos - 3) and then the column
  token at (pos - 2); the value token at pos is decoded by the caller.
  """

  ### Get beam search candidates for row
  beam_search_candidates = get_beam_search_candidates(
      input_seq, beam_search_candidates, state, p_eval_step, pos - 3, config
  )

  ### Get beam search candidates for column
  beam_search_candidates = get_beam_search_candidates(
      input_seq, beam_search_candidates, state, p_eval_step, pos - 2, config
  )
  ### Predict most confident column according to row
  # pred_logits = get_pred_logits(cur_input_seq, input_seq,
  #                               state, p_eval_step, config)

  # max_pos = pred_logits[:, :, pos-2, :].argmax(axis=-1).flatten()
  # cur_input_seq = np.hstack((
  #     cur_input_seq, jnp.reshape(max_pos, newshape=(-1, 1))))
  return beam_search_candidates
503
504
def get_greedy_pair(cur_input_seq, pos, input_seq, state, p_eval_step, config):
  """Get greedy pair decoding.

  Scores all 81 (row, col) pairs by joint log-probability (row prob times
  column-given-row prob) and appends the argmax pair to the sequence.

  NOTE(review): inside the loop `cur_input_seq` is reassigned with each
  trial row appended, so the sequence grows by one token per iteration
  and the final hstack appends the chosen pair after those nine trial
  tokens. Compare get_topk_probability_pairs, which rebuilds the prefix
  each iteration — this accumulation looks unintended; confirm before
  relying on the "greedy-pair" sampling method.
  """
  pred_logits = get_pred_logits(
      cur_input_seq, input_seq, state, p_eval_step, config
  )

  row_pred_logits = pred_logits[:, :, pos - 3, :].reshape(
      (-1, pred_logits.shape[-1])
  )
  # Slice [1:10] restricts normalization to the digit tokens 1..9.
  row_log_prob = jax.nn.log_softmax(row_pred_logits[:, 1:10])

  pairs_log_prob = np.zeros((input_seq.shape[0], 81))

  for i in range(9):
    # Condition on row token (i + 1) to obtain column logits.
    row_num = np.ones((input_seq.shape[0], 1), dtype=np.int32) * (i + 1)
    cur_input_seq = np.hstack((cur_input_seq, row_num))

    pred_logits_col = get_pred_logits(
        cur_input_seq, input_seq, state, p_eval_step, config
    )
    pred_logits_col = pred_logits_col[:, :, pos - 2, :].reshape(
        (-1, pred_logits.shape[-1])
    )

    col_log_prob = jax.nn.log_softmax(pred_logits_col[:, 1:10])

    for j in range(input_seq.shape[0]):
      for k in range(9):
        pairs_log_prob[j, i * 9 + k] = col_log_prob[j, k] + row_log_prob[j, i]

  # Decode the argmax pair index p into zero-based (row p // 9, col p % 9).
  pair = np.hstack((
      pairs_log_prob.argmax(axis=-1, keepdims=True) // 9,
      pairs_log_prob.argmax(axis=-1, keepdims=True) % 9,
  ))
  return np.hstack((cur_input_seq, pair))
540
541
def get_accuracy(
    cur_input_seq,
    state,
    p_eval_step,
    input_seq,
    puzzle_sol,
    config,
    eval_metrics,
    mistakes_metrics,
):
  """Get accuracy of a decoding sequence.

  Autoregressively decodes (row, col, value) triples with beam search,
  counts how many decoded triples are consistent with the reference
  solution, and returns the most likely fully decoded sequence.

  Args:
    cur_input_seq: (batch, 3 * start_index) prompt tokens.
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    input_seq: (batch, seq_len) ground-truth token sequences.
    puzzle_sol: (batch, 81) ground-truth board values.
    config: experiment config (sampling_method, seq_len, ...).
    eval_metrics: dict of metric lists; "acc" is appended to.
    mistakes_metrics: dict of mistake statistics (currently returned
      unchanged — the bookkeeping code below is commented out).

  Returns:
    (eval_metrics, mistakes_metrics, max_prob_seq) where max_prob_seq is
    the per-example highest-likelihood decoded sequence.
  """
  total_pred, _ = 0, 0

  ### Keeps tuple of best n sequences, log probability and correct pred for it
  beam_search_candidates = [(
      cur_input_seq,
      np.zeros(len(cur_input_seq)),
      np.zeros(len(cur_input_seq)),
  )]

  # i walks over the value-token positions of successive triples.
  for i in range(config.start_index * 3 + 2, config.seq_len, 3):
    if config.sampling_method == "greedy-row-col":
      # greedy-row-col: selects first max probability row and
      # then max probability column.
      beam_search_candidates = get_greedy_row_col(
          beam_search_candidates, i, input_seq, state, p_eval_step, config
      )

    elif config.sampling_method == "greedy-pair":
      # greedy-pair: selects max probability (row, column) pair
      cur_input_seq = get_greedy_pair(
          cur_input_seq, i, input_seq, state, p_eval_step, config
      )

    # Decode the value token (position i - 1 in the logits) for every beam.
    beam_search_candidates = get_beam_search_candidates(
        input_seq, beam_search_candidates, state, p_eval_step, i - 1, config
    )

    total_pred += len(beam_search_candidates[0][0])
    for candidate in beam_search_candidates:
      for j in range(
          len(candidate[0])
      ):  ## Iterate through all examples in batch
        try:
          # Raises AssertionError when the decoded (row, col, value)
          # triple contradicts the reference solution.
          sudoku.SudokuBoardStateUpdate(
              puzzle_sol[j],
              candidate[0][j][i - 2],
              candidate[0][j][i - 1],
              candidate[0][j][i],
          )

          # row_num = cur_input_seq[j, i-2] - 1
          # col_num = cur_input_seq[j, i-1] - 1

          # strategy_id = input_seq_strategies[ j, row_num * 9 + col_num ]
          # mistakes_metrics['total_strategies'][ strategy_id ] += 1

        except AssertionError:
          # mistakes_metrics['mistakes'].append((
          #     concat_batch[j], puzzle_sol[j]))
          # # if i < 81:
          # mistakes_metrics['mistake_pos'][i // 3] += 1
          # if first_mistake_ind[j] == 0:
          #   mistakes_metrics['first_mistake_pos'][i // 3] += 1

          # row_num = cur_input_seq[j, i-2] - 1
          # col_num = cur_input_seq[j, i-1] - 1
          # strategy_id = input_seq_strategies[j, row_num * 9 + col_num ]
          # mistakes_metrics['first_mistake_strategies'][ strategy_id ] += 1

          # g3pdb.set_trace()
          # if strategy_id == 0:
          #   g3pdb.set_trace()

          # first_mistake_ind[j] = 1
          pass
        else:
          candidate[2][j] += 1

  # cur_input_seq = input_seq[:, :(i+1)]

  # Pick, per example, the beam with the highest cumulative log-likelihood.
  max_prob_seq = np.zeros_like(beam_search_candidates[0][0])
  max_prob = np.zeros(
      (len(beam_search_candidates), beam_search_candidates[0][1].shape[0])
  )

  for j, candidate in enumerate(beam_search_candidates):
    max_prob[j, :] = candidate[1]

  max_prob_seq_ind = max_prob.argmax(axis=0)
  sucess_pred = np.zeros(len(max_prob_seq))

  for i in range(len(max_prob_seq)):
    max_prob_seq[i] = beam_search_candidates[max_prob_seq_ind[i]][0][i]
    sucess_pred[i] = beam_search_candidates[max_prob_seq_ind[i]][2][i]

  eval_metrics["acc"].append(sucess_pred.sum() * 1.0 / total_pred)
  return eval_metrics, mistakes_metrics, max_prob_seq
640
641
def set_set_accuracies(eval_metrics, set_acc, correct_cnt):
  """Appends the first three set accuracies and correct counts to metrics.

  Records set_acc[0..2] under "set_acc1..3" and correct_cnt[0..2] under
  "correct_cnt1..3", then returns the (mutated) metrics dict.
  """
  for rank in range(3):
    eval_metrics[f"set_acc{rank + 1}"].append(set_acc[rank])
    eval_metrics[f"correct_cnt{rank + 1}"].append(correct_cnt[rank])

  return eval_metrics
652
653
def get_position_hinted_eval_acc(
    input_seq, puzzle_sol, state, p_eval_step, eval_metrics, config
):
  """This function computes the accuracy of the position hinted decoding model.

  The ground-truth (row, col) tokens are fed to the model at every step,
  so the model only has to predict the value of each hinted cell.

  Args:
    input_seq: (batch, seq_len) ground-truth token sequences.
    puzzle_sol: (batch, 81) ground-truth board values.
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    eval_metrics: dict of metric lists; "hinted_acc" is appended to.
    config: experiment config (start_index, block_size, seq_len).

  Returns:
    eval_metrics with this batch's hinted accuracy appended.
  """

  total_pred, sucess_pred = 0, 0

  cur_input_seq = input_seq[:, : (config.start_index * 3)]
  for i in range(config.start_index, config.block_size):
    ### i^th cell in sequence will predict

    # Append the row number from the ground truth sequence
    cur_input_seq = np.hstack(
        (cur_input_seq, jnp.reshape(input_seq[:, 3 * i], newshape=(-1, 1)))
    )

    # Append the column number from the ground truth sequence
    cur_input_seq = np.hstack(
        (cur_input_seq, jnp.reshape(input_seq[:, 3 * i + 1], newshape=(-1, 1)))
    )

    padding = np.zeros(
        (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )
    concat_batch = np.hstack((cur_input_seq, padding))
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    # Predict and append value at the pos chosen by the ground truth sequence
    pred_logits = p_eval_step(state, concat_batch)
    max_number = pred_logits[:, :, (3 * i + 1), :].argmax(axis=-1).flatten()
    cur_input_seq = np.hstack(
        (cur_input_seq, jnp.reshape(max_number, newshape=(-1, 1)))
    )
    for j in range(
        len(cur_input_seq)
    ):  ## Iterate through all examples in batch
      total_pred += 1
      try:
        # The last three tokens are the hinted (row, col) and the model's
        # predicted value; an invalid triple raises AssertionError.
        sudoku.SudokuBoardStateUpdate(
            puzzle_sol[j],
            cur_input_seq[j, -3],
            cur_input_seq[j, -2],
            cur_input_seq[j, -1],
        )
      except AssertionError:
        pass
      else:
        sucess_pred += 1

  eval_metrics["hinted_acc"].append(sucess_pred * 1.0 / total_pred)
  return eval_metrics
708
709
def get_internal_model_stats(
    cur_input_seq,
    state,
    p_eval_step,
    input_seq,
    candidate_list,
    config,
    eval_metrics,
):
  """This function computes the internal model stats.

  At decoding checkpoints 35, 40, ..., 80 it queries the model's value
  distribution for every cell and measures how many of the model's top
  candidates agree with the solver's candidate list.

  Args:
    cur_input_seq: (batch, seq_len) decoded token sequences.
    state: replicated train state holding the model parameters.
    p_eval_step: pmapped forward-pass function.
    input_seq: (batch, seq_len) ground-truth token sequences.
    candidate_list: per-example 0/1 candidate masks indexed as
      [example, checkpoint, cell][value] — presumably shape
      (batch, 10, 81, 9); TODO confirm against caller.
    config: experiment config.
    eval_metrics: dict of metric lists; "intermediate_calc_acc{35..80}"
      entries are appended to.

  Returns:
    eval_metrics with the internal-consistency accuracies appended.
  """

  for i in range(10):  ### Checks internal model stats at [35, 40, 45,..., 80]
    ## Find already filled cell upto 35th position
    filled_cells = np.zeros((len(cur_input_seq), 81), dtype=np.int8)

    for i1 in range(len(cur_input_seq)):
      for j1 in range(5 * i + 35):
        # Flatten the one-based (row, col) tokens into a 0..80 cell index.
        cell_pos = int(
            (cur_input_seq[i1, 3 * j1] - 1) * 9
            + (cur_input_seq[i1, 3 * j1 + 1] - 1)
        )
        filled_cells[i1, cell_pos] = 1

    cur_board_state = cur_input_seq[:, : (3 * (5 * i + 35))]
    correct_pred = 0
    total_pred = 0

    for j in range(81):
      row = (j // 9) + 1
      col = (j % 9) + 1
      # Probe the model with (row, col) appended to the board prefix.
      test_board_state = np.hstack((
          cur_board_state,
          np.ones((len(cur_input_seq), 1), dtype=np.int8) * row,
      ))
      test_board_state = np.hstack((
          test_board_state,
          np.ones((len(cur_input_seq), 1), dtype=np.int8) * col,
      ))

      pred_logits = get_pred_logits(
          test_board_state, input_seq, state, p_eval_step, config
      )

      # Position of the value token following the appended (row, col).
      pos = 3 * (5 * i + 35) + 1
      pred_logits = pred_logits[:, :, pos, :].reshape(
          (len(cur_input_seq), pred_logits.shape[-1])
      )

      for k in range(len(cur_input_seq)):

        num_candidates = np.sum(candidate_list[k, i, j])
        # Skip already-filled cells and cells with no remaining candidates.
        if filled_cells[k, j] == 1 or num_candidates == 0:
          continue

        # Model's top `num_candidates` value tokens by logit.
        model_candidates = pred_logits[k].argpartition(
            -num_candidates, axis=-1
        )[-num_candidates:]
        correct_pred += np.sum(candidate_list[k, i, j][model_candidates - 1])
        total_pred += num_candidates

    eval_metrics["intermediate_calc_acc" + str(5 * i + 35)].append(
        correct_pred * 1.0 / total_pred
    )
  return eval_metrics
774
775
def get_sudoku_eval_metrics(
    state, eval_data_iter, p_eval_step, config
):
  """This function computes given evaluation metrics (e.g, accuracy).

  Args:
    state: contains model parameters, optimizer, etc.
    eval_data_iter: data iterator for evaluation dataset
    p_eval_step: pmap function for forward pass of model for evaluation
    config: general config file

  Returns:
    eval_metrics: contains list of evaluation metrics for each batch
    mistakes_metrics: dict of mistake-tracking arrays/lists
  """

  eval_metrics = {
      "acc": [],
      "hinted_acc": [],
      "acc_complete_puzzle": [],
      "edit_distance": [],
      "set_acc1": [],
      "set_acc2": [],
      "set_acc3": [],
      "correct_cnt1": [],
      "correct_cnt2": [],
      "correct_cnt3": [],
  }

  # One internal-consistency metric per checkpoint position 35, 40, ..., 80.
  eval_metrics.update(
      {"intermediate_calc_acc" + str(5 * i + 35): [] for i in range(10)}
  )

  mistakes = []
  mistake_pos = np.zeros(81, dtype=np.int32)
  first_mistake_pos = np.zeros(81, dtype=np.int32)
  first_mistake_strategies = np.zeros(8, dtype=np.int32)
  total_strategies = np.zeros(8, dtype=np.int32)
  mistakes_metrics = {
      "mistakes": mistakes,
      "mistake_pos": mistake_pos,
      "first_mistake_pos": first_mistake_pos,
      "first_mistake_strategies": first_mistake_strategies,
      "total_strategies": total_strategies,
  }

  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):

      batch_tuple = next(eval_data_iter)

      # Input seq is of the shape (batchsize, 3*81) and 3*81 because row, column
      # and value for each cell. Row, column and value all are in {1, ..., 9}
      input_seq = np.array(batch_tuple[0])

      # Puzzle solution is of the shape (batchsize, 81). Each pos in {0,.., 80}
      # for each puzzle contains value at cell (pos//9+1, pos%9 + 1)
      puzzle_sol = np.array(batch_tuple[1])

      cur_input_seq = input_seq[:, : (config.start_index * 3)]

      set_acc, correct_cnt, _ = get_set_accuracies(
          state, p_eval_step, input_seq, config
      )

      eval_metrics = set_set_accuracies(eval_metrics, set_acc, correct_cnt)

      # Decode the remainder of the puzzle; cur_input_seq becomes the
      # highest-likelihood fully decoded sequence.
      eval_metrics, mistakes_metrics, cur_input_seq = get_accuracy(
          cur_input_seq,
          state,
          p_eval_step,
          input_seq,
          puzzle_sol,
          config,
          eval_metrics,
          mistakes_metrics,
      )

      eval_metrics = get_position_hinted_eval_acc(
          input_seq, puzzle_sol, state, p_eval_step, eval_metrics, config
      )

      correct_eval_sudoku_puzzle = 0
      solution_edit_distance = 0.0

      for i, _ in enumerate(cur_input_seq):
        correct_eval_sudoku_puzzle += valid_solution(cur_input_seq[i])
        solution_edit_distance += get_edit_distance(
            config, cur_input_seq[i], input_seq[i]
        )

      # Fraction of puzzles in this batch solved completely correctly.
      eval_metrics["acc_complete_puzzle"].append(
          correct_eval_sudoku_puzzle * 1.0 / len(cur_input_seq)
      )

      eval_metrics["edit_distance"].append(
          solution_edit_distance * 1.0 / len(cur_input_seq)
      )
  return eval_metrics, mistakes_metrics
874
875
def get_eval_metrics(
    step, state, eval_data_iter, p_eval_step, config
):
  """Dispatches evaluation to the dataset-specific metric function.

  Args:
    step: current training step; unused, kept for caller compatibility.
    state: replicated train state holding the model parameters.
    eval_data_iter: iterator over evaluation batches.
    p_eval_step: pmapped forward-pass function.
    config: experiment config; `config.dataset` selects the dataset.

  Returns:
    The dataset-specific metrics, or None when `config.dataset` matches
    neither "othello" nor "sudoku".
  """
  del step  # Unused.
  if config.dataset == "othello":
    return get_othello_eval_metrics(
        state, eval_data_iter, p_eval_step, config
    )
  elif "sudoku" in config.dataset:
    # Bug fix: get_sudoku_eval_metrics takes (state, eval_data_iter,
    # p_eval_step, config); previously `step` was passed as an extra
    # leading argument, which raised TypeError at call time.
    return get_sudoku_eval_metrics(
        state, eval_data_iter, p_eval_step, config
    )
887
888
def plot_to_image(figure):
  """Converts the matplotlib plot specified by 'figure' to a PNG image and returns it."""
  # The supplied figure is closed and inaccessible after this call.
  buf = io.BytesIO()
  plt.savefig(buf, format="png")
  plt.close(figure)
  buf.seek(0)

  # Decode to a 4-channel (RGBA) tensor and add a leading batch dimension.
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  image = tf.expand_dims(image, 0)
  return image
900
901
def plot_ax(ax, num, wr, wc):
  """Plots the given axis with the given number of values.

  Draws a 9x9 Sudoku board on `ax`: non-zero cell values from `num`, a
  red square highlighting one cell, per-cell minor gridlines, and heavy
  lines separating the 3x3 boxes.

  Args:
    ax: matplotlib axes to draw on.
    num: (9, 9) array of cell values; zeros are left blank.
    wr: x index of the highlighted cell.
    wc: y index of the highlighted cell (drawn top-down via 8 - wc).
  """
  for i in range(9):
    for j in range(9):
      if num[i, j] == 0:
        continue
      ax.text(
          i + 0.5, (8 - j) + 0.5, str(int(num[i, j])), ha="center", va="center"
      )

  ax.axis([0, 9, 0, 9])

  # Highlight the queried cell in red.
  rect = matplotlib.patches.Rectangle((wr, 8 - wc), 1, 1, color="red")
  ax.add_patch(rect)

  for axis in [ax.xaxis, ax.yaxis]:
    axis.set_minor_locator(mticker.MultipleLocator(1))
    axis.set_major_locator(mticker.MultipleLocator(3))
    # axis.set_ticks(np.arange(maxnum) + 0.5)
    # axis.set_ticklabels(range(maxnum))

  ax.grid(which="minor")
  # ax.axis('off')
  ax.xaxis.set_ticks_position("top")

  # Heavy lines separating the 3x3 boxes.
  ax.hlines(y=3, xmin=0, xmax=10, color="0")
  ax.hlines(y=6, xmin=0, xmax=10, color="0")
  ax.vlines(x=6, ymin=0, ymax=10, color="0")
  ax.vlines(x=3, ymin=0, ymax=10, color="0")
931