# Copied from the google-research repository (protein_lm baselines).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Baselines and evaluation metrics for Jax language models."""
17import itertools
18from flax.training import common_utils
19import jax
20import jax.numpy as jnp
21import numpy as np
22
23from protein_lm import utils
24
25
class EmpiricalBaseline:
  """Empirical (unigram) baseline as described in the ProGen paper.

  A single token distribution is estimated by counting token frequencies
  over a training dataset (with Laplace smoothing) and is used as the
  prediction at every position.

  References:
    [ProGen](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v1)
  """

  def __init__(self, domain, train_ds, alpha=1.):
    """Creates an instance of this class.

    # TODO(gandreea): It's unclear how to handle the length (EOS token). The
    # fact that the uniform baseline is reported as (perplexity=25,
    # accuracy=0.04) suggests that the EOS prediction step is not included.

    Args:
      domain: An instance of domains.Domain.
      train_ds: A tf.data.Dataset containing the data to be used for computing
        the empirical distribution.
      alpha: A float indicating the Laplace smoothing constant.
    """
    vocab = domain.vocab
    self._vocab_size = domain.vocab_size
    # The distribution is estimated over regular tokens only; the BOS and EOS
    # special tokens are excluded.
    self._token_indices = [
        idx for idx in range(len(vocab.tokens))
        if idx not in (vocab.bos, vocab.eos)]
    self._mask_token = vocab.bos

    # Accumulate per-token counts over the whole training set via one-hot
    # encoding restricted to the kept token indices.
    counts = np.zeros((len(self._token_indices),))
    identity = np.eye(self._vocab_size)
    for batch in train_ds:
      one_hot = identity[np.atleast_2d(batch)]
      one_hot = one_hot[..., self._token_indices]
      counts += one_hot.sum(axis=0).sum(axis=0)

    counts += alpha  # Laplace smoothing.
    self._empirical_dist = counts / counts.sum()

  def evaluate_batch(self, batch):
    """Computes all metrics on the given batch."""
    labels = np.atleast_2d(batch)
    # Broadcast the single empirical log-distribution to every position.
    log_probs = np.tile(np.log(self._empirical_dist), list(labels.shape) + [1])
    # Positions holding the mask (BOS) token get zero weight in the metrics.
    weights = (labels != self._mask_token).astype(int)
    metrics = utils.compute_metrics(log_probs, labels, weights)
    return {key: jnp.atleast_1d(value) for key, value in metrics.items()}
def combine_metrics(step_metrics):
  """Given a list of metric dicts, combine to a single summary metrics dict.

  Args:
    step_metrics: A list of dicts with (metric name, metric value) items.
      Contains summed metrics and the corresponding denominator (the number
      of next-token prediction instances). Each metric value must have at
      least one dimension.

  Returns:
    A dict with (metric name, metric value) items containing combined metrics:
    each summed metric divided by the total denominator, plus the mean
    learning rate and a clipped perplexity when available.
  """
  metrics_all = common_utils.get_metrics(step_metrics)
  lr = None
  if 'learning_rate' in metrics_all:
    # The learning rate is averaged rather than sum-normalized like the rest.
    lr = metrics_all.pop('learning_rate').mean()
  # NOTE: jax.tree_map was deprecated and later removed; jax.tree_util.tree_map
  # is the spelling that works across jax versions.
  metrics_sums = jax.tree_util.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)
  if lr is not None:
    summary['learning_rate'] = lr

  # Calculate (clipped) perplexity after averaging log-perplexities. The
  # bounds are passed positionally because the old `a_max` keyword was renamed
  # to `max` in newer jax releases, while positional bounds work everywhere.
  if 'loss' in summary:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), None, 1.0e4)
  return summary
def evaluate(model, eval_ds, num_eval_steps=None):
  """Evaluates model on eval_ds for num_eval_steps.

  Args:
    model: A model to use for evaluation. Must have an evaluate_batch() method.
    eval_ds: A tensorflow dataset containing the data to be used for
      evaluation.
    num_eval_steps: If given, evaluate for this many steps, otherwise use the
      entire dataset.

  Returns:
    A dictionary with (metric name, metric value) items.
  """
  # With no step limit, pair batches with an endless counter; zip stops at
  # the shorter iterable either way, so the dataset bounds the loop.
  if num_eval_steps is None:
    step_iter = itertools.repeat(1)
  else:
    step_iter = range(num_eval_steps)
  step_metrics = [
      model.evaluate_batch(np.asarray(eval_batch))
      for _, eval_batch in zip(step_iter, iter(eval_ds))]
  return combine_metrics(step_metrics)