google-research

Форк
0
61 строка · 2.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for data_utils."""
17

18
import numpy as np
19
import pandas as pd
20
import tensorflow.compat.v2 as tf
21

22
from graph_compression.contrastive_learning.learning_latents import data_utils
23

24

25
# create some fake datasets with the same structure as the real ones
26
NUM_EXAMPLES = 5
27
FAKE_DSPRITES_DF = pd.DataFrame(
28
    np.concatenate([
29
        np.random.randint(0, 2, (NUM_EXAMPLES, 7)),
30
        np.random.randn(NUM_EXAMPLES, 4)
31
    ],
32
                   axis=1),
33
    columns=data_utils.DSPRITES_SHAPE_NAMES + data_utils.DSPRITES_LABEL_NAMES +
34
    data_utils.DSPRITES_VALUE_NAMES)
35
FAKE_THREEDIDENT_DF = pd.DataFrame(
36
    np.concatenate([
37
        np.arange(NUM_EXAMPLES).reshape(-1, 1),
38
        np.random.randn(NUM_EXAMPLES, 10)
39
    ],
40
                   axis=1),
41
    columns=['id'] + data_utils.THREEDIDENT_VALUE_NAMES)
42

43

44
class DataTest(tf.test.TestCase):
45

46
  def test_dsprites_simple_noise_fn(self):
47
    df = FAKE_DSPRITES_DF
48
    result = data_utils.dsprites_simple_noise_fn(df.iloc[0], df)
49
    for latent_name in data_utils.DSPRITES_SHAPE_NAMES + data_utils.DSPRITES_LABEL_NAMES:
50
      self.assertIn(latent_name, result.keys())
51

52
  def test_threedident_simple_noise_fn(self):
53
    df = FAKE_THREEDIDENT_DF
54
    result = data_utils.threedident_simple_noise_fn(df.iloc[0], df)
55
    for latent_name in data_utils.THREEDIDENT_VALUE_NAMES:
56
      self.assertIn(latent_name, result.keys())
57

58

59
if __name__ == '__main__':
60
  tf.compat.v1.enable_v2_behavior()
61
  tf.test.main()
62

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.