google-research

Форк
0
555 строк · 17.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Evaluation related functions."""
17

18
from flax.training import common_utils
19
import jax
20
from jax import numpy as jnp
21
import numpy as np
22

23
from sudoku_gpt import model
24
from sudoku_gpt import othello
25
from sudoku_gpt import sudoku
26

27

28
def valid_solution(output_seq):
  """Checks if the output sequence is a valid solution for the sudoku puzzle.

  The sequence is read as 81 consecutive (row, column, value) triples whose
  entries are expected to be in {1, ..., 9}.

  Args:
    output_seq: Indexable sequence of at least 243 numeric tokens.

  Returns:
    True if every token is in 1..9 and every row, column and 3x3 box contains
    each value at least once (with 81 placements, that means exactly once);
    False otherwise.
  """
  rows = np.zeros((9, 9))
  cols = np.zeros((9, 9))
  boxes = np.zeros((9, 9))

  for j in range(81):
    row_num = int(output_seq[3 * j] - 1)
    col_num = int(output_seq[3 * j + 1] - 1)
    val = int(output_seq[3 * j + 2] - 1)
    # Reject out-of-range tokens on BOTH sides: a token <= 0 (e.g. padding)
    # would otherwise become a negative index and silently wrap around.
    if not (0 <= row_num <= 8 and 0 <= col_num <= 8 and 0 <= val <= 8):
      return False
    rows[row_num, val] += 1
    cols[col_num, val] += 1
    boxes[3 * (row_num // 3) + (col_num // 3), val] += 1

  # Every (unit, value) counter must be non-zero.
  return bool(np.all(rows) and np.all(cols) and np.all(boxes))
54

55

56
def eval_step(state, batch, config):
  """Runs one forward pass of the LM head model and returns its output.

  Args:
    state: Object whose ``params`` attribute holds the model parameters.
    batch: Input batch handed to the model's ``apply``.
    config: Configuration used to build ``model.TransformerLMHeadModel``.

  Returns:
    Whatever ``TransformerLMHeadModel(config).apply`` returns for ``batch``.
  """
  lm_model = model.TransformerLMHeadModel(config)
  variables = {"params": state.params}
  logits = lm_model.apply(variables, batch)
  return logits
60

61

62
def get_othello_eval_metrics(state, eval_data_iter, p_eval_step, config):
  """Get evaluation metrics for Othello game.

  For each of ``config.eval_epochs`` batches and every prefix length
  ``i + 1`` (``i`` in ``range(config.seq_len)``), the prefix is right-padded
  with zeros to ``config.seq_len`` and fed to the model. The argmax prediction
  at position ``i`` is appended to the prefix, and the extended sequence is
  replayed on a fresh ``othello.OthelloBoardState``. A replay that raises
  ``AssertionError`` counts as a wrong prediction.

  Args:
    state: Training state; forwarded unchanged to ``p_eval_step``.
    eval_data_iter: Iterator for evaluation dataset.
    p_eval_step: Function to compute forward pass on a single evaluation batch
      (called on a batch sharded with ``common_utils.shard``).
    config: The config for the experiment; reads ``eval_epochs`` and
      ``seq_len``.

  Returns:
    Dict ``{"acc": [...]}`` with, per evaluated batch, the fraction of
    predicted moves that replayed without an ``AssertionError``.
  """
  eval_metrics = {"acc": []}
  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):

      batch = np.array(next(eval_data_iter))
      total_pred, success_pred = 0, 0

      for i in range(config.seq_len):
        padding = np.zeros((batch.shape[0],
                            config.seq_len - (i + 1)), dtype=np.int32)
        concat_batch = np.hstack((batch[:, :(i + 1)], padding))
        concat_batch = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, concat_batch)
        )
        pred_logits = p_eval_step(state, concat_batch)

        max_action = pred_logits[:, :, i, :].argmax(axis=-1)
        # `jnp.reshape(..., newshape=...)` is deprecated in JAX; reshape a
        # host-side NumPy copy instead (same values, same row order).
        pred_seq = np.hstack(
            (batch[:, :(i + 1)], np.asarray(max_action).reshape(-1, 1))
        )

        for j in range(pred_seq.shape[0]):
          ## When length of the game is small, then the model can simply keep
          ## predicting the next token which will increase the accuracy
          total_pred += 1
          try:
            othello.OthelloBoardState().update(pred_seq[j], prt=False)
          except AssertionError:
            # Illegal move sequence -> wrong prediction.
            pass
          else:
            success_pred += 1

      eval_metrics["acc"].append(success_pred * 1.0 / total_pred)

  return eval_metrics
109

110

111
def get_edit_distance(config, generated_input_seq, original_input_seq):
  """Get edit distance between generated input and original input.

  For every position ``p`` in ``range(config.start_index, config.block_size)``
  of ``generated_input_seq``, the first position ``q`` in the same range of
  ``original_input_seq`` with the same (row, column) tokens is located, and
  ``|q - p|`` is added. When no such ``q`` exists, ``|config.block_size - p|``
  is added instead.

  Args:
    config: Object exposing integer ``start_index`` and ``block_size``.
    generated_input_seq: Flat (row, column, value) token sequence from the
      model.
    original_input_seq: Flat (row, column, value) token sequence from the
      solver.

  Returns:
    The accumulated positional distance.
  """
  positions = range(config.start_index, config.block_size)
  distance = 0
  for gen_pos in positions:
    for sol_pos in positions:
      row_matches = (
          generated_input_seq[3 * gen_pos] == original_input_seq[3 * sol_pos]
      )
      col_matches = (
          generated_input_seq[3 * gen_pos + 1]
          == original_input_seq[3 * sol_pos + 1]
      )
      if row_matches and col_matches:
        distance += abs(sol_pos - gen_pos)
        break
    else:
      # The solver never filled this cell in the range: charge the remainder.
      distance += abs(config.block_size - gen_pos)
  return distance
136

137

138
def get_set_accuracy_for_pairs(
    pairs,
    state,
    p_eval_step,
    input_seq,
    possible_vals,
    given_vals,
    config,
):
  """Computes accuracy of set of possible values.

  For each candidate cell in ``pairs``, the 31-cell prompt is extended with
  that cell's 1-based (row, column) tokens. The model's logits for the next
  token are then ranked. For an empty cell with ``m`` candidate values, the
  ``t``-th ranked token (``t < m``) counts as correct when ``possible_vals``
  marks it as a candidate for that cell. Tokens outside 1..9 count as scored
  but incorrect.

  Args:
    pairs: Integer array indexed as ``pairs[i, k]`` -> 0-based (row, col) of
      the i-th candidate cell for example ``k``.
    state: Forwarded to ``p_eval_step``.
    p_eval_step: Forward pass over a sharded batch; its output is indexed as a
      4-D array whose third axis is the sequence position and last axis the
      vocabulary.
    input_seq: 2-D token array, one row per example.
    possible_vals: Array indexed ``[example, row, col, value - 1]``; 1 marks a
      candidate value.
    given_vals: Array indexed ``[example, row, col]``; 1 marks a given cell,
      which is skipped.
    config: Reads ``seq_len``.

  Returns:
    Tuple ``(accuracy, correct_cnt, total_cnt)`` of length-9 arrays indexed
    by rank ``t``. ``accuracy[t]`` stays 1.0 when no rank-``t`` prediction was
    scored.
  """
  correct_cnt = np.zeros(9)
  total_cnt = np.zeros(9)

  min_start_index = 31
  batch_size = input_seq.shape[0]
  prompt = input_seq[:, : (min_start_index * 3)]
  # Position of the appended column token; its logits predict the value.
  value_pos = 3 * min_start_index + 1

  for cell_pairs in pairs:
    cur_input_seq = np.hstack((prompt, cell_pairs + 1))

    padding = np.zeros(
        (batch_size, config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )
    concat_batch = np.hstack((cur_input_seq, padding))
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    pred_logits = p_eval_step(state, concat_batch)

    # Copy the logits to host memory once. The output is presumably a JAX
    # device array, and indexing it once per example below would issue one
    # device operation per example.
    cur_pred_logits = np.asarray(
        pred_logits[:, :, value_pos, :].reshape((-1, pred_logits.shape[-1]))
    )

    for k in range(batch_size):
      row, col = cell_pairs[k, 0], cell_pairs[k, 1]
      if given_vals[k, row, col] == 1:
        continue
      total_possible_vals = np.int32(np.sum(possible_vals[k, row, col, :]))
      # Highest-scoring tokens first, one rank per candidate value.
      ordering_ind = np.argsort(cur_pred_logits[k, :])[::-1][
          :total_possible_vals
      ]

      assert len(ordering_ind) <= 9

      for t, ind in enumerate(ordering_ind):
        if 1 <= ind <= 9:
          correct_cnt[t] += possible_vals[k, row, col, ind - 1] == 1
        total_cnt[t] += 1

  accuracy = np.ones(9)
  scored = total_cnt > 0
  accuracy[scored] = correct_cnt[scored] / total_cnt[scored]

  return accuracy, correct_cnt, total_cnt
199

200

201
def get_sampled_pairs(input_seq, pred_logits, state, p_eval_step, config, key):
  """Samples distinct next cells at config.start_index + 1 and returns them.

  Each round draws a row token from the logits at position
  ``3 * config.start_index - 1`` and appends it to the prompt. It then runs the
  model again and draws a column token from the new logits at position
  ``3 * config.start_index``. Draws whose row or column lies outside 1..9 are
  discarded. Rounds repeat while at least one example that got a valid draw
  in the current round still holds fewer than ``config.set_accuracy_top_k``
  distinct cells.

  NOTE(review): if the model keeps drawing the same few cells for some
  example, the loop never terminates. Slots that are never filled stay 0.

  Args:
    input_seq: 2-D token array, one row per example.
    pred_logits: Model output for the prompt; indexed as a 4-D array whose
      third axis is the sequence position and last axis the vocabulary.
    state: Forwarded to ``p_eval_step``.
    p_eval_step: Forward pass over a sharded batch.
    config: Reads ``start_index``, ``seq_len`` and ``set_accuracy_top_k``.
    key: JAX PRNG key; split twice per round.

  Returns:
    int32 array of shape ``(config.set_accuracy_top_k, batch, 2)`` holding
    0-based (row, col) pairs.
  """
  batch_size = input_seq.shape[0]
  top_k = config.set_accuracy_top_k
  prefix_len = config.start_index * 3
  sampled_cells = [set() for _ in range(batch_size)]
  pairs = np.zeros((top_k, batch_size, 2), dtype=np.int32)

  # Rebound to the row-conditioned output after every round, exactly as the
  # row logits of the next round are read from it.
  logits = pred_logits
  keep_sampling = True

  while keep_sampling:
    row_logits = logits[:, :, prefix_len - 1, :].reshape(
        (-1, logits.shape[-1])
    )
    row_key, key = jax.random.split(key, 2)
    sampled_rows = jax.random.categorical(row_key, row_logits)
    assert len(sampled_rows) == batch_size and sampled_rows.ndim == 1

    prefix = np.hstack(
        (input_seq[:, :prefix_len], sampled_rows.reshape(-1, 1))
    )
    zero_pad = np.zeros(
        (batch_size, config.seq_len - len(prefix[0])),
        dtype=np.int32,
    )
    sharded_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, np.hstack((prefix, zero_pad)))
    )

    logits = p_eval_step(state, sharded_batch)
    col_logits = logits[:, :, prefix_len, :].reshape((-1, logits.shape[-1]))
    col_key, key = jax.random.split(key, 2)
    sampled_cols = jax.random.categorical(col_key, col_logits)
    assert len(sampled_cols) == batch_size and sampled_cols.ndim == 1

    keep_sampling = False
    for b in range(batch_size):
      if not 1 <= sampled_rows[b] <= 9:
        continue
      if not 1 <= sampled_cols[b] <= 9:
        continue
      sampled_cells[b].add((int(sampled_rows[b]), int(sampled_cols[b])))
      if len(sampled_cells[b]) < top_k:
        keep_sampling = True

  for b, cells in enumerate(sampled_cells):
    for slot, (row, col) in enumerate(cells):
      pairs[slot, b, 0] = int(row - 1)
      pairs[slot, b, 1] = int(col - 1)
      if slot + 1 == top_k:
        break

  return pairs
269

270

271
def get_topk_probability_pairs(
    input_seq, pred_logits, state, p_eval_step, config, key=None
    ):
  """Returns the top-k most probable next cells after the 31-cell prompt.

  The log-probability of cell (r, c) is log P(row = r) + log P(col = c | row).
  P(row) is read from ``pred_logits`` at position ``3 * 31 - 1``. P(col | row)
  comes from one extra forward pass per row, with that row token appended.
  Both distributions are renormalized over tokens 1..9.

  Args:
    input_seq: 2-D token array, one row per example; the first ``3 * 31``
      tokens form the prompt.
    pred_logits: Model output for the prompt; indexed as a 4-D array whose
      third axis is the sequence position and last axis the vocabulary.
    state: Forwarded to ``p_eval_step``.
    p_eval_step: Forward pass over a sharded batch.
    config: Reads ``seq_len`` and ``set_accuracy_top_k``.
    key: Unused. Accepted (default ``None``) so callers may pass the same
      arguments they would pass to ``get_sampled_pairs``.

  Returns:
    int32 array of shape ``(config.set_accuracy_top_k, batch, 2)`` holding
    0-based (row, col) of each example's most probable cells, most probable
    first.
  """
  del key  # Unused; see docstring.

  min_start_index = 31
  batch_size = input_seq.shape[0]
  vocab_size = pred_logits.shape[-1]
  prompt = input_seq[:, : (min_start_index * 3)]

  pairs = np.zeros(
      (config.set_accuracy_top_k, batch_size, 2), dtype=np.int32
  )
  pred_logits_row = pred_logits[:, :, 3 * min_start_index - 1, :].reshape(
      (-1, vocab_size)
  )

  # Row log probability over row tokens 1..9 (host copy for vector math).
  row_log_prob = np.asarray(jax.nn.log_softmax(pred_logits_row[:, 1:10]))

  pairs_log_prob = np.zeros((batch_size, 81))

  for i in range(9):
    row_num = np.ones((batch_size, 1), dtype=np.int32) * (i + 1)
    cur_input_seq = np.hstack((prompt, row_num))

    padding = np.zeros(
        (batch_size, config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )
    concat_batch = np.hstack((cur_input_seq, padding))
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )

    pred_logits_col = p_eval_step(state, concat_batch)
    pred_logits_col = pred_logits_col[:, :, 3 * min_start_index, :].reshape(
        (-1, vocab_size)
    )

    # Column log probability given row i, over column tokens 1..9.
    col_log_prob = np.asarray(jax.nn.log_softmax(pred_logits_col[:, 1:10]))

    # log P(cell (i, k)) for all examples and all columns k at once; this
    # replaces a per-element Python loop over device scalars.
    pairs_log_prob[:, i * 9:(i + 1) * 9] = (
        col_log_prob + row_log_prob[:, i:i + 1]
    )

  for i in range(batch_size):
    # Selects top k most probable cells (flat index = row * 9 + col).
    topk_indices = np.argsort(pairs_log_prob[i, :])[::-1][
        : config.set_accuracy_top_k
    ]
    for j, ind in enumerate(topk_indices):
      pairs[j, i, 0] = ind // 9
      pairs[j, i, 1] = ind % 9

  return pairs
328

329

330
def get_set_accuracies(state, p_eval_step, input_seq, config):
  """Computes set accuracies for empty cells right after the 31-cell prompt.

  The first 31 (row, column, value) triples of each sequence are treated as
  given cells. A value stays a candidate for a cell unless the same value is
  given in that cell's row or column. Candidate cells are chosen by
  ``config.set_accuracy``:

  * ``"top-k"``: the ``config.set_accuracy_top_k`` most probable next cells
    (via ``get_topk_probability_pairs``);
  * ``"all"``: all 81 cells.

  Args:
    state: Forwarded to ``p_eval_step``.
    p_eval_step: Forward pass over a sharded batch.
    input_seq: 2-D token array of 1-based (row, column, value) triples, one
      row per example.
    config: Reads ``set_accuracy``, ``seq_len`` and, for "top-k",
      ``set_accuracy_top_k``.

  Returns:
    The ``(accuracy, correct_cnt, total_cnt)`` tuple produced by
    ``get_set_accuracy_for_pairs``.

  Raises:
    ValueError: If ``config.set_accuracy`` is neither "top-k" nor "all".
  """
  batch_size = input_seq.shape[0]
  possible_vals = np.ones((batch_size, 9, 9, 9))
  given_vals = np.zeros((batch_size, 9, 9))

  min_start_index = 31
  for i in range(batch_size):
    for j in range(min_start_index):
      row_num = input_seq[i, 3 * j] - 1
      col_num = input_seq[i, 3 * j + 1] - 1
      val = input_seq[i, 3 * j + 2] - 1

      # NOTE(review): only row and column conflicts remove candidates; the
      # 3x3 box constraint is not applied -- confirm this is intended.
      possible_vals[i, row_num, :, val] = 0
      possible_vals[i, :, col_num, val] = 0

      given_vals[i, row_num, col_num] = 1

  if config.set_accuracy == "top-k":
    # Top-k most probable cells at the position right after the prompt.
    cur_input_seq = input_seq[:, : (min_start_index * 3)]
    padding = np.zeros(
        (batch_size, config.seq_len - len(cur_input_seq[0])),
        dtype=np.int32,
    )
    concat_batch = np.hstack((cur_input_seq, padding))
    concat_batch = common_utils.shard(
        jax.tree_util.tree_map(np.asarray, concat_batch)
    )
    pred_logits = p_eval_step(state, concat_batch)

    # Alternative: sample candidate cells instead of taking the top-k:
    #   pairs = get_sampled_pairs(input_seq, pred_logits, state, p_eval_step,
    #                             config, jax.random.PRNGKey(98))
    print("get_topk_probability_pairs", flush=True)
    pairs = get_topk_probability_pairs(
        input_seq, pred_logits, state, p_eval_step, config
    )
    print("got_topk_probability_pairs", flush=True)

  elif config.set_accuracy == "all":
    # Every cell, in row-major order, for every example.
    cells = np.arange(81, dtype=np.int32)
    pairs = np.zeros((81, batch_size, 2), dtype=np.int32)
    pairs[:, :, 0] = (cells // 9)[:, None]
    pairs[:, :, 1] = (cells % 9)[:, None]

  else:
    raise ValueError(
        f"Unknown config.set_accuracy: {config.set_accuracy!r}; "
        "expected 'top-k' or 'all'."
    )

  return get_set_accuracy_for_pairs(
      pairs,
      state,
      p_eval_step,
      input_seq,
      possible_vals,
      given_vals,
      config,
  )
405

406

407
def get_sudoku_eval_metrics(state, eval_data_iter, p_eval_step, config):
  """Computes Sudoku evaluation metrics for ``config.eval_epochs`` batches.

  For each batch, the first 31 (row, column, value) triples of every puzzle
  form the prompt. The model then greedily decodes the rest of the sequence
  one token at a time. Per batch it appends:

  * ``acc``: fraction of decoded value tokens accepted by
    ``sudoku.SudokuBoardStateUpdate`` (an ``AssertionError`` counts as wrong);
  * ``acc_complete_puzzle``: fraction of decoded sequences accepted by
    ``valid_solution``;
  * ``edit_distance``: mean ``get_edit_distance`` between the decoded and the
    solver's cell order;
  * ``set_acc{1,2,3}`` / ``correct_cnt{1,2,3}``: first three entries of the
    set accuracy and correct counts from ``get_set_accuracies``.

  Args:
    state: Forwarded to ``p_eval_step`` and ``get_set_accuracies``.
    eval_data_iter: Iterator yielding tuples whose first three entries are the
      input sequences, the puzzle solutions and the per-puzzle starting
      indices.
    p_eval_step: pmap function for forward pass of model for evaluation.
    config: Reads ``eval_epochs`` and ``seq_len`` here; also forwarded to
      helpers.

  Returns:
    eval_metrics: Dict mapping each metric name to a list with one entry per
    batch.
  """
  eval_metrics = {
      "acc": [],
      "acc_complete_puzzle": [],
      "edit_distance": [],
      "set_acc1": [],
      "set_acc2": [],
      "set_acc3": [],
      "correct_cnt1": [],
      "correct_cnt2": [],
      "correct_cnt3": [],
  }

  min_start_index = 31

  for eval_epoch in range(config.eval_epochs):
    with jax.profiler.StepTraceAnnotation("eval", step_num=eval_epoch):

      batch_tuple = next(eval_data_iter)

      # Input seq is of the shape (batchsize, 3*81) and 3*81 because row,
      # column and value for each cell. Row, column and value all are in
      # {1, ..., 9}.
      input_seq = np.array(batch_tuple[0])

      # Puzzle solution is of the shape (batchsize, 81). Each pos in
      # {0,.., 80} for each puzzle contains value at cell
      # (pos//9+1, pos%9 + 1).
      puzzle_sol = np.array(batch_tuple[1])

      # The dataset's per-puzzle starting index is overridden: every puzzle
      # is decoded from the same fixed prompt length.
      starting_index = (
          np.ones_like(np.array(batch_tuple[2]), dtype=np.int32)
          * min_start_index
      )
      total_pred, success_pred = 0, 0

      # Decoding starts at step 3 * min_start_index and reads the logits at
      # position i - 1, so the prompt must be exactly 3 * min_start_index
      # tokens long. Slicing with config.start_index would break this whenever
      # the two differ.
      cur_input_seq = input_seq[:, :(min_start_index * 3)]

      # Computes set accuracy for empty cells in the puzzle.
      set_acc, correct_cnt, _ = get_set_accuracies(
          state, p_eval_step, input_seq, config
      )
      for n in range(3):
        eval_metrics[f"set_acc{n + 1}"].append(set_acc[n])
        eval_metrics[f"correct_cnt{n + 1}"].append(correct_cnt[n])

      for i in range(min_start_index * 3, config.seq_len):
        # In the i-th iteration the model predicts token i (row, column or
        # value, depending on i % 3) from the logits at position i - 1.
        padding = np.zeros((input_seq.shape[0],
                            config.seq_len - len(cur_input_seq[0])),
                           dtype=np.int32)
        concat_batch = np.hstack((cur_input_seq, padding))
        concat_batch = common_utils.shard(
            jax.tree_util.tree_map(np.asarray, concat_batch)
        )

        pred_logits = p_eval_step(state, concat_batch)

        predicted = np.asarray(
            pred_logits[:, :, i - 1, :].argmax(axis=-1)
        ).flatten()
        # Positions before a puzzle's starting index keep the ground truth.
        mask = np.array(i >= (3 * starting_index))
        next_token = predicted * mask + (1 - mask) * input_seq[:, i]
        cur_input_seq = np.hstack((cur_input_seq, next_token.reshape(-1, 1)))

        if i % 3 == 2:
          # The model just predicted the value at cell
          # (cur_input_seq[j][i-2], cur_input_seq[j][i-1]); score it.
          for j in range(len(cur_input_seq)):
            if not mask[j]:
              continue

            total_pred += 1
            try:
              sudoku.SudokuBoardStateUpdate(puzzle_sol[j],
                                            cur_input_seq[j][i - 2],
                                            cur_input_seq[j][i - 1],
                                            cur_input_seq[j][i])
            except AssertionError:
              # Wrong update.
              pass
            else:
              success_pred += 1

      eval_metrics["acc"].append(success_pred * 1.0 / total_pred)

      correct_eval_sudoku_puzzle = 0
      solution_edit_distance = 0.0

      for generated, original in zip(cur_input_seq, input_seq):
        # Count puzzles whose decoded sequence is a complete valid solution.
        correct_eval_sudoku_puzzle += valid_solution(generated)

        # Distance between the model's fill order and the solver's order.
        solution_edit_distance += get_edit_distance(
            config, generated, original
        )

      num_puzzles = len(cur_input_seq)
      eval_metrics["acc_complete_puzzle"].append(
          correct_eval_sudoku_puzzle * 1.0 / num_puzzles
      )
      eval_metrics["edit_distance"].append(
          solution_edit_distance * 1.0 / num_puzzles
      )

  return eval_metrics
546

547

548
def get_eval_metrics(state, eval_data_iter,
                     p_eval_step, config):
  """Dispatches to the evaluation routine matching ``config.dataset``.

  Args:
    state: Forwarded to the dataset-specific routine.
    eval_data_iter: Iterator for the evaluation dataset.
    p_eval_step: Forward pass used by the dataset-specific routine.
    config: Experiment config; ``config.dataset`` selects the routine.

  Returns:
    The metrics dict from ``get_othello_eval_metrics`` when
    ``config.dataset == "othello"``, or from ``get_sudoku_eval_metrics`` when
    ``"sudoku"`` occurs in ``config.dataset``.

  Raises:
    ValueError: If ``config.dataset`` matches neither case.
  """
  if config.dataset == "othello":
    return get_othello_eval_metrics(state, eval_data_iter, p_eval_step,
                                    config)
  if "sudoku" in config.dataset:
    return get_sudoku_eval_metrics(state, eval_data_iter, p_eval_step,
                                   config)
  raise ValueError(
      f"Unsupported dataset for evaluation: {config.dataset!r}"
  )
556

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.