google-research

Форк
0
121 строка · 4.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Tests for dataset."""
17
import itertools
18
from absl.testing import parameterized
19
import numpy as np
20
import tensorflow as tf
21

22
from dp_alternating_minimization import dataset
23

24

25
def _construct_input_matrix(indices0, indices1, values):
  """Builds a `dataset.InputMatrix` from parallel coordinate/value sequences.

  Args:
    indices0: Iterable of axis-0 indices, one per entry.
    indices1: Iterable of axis-1 indices, one per entry.
    values: Iterable of entry values, aligned with `indices0` and `indices1`.

  Returns:
    A `dataset.InputMatrix` whose `indices` is an (n, 2) array of
    `[indices0[k], indices1[k]]` pairs, with `weights=None` and
    `row_reg=None`. `num_rows` / `num_cols` are one past the largest index
    seen on axis 0 / axis 1 respectively.

  Raises:
    ValueError: If the three inputs do not all have the same length, or if
      they are empty (there is no maximum index to size the matrix with).
  """
  # Materialize first: callers may pass tuples from `zip(*...)`, generators
  # or numpy arrays.
  indices0 = np.array(list(indices0))
  indices1 = np.array(list(indices1))
  values = np.array(list(values))
  # `zip` would silently truncate to the shortest input while `num_rows` and
  # `num_cols` are still computed from the full arrays; fail loudly instead.
  if not len(indices0) == len(indices1) == len(values):
    raise ValueError(
        "indices0, indices1 and values must have the same length, got "
        "{}, {} and {}.".format(len(indices0), len(indices1), len(values)))
  return dataset.InputMatrix(
      # Vectorized equivalent of [[x, y] for x, y in zip(indices0, indices1)].
      indices=np.stack([indices0, indices1], axis=1),
      values=values,
      weights=None,
      row_reg=None,
      num_rows=indices0.max() + 1,
      num_cols=indices1.max() + 1)
36

37

38
class DatasetTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the batching methods of `dataset.InputMatrix`."""

  def _ten_user_input_matrix(self):
    """Returns a matrix where user i (0 <= i < 10) has i + 1 entries (i, i, i).

    Every entry has user == item == rating, which lets the batch checks
    compare the three columns directly.
    """
    data = itertools.chain.from_iterable(
        [(i, i, i)] * (i + 1) for i in range(10))
    return _construct_input_matrix(*zip(*data))

  @parameterized.parameters(
      (100, 20),
      (100, 1),
      (100, 100),
      (100, 27),
      )
  def test_batch_gd(self, n, batch_size):
    """The first n batches from `batch_gd` each hold exactly batch_size rows."""
    # n entries at (i, 10 * i) with value i / 10.
    input_data = _construct_input_matrix(np.arange(n), np.arange(n) * 10,
                                         np.arange(n) / 10)
    batches = iter(input_data.batch_gd(user_axis=0, batch_size=batch_size))
    # n full batches are drawn regardless of batch_size (e.g. 100 batches of
    # 100 from only 100 entries), so this expects `batch_gd` to keep yielding
    # full batches past one pass over the data -- presumably it repeats the
    # data; confirm against dataset.py.
    for _ in range(n):
      ids, values, weights = next(batches)
      self.assertLen(ids[dataset.OUTPUT_USER_KEY], batch_size)
      self.assertLen(ids[dataset.OUTPUT_ITEM_KEY], batch_size)
      self.assertLen(values, batch_size)
      self.assertLen(weights, batch_size)

  @parameterized.parameters(
      (3, 2),
      (2, 3),
      (1, 10),
      (10, 1),
      )
  def test_batch_gd_by_user(self, num_examples_per_user, num_users_per_batch):
    """Each batch holds distinct users, each as a contiguous run of rows."""
    num_batches_to_check = 5
    batch_size = num_examples_per_user * num_users_per_batch
    input_matrix = self._ten_user_input_matrix()
    batches = iter(input_matrix.batch_gd_by_user(
        0, num_examples_per_user, num_users_per_batch))

    def _check_batch(batch):
      """Checks if one user's data is consecutive in the batch."""
      ids, values, weights = batch
      self.assertSetEqual(
          set(ids.keys()),
          {dataset.OUTPUT_USER_KEY, dataset.OUTPUT_ITEM_KEY})
      users = ids[dataset.OUTPUT_USER_KEY].numpy()
      items = ids[dataset.OUTPUT_ITEM_KEY].numpy()
      ratings = values.numpy()
      weight_values = weights.numpy()
      self.assertLen(users, batch_size)
      self.assertLen(items, batch_size)
      self.assertLen(ratings, batch_size)
      self.assertLen(weight_values, batch_size)
      # As we have set user = item = rating.
      self.assertAllClose(users, items)
      self.assertAllClose(users, ratings)
      self.assertAllClose(weight_values, np.ones(batch_size))
      # Walk the batch one user-sized chunk at a time: each chunk must contain
      # a single user, and no user may own two chunks of the same batch.
      chunk_users = []
      for start in range(0, batch_size, num_examples_per_user):
        chunk = users[start:start + num_examples_per_user]
        self.assertAllEqual(
            chunk, users[start] * np.ones(num_examples_per_user))
        chunk_users.append(users[start])
      self.assertLen(set(chunk_users), num_users_per_batch)

    # Fetch every batch before checking any of them. A list comprehension (not
    # a generator) so a premature StopIteration surfaces as a test error
    # instead of silently checking fewer batches.
    fetched = [next(batches) for _ in range(num_batches_to_check)]
    for batch in fetched:
      _check_batch(batch)

  def test_batch_gd_by_user_users_per_batch_large(self):
    """Asking for more users per batch than exist raises ValueError."""
    input_matrix = self._ten_user_input_matrix()
    with self.assertRaises(ValueError):
      input_matrix.batch_gd_by_user(
          user_axis=0,
          num_examples_per_user=100,
          # num_users_per_batch exceeds total users (10).
          num_users_per_batch=11)
118

119

120
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()
122

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.