# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates data for the VAE benchmarks.

Branched from https://github.com/dawenl/vae_cf with modifications.
The generated datasets and splits are identical to the datasets and splits
used in the paper Liang et al., "Variational Autoencoders for Collaborative
Filtering", WWW '18.
"""

import argparse
import os
import sys
import urllib.request
import zipfile
import numpy as np
import pandas as pd


def get_count(tp, idx):
  """Counts the rows of `tp` grouped by the values in column `idx`."""
  playcount_groupbyid = tp[[idx]].groupby(idx, as_index=True)
  count = playcount_groupbyid.size()
  return count


def filter_triplets(tp, min_uc, min_sc):
  """Filters a DataFrame of rating triplets.

  Args:
    tp: a DataFrame of (movieId, userId, rating) triplets.
    min_uc: filter out users with fewer than min_uc ratings.
    min_sc: filter out items with fewer than min_sc ratings.
  Returns:
    A tuple of the filtered DataFrame, the user counts and the item counts.
  """
  # Only keep the triplets for items which were clicked on by at least min_sc
  # users.
  if min_sc > 0:
    itemcount = get_count(tp, 'movieId')
    tp = tp[tp['movieId'].isin(itemcount.index[itemcount >= min_sc])]

  # Only keep the triplets for users who clicked on at least min_uc items.
  # After doing this, some items may fall below min_sc users again, but that
  # should only affect a small proportion of items.
  if min_uc > 0:
    usercount = get_count(tp, 'userId')
    tp = tp[tp['userId'].isin(usercount.index[usercount >= min_uc])]

  # Update both usercount and itemcount after filtering.
  usercount, itemcount = get_count(tp, 'userId'), get_count(tp, 'movieId')
  return tp, usercount, itemcount
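
# For reference, the MovieLens 20M run in main() uses min_uc=5, min_sc=0
# (drop users with fewer than 5 ratings, keep all items), while the Million
# Song Data run uses min_uc=20 and min_sc=200.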


def split_train_test_proportion(data, test_prop=0.2):
  """Splits a DataFrame into train and test sets.

  Args:
    data: a DataFrame of (userId, movieId, rating).
    test_prop: the proportion of ratings to hold out for the test set.
  Returns:
    Two DataFrames of the train and test sets. The data is grouped by user,
    then each user (with 5 ratings or more) is randomly split into train and
    test ratings.
  """
  data_grouped_by_user = data.groupby('userId')
  tr_list, te_list = list(), list()

  np.random.seed(98765)

  for i, (_, group) in enumerate(data_grouped_by_user):
    n_items_u = len(group)

    if n_items_u >= 5:
      idx = np.zeros(n_items_u, dtype='bool')
      idx[np.random.choice(
          n_items_u, size=int(test_prop * n_items_u), replace=False)
          .astype('int64')] = True

      tr_list.append(group[np.logical_not(idx)])
      te_list.append(group[idx])
    else:
      tr_list.append(group)

    if i % 1000 == 0:
      print('%d users sampled' % i)
      sys.stdout.flush()

  data_tr = pd.concat(tr_list)
  data_te = pd.concat(te_list)

  return data_tr, data_te
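
# Illustrative example of the split above (not executed anywhere): with the
# default test_prop=0.2, a user with 10 ratings contributes int(0.2 * 10) = 2
# randomly chosen ratings to the test set and the remaining 8 to the train
# set, while a user with fewer than 5 ratings keeps all ratings in train.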


def generate_data(raw_data, output_dir, n_heldout_users, min_uc, min_sc):
  """Generates and writes train, validation and test data.

  The raw_data is first split into train, validation and test sets by user.
  For the validation set, each user's ratings are randomly partitioned into
  two subsets following an 80/20 split (see split_train_test_proportion), and
  written to validation_tr.csv and validation_te.csv. A similar split is
  applied to the test set.

  Args:
    raw_data: a DataFrame of (userId, movieId, rating).
    output_dir: path to the output directory.
    n_heldout_users: this many users are held out for each of the validation
      and test sets.
    min_uc: filter out users with fewer than min_uc ratings.
    min_sc: filter out items with fewer than min_sc ratings.
  """
  raw_data, user_activity, item_popularity = filter_triplets(
      raw_data, min_uc, min_sc)
  sparsity = 1. * raw_data.shape[0] / (
      user_activity.shape[0] * item_popularity.shape[0])
  print('After filtering, there are %d watching events from %d users and %d '
        'movies (sparsity: %.3f%%)' %
        (raw_data.shape[0], user_activity.shape[0], item_popularity.shape[0],
         sparsity * 100))
  unique_uid = user_activity.index
  np.random.seed(98765)
  idx_perm = np.random.permutation(unique_uid.size)
  unique_uid = unique_uid[idx_perm]
  n_users = unique_uid.size
  tr_users = unique_uid[:(n_users - n_heldout_users * 2)]
  vd_users = unique_uid[(n_users - n_heldout_users * 2):
                        (n_users - n_heldout_users)]
  te_users = unique_uid[(n_users - n_heldout_users):]
  train_plays = raw_data.loc[raw_data['userId'].isin(tr_users)]
  unique_sid = pd.unique(train_plays['movieId'])
  show2id = dict((sid, i) for (i, sid) in enumerate(unique_sid))
  profile2id = dict((pid, i) for (i, pid) in enumerate(unique_uid))

  def numerize(tp):
    """Maps userId/movieId values to contiguous integer ids uid/sid."""
    uid = [profile2id[x] for x in tp['userId']]
    sid = [show2id[x] for x in tp['movieId']]
    return pd.DataFrame(data={'uid': uid, 'sid': sid}, columns=['uid', 'sid'])
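
  # All files written below go directly under output_dir: unique_sid.txt plus
  # train.csv, validation_tr.csv, validation_te.csv, test_tr.csv and
  # test_te.csv, where every CSV holds re-indexed (uid, sid) pairs.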

  pro_dir = output_dir
  if not os.path.exists(pro_dir):
    os.makedirs(pro_dir)
  with open(os.path.join(pro_dir, 'unique_sid.txt'), 'w') as f:
    for sid in unique_sid:
      f.write('%s\n' % sid)
  vad_plays = raw_data.loc[raw_data['userId'].isin(vd_users)]
  vad_plays = vad_plays.loc[vad_plays['movieId'].isin(unique_sid)]
  vad_plays_tr, vad_plays_te = split_train_test_proportion(vad_plays)
  test_plays = raw_data.loc[raw_data['userId'].isin(te_users)]
  test_plays = test_plays.loc[test_plays['movieId'].isin(unique_sid)]
  test_plays_tr, test_plays_te = split_train_test_proportion(test_plays)

  train_data = numerize(train_plays)
  train_data.to_csv(os.path.join(pro_dir, 'train.csv'), index=False)

  vad_data_tr = numerize(vad_plays_tr)
  vad_data_tr.to_csv(os.path.join(pro_dir, 'validation_tr.csv'), index=False)

  vad_data_te = numerize(vad_plays_te)
  vad_data_te.to_csv(os.path.join(pro_dir, 'validation_te.csv'), index=False)

  test_data_tr = numerize(test_plays_tr)
  test_data_tr.to_csv(os.path.join(pro_dir, 'test_tr.csv'), index=False)

  test_data_te = numerize(test_plays_te)
  test_data_te.to_csv(os.path.join(pro_dir, 'test_te.csv'), index=False)


def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--output_dir', type=str, default='',
                      help='Path where to save the datasets.')
  args = parser.parse_args()
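
  # Note: with the default --output_dir='', all paths below are relative to
  # the current working directory.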

  # MovieLens 20M
  ml20m_zip = os.path.join(args.output_dir, 'ml20m.zip')
  ml20m_dir = os.path.join(args.output_dir, 'ml-20m/')
  ml20m_file = os.path.join(args.output_dir, 'ml-20m/ratings.csv')
  print('Downloading and extracting Movielens 20M data')
  urllib.request.urlretrieve(
      'http://files.grouplens.org/datasets/movielens/ml-20m.zip',
      ml20m_zip)
  with zipfile.ZipFile(ml20m_zip, 'r') as zipref:
    zipref.extract('ml-20m/ratings.csv', args.output_dir)
  os.remove(ml20m_zip)
  raw_data = pd.read_csv(ml20m_file, header=0)
  os.remove(ml20m_file)
  # Binarize the data: MovieLens ratings come in 0.5 steps, so keeping
  # ratings > 3.5 keeps exactly the ratings >= 4.
  raw_data = raw_data[raw_data['rating'] > 3.5]
  generate_data(
      raw_data, output_dir=ml20m_dir, n_heldout_users=10000, min_uc=5, min_sc=0)
  print('Done processing Movielens 20M.')

  # Million Song Data
  print('Downloading and extracting Million Song Data')
  msd_zip = os.path.join(args.output_dir, 'msd.zip')
  msd_dir = os.path.join(args.output_dir, 'msd/')
  msd_file = os.path.join(args.output_dir, 'msd/train_triplets.txt')
  urllib.request.urlretrieve(
      'http://millionsongdataset.com/sites/default/files/challenge/train_triplets.txt.zip',
      msd_zip)
  with zipfile.ZipFile(msd_zip, 'r') as zipref:
    zipref.extractall(msd_dir)
  os.remove(msd_zip)
  raw_data = pd.read_csv(
      msd_file, sep='\t', header=None, names=['userId', 'movieId', 'count'])
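  # The MSD triplets are really (user, song, play count); they are loaded
  # under the MovieLens-style column names above so that generate_data can be
  # reused unchanged.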
  os.remove(msd_file)
  generate_data(
      raw_data, output_dir=msd_dir, n_heldout_users=50000, min_uc=20,
      min_sc=200)
  print('Done processing Million Song Data.')


if __name__ == '__main__':
  main()