# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Flax modules."""

import functools

from absl.testing import parameterized
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import tensorflow.compat.v1 as tf

from protein_lm import domains
from protein_lm import models
from protein_lm import modules
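
# A deliberately tiny FlaxLM configuration (one layer, one attention head,
# 64-dim embeddings and MLP) so each test builds and samples quickly.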
lm_cls = functools.partial(
    models.FlaxLM,
    num_layers=1,
    num_heads=1,
    emb_dim=64,
    mlp_dim=64,
    qkv_dim=64)


class ModulesTest(tf.test.TestCase, parameterized.TestCase):

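  # Both positional-encoding modules below are drop-in options for FlaxLM:
  # AddLearnedPositionalEncodings trains the position embeddings, while
  # AddSinusoidalPositionalEncodings uses fixed sin/cos features.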
  @parameterized.parameters(
      (modules.AddLearnedPositionalEncodings,),
      (modules.AddSinusoidalPositionalEncodings,))
  def test_positional_encodings(self, positional_encoding_module):
    """Tests that the model runs with both types of positional encodings."""
    domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
    lm = lm_cls(domain=domain,
                positional_encoding_module=positional_encoding_module)
    lm.sample(1)

  def test_embeddings_for_one_hot(self):
    """Tests that the embeddings match for int and one-hot representations."""
    vocab_size = 10
    emb_dim = 7
    x_int = jnp.array([[1, 3], [2, 8]])
    module = modules.Embed.partial(
        num_embeddings=vocab_size, num_features=emb_dim)
    _, params = module.init(jrandom.PRNGKey(0), x_int)
    emb_int = module.call(params, x_int)
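    # A one-hot input should select the same embedding row as the
    # corresponding integer lookup, so both formats must agree exactly.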
    x_one_hot = jnp.eye(vocab_size)[x_int]
    emb_one_hot = module.call(params, x_one_hot)
    self.assertAllEqual(emb_int, emb_one_hot)

  def test_embeddings_for_dist(self):
    """Tests that a soft input embeds to a mixture of its tokens' embeddings."""
    vocab_size = 5
    emb_dim = 7
    x_int = np.array([[1], [3]])
    module = modules.Embed.partial(
        num_embeddings=vocab_size, num_features=emb_dim)
    _, params = module.init(jrandom.PRNGKey(0), x_int)
    emb_int = module.call(params, x_int)
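    # A soft input p embeds to the expected embedding sum_i p[i] * E[i].
    # Batch row 0 puts mass 0.25/0.75 on tokens 1 and 3 (the integer inputs
    # above) and row 1 splits 0.5/0.5, so the expected outputs are the
    # matching convex combinations of the emb_int rows.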
    x_dist = np.array([[[0, 0.25, 0, 0.75, 0]], [[0, 0.5, 0, 0.5, 0]]])
    emb_dist = np.array(module.call(params, x_dist))
    emb_expected = np.array([[emb_int[0, 0] * 0.25 + emb_int[1, 0] * 0.75],
                             [emb_int[0, 0] * 0.5 + emb_int[1, 0] * 0.5]])
    self.assertAllClose(emb_dist, emb_expected)

  @parameterized.parameters(
      ('logits', False),
      (['logits'], True),
      (('logits', 'output_emb'), True),
  )
  def test_output_head(self, output_head, multiple_heads):
    """Tests that predict_step returns a single array or a dict of heads."""
    domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
    inputs = domain.sample_uniformly(8)
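    # pmap=False presumably keeps the model unreplicated, so the raw
    # parameters in lm.optimizer.target can be fed to predict_step directly.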
    lm = lm_cls(domain=domain, pmap=False)
    outputs = models.predict_step(
        lm.optimizer.target,
        inputs,
        preprocess_fn=lm.preprocess,
        output_head=output_head)
    if multiple_heads:
      self.assertIsInstance(outputs, dict)
      self.assertLen(outputs, len(output_head))
    else:
      # We should have gotten a single output, the logits.
      self.assertEqual(outputs.shape,
                       (inputs.shape[0], inputs.shape[1], lm.vocab_size))


if __name__ == '__main__':
  tf.test.main()