google-research

Форк
0
/
per_residue_sparse_test.py 
373 строки · 12.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for protenn.per_residue_sparse."""
17

18
from absl.testing import absltest
19
from absl.testing import parameterized
20
import numpy as np
21
import scipy.sparse
22
from protenn import per_residue_sparse
23

24

25
class PerResidueSparseTest(parameterized.TestCase):
  """Unit tests for the helpers in `protenn.per_residue_sparse`."""

  def test_true_label_to_coo(self):
    """Each (residue, label) pair becomes an (i, j, 1.0) triple, in order."""
    input_ground_truth = [(4, 987), (13, 987), (2, 1234)]
    # Compare the complete result (no truncating slice) so that any extra,
    # unexpected triples make the test fail.
    actual = per_residue_sparse.true_label_to_coo(input_ground_truth)
    expected = [(4, 987, 1.0), (13, 987, 1.0), (2, 1234, 1.0)]
    self.assertListEqual(actual, expected)

  @parameterized.named_parameters(
      dict(
          testcase_name=' empty inputs',
          input_ijv_tuples=[],
          input_vocab=np.array([]),
          input_applicable_label_dict={},
          expected=[],
      ),
      dict(
          testcase_name=' one input, nothing implied',
          input_ijv_tuples=[
              (0, 0, 1),
          ],
          input_vocab=np.array(['PF00001']),
          input_applicable_label_dict={},
          expected=[(0, 0, 1)],
      ),
      dict(
          testcase_name=' one input, something implied',
          input_ijv_tuples=[
              (0, 0, 1),
          ],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Second tuple gets added because it's implied by the first.
          expected=[(0, 0, 1), (0, 1, 1)],
      ),
      dict(
          testcase_name=' clan already has prediction, clan prediction weaker',
          input_ijv_tuples=[(0, 0, 1), (0, 1, 0.5)],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Expect that, because the family label is larger than the clan label,
          # the second tuple's last entry is raised to 1, not left at .5.
          expected=[(0, 0, 1), (0, 1, 1)],
      ),
      dict(
          testcase_name=(
              ' clan already has prediction, clan prediction stronger'
          ),
          input_ijv_tuples=[(0, 0, 0.5), (0, 1, 1.0)],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Expect that, because the clan label is larger than the family label,
          # the clan keeps its own value: the second tuple's last entry stays 1
          # and is not lowered to the family's .5.
          expected=[(0, 0, 0.5), (0, 1, 1)],
      ),
      dict(
          testcase_name=' two inputs, clan label implied by both',
          input_ijv_tuples=[(0, 0, 0.5), (0, 1, 1.0)],
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_applicable_label_dict={
              'PF00001': 'CL0192',
              'PF00002': 'CL0192',
          },
          # Expect that the clan gets the maximum of either labels.
          expected=[(0, 0, 0.5), (0, 1, 1.0), (0, 2, 1.0)],
      ),
      dict(
          testcase_name=(
              ' two inputs at different indexes, clan label implied by both'
          ),
          input_ijv_tuples=[(0, 0, 0.5), (1, 0, 1.0)],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Expect that the clan label is applied to both indexes.
          expected=[(0, 0, 0.5), (1, 0, 1.0), (0, 1, 0.5), (1, 1, 1.0)],
      ),
  )
  def test_normalize_ijv_tuples(
      self, input_ijv_tuples, input_vocab, input_applicable_label_dict, expected
  ):
    """Implied (clan) labels are added with the max of their implying scores."""
    actual = per_residue_sparse.normalize_ijv_tuples(
        input_ijv_tuples, input_vocab, input_applicable_label_dict
    )
    # Order of the returned triples is not part of the contract under test.
    self.assertCountEqual(actual, expected)

  def test_dense_to_sparse_coo_list_of_tuples(self):
    """Zero entries are dropped; nonzero ones come out in row-major order."""
    input_dense = np.arange(9).reshape(3, 3)

    actual = per_residue_sparse.dense_to_sparse_coo_list_of_tuples(input_dense)
    expected = [
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 3),
        (1, 1, 4),
        (1, 2, 5),
        (2, 0, 6),
        (2, 1, 7),
        (2, 2, 8),
    ]
    self.assertListEqual(actual, expected)

  def test_np_matrix_to_array(self):
    """An np.matrix (from a sparse `todense()`) converts to equal values."""
    input_array = np.arange(9).reshape(3, 3)
    input_matrix = scipy.sparse.coo_matrix(input_array).todense()
    actual = per_residue_sparse.np_matrix_to_array(input_matrix)

    expected = input_array
    np.testing.assert_allclose(actual, expected)

  def test_ijv_tuples_to_sparse_coo(self):
    """Every (i, j, v) triple lands at position (i, j) of the sparse result."""
    # These are the nonzero entries of np.arange(9).reshape(3, 3); the (0, 0)
    # entry is zero and therefore omitted.
    input_ijv_list = [
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 3),
        (1, 1, 4),
        (1, 2, 5),
        (2, 0, 6),
        (2, 1, 7),
        (2, 2, 8),
    ]
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_sparse_coo(
        input_ijv_list, input_sequence_length, input_num_classes
    )
    expected_num_nonzero = len(input_ijv_list)

    self.assertEqual(actual.count_nonzero(), expected_num_nonzero)
    # Check every cell rather than a couple of spot values; np.asarray turns
    # the np.matrix returned by `todense()` into a plain ndarray.
    np.testing.assert_equal(
        np.asarray(actual.todense()), np.arange(9).reshape(3, 3)
    )

  def test_ijv_tuples_to_sparse_coo_empty_input(self):
    """No triples yields a sparse result with no stored nonzero values."""
    input_ijv_list = []
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_sparse_coo(
        input_ijv_list, input_sequence_length, input_num_classes
    )
    expected_num_nonzero = len(input_ijv_list)

    self.assertEqual(actual.count_nonzero(), expected_num_nonzero)

  def test_ijv_tuples_to_dense(self):
    """Triples are scattered into a dense (sequence_length, num_classes) array."""
    # These are the nonzero entries of np.arange(9).reshape(3, 3); the (0, 0)
    # entry is zero and therefore omitted.
    input_ijv_list = [
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 3),
        (1, 1, 4),
        (1, 2, 5),
        (2, 0, 6),
        (2, 1, 7),
        (2, 2, 8),
    ]
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_dense(
        input_ijv_list, input_sequence_length, input_num_classes
    )

    expected = np.arange(9).reshape(3, 3)
    np.testing.assert_equal(actual, expected)

  def test_ijv_tuples_to_dense_empty_input(self):
    """No triples yields an all-zero array of the requested shape."""
    input_ijv_list = []
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_dense(
        input_ijv_list, input_sequence_length, input_num_classes
    )

    # Derive the expected shape from the inputs so the two cannot drift apart.
    expected = np.zeros(shape=(input_sequence_length, input_num_classes))
    np.testing.assert_equal(actual, expected)

  @parameterized.named_parameters(
      dict(
          testcase_name=' all false',
          input_boolean_condition=np.array([False, False]),
          expected=np.empty(shape=(0, 2)),
      ),
      dict(
          testcase_name=' all true',
          input_boolean_condition=np.array([True, True, True]),
          expected=np.array([[0, 3]]),
      ),
      dict(
          testcase_name=' one true',
          input_boolean_condition=np.array([False, True, False]),
          expected=np.array([[1, 2]]),
      ),
      dict(
          testcase_name=' one true region',
          input_boolean_condition=np.array([False, True, True, False]),
          expected=np.array([[1, 3]]),
      ),
      dict(
          testcase_name=' two true regions',
          input_boolean_condition=np.array(
              [False, True, True, False, True, True]
          ),
          expected=np.array([[1, 3], [4, 6]]),
      ),
  )
  def test_contiguous_regions_1d(self, input_boolean_condition, expected):
    """Runs of True are reported as 0-based, end-exclusive [start, end) rows."""
    actual = per_residue_sparse.contiguous_regions_1d(input_boolean_condition)
    np.testing.assert_equal(actual, expected)

  @parameterized.named_parameters(
      dict(
          testcase_name=' no activations',
          input_activations=[],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_reporting_threshold=0.8,
          expected={},
      ),
      dict(
          testcase_name=' one activation, below threshold',
          # .3 is below reporting threshold.
          input_activations=[(0, 0, 0.3)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_reporting_threshold=0.8,
          expected={},
      ),
      dict(
          testcase_name=' one activation, above threshold',
          # .99 is above reporting threshold.
          input_activations=[(0, 0, 0.99)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 1)],
          },
      ),
      dict(
          testcase_name=' two contiguous regions DO GET merged',
          # The two residues should get merged into one region for PF00001.
          input_activations=[(0, 0, 0.99), (1, 0, 0.99)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 2)],
          },
      ),
      dict(
          testcase_name=' two NONcontiguous regions DO NOT get merged',
          input_activations=[(0, 0, 0.99), (3, 0, 0.99)],
          input_sequence_length=5,
          input_vocab=np.array(['PF00001']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 1), (4, 4)],
          },
      ),
      dict(
          testcase_name=(
              ' two contiguous regions belonging to different families DO NOT'
              ' GET merged'
          ),
          input_activations=[(0, 0, 0.99), (1, 1, 0.99)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 1)],
              'PF00002': [(2, 2)],
          },
      ),
  )
  def test_contiguous_regions_2d(
      self,
      input_activations,
      input_sequence_length,
      input_vocab,
      input_reporting_threshold,
      expected,
  ):
    """Above-threshold runs per label become 1-based, inclusive (start, end)."""
    actual = per_residue_sparse.contiguous_regions_2d(
        activations=input_activations,
        sequence_length=input_sequence_length,
        vocab=input_vocab,
        reporting_threshold=input_reporting_threshold,
    )
    self.assertDictEqual(actual, expected)

  def test_filter_domain_calls_by_length(self):
    """Calls shorter than the minimum length are removed."""
    input_domain_calls = {'CL0036': [(1, 2), (3, 300)], 'PF00001': [(20, 500)]}
    input_min_length = 42
    actual = per_residue_sparse.filter_domain_calls_by_length(
        input_domain_calls, input_min_length
    )
    expected = {'CL0036': [(3, 300)], 'PF00001': [(20, 500)]}
    self.assertDictEqual(actual, expected)

  def test_activations_to_domain_calls(self):
    """Thresholding, region finding and length filtering compose correctly."""
    # CLASS_0: 50 residues above the 0.3 threshold -> one call that meets the
    # minimum length of 50. CLASS_1: only 3 residues -> filtered out.
    input_activations_class_0 = [(i, 0, 0.4) for i in range(50)]
    input_activations_class_1 = [(i, 1, 1.0) for i in range(3)]
    input_activations = input_activations_class_0 + input_activations_class_1
    input_sequence_length = 200
    input_vocab = np.array(['CLASS_0', 'CLASS_1'])
    input_reporting_threshold = 0.3
    input_min_domain_call_length = 50

    actual = per_residue_sparse.activations_to_domain_calls(
        input_activations,
        input_sequence_length,
        input_vocab,
        input_reporting_threshold,
        input_min_domain_call_length,
    )
    expected = {'CLASS_0': [(1, 50)]}
    self.assertEqual(actual, expected)

  def test_num_labels_in_dense_label_dict(self):
    """The count is the total number of calls across all labels."""
    input_dense_label_dict = {
        'CL1234': [(1, 2), (3, 4)],
        'PF00001': [(100, 200)],
    }
    actual = per_residue_sparse.num_labels_in_dense_label_dict(
        input_dense_label_dict
    )
    expected = 3

    self.assertEqual(actual, expected)

  def test_flatten_dict_of_domain_calls(self):
    """Each call is paired with its label, preserving the per-label order."""
    input_dict_of_calls = {'CL0036': [(29, 252), (253, 254), (256, 257)]}
    expected = [
        ('CL0036', (29, 252)),
        ('CL0036', (253, 254)),
        ('CL0036', (256, 257)),
    ]
    actual = per_residue_sparse.flatten_dict_of_domain_calls(
        input_dict_of_calls
    )

    self.assertListEqual(actual, expected)
370

371

372
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()
374

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.