google-research
121 строка · 4.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Tests for dataset."""
17import itertools18from absl.testing import parameterized19import numpy as np20import tensorflow as tf21
22from dp_alternating_minimization import dataset23
24
def _construct_input_matrix(indices0, indices1, values):
  """Builds a `dataset.InputMatrix` from parallel coordinate/value sequences.

  Args:
    indices0: Iterable of row indices, one per entry.
    indices1: Iterable of column indices, one per entry.
    values: Iterable of entry values, one per entry.

  Returns:
    A `dataset.InputMatrix` with `weights` and `row_reg` left as None and a
    shape of `max(indices0) + 1` rows by `max(indices1) + 1` columns.
  """
  row_ids, col_ids, entries = (
      np.array(list(seq)) for seq in (indices0, indices1, values))
  # One (row, col) pair per entry; zip stops at the shorter of the two.
  coordinates = np.array(list(zip(row_ids, col_ids)))
  return dataset.InputMatrix(
      indices=coordinates,
      values=entries,
      weights=None,
      row_reg=None,
      num_rows=row_ids.max() + 1,
      num_cols=col_ids.max() + 1,
  )
37
class DatasetTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the batching helpers of `dataset.InputMatrix`."""

  @staticmethod
  def _make_per_user_matrix():
    """Returns an InputMatrix where user u (0 <= u < 10) has u + 1 copies of (u, u, u)."""
    rows = itertools.chain.from_iterable(
        [(u, u, u)] * (u + 1) for u in range(10))
    return _construct_input_matrix(*zip(*rows))

  def _assert_batch_grouped_by_user(self, batch, num_examples_per_user,
                                    num_users_per_batch):
    """Checks if one user's data is consecutive in the batch."""
    expected_size = num_examples_per_user * num_users_per_batch
    ids, values, weights = batch
    self.assertSetEqual(
        set(ids.keys()),
        {dataset.OUTPUT_USER_KEY, dataset.OUTPUT_ITEM_KEY})
    users = ids[dataset.OUTPUT_USER_KEY].numpy()
    items = ids[dataset.OUTPUT_ITEM_KEY].numpy()
    ratings = values.numpy()
    weight_values = weights.numpy()
    for column in (users, items, ratings, weight_values):
      self.assertLen(column, expected_size)
    # Every entry was built as (u, u, u), so user == item == rating.
    self.assertAllClose(users, items)
    self.assertAllClose(users, ratings)
    self.assertAllClose(weight_values, np.ones(expected_size))
    distinct_users = set()
    for start in range(0, expected_size, num_examples_per_user):
      # Each consecutive run of num_examples_per_user rows is a single user.
      self.assertAllEqual(users[start:start + num_examples_per_user],
                          users[start] * np.ones(num_examples_per_user))
      distinct_users.add(users[start])
    self.assertLen(distinct_users, num_users_per_batch)

  @parameterized.parameters(
      (100, 20),
      (100, 1),
      (100, 100),
      (100, 27),
  )
  def test_batch_gd(self, n, batch_size):
    """Every batch drawn from `batch_gd` holds exactly `batch_size` entries."""
    positions = np.arange(n)
    input_data = _construct_input_matrix(positions, positions * 10,
                                         positions / 10)
    batches = iter(input_data.batch_gd(user_axis=0, batch_size=batch_size))
    # NOTE(review): n batches are drawn, which is more than one pass over the
    # n entries for most parameters; presumably the dataset repeats — confirm.
    for _ in range(n):
      ids, values, weights = next(batches)
      for key in (dataset.OUTPUT_USER_KEY, dataset.OUTPUT_ITEM_KEY):
        self.assertLen(ids[key], batch_size)
      self.assertLen(values, batch_size)
      self.assertLen(weights, batch_size)

  @parameterized.parameters(
      (3, 2),
      (2, 3),
      (1, 10),
      (10, 1),
  )
  def test_batch_gd_by_user(self, num_examples_per_user, num_users_per_batch):
    """Batches from `batch_gd_by_user` keep each user's rows contiguous."""
    input_matrix = self._make_per_user_matrix()
    batches = iter(
        input_matrix.batch_gd_by_user(0, num_examples_per_user,
                                      num_users_per_batch))
    # Fetch all five batches first, then validate them one by one.
    fetched = [next(batches) for _ in range(5)]
    for batch in fetched:
      self._assert_batch_grouped_by_user(batch, num_examples_per_user,
                                         num_users_per_batch)

  def test_batch_gd_by_user_users_per_batch_large(self):
    """Requesting more users per batch than exist raises ValueError."""
    input_matrix = self._make_per_user_matrix()
    # Only users 0..9 are present, so 11 users per batch cannot be satisfied.
    with self.assertRaises(ValueError):
      input_matrix.batch_gd_by_user(
          user_axis=0, num_examples_per_user=100, num_users_per_batch=11)
119
if __name__ == "__main__":
  # Run the test cases defined in this module with TensorFlow's test runner.
  tf.test.main()