# google-research
# 297 lines · 9.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for working with sparse arrays for per-residue models."""
17
18import collections19from typing import Dict, List, Optional, Tuple20import numpy as np21import scipy.sparse22from protenn import utils23
24
# COO_ijv_list: list of triples (i index, j index, value).
# Compatible with scipy.sparse.coo format.
# The i index is the sequence position index, and the j index is the label
# index.
# This structure is 0-indexed by residue, unlike most tools in the
# bioinformatics world.
COO_ijv_list = List[Tuple[int, int, float]]  # pylint: disable=invalid-name
33
# DenseLabelDict: label -> list of (start index, end index).
# This structure is 1-indexed by residue (like all tools in the bioinformatics
# world), not 0-indexed, and is left-inclusive, right-inclusive.
# The reason to be 1-indexed is for better interoperability with tools like
# HMMER and InterProScan.
# See `programmer_range_to_biologist_range`.
DenseLabelDict = Dict[str, List[Tuple[int, int]]]
42
# Default minimum length, in residues, that a domain call must have in order
# to be kept. Used as the default for `min_length` in
# `filter_domain_calls_by_length` and for `min_domain_call_length` in
# `activations_to_domain_calls`.
DEFAULT_DOMAIN_CALL_MIN_LENGTH = 20
45
def true_label_to_coo(true_label_tuples):
  """Builds a COO ijv list from (seq_idx, class_idx) pairs.

  Every input pair becomes the triple (seq_idx, class_idx, 1.0), i.e. each
  true label is given the value 1.

  Args:
    true_label_tuples: iterable of indexable pairs (seq_idx, class_idx).

  Returns:
    List of triples (seq_idx, class_idx, 1.0), in input order.
  """
  ijv_list = []
  for pair in true_label_tuples:
    ijv_list.append((pair[0], pair[1], 1.))
  return ijv_list
50
def dense_to_sparse_coo_list_of_tuples(twod_nparray):
  """Lists the nonzero entries of a 2-D array as (i, j, value) triples.

  The output is compatible with the scipy.sparse.coo format.

  Args:
    twod_nparray: two-dimensional array.

  Returns:
    List of triples (i index, j index, value), one per nonzero entry, in the
    order reported by `twod_nparray.nonzero()`.
  """
  # nonzero() yields one index array per dimension; zipping them pairs up the
  # row and column index of every nonzero entry.
  return [
      (row, col, twod_nparray[row, col])
      for row, col in zip(*twod_nparray.nonzero())
  ]
68
def np_matrix_to_array(a):
  """Converts the result of scipy.sparse.coo_matrix.todense() to an array.

  Args:
    a: matrix-like object, e.g. the np.matrix returned by `todense()`.

  Returns:
    np.ndarray with the same contents as `a`. Note that every dimension of
    size 1 is squeezed away.
  """
  as_ndarray = np.asarray(a)
  return np.squeeze(as_ndarray)
73
def ijv_tuples_to_sparse_coo(ijv_list, sequence_length, num_classes):
  """Converts list of triples (i index, j index, values) to coo_matrix.

  Args:
    ijv_list: see COO_ijv_list above.
    sequence_length: int.
    num_classes: int.

  Returns:
    coo_matrix of shape (sequence_length, num_classes). For an empty
    `ijv_list` this is an all-zero float64 matrix.

  Raises:
    ValueError: if `ijv_list` cannot be interpreted as a 2-D array with at
      least three columns. The offending array is attached to the error to
      aid debugging.
  """
  if len(ijv_list) == 0:  # pylint: disable=g-explicit-length-test
    # np.float_ was removed in NumPy 2.0, so np.float64 is used instead. The
    # dtype is passed by keyword because the second positional parameter of
    # coo_matrix is `shape`, not `dtype`.
    return scipy.sparse.coo_matrix(
        (sequence_length, num_classes), dtype=np.float64)

  ijv_np = np.array(ijv_list)

  try:
    i = ijv_np[:, 0]
    j = ijv_np[:, 1]
    v = ijv_np[:, 2]
  except IndexError as e:
    # If there is an error, reraise it and include contents of ijv_np in the
    # stack trace to aid debugging.
    raise ValueError(ijv_np) from e
  return scipy.sparse.coo_matrix(
      (v, (i, j)), shape=(sequence_length, num_classes))
102
def ijv_tuples_to_dense(ijv_list, sequence_length, num_classes):
  """Converts list of triples (i index, j index, values) to dense np.array.

  The triples are first assembled into a sparse COO matrix, which is then
  densified and passed through `np_matrix_to_array`.

  Args:
    ijv_list: see COO_ijv_list above.
    sequence_length: int.
    num_classes: int.

  Returns:
    np.ndarray of shape (sequence_length, num_classes). Because
    `np_matrix_to_array` squeezes its result, dimensions of size 1 are dropped.
  """
  sparse_matrix = ijv_tuples_to_sparse_coo(
      ijv_list, sequence_length=sequence_length, num_classes=num_classes)
  dense_matrix = sparse_matrix.todense()
  return np_matrix_to_array(dense_matrix)
119
120# https://stackoverflow.com/questions/4494404/find-large-number-of-consecutive-values-fulfilling-condition-in-a-numpy-array
def contiguous_regions_1d(boolean_condition):
  """Finds contiguous True regions of the boolean array "boolean_condition".

  Args:
    boolean_condition: boolean array (or array-like) of shape
      (sequence_length,).

  Returns:
    a 2D array where the first column is the start index of the region and the
    second column is the end index.
    The output is 0-indexed, both by family and residue, and is left-inclusive,
    right exclusive.
    An empty input, or one with no True entries, gives an array of shape
    (0, 2).
  """
  boolean_condition = np.asarray(boolean_condition, dtype=bool)

  # Without this guard the [0] / [-1] lookups below raise IndexError on an
  # empty input.
  if boolean_condition.size == 0:
    return np.zeros((0, 2), dtype=np.intp)

  # Find the indices of changes in "boolean_condition".
  (idx,) = np.diff(boolean_condition).nonzero()

  # We need to start things after the change in "boolean_condition". Therefore,
  # we'll shift the index by 1 to the right.
  idx = idx + 1

  if boolean_condition[0]:
    # If the start of boolean_condition is True prepend a 0.
    idx = np.r_[0, idx]

  if boolean_condition[-1]:
    # If the end of boolean_condition is True, append the length of the array.
    idx = np.r_[idx, boolean_condition.size]

  # Boundaries now alternate start, end, start, end, ...; pair them up into
  # two columns. reshape is used instead of assigning to `.shape`, which the
  # NumPy documentation discourages.
  return idx.reshape(-1, 2)
154
def normalize_ijv_tuples(
    ijv_list,
    vocab,
    applicable_label_dict,
    label_to_idx=None,
):
  """Adds implied (e.g. clan) labels to a list of ijv label activations.

  Every input ijv is kept. In addition, if its label has an implied label in
  `applicable_label_dict`, an ijv for the implied label at the same sequence
  position is emitted as well. When several entries land on the same
  (sequence position, label) pair, the largest value wins.

  Args:
    ijv_list: see COO_ijv_list above.
    vocab: 1d array of string values corresponding to label indexes.
    applicable_label_dict: Mapping from labels to their parents (including
      indirect parents). E.g. utils.family_to_clan_mapping. Note that this is
      different from proteinfer-style applicable label dicts, where more than
      one label may be implied.
    label_to_idx: optional inverted lookup of vocab. Often, this function is
      called many times, and inverting the vocabulary for each call can cause
      performance problems. In this case, one can provide a precomputed lookup.
      If not provided, vocab will be manually inverted.

  Returns:
    ijv list as described above, ordered by first appearance of each
    (sequence position, label) pair.
  """
  if label_to_idx is None:
    label_to_idx = {label: index for index, label in enumerate(vocab)}

  max_value_by_key = {}

  def _keep_max(key, value):
    """Stores value under key unless a value at least as large is present."""
    if key not in max_value_by_key or max_value_by_key[key] < value:
      max_value_by_key[key] = value

  for seq_idx, label_idx, activation_confidence in ijv_list:
    _keep_max((seq_idx, label_idx), activation_confidence)

    label = vocab[label_idx]
    if label not in applicable_label_dict:
      continue

    # Propagate the activation to the label implied by this one.
    implied_label_idx = label_to_idx[applicable_label_dict[label]]
    _keep_max((seq_idx, implied_label_idx), activation_confidence)

  return [
      (seq_idx, label_idx, value)
      for (seq_idx, label_idx), value in max_value_by_key.items()
  ]
208
def contiguous_regions_2d(activations,
                          sequence_length,
                          vocab,
                          reporting_threshold = .5):
  """For a list of tuple activations ijv, compute contiguous domain calls.

  An entry counts as a call when its v is strictly greater than
  reporting_threshold. For each label separately, called residues that are
  adjacent along the sequence dimension are coalesced into ranges.

  No handling of label propagation is done in this function.

  Args:
    activations: see COO_ijv_list above.
    sequence_length: int.
    vocab: 1d array of string values corresponding to label indexes.
    reporting_threshold: float.

  Returns:
    label -> list of (start index, end index), where each range is produced by
    utils.programmer_range_to_biologist_range. Labels with no call are absent.
  """
  called_residues_by_label = collections.defaultdict(list)
  for entry in activations:
    if entry[2] > reporting_threshold:
      called_residues_by_label[vocab[entry[1]]].append(entry[0])

  domain_calls_by_label = {}
  for label, residue_indices in called_residues_by_label.items():
    # Mark every called residue on a per-position mask for this label.
    is_called = np.zeros((sequence_length,), dtype=np.bool_)
    is_called[residue_indices] = True

    # DenseLabelDict is a biologist-index-scheme based data structure.
    ranges = []
    for start, end in contiguous_regions_1d(is_called):
      ranges.append(utils.programmer_range_to_biologist_range(start, end))
    domain_calls_by_label[label] = ranges

  return domain_calls_by_label
248
def filter_domain_calls_by_length(
    calls_dict,
    min_length = DEFAULT_DOMAIN_CALL_MIN_LENGTH,
):
  """Filters out short calls from calls_dict.

  Args:
    calls_dict: label -> list of (start index, end index); see DenseLabelDict
      above.
    min_length: minimum number of residues a range must span to be kept.

  Returns:
    A plain dict of label -> list of (start index, end index) holding only the
    ranges of length at least min_length. Labels left with no range are
    omitted.
  """
  calls_filtered = {}
  for label, domain_ranges in calls_dict.items():
    # Add 1 because the input to this function is a DenseLabelDict,
    # which is 1-indexed, right inclusive.
    long_enough = [
        (domain_start, domain_end)
        for domain_start, domain_end in domain_ranges
        if domain_end - domain_start + 1 >= min_length
    ]
    if long_enough:
      calls_filtered[label] = long_enough

  return calls_filtered
269
def activations_to_domain_calls(
    activations,
    sequence_length,
    vocab,
    reporting_threshold = 0.025,
    min_domain_call_length = DEFAULT_DOMAIN_CALL_MIN_LENGTH,
):
  """Convert activations to dict of domain calls.

  Runs `contiguous_regions_2d` to coalesce activations above the threshold
  into per-label ranges, then drops ranges shorter than the minimum length
  with `filter_domain_calls_by_length`.

  Args:
    activations: see COO_ijv_list above.
    sequence_length: int.
    vocab: 1d array of string values corresponding to label indexes.
    reporting_threshold: float; forwarded to `contiguous_regions_2d`.
    min_domain_call_length: int; forwarded to `filter_domain_calls_by_length`
      as its minimum length.

  Returns:
    label -> list of (start index, end index).
  """
  unfiltered_calls = contiguous_regions_2d(
      activations, sequence_length, vocab, reporting_threshold)
  filtered_calls = filter_domain_calls_by_length(
      unfiltered_calls, min_domain_call_length)
  return filtered_calls
282
def num_labels_in_dense_label_dict(d):
  """Counts the ranges stored across all labels of a DenseLabelDict.

  Args:
    d: label -> list of (start index, end index).

  Returns:
    Total number of (start, end) ranges, summed over every label.
  """
  return sum(len(ranges) for ranges in d.values())
289
def flatten_dict_of_domain_calls(calls_dict):
  """Flattens label -> list[(start, end)] to list[(label, (start, end))].

  Args:
    calls_dict: label -> list of (start index, end index).

  Returns:
    List of (label, (start, end)) pairs, one per range, in dictionary order;
    each range is converted to a tuple.
  """
  return [
      (family, tuple(domain_range))
      for family, ranges in calls_dict.items()
      for domain_range in ranges
  ]