google-research

Форк
0
/
per_residue_sparse.py 
297 строк · 9.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utilities for working with sparse arrays for per-residue models."""
17

18
import collections
19
from typing import Dict, List, Optional, Tuple
20
import numpy as np
21
import scipy.sparse
22
from protenn import utils
23

24

25
# list of triples (i index, j index, values).
26
# Compatible with scipy.sparse.coo format.
27
# The i index is the sequence position index, and the j index is the label
28
# index.
29
# This structure is 0-indexed by residue, unlike most tools in the
30
# bioinformatics world.
31
# Example (values are illustrative): [(0, 3, 0.9), (1, 3, 0.8)].
COO_ijv_list = List[Tuple[int, int, float]]  # pylint: disable=invalid-name
32

33

34
# label -> list of (start index, end index).
35
# This structure is 1-indexed by residue (like all tools in the bioinformatics
36
# world), not 0-indexed, and is left-inclusive, right-inclusive.
37
# The reason to be 1-indexed is for better interoperability with tools like
38
# HMMER and InterProScan.
39
# See `programmer_range_to_biologist_range`.
40
# Example (label names are illustrative): {"PF00001": [(1, 20), (35, 80)]}.
DenseLabelDict = Dict[str, List[Tuple[int, int]]]
41

42

43
# Default minimum number of residues a domain call must span to be kept.
# Used as the default for `min_length` in filter_domain_calls_by_length and
# for `min_domain_call_length` in activations_to_domain_calls.
DEFAULT_DOMAIN_CALL_MIN_LENGTH = 20
44

45

46
def true_label_to_coo(true_label_tuples):
  """Builds an ijv COO list from (seq_idx, class_idx) entries.

  Args:
    true_label_tuples: iterable of indexable entries whose first two items are
      the sequence position index and the class index.

  Returns:
    List of (seq_idx, class_idx, 1.) triples, one per input entry, in input
    order.
  """
  ijv = []
  for entry in true_label_tuples:
    # Every true label gets the constant value 1.
    ijv.append((entry[0], entry[1], 1.))
  return ijv
49

50

51
def dense_to_sparse_coo_list_of_tuples(
    twod_nparray):
  """Lists the nonzero entries of a dense array as (i, j, value) triples.

  The output is compatible with the scipy.sparse.coo format.

  Args:
    twod_nparray: two-dimensional np.ndarray.

  Returns:
    List of triples (i, j, v), one per nonzero entry, in the order produced by
    `twod_nparray.nonzero()`.
  """
  # nonzero() gives one index array per axis; transposing pairs them up so
  # that each row is a single (i, j) coordinate.
  nonzero_coords = np.transpose(np.array(twod_nparray.nonzero()))
  return [(row, col, twod_nparray[row, col]) for row, col in nonzero_coords]
67

68

69
def np_matrix_to_array(a):
  """Converts e.g. scipy.sparse.coo_matrix.todense() output to a plain array.

  Args:
    a: array-like, e.g. an np.matrix.

  Returns:
    np.ndarray with the same values as `a` and every length-1 axis removed.
  """
  as_ndarray = np.asarray(a)
  return as_ndarray.squeeze()
72

73

74
def ijv_tuples_to_sparse_coo(ijv_list, sequence_length,
                             num_classes):
  """Converts list of triples (i index, j index, values) to coo_matrix.

  Args:
    ijv_list: see COO_ijv_list above.
    sequence_length: int. Number of rows of the returned matrix.
    num_classes: int. Number of columns of the returned matrix.

  Returns:
    coo_matrix of shape (sequence_length, num_classes). For an empty
    `ijv_list` this is an all-zero float64 matrix.

  Raises:
    ValueError: if `ijv_list` converts to an array that cannot be indexed as
      three columns; the offending array is attached to the error to aid
      debugging.
  """
  shape = (sequence_length, num_classes)

  if len(ijv_list) == 0:  # pylint: disable=g-explicit-length-test
    # np.float_ was removed in NumPy 2.0; np.float64 is the same type. It is
    # passed by keyword because the second positional parameter of coo_matrix
    # is `shape`, not `dtype`.
    return scipy.sparse.coo_matrix(shape, dtype=np.float64)

  ijv_np = np.array(ijv_list)

  try:
    i = ijv_np[:, 0]
    j = ijv_np[:, 1]
    v = ijv_np[:, 2]
  except IndexError as e:
    # If there is an error, reraise it and include contents of ijv_np in the
    # stack trace to aid debugging.
    raise ValueError(ijv_np) from e
  return scipy.sparse.coo_matrix((v, (i, j)), shape=shape)
101

102

103
def ijv_tuples_to_dense(ijv_list, sequence_length,
                        num_classes):
  """Converts list of triples (i index, j index, values) to dense np.array.

  Args:
    ijv_list: see COO_ijv_list above.
    sequence_length: int.
    num_classes: int.

  Returns:
    np.ndarray of shape (sequence_length, num_classes). The result is always
    two-dimensional, including when sequence_length or num_classes is 1.
  """
  coo = ijv_tuples_to_sparse_coo(
      ijv_list, sequence_length=sequence_length, num_classes=num_classes)
  # toarray() returns a plain 2-D ndarray directly. Going through todense()
  # and squeezing would drop any length-1 axis and so break the documented
  # (sequence_length, num_classes) shape for single-residue or single-class
  # inputs.
  return coo.toarray()
118

119

120
# https://stackoverflow.com/questions/4494404/find-large-number-of-consecutive-values-fulfilling-condition-in-a-numpy-array
121
def contiguous_regions_1d(boolean_condition):
  """Finds contiguous True regions of the boolean array "boolean_condition".

  Args:
    boolean_condition: one-dimensional boolean array (or array-like that can
      be converted to one) of shape (sequence_length,). Non-boolean values are
      interpreted by truthiness.

  Returns:
    a 2D array where the first column is the start index of the region and the
    second column is the end index.
    The output is 0-indexed, both by family and residue, and is left-inclusive,
    right exclusive.
    An empty input, or an input with no True entries, gives an array of shape
    (0, 2).
  """
  # Coerce to bool so that np.diff marks exactly the True/False transitions
  # (on non-bool input it would subtract values and could report spurious
  # changes inside a region).
  boolean_condition = np.asarray(boolean_condition, dtype=np.bool_)

  if boolean_condition.size == 0:
    # There is no first/last element to inspect below, and no regions.
    return np.zeros((0, 2), dtype=np.intp)

  # Find the indices of changes in "boolean_condition".
  d = np.diff(boolean_condition)
  (idx,) = d.nonzero()

  # We need to start things after the change in "boolean_condition". Therefore,
  # we'll shift the index by 1 to the right.
  idx = idx + 1

  if boolean_condition[0]:
    # If the start of boolean_condition is True prepend a 0.
    idx = np.r_[0, idx]

  if boolean_condition[-1]:
    # If the end of boolean_condition is True, append the length of the array.
    idx = np.r_[idx, boolean_condition.size]

  # Boundaries now alternate start, end, start, end, ...; pair them up into
  # two columns.
  return idx.reshape(-1, 2)
153

154

155
def normalize_ijv_tuples(
    ijv_list,
    vocab,
    applicable_label_dict,
    label_to_idx = None,
):
  """Adds implied labels (e.g. the clan of a family) to an ijv list.

  Every input ijv is kept. In addition, when the label of an ijv has an entry
  in `applicable_label_dict`, an ijv for the implied label at the same sequence
  position is produced with the same value. When several ijvs map to the same
  (position, label) pair, whether directly or through implication, the largest
  value wins.

  Args:
    ijv_list: see COO_ijv_list above.
    vocab: 1d array of string values corresponding to label indexes.
    applicable_label_dict: Mapping from a label to the single label it implies.
      E.g. utils.family_to_clan_mapping. Note that this is different from
      proteinfer-style applicable label dicts, where more than one label may
      be implied.
    label_to_idx: optional inverted lookup of vocab (label -> index). This
      function is often called many times, and inverting the vocabulary on
      every call can be slow, so a precomputed lookup may be supplied. When it
      is None, the lookup is built from vocab.

  Returns:
    List of (i, j, v) triples, one per distinct (position, label index) pair,
    ordered by the first time each pair was encountered.
  """
  if label_to_idx is None:
    label_to_idx = {label: index for index, label in enumerate(vocab)}

  strongest = {}

  def record(position, label_index, confidence):
    """Stores confidence for the pair unless a value at least as large exists."""
    key = (position, label_index)
    previous = strongest.get(key, confidence)
    # max() returns its first argument unless the second is strictly greater,
    # so an existing value is only replaced by a strictly larger one.
    strongest[key] = max(previous, confidence)

  for seq_idx, label_idx, activation_confidence in ijv_list:
    record(seq_idx, label_idx, activation_confidence)

    label = vocab[label_idx]
    if label not in applicable_label_dict:
      continue
    implied_label_idx = label_to_idx[applicable_label_dict[label]]
    record(seq_idx, implied_label_idx, activation_confidence)

  return [(key[0], key[1], value) for key, value in strongest.items()]
207

208

209
def contiguous_regions_2d(activations,
                          sequence_length,
                          vocab,
                          reporting_threshold = .5):
  """Computes contiguous domain calls from a list of ijv activations.

  An entry counts as a call when its v is strictly greater than
  reporting_threshold. For each label separately, called residues that are
  adjacent along the sequence are merged into one range.

  No handling of label propagation is done in this function.

  Args:
    activations: see COO_ijv_list above.
    sequence_length: int.
    vocab: 1d array of string values corresponding to label indexes.
    reporting_threshold: float.

  Returns:
    label -> list of ranges, each produced by
    utils.programmer_range_to_biologist_range from a 0-indexed, right-exclusive
    (start, end) pair. Labels with no call above the threshold are absent.
  """
  # Group the positions of the residues that pass the threshold by label.
  positions_by_label = collections.defaultdict(list)
  for activation in activations:
    if activation[2] > reporting_threshold:
      positions_by_label[vocab[activation[1]]].append(activation[0])

  calls_by_label = {}
  for label, positions in positions_by_label.items():
    # Boolean mask over the sequence marking this label's called residues.
    called = np.zeros((sequence_length,), dtype=np.bool_)
    called[positions] = True

    # DenseLabelDict is a biologist-index-scheme based data structure, so each
    # programmer-style region is converted before being stored.
    regions = contiguous_regions_1d(called)
    calls_by_label[label] = [
        utils.programmer_range_to_biologist_range(start, end)
        for start, end in regions
    ]

  return calls_by_label
247

248

249
def filter_domain_calls_by_length(
    calls_dict,
    min_length = DEFAULT_DOMAIN_CALL_MIN_LENGTH,
):
  """Drops domain calls shorter than min_length from calls_dict.

  Args:
    calls_dict: see DenseLabelDict above; label -> list of (start, end) ranges
      that are 1-indexed and inclusive on both ends.
    min_length: int. Minimum number of residues a range must span to be kept.

  Returns:
    A plain dict label -> list of (start, end) tuples containing only the
    ranges of length >= min_length. Labels left without any range are omitted.
  """
  kept = {}
  for label, domain_ranges in calls_dict.items():
    # The +1 accounts for the ranges being inclusive on both ends.
    long_enough = [
        (start, end)
        for start, end in domain_ranges
        if end - start + 1 >= min_length
    ]
    if long_enough:
      kept[label] = long_enough
  return kept
268

269

270
def activations_to_domain_calls(
    activations,
    sequence_length,
    vocab,
    reporting_threshold = 0.025,
    min_domain_call_length = DEFAULT_DOMAIN_CALL_MIN_LENGTH,
):
  """Converts ijv activations to a dict of length-filtered domain calls.

  Args:
    activations: see COO_ijv_list above.
    sequence_length: int.
    vocab: 1d array of string values corresponding to label indexes.
    reporting_threshold: float. Forwarded to contiguous_regions_2d.
    min_domain_call_length: int. Forwarded to filter_domain_calls_by_length as
      its min_length.

  Returns:
    The output of contiguous_regions_2d after filter_domain_calls_by_length
    has removed the calls that are too short.
  """
  unfiltered_calls = contiguous_regions_2d(
      activations,
      sequence_length,
      vocab,
      reporting_threshold=reporting_threshold,
  )
  return filter_domain_calls_by_length(
      unfiltered_calls, min_length=min_domain_call_length)
281

282

283
def num_labels_in_dense_label_dict(d):
  """Counts the ranges stored in `d`, summed over all of its keys.

  Args:
    d: dict whose values are sized collections, e.g. a DenseLabelDict
      (label -> list of (start, end) ranges).

  Returns:
    int. Total number of entries across all values of `d`.
  """
  return sum(len(ranges) for ranges in d.values())
288

289

290
def flatten_dict_of_domain_calls(
    calls_dict):
  """Flattens label -> list[(start, end)] to list[(label, (start, end))].

  Args:
    calls_dict: dict mapping a label to an iterable of ranges.

  Returns:
    List of (label, range) pairs with each range converted to a tuple. Pairs
    follow the iteration order of `calls_dict`, then of each label's ranges.
  """
  return [
      (family, tuple(domain_range))
      for family, ranges in calls_dict.items()
      for domain_range in ranges
  ]
298

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.