# Repository: google-research (file listing: 373 lines, 12.1 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for protenn.per_residue_sparse."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.sparse
from protenn import per_residue_sparse


class PerResidueSparseTest(parameterized.TestCase):
  """Unit tests for the helper functions in `protenn.per_residue_sparse`."""

  def test_true_label_to_coo(self):
    """Each (i, j) ground-truth pair gains a trailing value of 1.0."""
    input_ground_truth = [(4, 987), (13, 987), (2, 1234)]
    actual = per_residue_sparse.true_label_to_coo(input_ground_truth)[:10]
    expected = [(4, 987, 1.0), (13, 987, 1.0), (2, 1234, 1.0)]
    self.assertListEqual(actual, expected)

  @parameterized.named_parameters(
      dict(
          testcase_name=' empty inputs',
          input_ijv_tuples=[],
          input_vocab=np.array([]),
          input_applicable_label_dict={},
          expected=[],
      ),
      dict(
          testcase_name=' one input, nothing implied',
          input_ijv_tuples=[
              (0, 0, 1),
          ],
          input_vocab=np.array(['PF00001']),
          input_applicable_label_dict={},
          expected=[(0, 0, 1)],
      ),
      dict(
          testcase_name=' one input, something implied',
          input_ijv_tuples=[
              (0, 0, 1),
          ],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Second tuple gets added because it's implied by the first.
          expected=[(0, 0, 1), (0, 1, 1)],
      ),
      dict(
          testcase_name=' clan already has prediction, clan prediction weaker',
          input_ijv_tuples=[(0, 0, 1), (0, 1, 0.5)],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Expect that, because the family label is larger than the clan
          # label, the second tuple's last entry is raised to 1, not left
          # at .5.
          expected=[(0, 0, 1), (0, 1, 1)],
      ),
      dict(
          testcase_name=(
              ' clan already has prediction, clan prediction stronger'
          ),
          input_ijv_tuples=[(0, 0, 0.5), (0, 1, 1.0)],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Expect that, because the clan label is already larger than the
          # family label, the second tuple's last entry stays 1 and is not
          # lowered to .5.
          expected=[(0, 0, 0.5), (0, 1, 1)],
      ),
      dict(
          testcase_name=' two inputs, clan label implied by both',
          input_ijv_tuples=[(0, 0, 0.5), (0, 1, 1.0)],
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_applicable_label_dict={
              'PF00001': 'CL0192',
              'PF00002': 'CL0192',
          },
          # Expect that the clan gets the maximum of the two family labels.
          expected=[(0, 0, 0.5), (0, 1, 1.0), (0, 2, 1.0)],
      ),
      dict(
          testcase_name=(
              ' two inputs at different indexes, clan label implied by both'
          ),
          input_ijv_tuples=[(0, 0, 0.5), (1, 0, 1.0)],
          input_vocab=np.array(['PF00001', 'CL0192']),
          input_applicable_label_dict={'PF00001': 'CL0192'},
          # Expect that the clan label is applied to both indexes.
          expected=[(0, 0, 0.5), (1, 0, 1.0), (0, 1, 0.5), (1, 1, 1.0)],
      ),
  )
  def test_normalize_ijv_tuples(
      self, input_ijv_tuples, input_vocab, input_applicable_label_dict, expected
  ):
    """Implied labels are added or raised; output order is not checked."""
    actual = per_residue_sparse.normalize_ijv_tuples(
        input_ijv_tuples, input_vocab, input_applicable_label_dict
    )
    # assertCountEqual compares the tuples irrespective of their order.
    self.assertCountEqual(actual, expected)

  def test_dense_to_sparse_coo_list_of_tuples(self):
    """Nonzero entries of a dense array come back as (i, j, v) tuples."""
    input_dense = np.arange(9).reshape(3, 3)

    actual = per_residue_sparse.dense_to_sparse_coo_list_of_tuples(input_dense)
    # The zero at position (0, 0) is expected to be absent from the output.
    expected = [
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 3),
        (1, 1, 4),
        (1, 2, 5),
        (2, 0, 6),
        (2, 1, 7),
        (2, 2, 8),
    ]
    self.assertListEqual(actual, expected)

  def test_np_matrix_to_array(self):
    """A dense matrix converts to values matching the source array."""
    input_array = np.arange(9).reshape(3, 3)
    input_matrix = scipy.sparse.coo_matrix(input_array).todense()
    actual = per_residue_sparse.np_matrix_to_array(input_matrix)

    expected = input_array
    np.testing.assert_allclose(actual, expected)

  def test_ijv_tuples_to_sparse_coo(self):
    """The sparse result holds every input tuple at its (i, j) position."""
    # This is np.arange(9).reshape(3, 3).
    input_ijv_list = [
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 3),
        (1, 1, 4),
        (1, 2, 5),
        (2, 0, 6),
        (2, 1, 7),
        (2, 2, 8),
    ]
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_sparse_coo(
        input_ijv_list, input_sequence_length, input_num_classes
    )
    expected_num_nonzero = len(input_ijv_list)

    self.assertEqual(actual.count_nonzero(), expected_num_nonzero)
    # Densify once and spot-check two entries.
    actual_dense = actual.todense()
    self.assertEqual(actual_dense[0, 1], 1)
    self.assertEqual(actual_dense[0, 2], 2)

  def test_ijv_tuples_to_sparse_coo_empty_input(self):
    """An empty tuple list yields a sparse result with no nonzero entries."""
    input_ijv_list = []
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_sparse_coo(
        input_ijv_list, input_sequence_length, input_num_classes
    )
    expected_num_nonzero = len(input_ijv_list)

    self.assertEqual(actual.count_nonzero(), expected_num_nonzero)

  def test_ijv_tuples_to_dense(self):
    """The dense result reproduces the array the tuples were taken from."""
    # This is np.arange(9).reshape(3, 3); the zero at (0, 0) is left implicit.
    input_ijv_list = [
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 3),
        (1, 1, 4),
        (1, 2, 5),
        (2, 0, 6),
        (2, 1, 7),
        (2, 2, 8),
    ]
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_dense(
        input_ijv_list, input_sequence_length, input_num_classes
    )

    expected = np.arange(9).reshape(3, 3)
    np.testing.assert_equal(actual, expected)

  def test_ijv_tuples_to_dense_empty_input(self):
    """An empty tuple list yields an all-zero 3x3 dense result."""
    input_ijv_list = []
    input_sequence_length = 3
    input_num_classes = 3

    actual = per_residue_sparse.ijv_tuples_to_dense(
        input_ijv_list, input_sequence_length, input_num_classes
    )

    expected = np.zeros(shape=(3, 3))
    np.testing.assert_equal(actual, expected)

  @parameterized.named_parameters(
      dict(
          testcase_name=' all false',
          input_boolean_condition=np.array([False, False]),
          expected=np.empty(shape=(0, 2)),
      ),
      dict(
          testcase_name=' all true',
          input_boolean_condition=np.array([True, True, True]),
          expected=np.array([[0, 3]]),
      ),
      dict(
          testcase_name=' one true',
          input_boolean_condition=np.array([False, True, False]),
          expected=np.array([[1, 2]]),
      ),
      dict(
          testcase_name=' one true region',
          input_boolean_condition=np.array([False, True, True, False]),
          expected=np.array([[1, 3]]),
      ),
      dict(
          testcase_name=' two true regions',
          input_boolean_condition=np.array(
              [False, True, True, False, True, True]
          ),
          expected=np.array([[1, 3], [4, 6]]),
      ),
  )
  def test_contiguous_regions_1d(self, input_boolean_condition, expected):
    """Runs of True are reported as [start, end) rows, one row per run."""
    actual = per_residue_sparse.contiguous_regions_1d(input_boolean_condition)
    np.testing.assert_equal(actual, expected)

  @parameterized.named_parameters(
      dict(
          testcase_name=' no activations',
          input_activations=[],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_reporting_threshold=0.8,
          expected={},
      ),
      dict(
          testcase_name=' one activation, below threshold',
          # .3 is below reporting threshold.
          input_activations=[(0, 0, 0.3)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_reporting_threshold=0.8,
          expected={},
      ),
      dict(
          testcase_name=' one activation, above threshold',
          # .99 is above reporting threshold.
          input_activations=[(0, 0, 0.99)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002', 'CL0192']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 1)],
          },
      ),
      dict(
          testcase_name=' two contiguous regions DO GET merged',
          # The two residues should get merged into one region for PF00001.
          input_activations=[(0, 0, 0.99), (1, 0, 0.99)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 2)],
          },
      ),
      dict(
          testcase_name=' two NONcontiguous regions DO NOT get merged',
          input_activations=[(0, 0, 0.99), (3, 0, 0.99)],
          input_sequence_length=5,
          input_vocab=np.array(['PF00001']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 1), (4, 4)],
          },
      ),
      dict(
          testcase_name=(
              ' two contiguous regions belonging to different families DO NOT'
              ' GET merged'
          ),
          input_activations=[(0, 0, 0.99), (1, 1, 0.99)],
          input_sequence_length=3,
          input_vocab=np.array(['PF00001', 'PF00002']),
          input_reporting_threshold=0.8,
          expected={
              'PF00001': [(1, 1)],
              'PF00002': [(2, 2)],
          },
      ),
  )
  def test_contiguous_regions_2d(
      self,
      input_activations,
      input_sequence_length,
      input_vocab,
      input_reporting_threshold,
      expected,
  ):
    """Above-threshold activations become per-label (start, end) regions.

    In the cases above, an activation at zero-based residue index 0 is
    expected to be reported as the region (1, 1), i.e. the expected regions
    are one-indexed with an inclusive end.
    """
    actual = per_residue_sparse.contiguous_regions_2d(
        activations=input_activations,
        sequence_length=input_sequence_length,
        vocab=input_vocab,
        reporting_threshold=input_reporting_threshold,
    )
    self.assertDictEqual(actual, expected)

  def test_filter_domain_calls_by_length(self):
    """Domain calls shorter than the minimum length are removed."""
    input_domain_calls = {'CL0036': [(1, 2), (3, 300)], 'PF00001': [(20, 500)]}
    input_min_length = 42
    actual = per_residue_sparse.filter_domain_calls_by_length(
        input_domain_calls, input_min_length
    )
    # Only the short (1, 2) call is expected to be dropped.
    expected = {'CL0036': [(3, 300)], 'PF00001': [(20, 500)]}
    self.assertDictEqual(actual, expected)

  def test_activations_to_domain_calls(self):
    """Thresholding and minimum-length filtering are applied together."""
    input_activations_class_0 = [(i, 0, 0.4) for i in range(50)]
    input_activations_class_1 = [(i, 1, 1.0) for i in range(3)]
    input_activations = input_activations_class_0 + input_activations_class_1
    input_sequence_length = 200
    input_vocab = np.array(['CLASS_0', 'CLASS_1'])
    input_reporting_threshold = 0.3
    input_min_domain_call_length = 50

    actual = per_residue_sparse.activations_to_domain_calls(
        input_activations,
        input_sequence_length,
        input_vocab,
        input_reporting_threshold,
        input_min_domain_call_length,
    )
    # CLASS_1 clears the threshold but covers only 3 residues, so only the
    # 50-residue CLASS_0 call is expected to survive.
    expected = {'CLASS_0': [(1, 50)]}
    self.assertEqual(actual, expected)

  def test_num_labels_in_dense_label_dict(self):
    """The count is the total number of regions across all labels."""
    input_dense_label_dict = {
        'CL1234': [(1, 2), (3, 4)],
        'PF00001': [(100, 200)],
    }
    actual = per_residue_sparse.num_labels_in_dense_label_dict(
        input_dense_label_dict
    )
    expected = 3

    self.assertEqual(actual, expected)

  def test_flatten_dict_of_domain_calls(self):
    """A label -> regions dict flattens to a list of (label, region) pairs."""
    input_dict_of_calls = {'CL0036': [(29, 252), (253, 254), (256, 257)]}
    expected = [
        ('CL0036', (29, 252)),
        ('CL0036', (253, 254)),
        ('CL0036', (256, 257)),
    ]
    actual = per_residue_sparse.flatten_dict_of_domain_calls(
        input_dict_of_calls
    )

    self.assertListEqual(actual, expected)


# Run the test suite through absl's test runner when executed as a script.
if __name__ == '__main__':
  absltest.main()