google-research

Форк
0
220 строк · 8.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Custom metrics."""
17

18
import json
19
from typing import Mapping, Optional, Sequence
20

21
import gin
22
import tensorflow as tf
23

24

25
@gin.configurable
class PearsonCorrelation(tf.metrics.Metric):
  """Implements Pearson correlation as tf.metrics.Metric class.

  The coefficient is computed in streaming fashion from five running
  (optionally weighted) means over t = y_true and p = y_pred:
  E[t], E[p], E[t^2], E[p^2] and E[t*p]. `result()` returns

    (E[t*p] - E[t] E[p]) / sqrt(E[t^2] - E[t]^2) / sqrt(E[p^2] - E[p]^2).

  The value is not finite when either accumulated variance is zero
  (e.g. constant inputs).
  """

  def __init__(self, *args, **kwargs):
    """Initializes the metric; all arguments go to `tf.metrics.Metric`."""
    super().__init__(*args, **kwargs)
    self._y_true_mean = tf.metrics.Mean()
    self._y_pred_mean = tf.metrics.Mean()
    self._y_true_sq_mean = tf.metrics.Mean()
    self._y_pred_sq_mean = tf.metrics.Mean()
    self._y_true_dot_y_pred_mean = tf.metrics.Mean()

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the running moments for one batch.

    Args:
      y_true: ground-truth values; must broadcast against `y_pred`.
      y_pred: predicted values.
      sample_weight: optional weights forwarded to every running mean.
    """
    # Cast both inputs to the metric dtype so that e.g. integer labels can be
    # multiplied with float predictions below (mixed dtypes raise in TF).
    y_true = tf.cast(y_true, self.dtype)
    y_pred = tf.cast(y_pred, self.dtype)
    self._y_true_mean.update_state(y_true, sample_weight)
    self._y_pred_mean.update_state(y_pred, sample_weight)
    self._y_true_sq_mean.update_state(y_true ** 2, sample_weight)
    self._y_pred_sq_mean.update_state(y_pred ** 2, sample_weight)
    self._y_true_dot_y_pred_mean.update_state(y_true * y_pred, sample_weight)

  def result(self):
    """Returns the Pearson correlation of all values seen so far."""
    y_true_mean = self._y_true_mean.result()
    y_pred_mean = self._y_pred_mean.result()
    y_true_var = self._y_true_sq_mean.result() - y_true_mean ** 2
    y_pred_var = self._y_pred_sq_mean.result() - y_pred_mean ** 2
    cov = self._y_true_dot_y_pred_mean.result() - y_true_mean * y_pred_mean
    return cov / tf.sqrt(y_true_var) / tf.sqrt(y_pred_var)

  def reset_state(self):
    """Clears all running means."""
    self._y_true_mean.reset_state()
    self._y_pred_mean.reset_state()
    self._y_true_sq_mean.reset_state()
    self._y_pred_sq_mean.reset_state()
    self._y_true_dot_y_pred_mean.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()
57

58

59
@gin.configurable
class Perplexity(tf.metrics.SparseCategoricalCrossentropy):
  """Perplexity metric: exp of the mean sparse categorical cross-entropy.

  Accumulation is inherited unchanged from
  `tf.metrics.SparseCategoricalCrossentropy`; only the reported value is
  exponentiated. By default predictions are treated as logits.
  """

  def __init__(self, from_logits=True, name='perplexity', **kwargs):
    super().__init__(name=name, from_logits=from_logits, **kwargs)

  def result(self):
    """Returns exp(mean cross-entropy) over everything accumulated so far."""
    mean_cross_entropy = super().result()
    return tf.math.exp(mean_cross_entropy)
68

69

70
class DoubleMean(tf.keras.metrics.Metric):
  """The means of predictions and ground truth for a given metric.

  Wraps two instances of `mean_metric_cls`: one fed with `y_pred`, the other
  with `y_true`. `result()` reports both, keyed `<name>/true` and
  `<name>/pred`.
  """

  def __init__(self, mean_metric_cls, **kwargs):
    """Initializes the metric.

    Args:
      mean_metric_cls: zero-argument callable, called twice, returning a
        metric with `update_state(values, sample_weight)`, `result()` and
        `reset_state()` (presumably a mean metric such as
        `tf.keras.metrics.Mean`).
      **kwargs: forwarded to `tf.keras.metrics.Metric.__init__`. `name`
        defaults to the name of the wrapped ground-truth metric.
    """
    predicted = mean_metric_cls()
    expected = mean_metric_cls()
    # Forward caller options such as `dtype` rather than dropping them; keep
    # the wrapped metric's name as the default.
    kwargs.setdefault('name', expected.name)
    super().__init__(**kwargs)
    self._predicted = predicted
    self._expected = expected

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Feeds predictions and ground truth to their respective means."""
    self._predicted.update_state(y_pred, sample_weight)
    self._expected.update_state(y_true, sample_weight)

  def reset_state(self):
    """Clears both wrapped metrics."""
    self._predicted.reset_state()
    self._expected.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()

  def result(self):
    """Returns a dict with the ground-truth and predicted means."""
    return {
        f'{self.name}/true': self._expected.result(),
        f'{self.name}/pred': self._predicted.result()
    }
91

92

93
@gin.configurable
class SparseLiftedClanAccuracy(tf.metrics.Accuracy):
  """Evaluates SparseCategoricalAccuracy at the lifted clan level.

  The predicted family index (argmax over the last axis of `y_pred`) is mapped
  to a clan index through a lookup table built from a JSON file. Accuracy is
  then measured against the clan labels found in `metadata[0]`.
  """

  def __init__(self, filename, name='lifted_clan_accuracy', **kwargs):
    super().__init__(name=name, **kwargs)
    # Path of the JSON file mapping family IDs to clan IDs.
    self._filename = filename
    # Lookup table: position `fam_key` holds the clan key of family `fam_key`.
    clan_of_family = self._load_mapping()
    self._cla_from_fam = tf.convert_to_tensor(
        [clan_of_family[fam_key] for fam_key in sorted(clan_of_family)],
        dtype=tf.int64)

  def _load_mapping(self):
    """Reads the family->clan JSON mapping and re-keys it with integers.

    Family and clan IDs each get a dense integer index, assigned in order of
    first appearance while iterating the JSON object, starting from 0.

    Returns:
      A dict mapping each family index to the index of its clan.
    """
    with tf.io.gfile.GFile(self._filename, 'r') as json_file:
      cla_id_from_fam_id = json.load(json_file)
    fam_index, cla_index = {}, {}
    for fam_id, cla_id in cla_id_from_fam_id.items():
      fam_index.setdefault(fam_id, len(fam_index))
      cla_index.setdefault(cla_id, len(cla_index))
    return {fam_index[fam_id]: cla_index[cla_id]
            for fam_id, cla_id in cla_id_from_fam_id.items()}

  def update_state(self,
                   y_true,
                   y_pred,
                   sample_weight=None,
                   metadata=()):
    """Accumulates clan-level accuracy for one batch.

    Args:
      y_true: ignored; family labels are not used.
      y_pred: family probabilities or logits; argmax over the last axis gives
        the predicted family index.
      sample_weight: forwarded to `tf.metrics.Accuracy.update_state`.
      metadata: sequence whose first element holds the clan labels.
    """
    del y_true  # Clan labels come from metadata instead.
    clan_labels = metadata[0]
    predicted_families = tf.math.argmax(y_pred, axis=-1)
    predicted_clans = tf.gather(self._cla_from_fam, predicted_families)
    super().update_state(
        clan_labels, predicted_clans, sample_weight=sample_weight)
136

137

138
@gin.configurable
class ContactPrecisionRecallFixedK(tf.metrics.Metric):
  """Implements basic PR metrics for residue-residue contact prediction.

  Only residue pairs (i, j) with range_low <= |i - j| <= range_high are
  scored; all other pairs are zeroed out. For each example, the `at_k` pairs
  with the highest masked probability are selected, and a pair counts as a
  predicted contact when its probability exceeds 0.5. Per-example precision,
  recall and F1 over those top pairs are averaged across examples. AUPRC is
  accumulated over the in-range pairs.
  """

  def __init__(self,
               name='contact_pr',
               range_low=12,
               range_high=23,
               at_k=50,
               **kwargs):
    """Initializes the metric.

    Args:
      name: metric name; also the prefix of the reported result keys.
      range_low: minimum sequence separation |i - j| considered (inclusive).
      range_high: maximum sequence separation |i - j| considered (inclusive).
      at_k: number of top-scoring pairs per example to evaluate.
      **kwargs: forwarded to `tf.metrics.Metric.__init__`.
    """
    super().__init__(name=name, **kwargs)
    self._range_low = range_low
    self._range_high = range_high
    self._at_k = at_k  # TODO(qberthet): allow at_k to be list to clean up gin.

    self._precision = tf.metrics.Mean()
    self._recall = tf.metrics.Mean()
    self._f1score = tf.metrics.Mean()
    self._auprc = tf.metrics.AUC(
        num_thresholds=1000, curve='PR', from_logits=False)

  def update_state(self,
                   y_true,
                   y_pred,
                   sample_weight=None):
    """Accumulates contact PR statistics for one batch.

    Args:
      y_true: contact labels; must broadcast against the (1, L, L, 1) range
        mask built below.
      y_pred: contact logits. Axis 0 is the batch and axis 1 the sequence
        length L. NOTE(review): the masking suggests a (batch, L, L, 1)
        layout; confirm against the caller.
      sample_weight: optional per-pair weights; `sample_weight[..., None]` is
        multiplied into the predicted probabilities. When None, all pairs are
        weighted equally.
    """
    batch = tf.shape(y_pred)[0]
    num_embs = tf.shape(y_pred)[1]
    y_true = tf.cast(y_true, tf.float32)  # Allow integer labels.
    proba_pred = tf.nn.sigmoid(y_pred)
    if sample_weight is not None:
      proba_pred *= sample_weight[Ellipsis, None]

    # Mask of shape (1, L, L, 1) selecting range_low <= |i - j| <= range_high.
    positions = tf.range(num_embs, dtype=tf.float32)
    separation = tf.abs(positions[:, tf.newaxis] - positions[tf.newaxis, :])
    in_range = tf.logical_and(separation >= self._range_low,
                              separation <= self._range_high)
    weights_square = tf.cast(in_range[None, Ellipsis, None], dtype=tf.float32)

    proba_pred_filter = proba_pred * weights_square
    flat_proba_pred_filter = tf.reshape(proba_pred_filter, (batch, -1))
    y_true_filter = y_true * weights_square
    flat_y_true_filter = tf.reshape(y_true_filter, (batch, -1))

    # Indices of the at_k highest masked probabilities per example; top_k
    # needs at least at_k entries per flattened example.
    _, indices = tf.math.top_k(flat_proba_pred_filter, k=self._at_k)
    flat_y_pred_filter = tf.cast(flat_proba_pred_filter > 0.5, tf.float32)

    true_in_top = tf.gather(flat_y_true_filter, indices, batch_dims=-1)
    pred_in_top = tf.gather(flat_y_pred_filter, indices, batch_dims=-1)
    true_pred_in_top = tf.gather(
        flat_y_true_filter * flat_y_pred_filter, indices, batch_dims=-1)

    # Counts are floored at 1e-6 to avoid dividing by zero.
    number_true = tf.maximum(tf.reduce_sum(true_in_top, axis=-1), 1e-6)
    number_preds = tf.maximum(tf.reduce_sum(pred_in_top, axis=-1), 1e-6)
    number_true_preds = tf.reduce_sum(true_pred_in_top, axis=-1)

    # Floored so that the F1 denominator (precision + recall) is never zero.
    precision = tf.maximum(number_true_preds / number_preds, 1e-6)
    recall = tf.maximum(number_true_preds / number_true, 1e-6)

    self._precision.update_state(precision)
    self._recall.update_state(recall)
    self._f1score.update_state(
        2 * (precision * recall) / (precision + recall))

    # Restrict AUPRC to in-range pairs, combined with sample_weight if given.
    # TODO(qberthet): double-check.
    auprc_weight = weights_square[Ellipsis, 0]
    if sample_weight is not None:
      auprc_weight = sample_weight * auprc_weight
    self._auprc.update_state(
        y_true_filter, proba_pred_filter, sample_weight=auprc_weight)

  def result(self):
    """Returns precision, recall, F1 and AUPRC keyed by `<name>/<metric>`."""
    return {
        f'{self.name}/precision': self._precision.result(),
        f'{self.name}/recall': self._recall.result(),
        f'{self.name}/f1': self._f1score.result(),
        f'{self.name}/auprc': self._auprc.result(),
    }

  def reset_state(self):
    """Clears all accumulated statistics."""
    self._precision.reset_state()
    self._recall.reset_state()
    self._f1score.reset_state()
    self._auprc.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()
221

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.