google-research
220 lines · 8.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Custom metrics."""
17
18import json19from typing import Mapping, Optional, Sequence20
21import gin22import tensorflow as tf23
24
@gin.configurable
class PearsonCorrelation(tf.metrics.Metric):
  """Implements Pearson correlation as tf.metrics.Metric class.

  Keeps running (optionally weighted) means of y_true, y_pred, their squares
  and their elementwise product, and derives the correlation coefficient from
  these moments in `result()`.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._y_true_mean = tf.metrics.Mean()
    self._y_pred_mean = tf.metrics.Mean()
    self._y_true_sq_mean = tf.metrics.Mean()
    self._y_pred_sq_mean = tf.metrics.Mean()
    self._y_true_dot_y_pred_mean = tf.metrics.Mean()

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates first and second moments of a new batch.

    Args:
      y_true: ground-truth values.
      y_pred: predicted values; must broadcast against `y_true`.
      sample_weight: optional weights forwarded to every running mean.
    """
    # TF does not auto-promote dtypes, so e.g. integer labels times float
    # predictions would raise; bring both to the metric's dtype first.
    y_true = tf.cast(y_true, self.dtype)
    y_pred = tf.cast(y_pred, self.dtype)
    self._y_true_mean.update_state(y_true, sample_weight)
    self._y_pred_mean.update_state(y_pred, sample_weight)
    self._y_true_sq_mean.update_state(y_true ** 2, sample_weight)
    self._y_pred_sq_mean.update_state(y_pred ** 2, sample_weight)
    self._y_true_dot_y_pred_mean.update_state(y_true * y_pred, sample_weight)

  def result(self):
    """Returns the correlation of all values seen since the last reset.

    Returns 0 instead of NaN/inf when either variance is zero (e.g. when one
    of the inputs is constant).
    """
    y_true_mean = self._y_true_mean.result()
    y_pred_mean = self._y_pred_mean.result()
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation; clamp so tf.sqrt never receives a negative argument.
    y_true_var = tf.maximum(
        self._y_true_sq_mean.result() - y_true_mean ** 2, 0.)
    y_pred_var = tf.maximum(
        self._y_pred_sq_mean.result() - y_pred_mean ** 2, 0.)
    cov = self._y_true_dot_y_pred_mean.result() - y_true_mean * y_pred_mean
    return tf.math.divide_no_nan(cov, tf.sqrt(y_true_var) * tf.sqrt(y_pred_var))

  def reset_state(self):
    """Clears all accumulated moments."""
    self._y_true_mean.reset_state()
    self._y_pred_mean.reset_state()
    self._y_true_sq_mean.reset_state()
    self._y_pred_sq_mean.reset_state()
    self._y_true_dot_y_pred_mean.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()
58
@gin.configurable
class Perplexity(tf.metrics.SparseCategoricalCrossentropy):
  """Perplexity metric: the exponential of the mean sparse cross-entropy."""

  def __init__(self, from_logits=True, name='perplexity', **kwargs):
    super().__init__(name=name, from_logits=from_logits, **kwargs)

  def result(self):
    """Exponentiates the running mean cross-entropy of the parent metric."""
    mean_cross_entropy = super().result()
    return tf.math.exp(mean_cross_entropy)
69
class DoubleMean(tf.keras.metrics.Metric):
  """The means of predictions and ground truth for a given metrics.

  Two instances of `mean_metric_cls` accumulate `y_pred` and `y_true`
  respectively; `result()` reports both under `<name>/pred` and `<name>/true`.
  """

  def __init__(self, mean_metric_cls, **kwargs):
    """Builds the two underlying mean metrics.

    Args:
      mean_metric_cls: zero-argument callable returning a metric object that
        provides `name`, `update_state(values, sample_weight)`, `result()` and
        `reset_state()`.
      **kwargs: forwarded to `tf.keras.metrics.Metric`. When `name` is not
        given, the name of a freshly built `mean_metric_cls()` is used.
    """
    predicted = mean_metric_cls()
    expected = mean_metric_cls()
    # Default to the wrapped metric's name so result keys read
    # '<wrapped name>/true' and '<wrapped name>/pred'.
    kwargs.setdefault('name', expected.name)
    super().__init__(**kwargs)
    # Assigned after the base-class constructor so Keras attribute tracking
    # is already set up when the sub-metrics are attached.
    self._predicted = predicted
    self._expected = expected

  def update_state(self, y_true, y_pred, sample_weight=None):
    self._predicted.update_state(y_pred, sample_weight)
    self._expected.update_state(y_true, sample_weight)

  def reset_state(self):
    """Clears both underlying means."""
    self._predicted.reset_state()
    self._expected.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()

  def result(self):
    return {
        f'{self.name}/true': self._expected.result(),
        f'{self.name}/pred': self._predicted.result()
    }
92
@gin.configurable
class SparseLiftedClanAccuracy(tf.metrics.Accuracy):
  """Evaluates SparseCategoricalAccuracy at the lifted clan level.

  Predicted family labels (argmax over the last axis of `y_pred`) are mapped
  to clan labels through a family -> clan table read from a JSON file, and are
  then compared against the clan labels passed in `metadata[0]`.
  """

  def __init__(self,
               filename: str,
               name: str = 'lifted_clan_accuracy',
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._filename = filename  # A json file.
    # Precomputes a 1D Tensor cla_from_fam such that cla_from_fam[fam_key]
    # contains the label cla_key of the clan to which the family indexed by
    # fam_key belongs.
    cla_key_from_fam_key = self._load_mapping()
    self._cla_from_fam = tf.convert_to_tensor(
        [cla_key for _, cla_key in sorted(cla_key_from_fam_key.items())],
        tf.int64)

  def _load_mapping(self) -> Mapping[int, int]:
    """Prepares family to clan key mapping from JSON file.

    The file holds a JSON object whose keys are family IDs and whose values
    are clan IDs. Families and clans each receive consecutive integer keys in
    order of first appearance in that object.

    Returns:
      A dict mapping every family key to the key of its clan.
    """
    with tf.io.gfile.GFile(self._filename, 'r') as f:
      cla_id_from_fam_id = json.load(f)
    # "Translates" the mapping between IDs to a mapping between integer keys.
    idx_from_fam, idx_from_cla = {}, {}
    for fam, cla in cla_id_from_fam_id.items():
      idx_from_fam.setdefault(fam, len(idx_from_fam))
      idx_from_cla.setdefault(cla, len(idx_from_cla))
    return {idx_from_fam[k]: idx_from_cla[v]
            for k, v in cla_id_from_fam_id.items()}

  def update_state(self,
                   y_true: tf.Tensor,
                   y_pred: tf.Tensor,
                   sample_weight: Optional[tf.Tensor] = None,
                   metadata: Sequence[tf.Tensor] = ()):
    """Accumulates clan-level accuracy for a batch.

    Args:
      y_true: family labels; ignored.
      y_pred: family probabilities or logits; the argmax over the last axis
        is taken as the predicted family key.
      sample_weight: optional weights forwarded to `tf.metrics.Accuracy`.
      metadata: sequence whose first element holds the clan labels.

    Raises:
      ValueError: if `metadata` is empty, so no clan labels are available.
    """
    # Ignores family labels, assumes metadata always contains clan labels.
    try:
      y_true = metadata[0]
    except IndexError as err:
      raise ValueError(
          f'{type(self).__name__} reads clan labels from `metadata[0]`, but '
          'no metadata was provided.') from err
    # Computes predicted family labels from probabilities / logits. Then, maps
    # these to clan labels.
    y_pred = tf.gather(self._cla_from_fam, tf.math.argmax(y_pred, axis=-1))
    super().update_state(y_true, y_pred, sample_weight=sample_weight)
137
@gin.configurable
class ContactPrecisionRecallFixedK(tf.metrics.Metric):
  """Implements basic PR metrics for residue-residue contact prediction.

  Only residue pairs (i, j) whose separation |i - j| lies in
  [range_low, range_high] are considered. For each example, the `at_k` pairs
  with the highest predicted probability are selected, and precision, recall
  and F1 are computed from those entries. An AUPRC over all in-range pairs is
  tracked as well. Every statistic is averaged over examples until reset.
  """

  def __init__(self,
               name: str = 'contact_pr',
               range_low: int = 12,
               range_high: int = 23,
               at_k: int = 50,
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._range_low = range_low
    self._range_high = range_high
    self._at_k = at_k  # TODO(qberthet): allow at_k to be list to clean up gin.

    self._precision = tf.metrics.Mean()
    self._recall = tf.metrics.Mean()
    self._f1score = tf.metrics.Mean()
    self._auprc = tf.metrics.AUC(
        num_thresholds=1000, curve='PR', from_logits=False)

  def _band_mask(self, num_embs):
    """Returns a (1, num_embs, num_embs, 1) 0/1 mask of in-range pairs."""
    positions = tf.range(num_embs, dtype=tf.float32)
    separation = tf.abs(positions[:, tf.newaxis] - positions[tf.newaxis, :])
    in_range = tf.logical_and(separation >= self._range_low,
                              separation <= self._range_high)
    return tf.cast(in_range[None, ..., None], dtype=tf.float32)

  def update_state(self,
                   y_true,
                   y_pred,
                   sample_weight=None):
    """Accumulates contact PR statistics for a batch.

    Args:
      y_true: 0/1 contact map; presumably (batch, L, L, 1) like `y_pred` —
        TODO confirm against callers.
      y_pred: contact logits; dims 0 and 1 are read as batch size and L.
      sample_weight: optional weights, presumably (batch, L, L). They scale the
        predicted probabilities and the AUPRC weights. Defaults to None (no
        weighting), which previously crashed in the AUPRC update.
    """
    batch = tf.shape(y_pred)[0]
    proba_pred = tf.nn.sigmoid(y_pred)
    if sample_weight is not None:
      proba_pred *= sample_weight[..., None]

    weights_square = self._band_mask(tf.shape(y_pred)[1])

    proba_pred_filter = proba_pred * weights_square
    flat_proba_pred_filter = tf.reshape(proba_pred_filter, (batch, -1))
    y_true_filter = y_true * weights_square
    flat_y_true_filter = tf.reshape(y_true_filter, (batch, -1))

    # tf.math.top_k fails when k exceeds the number of candidates. Cap k with
    # the static pair count when it is known, which keeps k a Python int.
    num_pairs = flat_proba_pred_filter.shape[-1]
    k = self._at_k if num_pairs is None else min(self._at_k, num_pairs)
    _, indices = tf.math.top_k(flat_proba_pred_filter, k=k)
    flat_y_pred_filter = tf.cast(flat_proba_pred_filter > 0.5, tf.float32)

    true_in_top = tf.gather(flat_y_true_filter, indices, batch_dims=-1)
    pred_in_top = tf.gather(flat_y_pred_filter, indices, batch_dims=-1)
    true_pred_in_top = tf.gather(
        flat_y_true_filter * flat_y_pred_filter, indices, batch_dims=-1)

    # Floors keep the ratios below (and the F1 denominator) finite.
    number_true = tf.maximum(tf.reduce_sum(true_in_top, axis=-1), 1e-6)
    number_preds = tf.maximum(tf.reduce_sum(pred_in_top, axis=-1), 1e-6)
    number_true_preds = tf.reduce_sum(true_pred_in_top, axis=-1)

    precision = tf.maximum(number_true_preds / number_preds, 1e-6)
    recall = tf.maximum(number_true_preds / number_true, 1e-6)

    self._precision.update_state(precision)
    self._recall.update_state(recall)
    self._f1score.update_state(
        2 * (precision * recall) / (precision + recall))

    # TODO(qberthet): double-check.
    # Out-of-range pairs get weight 0, so the AUPRC only sees in-range pairs.
    auprc_weight = weights_square[..., 0]
    if sample_weight is not None:
      auprc_weight = sample_weight * auprc_weight
    self._auprc.update_state(
        y_true_filter, proba_pred_filter, sample_weight=auprc_weight)

  def result(self):
    return {
        f'{self.name}/precision': self._precision.result(),
        f'{self.name}/recall': self._recall.result(),
        f'{self.name}/f1': self._f1score.result(),
        f'{self.name}/auprc': self._auprc.result(),
    }

  def reset_state(self):
    """Clears every accumulated statistic."""
    self._precision.reset_state()
    self._recall.reset_state()
    self._f1score.reset_state()
    self._auprc.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()