google-research / evaluation.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Baselines and evaluation metrics for Jax language models."""
import itertools

from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np

from protein_lm import utils


class EmpiricalBaseline:
  """Empirical baseline as described in the ProGen paper.

  References:
    [ProGen](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v1)
  """

  def __init__(self, domain, train_ds, alpha=1.):
    """Creates an instance of this class.

    # TODO(gandreea): It's unclear how to handle the length (EOS token). The
    #   fact that the uniform baseline is reported as (perplexity=25,
    #   accuracy=0.04) suggests that the EOS prediction step is not included.

    Args:
      domain: An instance of domains.Domain.
      train_ds: A tf.data.Dataset containing the data to be used for computing
        the empirical distribution.
      alpha: A float indicating the Laplace smoothing constant.
    """
    self._vocab_size = domain.vocab_size
    self._token_indices = [
        idx for idx in range(len(domain.vocab.tokens))
        if idx != domain.vocab.bos and idx != domain.vocab.eos]
    self._mask_token = domain.vocab.bos

    self._empirical_dist = np.zeros((len(self._token_indices),))
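    # Accumulate raw per-token counts over the training set by summing
    # one-hot encodings of each batch.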
    for batch in train_ds:
      batch = np.atleast_2d(batch)
      batch_one_hot = np.eye(self._vocab_size)[batch]
      batch_one_hot = np.take(batch_one_hot, self._token_indices, axis=-1)
      self._empirical_dist += np.sum(np.sum(batch_one_hot, axis=0), axis=0)

    self._empirical_dist += alpha  # Laplace smoothing.
    self._empirical_dist /= np.sum(self._empirical_dist)
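    # Worked example (hypothetical counts): raw counts [3, 0, 1] with alpha=1
    # become [4, 1, 2] / 7, so a token unseen in training gets probability
    # 1/7 rather than 0.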

  def evaluate_batch(self, batch):
    """Computes all metrics on the given batch."""
    labels = np.atleast_2d(batch)
    logits = np.log(self._empirical_dist)
    logits = np.tile(logits, list(labels.shape) + [1])
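    # Positions holding the mask token (BOS) get weight 0 and are excluded
    # from the metrics.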
    weights = np.where(labels != self._mask_token, 1, 0)
    metrics = utils.compute_metrics(logits, labels, weights)
    for key, value in metrics.items():
      metrics[key] = jnp.atleast_1d(value)
    return metrics
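
# A minimal usage sketch (hypothetical `domain` and `train_ds`; any
# domains.Domain with bos/eos tokens and a dataset of integer-encoded
# sequences should work):
#
#   baseline = EmpiricalBaseline(domain, train_ds, alpha=1.)
#   batch_metrics = baseline.evaluate_batch(next(iter(eval_ds)))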


def combine_metrics(step_metrics):
  """Given a list of metric dicts, combine to a single summary metrics dict.

  Args:
    step_metrics: A list of dicts with (metric name, metric value) items. Each
      dict contains summed metrics and the corresponding denominator (the
      number of next-token prediction instances). Each metric value must have
      at least one dimension.

  Returns:
    A dict with (metric name, metric value) items containing combined metrics.
  """
  metrics_all = common_utils.get_metrics(step_metrics)
  lr = None
  if 'learning_rate' in metrics_all:
    lr = metrics_all.pop('learning_rate').mean()
  metrics_sums = jax.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
  if lr is not None:
    summary['learning_rate'] = lr

  # Calculate (clipped) perplexity after averaging log-perplexities:
  if 'loss' in summary:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
  return summary
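
# Hypothetical worked example: two steps with summed losses 2.0 and 4.0 over
# denominators 10 and 20 combine to loss (2.0 + 4.0) / (10 + 20) = 0.2 and
# perplexity exp(0.2) ≈ 1.22:
#
#   step_metrics = [
#       {'loss': jnp.atleast_1d(2.0), 'denominator': jnp.atleast_1d(10.)},
#       {'loss': jnp.atleast_1d(4.0), 'denominator': jnp.atleast_1d(20.)},
#   ]
#   summary = combine_metrics(step_metrics)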


def evaluate(model, eval_ds, num_eval_steps=None):
  """Evaluates model on eval_ds for num_eval_steps.

  Args:
    model: A model to use for evaluation. Must have an evaluate_batch() method.
    eval_ds: A tensorflow dataset containing the data to be used for evaluation.
    num_eval_steps: If given, evaluate for this many steps; otherwise, use the
      entire dataset.

  Returns:
    A dictionary with (metric name, metric value) items.
  """
  eval_metrics = []
  eval_iter = iter(eval_ds)
  if num_eval_steps is None:
    num_iter = itertools.repeat(1)
  else:
    num_iter = range(num_eval_steps)
  for _, eval_batch in zip(num_iter, eval_iter):
    eval_batch = np.asarray(eval_batch)
    metrics = model.evaluate_batch(eval_batch)
    eval_metrics.append(metrics)
  eval_summary = combine_metrics(eval_metrics)
  return eval_summary
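
# Typical call (hypothetical names; assumes utils.compute_metrics emits a
# 'loss' entry so that combine_metrics can add 'perplexity'):
#
#   summary = evaluate(baseline, eval_ds, num_eval_steps=100)
#   print(summary['perplexity'])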