google-research

Форк
0
/
inference_eval_utils.py 
930 строк · 28.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""This file contains functions for evaluation of the trained model."""
17

18
import io
19

20
from flax.training import common_utils
21
import jax
22
from jax import numpy as jnp
23
import matplotlib
24
import matplotlib.pyplot as plt
25
import matplotlib.ticker as mticker
26
import numpy as np
27
import tensorflow as tf
28

29
from sudoku_gpt import model
30
from sudoku_gpt import othello
31
from sudoku_gpt import sudoku
32

33

34
def valid_solution(output_seq):
  """Checks whether a decoded sequence is a complete, valid Sudoku solution.

  The sequence is read as 81 consecutive (row, column, value) triples whose
  tokens are 1-indexed. It is valid when every (row, value), (column, value)
  and (box, value) slot is hit by at least one triple. With exactly 81 triples
  over 81 slots of each kind, this means each slot is hit exactly once.

  Args:
    output_seq: indexable of at least 243 integer tokens.

  Returns:
    True if the triples fill a valid Sudoku grid, False otherwise. Any token
    outside 1..9 also yields False.
  """
  rows = np.zeros((9, 9))
  cols = np.zeros((9, 9))
  boxes = np.zeros((9, 9))

  for j in range(81):
    row_num = int(output_seq[3 * j] - 1)
    col_num = int(output_seq[3 * j + 1] - 1)
    val_num = int(output_seq[3 * j + 2] - 1)
    # Reject out-of-range tokens on BOTH sides. A 0 token (e.g. padding)
    # would otherwise become index -1 and silently alias to row/col/value 9.
    if not (0 <= row_num <= 8 and 0 <= col_num <= 8 and 0 <= val_num <= 8):
      return False
    box_num = 3 * (row_num // 3) + (col_num // 3)
    rows[row_num, val_num] += 1
    cols[col_num, val_num] += 1
    boxes[box_num, val_num] += 1

  return bool(np.all(rows) and np.all(cols) and np.all(boxes))
60

61

62
def eval_step(state, batch, config):
  """Runs a single forward pass of the language model.

  Args:
    state: training state; only ``state.params`` is read.
    batch: token ids handed directly to the model's ``apply``.
    config: configuration used to construct ``TransformerLMHeadModel``.

  Returns:
    The output of ``TransformerLMHeadModel.apply``; callers in this file
    treat it as per-position logits.
  """
  lm = model.TransformerLMHeadModel(config)
  variables = {"params": state.params}
  return lm.apply(variables, batch)
67

68

69
def get_othello_eval_metrics(
    state, eval_data_iter, p_eval_step, config
):
  """Computes next-move legality accuracy on the Othello dataset.

  For every prefix length ``i + 1`` of each evaluation batch, the model's
  greedy next move is appended to the prefix. The prediction counts as a
  success when replaying the extended sequence through
  ``othello.OthelloBoardState().update`` raises no ``AssertionError``.

  Args:
    state: model state forwarded to ``p_eval_step``.
    eval_data_iter: iterator yielding batches of move sequences.
    p_eval_step: parallel forward function. Its output is indexed here as
      ``[:, :, position, vocab]``; the leading two axes presumably are
      device and per-device batch (input is ``common_utils.shard``-ed).
    config: needs ``eval_epochs`` and ``seq_len``.

  Returns:
    Dict whose ``"acc"`` list gets one accuracy per evaluation epoch.
  """
  eval_metrics = {"acc": []}
  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):
      batch = np.array(next(eval_data_iter))
      total_pred, success_pred = 0, 0

      for i in range(config.seq_len):
        # Keep the first i + 1 moves and zero-pad to the model's length.
        padding = np.zeros(
            (batch.shape[0], config.seq_len - (i + 1)), dtype=np.int32
        )
        concat_batch = np.hstack((batch[:, : (i + 1)], padding))
        concat_batch = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, concat_batch)
        )
        pred_logits = p_eval_step(state, concat_batch)

        # Greedy move predicted at position i, flattened to (batch, 1). The
        # shape is passed positionally: jnp.reshape's `newshape=` keyword is
        # deprecated in newer JAX releases.
        max_action = pred_logits[:, :, i, :].argmax(axis=-1)
        pred_seq = np.hstack(
            (batch[:, : (i + 1)], jnp.reshape(max_action, (-1, 1)))
        )

        for j in range(pred_seq.shape[0]):
          # When games are short, the model can keep predicting easy next
          # tokens, which inflates this accuracy.
          total_pred += 1
          try:
            othello.OthelloBoardState().update(pred_seq[j], prt=False)
          except AssertionError:
            # Illegal move sequence: wrong prediction.
            continue
          success_pred += 1

      eval_metrics["acc"].append(success_pred * 1.0 / total_pred)

  return eval_metrics
110

111

112
def get_edit_distance(config, generated_input_seq, original_input_seq):
  """Measures how far the model's cell order is from the solver's order.

  For every cell slot ``gen_pos`` in ``[config.start_index, config.block_size)``
  of the generated sequence, the first slot of the reference sequence with
  the same (row, column) is located. ``|match_pos - gen_pos|`` is added to
  the total. When the cell never appears in that range of the reference,
  ``|block_size - gen_pos|`` is added instead.

  Args:
    config: provides ``start_index`` and ``block_size``.
    generated_input_seq: model output as flat (row, col, value) triples.
    original_input_seq: reference (solver) sequence in the same layout.

  Returns:
    The summed positional distance.
  """
  start, end = config.start_index, config.block_size
  distance = 0
  for gen_pos in range(start, end):
    gen_row = generated_input_seq[3 * gen_pos]
    gen_col = generated_input_seq[3 * gen_pos + 1]
    match_pos = next(
        (
            ref_pos
            for ref_pos in range(start, end)
            if gen_row == original_input_seq[3 * ref_pos]
            and gen_col == original_input_seq[3 * ref_pos + 1]
        ),
        None,
    )
    if match_pos is None:
      distance += abs(end - gen_pos)
    else:
      distance += abs(match_pos - gen_pos)
  return distance
138

139

140
def get_set_accuracy_for_pairs(
    pairs,
    state,
    p_eval_step,
    input_seq,
    possible_vals,
    given_vals,
    config,
):
  """Scores the model's ranked value predictions against candidate sets.

  For each queried cell, the puzzle prefix plus that cell's 1-based
  (row, col) tokens is fed to the model. Its vocabulary is ranked by logit at
  position ``3 * config.start_index + 1``, and the first ``m`` ranks are
  kept, where ``m`` is the number of candidates of that cell. Rank ``t``
  counts as correct when its token lies in 1..9 and names a candidate value.
  Cells marked in ``given_vals`` are skipped.

  Args:
    pairs: int array ``[query, example, 2]`` of 0-based (row, col).
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    input_seq: token array ``[example, token]``. Its first
      ``3 * config.start_index`` tokens form the prefix.
    possible_vals: ``[example, row, col, value]`` array; 1 marks a candidate.
    given_vals: ``[example, row, col]`` array; 1 marks a given cell.
    config: needs ``start_index`` and ``seq_len``.

  Returns:
    ``(accuracy, correct_cnt, total_cnt)``, each of length 9 and indexed by
    rank. ``accuracy`` stays 1.0 for ranks that were never evaluated.
  """
  batch_size = input_seq.shape[0]
  prefix = input_seq[:, : (config.start_index * 3)]
  value_pos = 3 * config.start_index + 1
  correct_cnt = np.zeros(9)
  total_cnt = np.zeros(9)

  for query in pairs:
    # Coordinates are stored 0-based; the model consumes 1-based tokens.
    query_seq = np.hstack((prefix, query + 1))
    pred_logits = get_pred_logits(
        query_seq, input_seq, state, p_eval_step, config
    )
    value_logits = pred_logits[:, :, value_pos, :].reshape(
        (-1, pred_logits.shape[-1])
    )

    for k in range(batch_size):
      row, col = query[k, 0], query[k, 1]
      if given_vals[k, row, col] == 1:
        continue
      cell_candidates = possible_vals[k, row, col, :]
      num_candidates = np.int32(np.sum(cell_candidates))
      ranked = np.argsort(value_logits[np.int32(k), :])[::-1][:num_candidates]

      assert len(ranked) <= 9

      for rank, token in enumerate(ranked):
        if 1 <= token <= 9:
          correct_cnt[rank] += cell_candidates[token - 1] == 1
        total_cnt[rank] += 1

  accuracy = np.ones(9)
  evaluated = total_cnt > 0
  accuracy[evaluated] = correct_cnt[evaluated] / total_cnt[evaluated]
  return accuracy, correct_cnt, total_cnt
200

201

202
def get_sampled_pairs(
    input_seq, pred_logits, state, p_eval_step, config, key, max_rounds=None
):
  """Samples distinct (row, col) cells for every example from the model.

  Each round samples a row token from the logits at position
  ``3 * config.start_index - 1``. After appending that row, it samples a
  column token from the logits at position ``3 * config.start_index``.
  Samples whose row or column token lies outside 1..9 are discarded. Rounds
  repeat until every example has ``config.set_accuracy_top_k`` distinct cells.

  Args:
    input_seq: token array ``[batch, tokens]``. Its first
      ``3 * config.start_index`` tokens are the conditioning prefix.
    pred_logits: output of ``p_eval_step`` on that prefix, indexed here as
      ``[:, :, position, vocab]``.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    config: needs ``start_index``, ``seq_len`` and ``set_accuracy_top_k``.
    key: JAX PRNG key; split before every categorical draw.
    max_rounds: optional cap on sampling rounds. The default ``None`` samples
      until every example is complete. That never terminates if the model
      cannot produce enough distinct valid cells for some example.

  Returns:
    int32 array ``[set_accuracy_top_k, batch, 2]`` of 0-based (row, col).
    Slots an example could not fill (only possible with ``max_rounds``) stay
    ``(0, 0)``.
  """
  batch_size = input_seq.shape[0]
  top_k = config.set_accuracy_top_k
  row_pos = 3 * config.start_index - 1
  prefix = input_seq[:, : (config.start_index * 3)]
  pairs_set = [set() for _ in range(batch_size)]
  pairs = np.zeros((top_k, batch_size, 2), dtype=np.int32)

  rounds = 0
  need_more = True
  while need_more and (max_rounds is None or rounds < max_rounds):
    rounds += 1
    # Row logits come from the most recent forward pass: the caller's on the
    # first round, the column pass on later rounds (same as before the fix).
    pred_logits_row = pred_logits[:, :, row_pos, :].reshape(
        (-1, pred_logits.shape[-1])
    )
    rkey, key = jax.random.split(key, 2)
    pair_row = jax.random.categorical(rkey, pred_logits_row)
    assert len(pair_row) == batch_size and pair_row.ndim == 1

    row_seq = np.hstack((prefix, pair_row.reshape(-1, 1)))
    pred_logits = get_pred_logits(
        row_seq, input_seq, state, p_eval_step, config
    )
    pred_logits_col = pred_logits[:, :, row_pos + 1, :].reshape(
        (-1, pred_logits.shape[-1])
    )
    rkey, key = jax.random.split(key, 2)
    pair_col = jax.random.categorical(rkey, pred_logits_col)
    assert len(pair_col) == batch_size and pair_col.ndim == 1

    for i in range(batch_size):
      row, col = int(pair_row[i]), int(pair_col[i])
      if 1 <= row <= 9 and 1 <= col <= 9:
        pairs_set[i].add((row, col))

    # Continue while ANY example is still short. The previous version only
    # re-armed the loop for examples whose sample was valid this round. It
    # could therefore stop with incomplete examples, whose slots then
    # silently pointed at cell (0, 0).
    need_more = any(len(cells) < top_k for cells in pairs_set)

  for i, cells in enumerate(pairs_set):
    for j, (row, col) in enumerate(cells):
      if j == top_k:
        break
      pairs[j, i, 0] = row - 1
      pairs[j, i, 1] = col - 1

  return pairs
270

271

272
def get_topk_probability_pairs(
    input_seq, pred_logits, state, p_eval_step, config
):
  """Ranks all 81 cells by joint (row, col) log-probability under the model.

  The score of cell (r, c) is ``log p(row=r | prefix) +
  log p(col=c | prefix, row=r)``. Both distributions are renormalised over
  tokens 1..9 only. Row logits are read at position
  ``3 * config.start_index - 1`` of ``pred_logits``. Column logits come from
  one extra forward pass per candidate row, read at position
  ``3 * config.start_index``.

  Args:
    input_seq: token array ``[batch, tokens]``. Its first
      ``3 * config.start_index`` tokens form the prefix.
    pred_logits: output of ``p_eval_step`` on that prefix.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    config: needs ``start_index``, ``seq_len`` and ``set_accuracy_top_k``.

  Returns:
    int32 array ``[set_accuracy_top_k, batch, 2]`` holding the 0-based
    (row, col) of each example's highest-scoring cells, best first.
  """
  batch_size = input_seq.shape[0]
  vocab_size = pred_logits.shape[-1]
  col_pos = 3 * config.start_index
  prefix = input_seq[:, : (config.start_index * 3)]

  row_logits = pred_logits[:, :, col_pos - 1, :].reshape((-1, vocab_size))
  row_log_prob = jax.nn.log_softmax(row_logits[:, 1:10])

  pairs_log_prob = np.zeros((batch_size, 81))
  for row in range(9):
    row_tokens = np.full((batch_size, 1), row + 1, dtype=np.int32)
    col_logits = get_pred_logits(
        np.hstack((prefix, row_tokens)), input_seq, state, p_eval_step, config
    )
    col_logits = col_logits[:, :, col_pos, :].reshape((-1, vocab_size))
    col_log_prob = jax.nn.log_softmax(col_logits[:, 1:10])
    # Columns row*9 .. row*9+8 hold the nine cells of this row.
    pairs_log_prob[:, row * 9 : (row + 1) * 9] = (
        col_log_prob + row_log_prob[:, row : row + 1]
    )

  pairs = np.zeros(
      (config.set_accuracy_top_k, batch_size, 2), dtype=np.int32
  )
  for example, scores in enumerate(pairs_log_prob):
    best = np.argsort(scores)[::-1][: config.set_accuracy_top_k]
    pairs[: len(best), example, 0] = best // 9
    pairs[: len(best), example, 1] = best % 9

  return pairs
323

324

325
def get_set_accuracies(state, p_eval_step, input_seq, config):
  """Computes candidate-set accuracies for the cells queried by the model.

  Candidate values start as all of 1..9 for every cell. Each given triple in
  the first ``config.start_index`` cells of ``input_seq`` then removes its
  value from its row and its column and marks its cell as given. Which cells
  get scored depends on ``config.set_accuracy``:

  * ``"top-k"``: the model's ``set_accuracy_top_k`` most probable cells.
  * ``"all"``: all 81 cells.

  Returns:
    The ``(accuracy, correct_cnt, total_cnt)`` tuple from
    ``get_set_accuracy_for_pairs``, or ``None`` for any other mode.
  """
  batch_size = input_seq.shape[0]
  possible_vals = np.ones((batch_size, 9, 9, 9))
  given_vals = np.zeros((batch_size, 9, 9))

  # NOTE(review): values are eliminated by row and column only. Values
  # already placed in the same 3x3 box are NOT removed. Confirm whether this
  # is intended for the metric.
  for example in range(batch_size):
    for cell in range(config.start_index):
      row = input_seq[example, 3 * cell] - 1
      col = input_seq[example, 3 * cell + 1] - 1
      value = input_seq[example, 3 * cell + 2] - 1
      possible_vals[example, row, :, value] = 0
      possible_vals[example, :, col, value] = 0
      given_vals[example, row, col] = 1

  if config.set_accuracy == "top-k":
    prefix = input_seq[:, : (config.start_index * 3)]
    pred_logits = get_pred_logits(
        prefix, input_seq, state, p_eval_step, config
    )
    print("get_topk_probability_pairs", flush=True)
    pairs = get_topk_probability_pairs(
        input_seq, pred_logits, state, p_eval_step, config
    )
    print("got_topk_probability_pairs", flush=True)
  elif config.set_accuracy == "all":
    cells = np.arange(81, dtype=np.int32)
    pairs = np.zeros((81, batch_size, 2), dtype=np.int32)
    pairs[:, :, 0] = (cells // 9)[:, None]
    pairs[:, :, 1] = (cells % 9)[:, None]
  else:
    # Unrecognised mode: nothing is evaluated.
    return None

  return get_set_accuracy_for_pairs(
      pairs,
      state,
      p_eval_step,
      input_seq,
      possible_vals,
      given_vals,
      config,
  )
387

388

389
def get_pred_logits(cur_input_seq, input_seq, state, p_eval_step, config):
  """Zero-pads ``cur_input_seq`` to ``config.seq_len`` and runs the model.

  Args:
    cur_input_seq: 2-D token array. Its row length (``len(cur_input_seq[0])``)
      is the number of real tokens.
    input_seq: only its row count is used, to size the padding.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    config: needs ``seq_len``.

  Returns:
    Output of ``p_eval_step`` on the device-sharded, padded batch.
  """
  num_rows = input_seq.shape[0]
  pad_len = config.seq_len - len(cur_input_seq[0])
  zeros = np.zeros((num_rows, pad_len), dtype=np.int32)
  full_batch = np.hstack((cur_input_seq, zeros))
  sharded_batch = common_utils.shard(
      jax.tree_util.tree_map(np.asarray, full_batch)
  )
  return p_eval_step(state, sharded_batch)
401

402

403
def get_beam_search_candidates(
    input_seq, beam_search_candidates, state, p_eval_step, pos, config
):
  """Extends every beam by one token and keeps the best ``beam_search_n``.

  Each beam is a tuple ``(seqs, log_likelihood, success_pred)`` of
  per-example arrays: token sequences, summed log-likelihood and a running
  success count. For every beam the model runs once, and each of the
  ``beam_search_n`` highest-logit tokens at position ``pos`` is appended.
  Then, independently for each example, the ``beam_search_n`` extensions with
  the largest summed log-likelihood are kept.

  Args:
    input_seq: token array; only used (via ``get_pred_logits``) for batch
      size.
    beam_search_candidates: list of beams in the tuple format above.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    pos: sequence position whose logits predict the appended token.
    config: needs ``beam_search_n``, ``minibatch_size`` and ``seq_len``.

  Returns:
    List of ``config.beam_search_n`` beams with freshly allocated arrays.
  """
  beam_n = config.beam_search_n
  new_beam_candidate_list = []
  new_beam_candidate_likelihood_list = []
  for cur_seqs, cur_likelihood, cur_success in beam_search_candidates:
    pred_logits = get_pred_logits(
        cur_seqs, input_seq, state, p_eval_step, config
    )

    # Token ids of the beam_n largest logits at `pos` (in no particular
    # order), flattened to [batch, beam_n].
    max_pos = (
        pred_logits[:, :, pos, :]
        .argpartition(-beam_n, axis=-1)[:, :, -beam_n:]
        .reshape((-1, beam_n))
    )
    log_likelihood = jax.nn.log_softmax(pred_logits[:, :, pos, :]).reshape(
        (-1, pred_logits.shape[-1])
    )
    log_likelihood = np.take_along_axis(log_likelihood, max_pos, 1)

    for j in range(beam_n):
      # Shape passed positionally: jnp.reshape's `newshape=` keyword is
      # deprecated in newer JAX releases.
      new_seqs = np.hstack((cur_seqs, jnp.reshape(max_pos[:, j], (-1, 1))))
      new_likelihood = cur_likelihood + log_likelihood[:, j]
      new_beam_candidate_likelihood_list.append(new_likelihood)
      # All extensions of a beam share its success array; the truncation
      # below copies values into new arrays.
      new_beam_candidate_list.append((new_seqs, new_likelihood, cur_success))

  likelihoods = np.stack(new_beam_candidate_likelihood_list, axis=0)
  assert likelihoods.shape == (
      len(beam_search_candidates) * beam_n,
      config.minibatch_size,
  ), likelihoods.shape

  # Per example (column), indices of the beam_n most likely extensions.
  keep_ind = likelihoods.argpartition(-beam_n, axis=0)[-beam_n:, :]
  assert keep_ind.shape == (beam_n, config.minibatch_size), keep_ind.shape

  truncated_candidate_list = []
  for i in range(beam_n):
    new_candidate = np.zeros_like(new_beam_candidate_list[0][0])
    new_candidate_likelihood = np.zeros_like(new_beam_candidate_list[0][1])
    new_candidate_success_pred = np.zeros_like(new_beam_candidate_list[0][2])

    for j in range(config.minibatch_size):
      source = new_beam_candidate_list[keep_ind[i, j]]
      new_candidate[j] = source[0][j]
      new_candidate_likelihood[j] = source[1][j]
      new_candidate_success_pred[j] = source[2][j]

    truncated_candidate_list.append(
        (new_candidate, new_candidate_likelihood, new_candidate_success_pred)
    )

  return truncated_candidate_list
479

480

481
def get_greedy_row_col(
    beam_search_candidates, pos, input_seq, state, p_eval_step, config
):
  """Extends every beam with a row token, then with a column token.

  ``pos`` is the position of the value token of the cell being decoded. The
  row token is chosen from the logits at ``pos - 3`` and the column token
  from the logits at ``pos - 2``. Each choice is one
  ``get_beam_search_candidates`` expansion, so despite the name this keeps
  ``config.beam_search_n`` beams rather than a single greedy path.

  Returns:
    The updated list of beams.
  """
  for logit_pos in (pos - 3, pos - 2):
    beam_search_candidates = get_beam_search_candidates(
        input_seq, beam_search_candidates, state, p_eval_step, logit_pos,
        config,
    )
  return beam_search_candidates
503

504

505
def get_greedy_pair(cur_input_seq, pos, input_seq, state, p_eval_step, config):
  """Appends the jointly most likely (row, col) tokens to ``cur_input_seq``.

  The score of cell (r, c) is ``log p(row=r | seq) + log p(col=c | seq, r)``.
  Both distributions are renormalised over tokens 1..9. Row logits are read
  at ``pos - 3``; column logits come from one forward pass per candidate row,
  read at ``pos - 2``.

  Args:
    cur_input_seq: token array ``[batch, tokens]``. It is expected to end
      just before the row token of the cell whose value sits at ``pos``.
    pos: position of the value token of the cell being decoded.
    input_seq: only used (via ``get_pred_logits``) for batch size.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    config: needs ``seq_len``.

  Returns:
    ``cur_input_seq`` with two columns appended: the 1-based row and column
    tokens of the best cell.
  """
  batch_size = input_seq.shape[0]
  pred_logits = get_pred_logits(
      cur_input_seq, input_seq, state, p_eval_step, config
  )
  vocab_size = pred_logits.shape[-1]
  row_pred_logits = pred_logits[:, :, pos - 3, :].reshape((-1, vocab_size))
  row_log_prob = jax.nn.log_softmax(row_pred_logits[:, 1:10])

  pairs_log_prob = np.zeros((batch_size, 81))
  for i in range(9):
    row_tokens = np.full((batch_size, 1), i + 1, dtype=np.int32)
    # Condition on a fresh copy holding exactly one row token. The previous
    # version re-bound cur_input_seq here, so row tokens piled up. The column
    # logits at pos - 2 were then always conditioned on row 1, and the result
    # carried nine spurious tokens.
    row_input_seq = np.hstack((cur_input_seq, row_tokens))
    pred_logits_col = get_pred_logits(
        row_input_seq, input_seq, state, p_eval_step, config
    )
    pred_logits_col = pred_logits_col[:, :, pos - 2, :].reshape(
        (-1, vocab_size)
    )
    col_log_prob = jax.nn.log_softmax(pred_logits_col[:, 1:10])
    pairs_log_prob[:, i * 9 : (i + 1) * 9] = (
        col_log_prob + row_log_prob[:, i : i + 1]
    )

  best = pairs_log_prob.argmax(axis=-1, keepdims=True)
  # Index 0..8 corresponds to token 1..9 (log-softmax over logits[:, 1:10]),
  # so shift back to token ids before appending.
  pair = np.hstack((best // 9 + 1, best % 9 + 1))
  return np.hstack((cur_input_seq, pair))
540

541

542
def get_accuracy(
    cur_input_seq,
    state,
    p_eval_step,
    input_seq,
    puzzle_sol,
    config,
    eval_metrics,
    mistakes_metrics,
):
  """Decodes the remaining cells with beam search and scores every step.

  Starting from the prefix ``cur_input_seq``, each loop iteration decodes one
  cell. Row/column tokens are chosen according to ``config.sampling_method``;
  the value token is then chosen via ``get_beam_search_candidates`` from the
  logits at the column position. Every beam's (row, col, value) triple is
  checked with ``sudoku.SudokuBoardStateUpdate`` against ``puzzle_sol``. A
  triple that raises no ``AssertionError`` increments that beam's
  per-example success count.

  Args:
    cur_input_seq: prefix token array, one row per example.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    input_seq: full token array; the helpers use it for batch size.
    puzzle_sol: per-example solutions for the Sudoku checker.
    config: needs ``start_index``, ``seq_len``, ``sampling_method`` and the
      beam-search settings.
    eval_metrics: dict; one accuracy is appended to ``"acc"``.
    mistakes_metrics: returned untouched. Per-mistake tracking existed only
      as commented-out code.

  Returns:
    ``(eval_metrics, mistakes_metrics, max_prob_seq)``. ``max_prob_seq``
    holds, per example, the sequence of the beam with the highest summed
    log-likelihood.
  """
  batch_len = len(cur_input_seq)
  # Each beam: (token sequences, summed log-likelihood, success count).
  beams = [(cur_input_seq, np.zeros(batch_len), np.zeros(batch_len))]
  total_pred = 0

  for val_pos in range(config.start_index * 3 + 2, config.seq_len, 3):
    if config.sampling_method == "greedy-row-col":
      # Beam-expand the row token, then the column token.
      beams = get_greedy_row_col(
          beams, val_pos, input_seq, state, p_eval_step, config
      )
    elif config.sampling_method == "greedy-pair":
      # NOTE(review): the extended sequence is stored in cur_input_seq but
      # never fed into `beams`, so the value step below does not see the
      # chosen (row, col) pair.
      cur_input_seq = get_greedy_pair(
          cur_input_seq, val_pos, input_seq, state, p_eval_step, config
      )

    beams = get_beam_search_candidates(
        input_seq, beams, state, p_eval_step, val_pos - 1, config
    )

    total_pred += len(beams[0][0])
    for seqs, _, successes in beams:
      for example in range(len(seqs)):
        try:
          sudoku.SudokuBoardStateUpdate(
              puzzle_sol[example],
              seqs[example][val_pos - 2],
              seqs[example][val_pos - 1],
              seqs[example][val_pos],
          )
        except AssertionError:
          continue  # Invalid placement: not a success.
        successes[example] += 1

  # Per example, pick the beam with the highest summed log-likelihood.
  likelihoods = np.zeros((len(beams), beams[0][1].shape[0]))
  for beam_idx, (_, beam_likelihood, _) in enumerate(beams):
    likelihoods[beam_idx, :] = beam_likelihood
  best_beam = likelihoods.argmax(axis=0)

  max_prob_seq = np.zeros_like(beams[0][0])
  best_successes = np.zeros(len(max_prob_seq))
  for example in range(len(max_prob_seq)):
    chosen = beams[best_beam[example]]
    max_prob_seq[example] = chosen[0][example]
    best_successes[example] = chosen[2][example]

  eval_metrics["acc"].append(best_successes.sum() * 1.0 / total_pred)
  return eval_metrics, mistakes_metrics, max_prob_seq
640

641

642
def set_set_accuracies(eval_metrics, set_acc, correct_cnt):
  """Appends the first three set accuracies and correct counts.

  Entries 0..2 of ``set_acc`` go to ``"set_acc1"``..``"set_acc3"``, and
  entries 0..2 of ``correct_cnt`` go to ``"correct_cnt1"``..``"correct_cnt3"``.
  ``eval_metrics`` is modified in place and also returned.
  """
  for idx, name in enumerate(("set_acc1", "set_acc2", "set_acc3")):
    eval_metrics[name].append(set_acc[idx])
  for idx, name in enumerate(("correct_cnt1", "correct_cnt2", "correct_cnt3")):
    eval_metrics[name].append(correct_cnt[idx])
  return eval_metrics
652

653

654
def get_position_hinted_eval_acc(
    input_seq, puzzle_sol, state, p_eval_step, eval_metrics, config
):
  """Measures value accuracy when the cell positions come from ground truth.

  For each cell index ``i`` in ``[config.start_index, config.block_size)``,
  the ground-truth row and column of cell ``i`` are appended to the running
  sequence. The model's greedy value at position ``3 * i + 1`` is then
  appended, and the resulting (row, col, value) triple is checked with
  ``sudoku.SudokuBoardStateUpdate``. Appended values are the model's own
  predictions, so later cells are conditioned on earlier (possibly wrong)
  ones.

  Returns:
    ``eval_metrics`` with the accuracy appended to ``"hinted_acc"``.
  """
  total_pred, success_pred = 0, 0

  cur_input_seq = input_seq[:, : (config.start_index * 3)]
  for i in range(config.start_index, config.block_size):
    # Append the ground-truth row and column of the i-th cell. Shapes are
    # passed positionally: jnp.reshape's `newshape=` keyword is deprecated in
    # newer JAX releases.
    cur_input_seq = np.hstack(
        (cur_input_seq, jnp.reshape(input_seq[:, 3 * i], (-1, 1)))
    )
    cur_input_seq = np.hstack(
        (cur_input_seq, jnp.reshape(input_seq[:, 3 * i + 1], (-1, 1)))
    )

    # Predict and append the value at the hinted position.
    pred_logits = get_pred_logits(
        cur_input_seq, input_seq, state, p_eval_step, config
    )
    max_number = pred_logits[:, :, (3 * i + 1), :].argmax(axis=-1).flatten()
    cur_input_seq = np.hstack(
        (cur_input_seq, jnp.reshape(max_number, (-1, 1)))
    )

    for j in range(len(cur_input_seq)):
      total_pred += 1
      try:
        sudoku.SudokuBoardStateUpdate(
            puzzle_sol[j],
            cur_input_seq[j, -3],
            cur_input_seq[j, -2],
            cur_input_seq[j, -1],
        )
      except AssertionError:
        continue
      success_pred += 1

  eval_metrics["hinted_acc"].append(success_pred * 1.0 / total_pred)
  return eval_metrics
708

709

710
def get_internal_model_stats(
    cur_input_seq,
    state,
    p_eval_step,
    input_seq,
    candidate_list,
    config,
    eval_metrics,
):
  """Probes how well the model's value logits match per-cell candidate sets.

  At 10 checkpoints (after 35, 40, ..., 80 decoded cells), the decoded
  prefix is queried once for every cell (row, col). For each example whose
  cell is still empty and has a non-empty candidate set of size ``m``, the
  model's ``m`` highest-logit value tokens are compared with
  ``candidate_list``.

  Args:
    cur_input_seq: decoded tokens ``[batch, tokens]`` as (row, col, value)
      triples.
    state: model state forwarded to ``p_eval_step``.
    p_eval_step: parallel forward function.
    input_seq: only used (via ``get_pred_logits``) for batch size.
    candidate_list: indexed ``[example, checkpoint, cell, value - 1]``.
      Presumably 0/1 candidate flags (summed to count candidates) — TODO
      confirm against the producer.
    config: needs ``seq_len``.
    eval_metrics: dict with ``"intermediate_calc_acc{N}"`` lists for
      N = 35, 40, ..., 80.

  Returns:
    ``eval_metrics`` with one value appended per checkpoint. The value is
    NaN when no cell was evaluated at that checkpoint.
  """
  batch_size = len(cur_input_seq)

  for i in range(10):
    num_cells = 5 * i + 35

    # Mark cells already filled within the first num_cells triples.
    filled_cells = np.zeros((batch_size, 81), dtype=np.int8)
    for i1 in range(batch_size):
      for j1 in range(num_cells):
        cell_pos = int(
            (cur_input_seq[i1, 3 * j1] - 1) * 9
            + (cur_input_seq[i1, 3 * j1 + 1] - 1)
        )
        filled_cells[i1, cell_pos] = 1

    cur_board_state = cur_input_seq[:, : (3 * num_cells)]
    # Logits right after the queried (row, col) predict its value.
    pos = 3 * num_cells + 1
    correct_pred = 0
    total_pred = 0

    for j in range(81):
      row = (j // 9) + 1
      col = (j % 9) + 1
      test_board_state = np.hstack((
          cur_board_state,
          np.ones((batch_size, 1), dtype=np.int8) * row,
          np.ones((batch_size, 1), dtype=np.int8) * col,
      ))

      pred_logits = get_pred_logits(
          test_board_state, input_seq, state, p_eval_step, config
      )
      pred_logits = pred_logits[:, :, pos, :].reshape(
          (batch_size, pred_logits.shape[-1])
      )

      for k in range(batch_size):
        cell_candidates = candidate_list[k, i, j]
        num_candidates = np.sum(cell_candidates)
        if filled_cells[k, j] == 1 or num_candidates == 0:
          continue

        model_candidates = np.asarray(
            pred_logits[k].argpartition(-num_candidates, axis=-1)[
                -num_candidates:
            ]
        )
        # Only tokens 1..len(cell_candidates) name a value; anything else
        # (e.g. token 0) is wrong. Indexing with token - 1 directly made
        # token 0 wrap to index -1 and be scored as the last value.
        valid = (model_candidates >= 1) & (
            model_candidates <= len(cell_candidates)
        )
        correct_pred += np.sum(cell_candidates[model_candidates[valid] - 1])
        total_pred += num_candidates

    accuracy = correct_pred * 1.0 / total_pred if total_pred else float("nan")
    eval_metrics["intermediate_calc_acc" + str(num_cells)].append(accuracy)
  return eval_metrics
774

775

776
def get_sudoku_eval_metrics(
    state, eval_data_iter, p_eval_step, config
):
  """Computes Sudoku evaluation metrics over ``config.eval_epochs`` batches.

  Per batch this records set accuracies, beam-search cell accuracy, the
  position-hinted accuracy, the fraction of fully valid decoded puzzles and
  the mean edit distance to the solver's cell order.

  Args:
    state: contains model parameters, optimizer, etc.
    eval_data_iter: iterator yielding ``(input_seq, puzzle_sol, ...)``
      batches. ``input_seq`` is ``(batch, 3*81)`` (row, col, value) tokens in
      1..9. ``puzzle_sol`` is ``(batch, 81)``, where position ``p`` holds the
      value of cell ``(p // 9 + 1, p % 9 + 1)``.
    p_eval_step: pmap function for the model's forward pass.
    config: general config.

  Returns:
    ``(eval_metrics, mistakes_metrics)``. ``eval_metrics`` maps each metric
    name to a list with one entry per batch.
  """
  metric_names = (
      "acc",
      "hinted_acc",
      "acc_complete_puzzle",
      "edit_distance",
      "set_acc1",
      "set_acc2",
      "set_acc3",
      "correct_cnt1",
      "correct_cnt2",
      "correct_cnt3",
  )
  eval_metrics = {name: [] for name in metric_names}
  for i in range(10):
    eval_metrics["intermediate_calc_acc" + str(5 * i + 35)] = []

  mistakes_metrics = {
      "mistakes": [],
      "mistake_pos": np.zeros(81, dtype=np.int32),
      "first_mistake_pos": np.zeros(81, dtype=np.int32),
      "first_mistake_strategies": np.zeros(8, dtype=np.int32),
      "total_strategies": np.zeros(8, dtype=np.int32),
  }

  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):
      batch_tuple = next(eval_data_iter)
      input_seq = np.array(batch_tuple[0])
      puzzle_sol = np.array(batch_tuple[1])
      prefix = input_seq[:, : (config.start_index * 3)]

      set_acc, correct_cnt, _ = get_set_accuracies(
          state, p_eval_step, input_seq, config
      )
      eval_metrics = set_set_accuracies(eval_metrics, set_acc, correct_cnt)

      eval_metrics, mistakes_metrics, decoded = get_accuracy(
          prefix,
          state,
          p_eval_step,
          input_seq,
          puzzle_sol,
          config,
          eval_metrics,
          mistakes_metrics,
      )

      eval_metrics = get_position_hinted_eval_acc(
          input_seq, puzzle_sol, state, p_eval_step, eval_metrics, config
      )

      num_examples = len(decoded)
      num_solved = sum(valid_solution(seq) for seq in decoded)
      total_edit_distance = sum(
          get_edit_distance(config, decoded[idx], input_seq[idx])
          for idx in range(num_examples)
      )

      eval_metrics["acc_complete_puzzle"].append(
          num_solved * 1.0 / num_examples
      )
      eval_metrics["edit_distance"].append(
          total_edit_distance * 1.0 / num_examples
      )
  return eval_metrics, mistakes_metrics
874

875

876
def get_eval_metrics(
    step, state, eval_data_iter, p_eval_step, config
):
  """Dispatches to the dataset-specific evaluation routine.

  Args:
    step: current training step. It is accepted for interface compatibility
      but neither evaluator uses it.
    state: model state.
    eval_data_iter: iterator over evaluation batches.
    p_eval_step: parallel forward function.
    config: ``config.dataset`` selects the evaluator.

  Returns:
    For ``"othello"``, the metrics dict. For datasets containing
    ``"sudoku"``, the ``(eval_metrics, mistakes_metrics)`` tuple.

  Raises:
    ValueError: if ``config.dataset`` matches neither case.
  """
  del step  # Not used by either evaluator.
  if config.dataset == "othello":
    return get_othello_eval_metrics(
        state, eval_data_iter, p_eval_step, config
    )
  if "sudoku" in config.dataset:
    # get_sudoku_eval_metrics takes (state, eval_data_iter, p_eval_step,
    # config); forwarding `step` as well raised a TypeError.
    return get_sudoku_eval_metrics(
        state, eval_data_iter, p_eval_step, config
    )
  raise ValueError(f"Unsupported dataset: {config.dataset!r}")
887

888

889
def plot_to_image(figure):
  """Renders a matplotlib figure to a PNG-decoded image tensor.

  The figure is closed by this call and must not be used afterwards.

  Args:
    figure: the matplotlib Figure to render.

  Returns:
    RGBA image tensor (``channels=4``) with a leading batch axis of size 1.
  """
  buf = io.BytesIO()
  # Save the given figure explicitly. plt.savefig would render whichever
  # figure is *current*, which need not be `figure`.
  figure.savefig(buf, format="png")
  plt.close(figure)
  buf.seek(0)

  image = tf.image.decode_png(buf.getvalue(), channels=4)
  return tf.expand_dims(image, 0)
900

901

902
def plot_ax(ax, num, wr, wc):
  """Draws a 9x9 Sudoku grid of ``num`` onto ``ax`` and highlights one cell.

  ``num[x, y]`` is written at ``(x + 0.5, 8.5 - y)``, so ``y = 0`` is the top
  line, and zero entries are left blank. The unit square with lower-left
  corner ``(wr, 8 - wc)`` is filled red. Minor grid lines separate cells,
  and black lines at 3 and 6 separate the 3x3 boxes.
  """
  for x in range(9):
    for y in range(9):
      value = num[x, y]
      if value == 0:
        continue
      ax.text(x + 0.5, 8.5 - y, str(int(value)), ha="center", va="center")

  ax.axis([0, 9, 0, 9])

  highlight = matplotlib.patches.Rectangle((wr, 8 - wc), 1, 1, color="red")
  ax.add_patch(highlight)

  for axis in (ax.xaxis, ax.yaxis):
    axis.set_minor_locator(mticker.MultipleLocator(1))
    axis.set_major_locator(mticker.MultipleLocator(3))

  ax.grid(which="minor")
  ax.xaxis.set_ticks_position("top")

  # Box boundaries ("0" is black in matplotlib's grayscale-string notation).
  for y in (3, 6):
    ax.hlines(y=y, xmin=0, xmax=10, color="0")
  for x in (6, 3):
    ax.vlines(x=x, ymin=0, ymax=10, color="0")
931

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.