google-research
90 lines · 3.3 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for utils."""

import functools
from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from protein_lm import domains
from protein_lm import models
from protein_lm import utils

# Deliberately tiny transformer configuration so each test builds a small,
# fast-to-construct model.
lm_cfg = dict(
    batch_size=1, num_layers=2, num_heads=2, emb_dim=32, mlp_dim=32,
    qkv_dim=32)
# Factory that constructs a FlaxLM with the tiny config pre-applied.
lm_cls = functools.partial(models.FlaxLM, **lm_cfg)
29
30
class UtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for protein_lm.utils helpers."""

  def test_count_params(self):
    """Parameter counting/pretty-printing utilities on a tiny model."""
    domain = domains.FixedLengthDiscreteDomain(length=4, vocab_size=2)
    lm = lm_cls(domain=domain)
    count = utils.param_count(lm)
    self.assertEqual(13059, count)

    # Check these methods run.
    utils.param_pprint(lm)
    sizes = utils.param_reduce(lm, log=True)
    self.assertIsInstance(sizes, dict)

  @parameterized.parameters((5, 5), (5, 1), (5, 2), (5, 6), (5, 12))
  def test_batch_apply(self, batch_size, num_inputs):
    """batch_apply must match the unbatched result for every split size."""

    def fn(inputs):
      return np.power(inputs + 1, 2)

    def batch_fn(batched_inputs):
      # Guard: utils.batch_apply must always call us with exactly
      # `batch_size` rows (padding the final partial batch if needed).
      if len(batched_inputs) != batch_size:
        raise ValueError('fn() called with a batch that is '
                         'the wrong size (%d vs. %d).' % (len(batched_inputs),
                                                          batch_size))
      return fn(batched_inputs)

    inputs = np.stack([np.arange(num_inputs), -np.arange(num_inputs)], axis=1)
    unbatched_output = fn(inputs)
    batched_output = utils.batch_apply(batch_fn, inputs, batch_size)
    np.testing.assert_array_equal(unbatched_output, batched_output)

  def test_get_normalized_matrix(self):
    """Tests that the normalized matrix is computed correctly."""
    domain = domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(tokens=['A', 'B', 'C']),
        length=2)
    freq_dict = {'A': {'A': 5, 'B': 3, 'C': 1},
                 'B': {'A': 3, 'B': 5, 'C': 1},
                 'C': {'A': 1, 'B': 1, 'C': 1}}
    matrix = utils.get_normalized_matrix(domain, freq_dict)
    expected_matrix = [[1, 0.5, 0], [0.5, 1, 0,], [0, 0, 0]]
    self.assertAllEqual(matrix, expected_matrix)

  def test_soft_accuracy(self):
    """Tests that soft accuracy is computed correctly."""
    domain = domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(tokens=['A', 'B', 'C']),
        length=2)
    targets = np.array([[0, 1]])
    logits = np.log([[[0.9, 0.1], [0.6, 0.4]]])
    freq_dict = {'A': {'A': 5, 'B': 3, 'C': 1},
                 'B': {'A': 3, 'B': 5, 'C': 1},
                 'C': {'A': 1, 'B': 1, 'C': 1}}
    accuracy, denominator = utils.compute_weighted_soft_accuracy(
        logits, targets,
        weights=None,
        matrix=utils.get_normalized_matrix(domain, freq_dict))
    self.assertEqual(accuracy / denominator, 0.75)
87
88
# Run the absl/TF test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
91