google-research
555 строк · 17.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Evaluation related functions."""
17
18from flax.training import common_utils
19import jax
20from jax import numpy as jnp
21import numpy as np
22
23from sudoku_gpt import model
24from sudoku_gpt import othello
25from sudoku_gpt import sudoku
26
27
def valid_solution(output_seq):
  """Checks whether a decoded sequence is a complete, valid sudoku grid.

  The sequence is read as 81 consecutive (row, col, value) token triples, each
  token expected to lie in {1, ..., 9}. The grid is valid when every row,
  column and 3x3 box contains each value exactly once.

  Note: only sudoku validity of the filled grid is checked; the grid is not
  compared against the puzzle's given cells or a reference solution.

  Args:
    output_seq: Indexable sequence with at least 243 integer tokens.

  Returns:
    True if the 81 triples form a valid sudoku solution, False otherwise
    (including when any row, column or value token lies outside {1, ..., 9}).
  """
  rows = np.zeros((9, 9))
  cols = np.zeros((9, 9))
  boxes = np.zeros((9, 9))

  for j in range(81):
    row_num = int(output_seq[3 * j] - 1)
    col_num = int(output_seq[3 * j + 1] - 1)
    val = int(output_seq[3 * j + 2] - 1)
    # Both bounds are needed: a 0 token maps to index -1, which NumPy would
    # silently wrap to the last slot, letting a 0 stand in for a 9.
    if not (0 <= row_num <= 8 and 0 <= col_num <= 8 and 0 <= val <= 8):
      return False
    rows[row_num, val] += 1
    cols[col_num, val] += 1
    boxes[3 * (row_num // 3) + col_num // 3, val] += 1

  # 81 placements fill 81 (unit, value) slots per matrix, so "all non-zero"
  # means every row/column/box holds each value exactly once.
  return bool(np.all(rows) and np.all(cols) and np.all(boxes))
54
55
def eval_step(state, batch, config):
  """Runs one forward pass of the LM-head transformer and returns its logits.

  Args:
    state: Training state; its `params` attribute supplies the model weights.
    batch: Token ids fed to the model.
    config: Model config used to build `model.TransformerLMHeadModel`.

  Returns:
    Whatever `TransformerLMHeadModel(config).apply` produces for `batch`.
  """
  lm = model.TransformerLMHeadModel(config)
  variables = {"params": state.params}
  return lm.apply(variables, batch)
60
61
def get_othello_eval_metrics(state, eval_data_iter, p_eval_step, config):
  """Get evaluation metrics for Othello game.

  For every prefix length i + 1 (i in [0, config.seq_len)) of each evaluation
  batch, the model's greedy next-token prediction is appended to the
  ground-truth prefix. The extended sequence is then replayed with
  `othello.OthelloBoardState().update`; the prediction counts as correct when
  the replay raises no AssertionError.

  Args:
    state: Model state (parameters) passed through to `p_eval_step`.
    eval_data_iter: Iterator for evaluation dataset; each `next()` yields a
      2-D batch of move-token sequences.
    p_eval_step: Function to compute forward pass on a single evaluation batch.
    config: The config for the experiment (uses `eval_epochs` and `seq_len`).

  Returns:
    Dict with key "acc" holding one accuracy (fraction of predictions whose
    replay succeeded) per evaluated batch.
  """
  eval_metrics = {"acc": []}
  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):

      batch = np.array(next(eval_data_iter))
      total_pred, success_pred = 0, 0

      for i in range(config.seq_len):
        # Keep the first i + 1 ground-truth tokens and zero-pad to seq_len.
        padding = np.zeros(
            (batch.shape[0], config.seq_len - (i + 1)), dtype=np.int32
        )
        concat_batch = np.hstack((batch[:, : (i + 1)], padding))
        concat_batch = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, concat_batch)
        )
        pred_logits = p_eval_step(state, concat_batch)

        # Greedy prediction read at position i; the two leading axes
        # (presumably device and per-device batch from `shard`) are flattened
        # into one column by the reshape below.
        max_action = pred_logits[:, :, i, :].argmax(axis=-1)
        # Shape is passed positionally: the `newshape=` keyword of
        # jnp.reshape is deprecated in recent JAX releases.
        pred_seq = np.hstack(
            (batch[:, : (i + 1)], jnp.reshape(max_action, (-1, 1)))
        )

        for j in range(pred_seq.shape[0]):
          # When length of the game is small, then the model can simply keep
          # predicting the next token which will increase the accuracy.
          total_pred += 1
          try:
            othello.OthelloBoardState().update(pred_seq[j], prt=False)
          except AssertionError:
            # Wrong prediction: the extended sequence could not be replayed.
            pass
          else:
            success_pred += 1

      eval_metrics["acc"].append(success_pred * 1.0 / total_pred)

  return eval_metrics
109
110
def get_edit_distance(config, generated_input_seq, original_input_seq):
  """Get edit distance between generated input and original input.

  For each cell index i in [config.start_index, config.block_size), the
  (row, col) of the generated triple i is looked up among the original triples
  in the same index range. The first match j contributes |j - i|; if there is
  no match, the cell contributes config.block_size - i.

  Args:
    config: Provides `start_index` and `block_size` (triple indices).
    generated_input_seq: Flat (row, col, value) token sequence from the model.
    original_input_seq: Flat (row, col, value) token sequence from the solver.

  Returns:
    The summed positional displacement over the scanned cells.
  """
  first, last = config.start_index, config.block_size
  distance = 0
  for gen_pos in range(first, last):
    gen_row = generated_input_seq[3 * gen_pos]
    gen_col = generated_input_seq[3 * gen_pos + 1]
    # First solver position that fills the same (row, col) cell, if any.
    match_pos = next(
        (
            ref_pos
            for ref_pos in range(first, last)
            if gen_row == original_input_seq[3 * ref_pos]
            and gen_col == original_input_seq[3 * ref_pos + 1]
        ),
        None,
    )
    # An unmatched cell is charged as if it were placed at `block_size`.
    target = last if match_pos is None else match_pos
    distance += abs(target - gen_pos)
  return distance
136
137
def get_set_accuracy_for_pairs(
    pairs,
    state,
    p_eval_step,
    input_seq,
    possible_vals,
    given_vals,
    config,
):
  """Computes accuracy of set of possible values.

  For every candidate cell in `pairs`, the first 31 (row, col, value) triples
  of each puzzle are followed by the cell's 1-based (row, col) tokens. The
  model's logits for the next (value) token are then ranked. For a non-given
  cell with m possible values, the model's top-m tokens are scored by rank t:
  the token is correct when it is a digit token in {1, ..., 9} whose value is
  marked possible for that cell.

  Args:
    pairs: int array (num_pairs, batch_size, 2) of 0-based (row, col) cells.
    state: Model state passed to `p_eval_step`.
    p_eval_step: Sharded forward-pass function returning logits.
    input_seq: (batch_size, L) token array of (row, col, value) triples.
    possible_vals: (batch_size, 9, 9, 9) array; [b, r, c, v] == 1 marks value
      v + 1 as possible at cell (r, c).
    given_vals: (batch_size, 9, 9) array; 1 marks a cell given in the prefix.
    config: Experiment config (uses `seq_len`).

  Returns:
    Tuple (accuracy, correct_cnt, total_cnt) of length-9 arrays indexed by
    rank t; accuracy[t] = correct_cnt[t] / total_cnt[t], or 1.0 when
    total_cnt[t] == 0.
  """
  # Number of leading cell triples kept as context.
  prefix_cells = 31
  prefix_len = 3 * prefix_cells
  batch_size = input_seq.shape[0]

  correct_cnt = np.zeros(9)
  total_cnt = np.zeros(9)

  for cell_pairs in pairs:
    # Context prefix followed by each example's candidate (row, col) tokens.
    prompt = np.hstack((input_seq[:, :prefix_len], cell_pairs + 1))
    padding = np.zeros(
        (batch_size, config.seq_len - prompt.shape[1]),
        dtype=np.int32,
    )
    sharded = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, np.hstack((prompt, padding)))
    )
    logits = p_eval_step(state, sharded)

    # Logits at the column-token position predict the value token.
    value_logits = logits[:, :, prefix_len + 1, :].reshape(
        (-1, logits.shape[-1])
    )

    for k, (row, col) in enumerate(cell_pairs):
      if given_vals[k, row, col] == 1:
        continue  # Only empty cells are scored.
      cell_possible = possible_vals[k, row, col, :]
      num_possible = np.int32(np.sum(cell_possible))
      ranked = np.argsort(value_logits[np.int32(k), :])[::-1][:num_possible]

      assert len(ranked) <= 9

      for rank, token in enumerate(ranked):
        if 1 <= token <= 9:
          correct_cnt[rank] += cell_possible[token - 1] == 1
        total_cnt[rank] += 1

  accuracy = np.ones(9)
  has_samples = total_cnt > 0
  accuracy[has_samples] = correct_cnt[has_samples] / total_cnt[has_samples]
  return accuracy, correct_cnt, total_cnt
199
200
def get_sampled_pairs(input_seq, pred_logits, state, p_eval_step, config, key):
  """Samples distinct next cells (row, col) from the model's predictions.

  Row tokens are sampled from `pred_logits` at position
  `3 * config.start_index - 1`. Each sampled row is appended to the first
  `config.start_index` cell triples, and a column token is sampled from the
  model's logits at position `3 * config.start_index`. Sampling repeats until
  every example has `config.set_accuracy_top_k` distinct valid pairs, or all
  81 cells when fewer cells exist than requested.

  Args:
    input_seq: (batch_size, L) token array of (row, col, value) triples.
    pred_logits: Model logits for the context prefix, with two leading batch
      axes as produced by the sharded `p_eval_step`.
    state: Model state passed to `p_eval_step`.
    p_eval_step: Sharded forward-pass function returning logits.
    config: Experiment config (uses `start_index`, `seq_len`,
      `set_accuracy_top_k`).
    key: JAX PRNG key used for sampling.

  Returns:
    int32 array of shape (config.set_accuracy_top_k, batch_size, 2) with
    0-based (row, col) pairs; slots without a sampled pair stay (0, 0).
  """
  batch_size = input_seq.shape[0]
  top_k = config.set_accuracy_top_k
  # Only 81 distinct cells exist; demanding more would loop forever.
  target = min(top_k, 81)
  prefix = input_seq[:, : (config.start_index * 3)]

  pairs_set = [set() for _ in range(batch_size)]
  pairs = np.zeros((top_k, batch_size, 2), dtype=np.int32)

  # Row logits are read once from the caller's prefix forward pass. The column
  # pass below uses its own variable instead of overwriting `pred_logits`.
  # NOTE(review): equivalent to re-reading them each round only if the model
  # is causal -- confirm.
  pred_logits_row = pred_logits[:, :, 3 * config.start_index - 1, :].reshape(
      (-1, pred_logits.shape[-1])
  )

  pending = True
  while pending:
    rkey, key = jax.random.split(key, 2)
    pair_row = jax.random.categorical(rkey, pred_logits_row)
    assert len(pair_row) == batch_size and pair_row.ndim == 1

    cur_input_seq = np.hstack((prefix, np.asarray(pair_row).reshape(-1, 1)))
    padding = np.zeros(
        (batch_size, config.seq_len - cur_input_seq.shape[1]),
        dtype=np.int32,
    )
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, np.hstack((cur_input_seq, padding)))
    )

    col_logits = p_eval_step(state, concat_batch)
    pred_logits_col = col_logits[:, :, 3 * config.start_index, :].reshape(
        (-1, col_logits.shape[-1])
    )

    rkey, key = jax.random.split(key, 2)
    pair_col = jax.random.categorical(rkey, pred_logits_col)
    assert len(pair_col) == batch_size and pair_col.ndim == 1

    rows = np.asarray(pair_row)
    cols = np.asarray(pair_col)
    for i in range(batch_size):
      # Only digit tokens 1..9 name a cell; other samples are discarded.
      if 1 <= rows[i] <= 9 and 1 <= cols[i] <= 9:
        pairs_set[i].add((int(rows[i]), int(cols[i])))

    # Keep sampling while ANY example is short of pairs -- including examples
    # whose sample in this round was invalid.
    pending = any(len(cells) < target for cells in pairs_set)

  for i, cells in enumerate(pairs_set):
    for j, (row, col) in enumerate(list(cells)[:top_k]):
      pairs[j, i, 0] = row - 1
      pairs[j, i, 1] = col - 1

  return pairs
269
270
def get_topk_probability_pairs(
    input_seq, pred_logits, state, p_eval_step, config
):
  """Returns the `config.set_accuracy_top_k` most probable next cells.

  The next cell follows the first 31 (row, col, value) triples of each puzzle;
  the prefix length is fixed at 31 cells, not `config.start_index`. A cell's
  score is log p(row) + log p(col | row), with both distributions log-softmaxed
  over the digit tokens 1..9. p(row) is read from `pred_logits` at position 92.
  p(col | row) comes from a forward pass on the prefix extended by that row,
  run once for each of the 9 rows.

  Args:
    input_seq: (batch_size, L) token array of (row, col, value) triples.
    pred_logits: Model logits for the 31-cell prefix, with two leading batch
      axes as produced by the sharded `p_eval_step`.
    state: Model state passed to `p_eval_step`.
    p_eval_step: Sharded forward-pass function returning logits.
    config: Experiment config (uses `seq_len`, `set_accuracy_top_k`).

  Returns:
    int32 array (config.set_accuracy_top_k, batch_size, 2) of 0-based
    (row, col) pairs, most probable first; unused slots stay (0, 0).
  """
  min_start_index = 31
  batch_size = input_seq.shape[0]
  prefix = input_seq[:, : (min_start_index * 3)]
  pairs = np.zeros(
      (config.set_accuracy_top_k, batch_size, 2), dtype=np.int32
  )

  pred_logits_row = pred_logits[:, :, 3 * min_start_index - 1, :].reshape(
      (-1, pred_logits.shape[-1])
  )
  # Row log probability, normalized over the digit tokens 1..9.
  row_log_prob = jax.nn.log_softmax(pred_logits_row[:, 1:10])

  # Entry [b, 9 * r + c] = log p(row = r + 1) + log p(col = c + 1 | row).
  pairs_log_prob = np.zeros((batch_size, 81))

  for i in range(9):
    row_num = np.full((batch_size, 1), i + 1, dtype=np.int32)
    cur_input_seq = np.hstack((prefix, row_num))

    padding = np.zeros(
        (batch_size, config.seq_len - cur_input_seq.shape[1]),
        dtype=np.int32,
    )
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, np.hstack((cur_input_seq, padding)))
    )

    col_logits = p_eval_step(state, concat_batch)
    pred_logits_col = col_logits[:, :, 3 * min_start_index, :].reshape(
        (-1, pred_logits.shape[-1])
    )

    # Column log probability given row i + 1.
    col_log_prob = jax.nn.log_softmax(pred_logits_col[:, 1:10])

    # Fill the whole (batch, 9) block for row i at once; indexing device
    # arrays element by element dispatched one JAX op per entry.
    pairs_log_prob[:, 9 * i : 9 * (i + 1)] = np.asarray(
        col_log_prob + row_log_prob[:, i : i + 1]
    )

  for i in range(batch_size):
    # Selects top k most probable cells.
    topk_indices = np.argsort(pairs_log_prob[i, :])[::-1][
        : config.set_accuracy_top_k
    ]
    for j, ind in enumerate(topk_indices):
      pairs[j, i, 0] = ind // 9
      pairs[j, i, 1] = ind % 9

  return pairs
328
329
def get_set_accuracies(state, p_eval_step, input_seq, config):
  """Computes set accuracies for empty cells after a fixed 31-cell prefix.

  The first 31 (row, col, value) triples of each puzzle are treated as given.
  A value counts as possible for a cell unless it already appears among the
  given cells in that cell's row or column. `config.set_accuracy` selects
  which cells are scored:
    * "top-k": the `config.set_accuracy_top_k` cells the model finds most
      probable to be filled next (via `get_topk_probability_pairs`).
    * "all": every one of the 81 cells.

  Args:
    state: Model state passed to `p_eval_step`.
    p_eval_step: Sharded forward-pass function returning logits.
    input_seq: (batch_size, L) token array of 1-based (row, col, value)
      triples.
    config: Experiment config (uses `set_accuracy`, `seq_len`, and for
      "top-k" also `set_accuracy_top_k`).

  Returns:
    (accuracy, correct_cnt, total_cnt) from `get_set_accuracy_for_pairs`.

  Raises:
    ValueError: if `config.set_accuracy` is neither "top-k" nor "all".
  """
  batch_size = input_seq.shape[0]
  possible_vals = np.ones((batch_size, 9, 9, 9))
  given_vals = np.zeros((batch_size, 9, 9))

  min_start_index = 31
  for i in range(batch_size):
    for j in range(min_start_index):
      row_num = input_seq[i, 3 * j] - 1
      col_num = input_seq[i, 3 * j + 1] - 1
      val = input_seq[i, 3 * j + 2] - 1

      # A given value is ruled out for every cell of its row and its column.
      # NOTE(review): 3x3 box constraints are not applied -- confirm this is
      # the intended definition of "possible values".
      possible_vals[i, row_num, :, val] = 0
      possible_vals[i, :, col_num, val] = 0

      given_vals[i, row_num, col_num] = 1

  if config.set_accuracy == "top-k":
    # Scores the top k most probable next cells after the 31-cell prefix.
    cur_input_seq = input_seq[:, : (min_start_index * 3)]
    padding = np.zeros(
        (batch_size, config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, np.hstack((cur_input_seq, padding)))
    )
    pred_logits = p_eval_step(state, concat_batch)

    # Sampling-based alternative: get_sampled_pairs(input_seq, pred_logits,
    # state, p_eval_step, config, jax.random.PRNGKey(98)).
    print("get_topk_probability_pairs", flush=True)
    # get_topk_probability_pairs is deterministic and takes no PRNG key.
    pairs = get_topk_probability_pairs(
        input_seq, pred_logits, state, p_eval_step, config
    )
    print("got_topk_probability_pairs", flush=True)

  elif config.set_accuracy == "all":
    # Scores every (row, col) cell; given cells are skipped when scoring.
    pairs = np.zeros((81, batch_size, 2), dtype=np.int32)
    for i in range(81):
      pairs[i, :, 0] = i // 9
      pairs[i, :, 1] = i % 9

  else:
    raise ValueError(
        f"Unsupported config.set_accuracy: {config.set_accuracy!r}; "
        "expected 'top-k' or 'all'."
    )

  return get_set_accuracy_for_pairs(
      pairs,
      state,
      p_eval_step,
      input_seq,
      possible_vals,
      given_vals,
      config,
  )
405
406
def get_sudoku_eval_metrics(state, eval_data_iter, p_eval_step, config):
  """Computes sudoku evaluation metrics for `config.eval_epochs` batches.

  For each batch, the first 31 (row, col, value) triples are the context.
  The model then greedily decodes the remaining tokens up to `config.seq_len`,
  one forward pass per token. One value per key is appended per batch:
    * "acc": fraction of decoded value tokens for which
      `sudoku.SudokuBoardStateUpdate` raises no AssertionError against the
      batch's puzzle solution.
    * "acc_complete_puzzle": fraction of decoded sequences that
      `valid_solution` accepts.
    * "edit_distance": mean `get_edit_distance` between the decoded and the
      original cell ordering.
    * "set_acc{1,2,3}", "correct_cnt{1,2,3}": ranks 1-3 from
      `get_set_accuracies`.

  Args:
    state: contains model parameters, optimizer, etc.
    eval_data_iter: data iterator for evaluation dataset; each `next()`
      yields (input_seq, puzzle_sol, start_index) for one batch.
    p_eval_step: pmap function for forward pass of model for evaluation
    config: general experiment config file

  Returns:
    eval_metrics: dict mapping each metric name to its per-batch values.
  """

  eval_metrics = {
      "acc": [],
      "acc_complete_puzzle": [],
      "edit_distance": [],
      "set_acc1": [],
      "set_acc2": [],
      "set_acc3": [],
      "correct_cnt1": [],
      "correct_cnt2": [],
      "correct_cnt3": [],
  }

  # Number of leading cell triples always given to the model as context.
  min_start_index = 31

  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):

      batch_tuple = next(eval_data_iter)

      # Input seq is of the shape (batchsize, 3*81) and 3*81 because row, column
      # and value for each cell. Row, column and value all are in {1, ..., 9}
      input_seq = np.array(batch_tuple[0])

      # Puzzle solution is of the shape (batchsize, 81). Each pos in {0,.., 80}
      # for each puzzle contains value at cell (pos//9+1, pos%9 + 1)
      puzzle_sol = np.array(batch_tuple[1])

      # The per-example start indices from the data are replaced so every
      # puzzle is decoded after the same 31-cell prefix.
      starting_index = (
          np.ones_like(np.array(batch_tuple[2]), dtype=np.int32)
          * min_start_index
      )
      total_pred, success_pred = 0, 0

      # The prefix length must equal the decoding loop's start position
      # (3 * min_start_index); slicing by config.start_index would append the
      # predicted tokens at the wrong positions whenever it is not 31.
      cur_input_seq = input_seq[:, : (min_start_index * 3)]

      # Computes set accuracy for empty cells in the puzzle.
      set_acc, correct_cnt, _ = get_set_accuracies(
          state, p_eval_step, input_seq, config
      )
      for rank in range(3):
        eval_metrics[f"set_acc{rank + 1}"].append(set_acc[rank])
        eval_metrics[f"correct_cnt{rank + 1}"].append(correct_cnt[rank])

      for i in range(min_start_index * 3, config.seq_len):
        # In the i-th iteration the model predicts the i-th token.
        padding = np.zeros(
            (input_seq.shape[0], config.seq_len - len(cur_input_seq[0])),
            dtype=np.int32,
        )
        concat_batch = common_utils.shard(
            jax.tree_util.tree_map(
                np.asarray, np.hstack((cur_input_seq, padding))
            )
        )
        pred_logits = p_eval_step(state, concat_batch)

        # Greedy prediction for token i, read from the logits at i - 1.
        max_token = pred_logits[:, :, i - 1, :].argmax(axis=-1).flatten()
        # Positions before an example's start index are copied from
        # input_seq; with the fixed start index above this mask is all True.
        mask_arr = np.array(i >= (3 * starting_index))
        next_token = max_token * mask_arr + (1 - mask_arr) * input_seq[:, i]
        # Shape is passed positionally: the `newshape=` keyword of
        # jnp.reshape is deprecated in recent JAX releases.
        cur_input_seq = np.hstack(
            (cur_input_seq, jnp.reshape(next_token, (-1, 1)))
        )

        if i % 3 == 2:
          # Token i is the value predicted for the cell
          # (cur_input_seq[j][i-2], cur_input_seq[j][i-1]); check it against
          # the puzzle solution for every decoded example.
          for j in range(len(cur_input_seq)):
            if not mask_arr[j]:
              continue

            total_pred += 1
            try:
              sudoku.SudokuBoardStateUpdate(
                  puzzle_sol[j],
                  cur_input_seq[j][i - 2],
                  cur_input_seq[j][i - 1],
                  cur_input_seq[j][i],
              )
            except AssertionError:
              # Wrong update.
              pass
            else:
              success_pred += 1

      eval_metrics["acc"].append(success_pred * 1.0 / total_pred)

      num_puzzles = len(cur_input_seq)
      correct_eval_sudoku_puzzle = 0
      solution_edit_distance = 0.0
      for generated, original in zip(cur_input_seq, input_seq):
        # Counts decoded grids that form a valid sudoku solution.
        correct_eval_sudoku_puzzle += valid_solution(generated)
        # Distance between the model's and the solver's cell ordering.
        solution_edit_distance += get_edit_distance(
            config, generated, original
        )

      eval_metrics["acc_complete_puzzle"].append(
          correct_eval_sudoku_puzzle * 1.0 / num_puzzles
      )
      eval_metrics["edit_distance"].append(
          solution_edit_distance * 1.0 / num_puzzles
      )

  return eval_metrics
546
547
def get_eval_metrics(state, eval_data_iter, p_eval_step, config):
  """Dispatches to the dataset-specific evaluation routine.

  Args:
    state: Model state passed through to the evaluation routine.
    eval_data_iter: Iterator over evaluation batches.
    p_eval_step: Sharded forward-pass function.
    config: Experiment config; `config.dataset` selects the routine
      ("othello" exactly, or any name containing "sudoku").

  Returns:
    The metrics dict produced by the selected routine.

  Raises:
    ValueError: if `config.dataset` matches no supported dataset.
  """
  if config.dataset == "othello":
    return get_othello_eval_metrics(state, eval_data_iter, p_eval_step, config)
  if "sudoku" in config.dataset:
    return get_sudoku_eval_metrics(state, eval_data_iter, p_eval_step, config)
  # Fail loudly rather than handing None to the caller as "metrics".
  raise ValueError(f"Unsupported dataset for evaluation: {config.dataset!r}")
556