# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert sanitized data to tf dataset."""
import dataclasses
from typing import Optional
import numpy as np
import pandas as pd
import tensorflow as tf

Array = np.ndarray
OUTPUT_USER_KEY = "uid"
OUTPUT_ITEM_KEY = "sid"


@dataclasses.dataclass
class InputMatrix:
  """Represents a sparse input matrix. Used for batching of input data."""
  indices: np.ndarray
  values: np.ndarray
  num_rows: int
  num_cols: int
  # Weights of the examples in the loss function.
  weights: Optional[np.ndarray]
  # Regularization weights (one per row) for frequency regularization.
  row_reg: Optional[np.ndarray]

  def to_sparse_tensor(self):
    return tf.SparseTensor(
        indices=self.indices,
        values=self.values,
        dense_shape=[self.num_rows, self.num_cols])
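
  # Illustrative usage (an assumption for documentation, not from the original
  # file): `indices` holds one [row, col] pair per rating, aligned with
  # `values`. The concrete numbers below are made up.
  #
  #   m = InputMatrix(
  #       indices=np.array([[0, 1], [1, 0], [2, 3]], dtype=np.int64),
  #       values=np.array([5.0, 3.0, 1.0], dtype=np.float32),
  #       num_rows=3, num_cols=4, weights=None, row_reg=None)
  #   m.to_sparse_tensor()  # tf.SparseTensor with dense_shape [3, 4]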

  def batch_ls(self, batch_size):
    """Batches the data for least squares solvers.

    Args:
      batch_size: the number of rows per batch.

    Returns:
      A dataset with these fields:
        shifted_indices: indices of the ratings, where the rows are shifted to
          start at 0.
        values: values of the ratings.
        weights: weights of the ratings.
        row_reg: the regularization weights, of shape [batch_size].
        update_indices: the original row indices, of shape [batch_size].
    """
    row_ids = self.indices[:, 0]
    def make_batch(start_row, end_row):
      mask = np.greater_equal(row_ids, start_row) & np.less(row_ids, end_row)
      indices = self.indices[mask]
      shifted_indices = indices - np.array([start_row, 0])
      update_indices = np.arange(start_row, end_row)
      batch = dict(
          shifted_indices=shifted_indices,
          update_indices=update_indices,
          values=self.values[mask],
          num_rows=end_row - start_row)
      if self.weights is not None:
        batch["weights"] = self.weights[mask]
      if self.row_reg is not None:
        batch["row_reg"] = self.row_reg[start_row:end_row]
      return batch
    if not batch_size:
      endpoints = [0, self.num_rows]
    else:
      endpoints = [i*batch_size for i in range(self.num_rows//batch_size + 1)]
      if self.num_rows % batch_size:
        endpoints.append(self.num_rows)
    intervals = list(zip(endpoints, endpoints[1:]))
    def gen():
      for start, end in intervals:
        yield make_batch(start, end)
    output_signature = dict(
        shifted_indices=tf.TensorSpec([None, 2], tf.int64),
        update_indices=tf.TensorSpec([None], tf.int64),
        values=tf.TensorSpec([None], tf.float32),
        num_rows=tf.TensorSpec([], tf.int64))
    if self.weights is not None:
      output_signature["weights"] = tf.TensorSpec([None], tf.float32)
    if self.row_reg is not None:
      output_signature["row_reg"] = tf.TensorSpec([None], tf.float32)
    return tf.data.Dataset.from_generator(
        gen, output_signature=output_signature)
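
  # Illustrative sketch (assumed usage, not part of the original file): batch
  # a 4-row matrix into row blocks of 2 and iterate over the finite dataset.
  #
  #   m = InputMatrix(
  #       indices=np.array([[0, 0], [1, 1], [2, 0], [3, 1]], dtype=np.int64),
  #       values=np.ones(4, dtype=np.float32),
  #       num_rows=4, num_cols=2, weights=None, row_reg=None)
  #   for batch in m.batch_ls(batch_size=2):
  #     # shifted_indices are relative to the block start; update_indices map
  #     # the block back to the original rows.
  #     print(batch["update_indices"].numpy())  # [0 1], then [2 3]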

  def batch_gd(
      self,
      user_axis,
      batch_size,
      random_seed=1):
    """Batches the data for gradient descent solvers.

    Args:
      user_axis: axis of the user in the input data. Should be set to 0 if the
        rows represent users, and 1 if the columns represent users.
      batch_size: the batch size of the dataset.
      random_seed: seed for the random generator.

    Returns:
      A tf.Dataset of the form ({"uid": user_ids, "sid": item_ids}, rating,
      weight).
    """
    if user_axis not in [0, 1]:
      raise ValueError("user_axis must be 0 or 1")
    uids = self.indices[:, user_axis]
    sids = self.indices[:, 1-user_axis]
    values = self.values
    weights = self.weights
    if weights is None:
      weights = np.ones(len(values), dtype=np.float32)
    data_size = len(values)
    if batch_size > data_size:
      raise ValueError(f"{batch_size=} cannot be larger than the size of the "
                       f"data ({data_size})")
    def generator():
      rng = np.random.default_rng(random_seed)
      while True:
        perm = np.tile(rng.permutation(data_size), 2)
        num_batches = data_size // batch_size
        if data_size % batch_size:
          num_batches += 1
        for i in range(num_batches):
          indices = perm[i*batch_size: (i+1)*batch_size]
          yield ({OUTPUT_USER_KEY: uids[indices],
                  OUTPUT_ITEM_KEY: sids[indices]},
                 values[indices],
                 weights[indices])
    return tf.data.Dataset.from_generator(
        generator,
        output_types=({OUTPUT_USER_KEY: tf.int64, OUTPUT_ITEM_KEY: tf.int64},
                      tf.float32, tf.float32),
        output_shapes=({OUTPUT_USER_KEY: (batch_size,),
                        OUTPUT_ITEM_KEY: (batch_size,)},
                       (batch_size,),
                       (batch_size,)))
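
  # Illustrative sketch (assumed usage): given an InputMatrix `m` (e.g. from
  # df_to_input_matrix below), the dataset repeats indefinitely, so take a
  # fixed number of batches. With user_axis=0, rows are users.
  #
  #   ds = m.batch_gd(user_axis=0, batch_size=2, random_seed=0)
  #   for features, rating, weight in ds.take(3):
  #     print(features[OUTPUT_USER_KEY].numpy(), rating.numpy())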

  def batch_gd_by_user(
      self,
      user_axis,
      num_examples_per_user,
      num_users_per_batch,
      random_seed=1):
    """Batches the data for gradient descent solvers, grouped by users.

    Suppose we have n users with data {D_1,..., D_n}. For each batch, we
    randomly sample `num_users_per_batch` users, and take
    `num_examples_per_user` randomly sampled examples from each of them.

    User-grouping is useful for user-level privacy.

    Args:
      user_axis: axis of the user in the input data. Should be set to 0 if the
        rows represent users, and 1 if the columns represent users.
      num_examples_per_user: the number of examples taken from each user
        when forming one batch.
      num_users_per_batch: the number of users in each batch.
      random_seed: seed for the random generator.

    Returns:
      A tf.Dataset of the form ({"uid": user_ids, "sid": item_ids}, rating,
      weight). The batch size is `num_examples_per_user * num_users_per_batch`.

    Raises:
      ValueError: if `num_users_per_batch` is larger than the number of users.
    """
    if user_axis not in [0, 1]:
      raise ValueError("user_axis must be 0 or 1")
    uids = self.indices[:, user_axis]
    sids = self.indices[:, 1-user_axis]
    values = self.values
    weights = self.weights
    if weights is None:
      weights = np.ones(len(values), dtype=np.float32)
    # Sort the data by user id.
    indices = np.argsort(uids)
    uids = uids[indices]
    sids = sids[indices]
    values = values[indices]
    weights = weights[indices]
    # Compute the number of examples per user.
    user_sizes = np.zeros(max(uids) + 1, dtype=np.int32)
    for uid in uids:
      user_sizes[uid] += 1
    offsets = np.concatenate([[0], np.cumsum(user_sizes)])
    (nonempty_users,) = np.where(user_sizes)
    nusers = len(nonempty_users)
    if num_users_per_batch > nusers:
      raise ValueError(
          f"num_users_per_batch ({num_users_per_batch}) should not be larger "
          f"than the number of users ({nusers}).")
    # Restrict to non-empty users: for an empty user we cannot yield any
    # examples, so including one would change the batch size.
    def generator():
      """Combines data for a batch of users."""
      rng = np.random.default_rng(random_seed)
      def user_gen(uid, rng):
        """Yields num_examples_per_user sampled indices for a given user.

        If the user has fewer items than `num_examples_per_user`, then some
        items will be sampled more than once.

        Args:
          uid: the user id.
          rng: the random number generator.
        """
        user_size = user_sizes[uid]
        samples = (
            offsets[uid] +
            np.tile(np.arange(user_size), num_examples_per_user//user_size))
        remaining_to_sample = num_examples_per_user % user_size
        if remaining_to_sample:
          perm = offsets[uid] + np.tile(rng.permutation(user_size), 2)
          i = 0
          while True:
            yield np.concatenate([samples, perm[i:i+remaining_to_sample]])
            i = (i + remaining_to_sample) % user_size
        else:
          while True:
            yield samples
      user_gens = {uid: user_gen(uid, rng) for uid in nonempty_users}
      while True:
        rng.shuffle(nonempty_users)
        # Tile to avoid an incomplete final batch. Note that we already
        # checked num_users_per_batch <= nusers.
        shuffled_uids = np.tile(nonempty_users, 2)
        i = 0
        while i < nusers:
          samples = np.concatenate([
              next(user_gens[uid])
              for uid in shuffled_uids[i : i + num_users_per_batch]
          ])
          i += num_users_per_batch
          yield ({OUTPUT_USER_KEY: uids[samples],
                  OUTPUT_ITEM_KEY: sids[samples]},
                 values[samples],
                 weights[samples])

    batch_size = num_examples_per_user * num_users_per_batch
    return tf.data.Dataset.from_generator(
        generator,
        output_types=({OUTPUT_USER_KEY: tf.int64, OUTPUT_ITEM_KEY: tf.int64},
                      tf.float32, tf.float32),
        output_shapes=({OUTPUT_USER_KEY: (batch_size,),
                        OUTPUT_ITEM_KEY: (batch_size,)},
                       (batch_size,),
                       (batch_size,)))
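
  # Illustrative sketch (assumed usage): each batch stacks
  # num_examples_per_user examples from each of num_users_per_batch users, so
  # the batch size is fixed at 2 * 2 = 4 here. A user with fewer examples than
  # requested has some of them repeated within the batch.
  #
  #   ds = m.batch_gd_by_user(
  #       user_axis=0, num_examples_per_user=2, num_users_per_batch=2)
  #   for features, rating, weight in ds.take(1):
  #     print(features[OUTPUT_USER_KEY].numpy())  # e.g. [3 3 0 0]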


def df_to_input_matrix(
    df,
    num_rows,
    num_cols,
    row_key="uid",
    col_key="sid",
    value_key=None,
    sort=True):
  """Creates an InputMatrix from a pd.DataFrame.

  Args:
    df: the DataFrame. Must contain row_key and col_key.
    num_rows: the number of rows.
    num_cols: the number of columns.
    row_key: the dataframe key to use for the row ids.
    col_key: the dataframe key to use for the column ids.
    value_key: uses this field for the values. If None, fills with ones.
    sort: whether to sort the indices.

  Returns:
    An InputMatrix.
  """
  if sort:
    df = df.sort_values([row_key, col_key])
  # Convert to int explicitly, to handle empty DataFrames.
  row_ids = df[row_key].values.astype(np.int64)
  col_ids = df[col_key].values.astype(np.int64)
  indices = np.stack([row_ids, col_ids], axis=1)
  if value_key:
    if value_key not in df:
      raise ValueError(f"key {value_key} is missing from the DataFrame")
    values = df[value_key].values
  else:
    values = np.ones(len(df))
  values = values.astype(np.float32)
  return InputMatrix(
      indices=indices,
      values=values,
      weights=None,
      row_reg=None,
      num_rows=num_rows,
      num_cols=num_cols)
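

# Minimal end-to-end sketch (not part of the original module): build an
# InputMatrix from a toy ratings DataFrame and draw one gradient-descent
# batch. The column names "uid"/"sid"/"rating" and all values here are made
# up for illustration.
if __name__ == "__main__":
  ratings = pd.DataFrame({
      "uid": [0, 0, 1, 2],
      "sid": [1, 2, 0, 2],
      "rating": [4.0, 3.0, 5.0, 1.0],
  })
  # 3 users and 3 items, with the rating column as the matrix values.
  matrix = df_to_input_matrix(
      ratings, num_rows=3, num_cols=3, value_key="rating")
  ds = matrix.batch_gd(user_axis=0, batch_size=2, random_seed=0)
  for features, rating, weight in ds.take(1):
    print(features[OUTPUT_USER_KEY].numpy(), rating.numpy(), weight.numpy())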