google-research
86 строк · 3.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for supcon.classification_head."""
17from absl.testing import parameterized
18
19import numpy as np
20import tensorflow.compat.v1 as tf
21
22from supcon import classification_head
23
24
class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `classification_head.ClassificationHead`.

  The tests build TF1-style graphs (placeholders, `tf.gradients`, sessions)
  through the `tensorflow.compat.v1` API imported as `tf`.
  """

  @parameterized.named_parameters(
      ('rank_1', 1),
      ('rank_4', 4),
      ('rank_8', 8),
  )
  def testIncorrectRank(self, rank):
    """Checks that an input whose rank is not 2 raises a `ValueError`.

    Args:
      rank: Number of dimensions (each of size 10) of the placeholder input.
    """
    inputs = tf.placeholder(tf.float32, shape=[10] * rank)
    # Build the head outside the `assertRaisesRegex` context so that only the
    # call on the wrongly-ranked input can satisfy the expected error.
    classifier = classification_head.ClassificationHead(num_classes=10)
    with self.assertRaisesRegex(ValueError, 'is expected to have rank 2'):
      classifier(inputs)

  @parameterized.named_parameters(
      ('float32', tf.float32),
      ('float64', tf.float64),
      ('float16', tf.float16),
  )
  def testConstructClassificationHead(self, dtype):
    """Checks the static output shape and that the input dtype is preserved.

    Args:
      dtype: Floating point dtype of the randomly generated input tensor.
    """
    batch_size = 3
    num_classes = 10
    input_shape = [batch_size, 4]
    expected_output_shape = [batch_size, num_classes]
    inputs = tf.random.uniform(input_shape, seed=1, dtype=dtype)
    classifier = classification_head.ClassificationHead(num_classes=num_classes)
    output = classifier(inputs)
    self.assertListEqual(expected_output_shape, output.shape.as_list())
    self.assertEqual(inputs.dtype, output.dtype)

  def testGradient(self):
    """Checks that the output has a gradient with respect to the input."""
    inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
    classifier = classification_head.ClassificationHead(num_classes=10)
    output = classifier(inputs)
    gradients = tf.gradients(output, inputs)
    # `tf.gradients` always returns a list with one entry per `xs` tensor and
    # reports a missing gradient as a `None` entry, so asserting on the list
    # object itself can never fail. Check the entry instead.
    self.assertLen(gradients, 1)
    self.assertIsNotNone(gradients[0])

  def testCreateVariables(self):
    """Checks that calling the head creates one kernel and one bias variable."""
    inputs = tf.random.uniform((3, 4), dtype=tf.float64, seed=1)
    classifier = classification_head.ClassificationHead(num_classes=10)
    classifier(inputs)
    trainable_names = [var.name for var in tf.trainable_variables()]
    self.assertLen([name for name in trainable_names if 'kernel' in name], 1)
    self.assertLen([name for name in trainable_names if 'bias' in name], 1)

  def testInputOutput(self):
    """Runs the head in a session and checks the output shape and values."""
    batch_size = 3
    num_classes = 10
    expected_output_shape = (batch_size, num_classes)
    inputs = tf.random.uniform((batch_size, 4), dtype=tf.float64, seed=1)
    classifier = classification_head.ClassificationHead(num_classes=num_classes)
    output_tensor = classifier(inputs)
    with self.cached_session() as sess:
      # Variables created by the head must be initialized before evaluation.
      sess.run(tf.global_variables_initializer())
      outputs = sess.run(output_tensor)
    # Make sure that there are no NaNs
    self.assertFalse(np.isnan(outputs).any())
    self.assertEqual(outputs.shape, expected_output_shape)
83
84
if __name__ == '__main__':
  # Executed as a script: hand control to the TensorFlow test runner.
  tf.test.main()
87