# google-research
# 61 lines · 2.0 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Tests for data_utils."""
17
18import numpy as np
19import pandas as pd
20import tensorflow.compat.v2 as tf
21
22from graph_compression.contrastive_learning.learning_latents import data_utils
23
24
# Create small fake datasets with the same column layout as the real ones, so
# the noise functions in data_utils can be exercised in isolation.
NUM_EXAMPLES = 5

# Dedicated, fixed-seed generator: makes the fake data (and hence the tests)
# reproducible across runs without reseeding NumPy's global random state.
_RNG = np.random.RandomState(0)

# dSprites: the shape and label latent columns are filled with random 0/1
# integers, the value latent columns with standard-normal floats. Column
# counts are derived from the name lists so the fixture stays in sync with
# data_utils instead of relying on hard-coded widths.
_DSPRITES_BINARY_COLUMNS = (
    data_utils.DSPRITES_SHAPE_NAMES + data_utils.DSPRITES_LABEL_NAMES)
FAKE_DSPRITES_DF = pd.DataFrame(
    np.concatenate([
        _RNG.randint(0, 2, (NUM_EXAMPLES, len(_DSPRITES_BINARY_COLUMNS))),
        _RNG.randn(NUM_EXAMPLES, len(data_utils.DSPRITES_VALUE_NAMES)),
    ], axis=1),
    columns=_DSPRITES_BINARY_COLUMNS + data_utils.DSPRITES_VALUE_NAMES)

# 3DIdent: an 'id' column holding 0..NUM_EXAMPLES-1, followed by one
# standard-normal column per value latent.
FAKE_THREEDIDENT_DF = pd.DataFrame(
    np.concatenate([
        np.arange(NUM_EXAMPLES).reshape(-1, 1),
        _RNG.randn(NUM_EXAMPLES, len(data_utils.THREEDIDENT_VALUE_NAMES)),
    ], axis=1),
    columns=['id'] + data_utils.THREEDIDENT_VALUE_NAMES)
42
43
class DataTest(tf.test.TestCase):
  """Checks the keys produced by the per-dataset simple noise functions."""

  def _assert_has_latents(self, output, latent_names):
    """Assert that each name in `latent_names` is among `output.keys()`."""
    for name in latent_names:
      self.assertIn(name, output.keys())

  def test_dsprites_simple_noise_fn(self):
    """The dSprites noise output must contain every shape and label latent."""
    frame = FAKE_DSPRITES_DF
    first_row = frame.iloc[0]
    output = data_utils.dsprites_simple_noise_fn(first_row, frame)
    expected_names = (
        data_utils.DSPRITES_SHAPE_NAMES + data_utils.DSPRITES_LABEL_NAMES)
    self._assert_has_latents(output, expected_names)

  def test_threedident_simple_noise_fn(self):
    """The 3DIdent noise output must contain every value latent."""
    frame = FAKE_THREEDIDENT_DF
    first_row = frame.iloc[0]
    output = data_utils.threedident_simple_noise_fn(first_row, frame)
    self._assert_has_latents(output, data_utils.THREEDIDENT_VALUE_NAMES)
57
58
if __name__ == '__main__':
  # Switch TensorFlow to v2 behavior (this module imports
  # `tensorflow.compat.v2`) before any test is run.
  tf.compat.v1.enable_v2_behavior()
  # Hand control to TensorFlow's test runner.
  tf.test.main()
62