google-research
303 lines · 10.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Convert sanitized data to tf dataset."""
17import dataclasses
18from typing import Optional
19import numpy as np
20import pandas as pd
21import tensorflow as tf
22
# Alias for numpy arrays used in this module.
Array = np.ndarray
# Feature keys of the example dicts produced by the batching methods below.
OUTPUT_USER_KEY = "uid"
OUTPUT_ITEM_KEY = "sid"
26
27
@dataclasses.dataclass
class InputMatrix:
  """Represents a sparse input matrix. Used for batching of input data."""
  # [num_ratings, 2] integer coordinates (row, col) of the stored entries.
  indices: np.ndarray
  # [num_ratings] values of the stored entries.
  values: np.ndarray
  num_rows: int
  num_cols: int
  # Weights of the examples in the loss function (None means unweighted).
  weights: Optional[np.ndarray]
  # Regularization weights (one per row) for frequency-regularization.
  row_reg: Optional[np.ndarray]

  def to_sparse_tensor(self):
    """Returns this matrix as a tf.SparseTensor."""
    return tf.SparseTensor(
        indices=self.indices,
        values=self.values,
        dense_shape=[self.num_rows, self.num_cols])

  def batch_ls(self, batch_size):
    """Batches the data for least squares solvers.

    Args:
      batch_size: the number of rows per batch. A falsy value (0 or None)
        yields the whole matrix as a single batch.

    Returns:
      A dataset with these fields:
        shifted_indices: indices of the ratings, where the rows are shifted to
          start at 0.
        values: values of the ratings.
        weights: weights of the ratings.
        row_reg: the regularization weights, of shape [batch_size].
        update_indices: the original row indices, of shape [batch_size].
    """
    rows = self.indices[:, 0]

    def rows_slice(lo, hi):
      """Builds the batch dict for all entries whose row is in [lo, hi)."""
      in_range = (rows >= lo) & (rows < hi)
      selected = self.indices[in_range]
      batch = dict(
          # Shift rows so each batch's rows start at 0.
          shifted_indices=selected - np.array([lo, 0]),
          update_indices=np.arange(lo, hi),
          values=self.values[in_range],
          num_rows=hi - lo)
      if self.weights is not None:
        batch["weights"] = self.weights[in_range]
      if self.row_reg is not None:
        batch["row_reg"] = self.row_reg[lo:hi]
      return batch

    if not batch_size:
      boundaries = [0, self.num_rows]
    else:
      boundaries = [
          k * batch_size for k in range(self.num_rows // batch_size + 1)]
      if self.num_rows % batch_size:
        # Trailing partial batch.
        boundaries.append(self.num_rows)
    spans = list(zip(boundaries, boundaries[1:]))

    def gen():
      for lo, hi in spans:
        yield rows_slice(lo, hi)

    signature = dict(
        shifted_indices=tf.TensorSpec([None, 2], tf.int64),
        update_indices=tf.TensorSpec([None], tf.int64),
        values=tf.TensorSpec([None], tf.float32),
        num_rows=tf.TensorSpec([], tf.int64))
    if self.weights is not None:
      signature["weights"] = tf.TensorSpec([None], tf.float32)
    if self.row_reg is not None:
      signature["row_reg"] = tf.TensorSpec([None], tf.float32)
    return tf.data.Dataset.from_generator(gen, output_signature=signature)

  def batch_gd(
      self,
      user_axis,
      batch_size,
      random_seed = 1):
    """Batches the data for gradient descent solvers.

    Args:
      user_axis: axis of the user in the input_data. Should be set as 0 if the
        rows represent users, and 1 if the columns represent users.
      batch_size: the batch size of the dataset.
      random_seed: seed for random generator.

    Returns:
      A tf.Dataset of the form ({"uid": user_ids, "sid": item_ids}}, rating,
      weight).

    Raises:
      ValueError: if `user_axis` is not 0 or 1, or `batch_size` exceeds the
        number of examples.
    """
    if user_axis not in [0, 1]:
      raise ValueError("user_axis must be 0 or 1")
    uids = self.indices[:, user_axis]
    sids = self.indices[:, 1 - user_axis]
    values = self.values
    weights = (self.weights if self.weights is not None
               else np.ones(len(values), dtype=np.float32))
    data_size = len(values)
    if batch_size > data_size:
      raise ValueError(f"{batch_size=} cannot be larger than the size of the "
                       f"data ({data_size})")

    def generator():
      rng = np.random.default_rng(random_seed)
      # Ceiling division: a trailing partial batch is completed with examples
      # wrapped around from the start of the (tiled) permutation, so every
      # batch has exactly batch_size examples.
      num_batches = (data_size + batch_size - 1) // batch_size
      while True:
        order = np.tile(rng.permutation(data_size), 2)
        for b in range(num_batches):
          chosen = order[b * batch_size: (b + 1) * batch_size]
          yield ({OUTPUT_USER_KEY: uids[chosen],
                  OUTPUT_ITEM_KEY: sids[chosen]},
                 values[chosen],
                 weights[chosen])

    return tf.data.Dataset.from_generator(
        generator,
        output_types=({OUTPUT_USER_KEY: tf.int64, OUTPUT_ITEM_KEY: tf.int64},
                      tf.float32, tf.float32),
        output_shapes=({OUTPUT_USER_KEY: (batch_size,),
                        OUTPUT_ITEM_KEY: (batch_size,)},
                       (batch_size,),
                       (batch_size,)))

  def batch_gd_by_user(
      self,
      user_axis,
      num_examples_per_user,
      num_users_per_batch,
      random_seed = 1):
    """Batches the data for gradient descent solvers, grouped by users.

    Suppose we have n users with data {D_1,..., D_n}. For each batch, we want
    to randomly sample `num_users_per_batch` users, and take randomly
    `num_examples_per_user` examples from each of them.

    User-grouping is useful for user-level privacy.

    Args:
      user_axis: axis of the user in the input_data. Should be set as 0 if the
        rows represent users, and 1 if the columns represent users.
      num_examples_per_user: the number of examples taken from each user
        when forming one batch.
      num_users_per_batch: the number of users in each batch.
      random_seed: seed for random generator.

    Returns:
      A tf.Dataset of the form ({"uid": user_ids, "sid": item_ids}}, rating,
      weight). The batch size is `num_examples_per_user * num_users_per_batch`.

    Raises:
      ValueError if `num_users_per_batch` is larger than the number of users.
    """
    if user_axis not in [0, 1]:
      raise ValueError("user_axis must be 0 or 1")
    uids = self.indices[:, user_axis]
    sids = self.indices[:, 1 - user_axis]
    values = self.values
    weights = (self.weights if self.weights is not None
               else np.ones(len(values), dtype=np.float32))
    # Sort every per-example array by user id so that each user's examples
    # are contiguous and addressable as [offsets[uid], offsets[uid + 1]).
    order = np.argsort(uids)
    uids, sids = uids[order], sids[order]
    values, weights = values[order], weights[order]
    # Histogram of examples per user id.
    user_sizes = np.zeros(max(uids) + 1, dtype=np.int32)
    np.add.at(user_sizes, uids, 1)
    offsets = np.concatenate([[0], np.cumsum(user_sizes)])
    # Restrict to non-empty users, because for an empty user we cannot yield
    # any examples, so including them would change the batch size.
    (nonempty_users,) = np.where(user_sizes)
    nusers = len(nonempty_users)
    if num_users_per_batch > nusers:
      raise ValueError(
          f"num_users_per_batch ({num_users_per_batch}) should not be larger "
          f"than the number of users ({nusers}).")

    def generator():
      """Combines data for a batch of users."""
      rng = np.random.default_rng(random_seed)

      def per_user_stream(uid, rng):
        """Yields num_examples_per_user sampled indices for a given user.

        If the user has fewer items than `num_examples_per_user`, then some
        items will be sampled more than once.

        Args:
          uid: the user id.
          rng: the random number generator.
        """
        user_size = user_sizes[uid]
        # Deterministic part: whole copies of the user's index range.
        base = (offsets[uid] + np.tile(
            np.arange(user_size), num_examples_per_user // user_size))
        leftover = num_examples_per_user % user_size
        if not leftover:
          while True:
            yield base
        else:
          # Random part: walk a tiled permutation, consuming `leftover`
          # indices per batch so consecutive draws differ.
          perm = offsets[uid] + np.tile(rng.permutation(user_size), 2)
          cursor = 0
          while True:
            yield np.concatenate([base, perm[cursor:cursor + leftover]])
            cursor = (cursor + leftover) % user_size

      streams = {uid: per_user_stream(uid, rng) for uid in nonempty_users}
      while True:
        rng.shuffle(nonempty_users)
        # Tile to avoid an incomplete batch.
        # Note that we already checked num_users_per_batch <= nusers.
        shuffled = np.tile(nonempty_users, 2)
        start = 0
        while start < nusers:
          picked = np.concatenate([
              next(streams[uid])
              for uid in shuffled[start:start + num_users_per_batch]
          ])
          start += num_users_per_batch
          yield ({OUTPUT_USER_KEY: uids[picked],
                  OUTPUT_ITEM_KEY: sids[picked]},
                 values[picked],
                 weights[picked])

    batch_size = num_examples_per_user * num_users_per_batch
    return tf.data.Dataset.from_generator(
        generator,
        output_types=({OUTPUT_USER_KEY: tf.int64, OUTPUT_ITEM_KEY: tf.int64},
                      tf.float32, tf.float32),
        output_shapes=({OUTPUT_USER_KEY: (batch_size,),
                        OUTPUT_ITEM_KEY: (batch_size,)},
                       (batch_size,),
                       (batch_size,)))
260
261
def df_to_input_matrix(
    df,
    num_rows,
    num_cols,
    row_key = "uid",
    col_key = "sid",
    value_key = None,
    sort = True):
  """Creates an InputMatrix from a pd.DataFrame.

  Args:
    df: the DataFrame. Must contain row_key and col_key.
    num_rows: the number of rows.
    num_cols: the number of columns.
    row_key: the dataframe key to use for the row ids.
    col_key: the dataframe key to use for the column ids.
    value_key: uses this field for the values. If None, fills with ones.
    sort: whether to sort the indices.

  Returns:
    An InputMatrix.

  Raises:
    ValueError: if `value_key` is given but is not a column of `df`.
  """
  if sort:
    df = df.sort_values([row_key, col_key])
  # Cast explicitly to int so empty DataFrames (object dtype) still work.
  rows = df[row_key].values.astype(np.int64)
  cols = df[col_key].values.astype(np.int64)
  if value_key:
    if value_key not in df:
      raise ValueError(f"key {value_key} is missing from the DataFrame")
    raw_values = df[value_key].values
  else:
    raw_values = np.ones(len(df))
  return InputMatrix(
      indices=np.column_stack([rows, cols]),
      values=raw_values.astype(np.float32),
      weights=None,
      row_reg=None,
      num_rows=num_rows,
      num_cols=num_cols)
304