# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Generates data for the VAE benchmarks.
17
18Branched from https://github.com/dawenl/vae_cf with modification.
19The generated datasets and splits are identical with the datasets and splits
20from the paper Liang et al., "Variational Autoencoders for Collaborative
21Filtering", WWW '18.
22"""
23
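# Example invocation (a minimal sketch; the file name generate_data.py is an
# assumption, substitute the actual script name):
#
#   python generate_data.py --output_dir=/tmp/vae_data
#
# Both datasets are downloaded into --output_dir and the processed splits are
# written to the ml-20m/ and msd/ subdirectories.
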
import argparse
import os
import sys
import urllib.request
import zipfile

import numpy as np
import pandas as pd


def get_count(tp, idx):
  """Counts the number of ratings per unique value of the column idx."""
  playcount_groupbyid = tp[[idx]].groupby(idx, as_index=True)
  count = playcount_groupbyid.size()
  return count


def filter_triplets(tp, min_uc, min_sc):
  """Filters a DataFrame of rating triplets.

  Args:
    tp: a DataFrame of (movieId, userId, rating) triplets.
    min_uc: filter out users with fewer than min_uc ratings.
    min_sc: filter out items with fewer than min_sc ratings.

  Returns:
    A tuple of the filtered DataFrame, the user counts and the item counts.
  """
  # Only keep the triplets for items which were clicked on by at least min_sc
  # users.
  if min_sc > 0:
    itemcount = get_count(tp, 'movieId')
    tp = tp[tp['movieId'].isin(itemcount.index[itemcount >= min_sc])]

  # Only keep the triplets for users who clicked on at least min_uc items.
  # After doing this, some of the items will have fewer than min_sc users, but
  # this should only be a small proportion.
  if min_uc > 0:
    usercount = get_count(tp, 'userId')
    tp = tp[tp['userId'].isin(usercount.index[usercount >= min_uc])]

  # Update both usercount and itemcount after filtering.
  usercount, itemcount = get_count(tp, 'userId'), get_count(tp, 'movieId')
  return tp, usercount, itemcount


def split_train_test_proportion(data, test_prop=0.2):
  """Splits a DataFrame into train and test sets.

  Args:
    data: a DataFrame of (userId, itemId, rating).
    test_prop: the proportion of test ratings.

  Returns:
    Two DataFrames of the train and test sets. The data is grouped by user,
    then each user (with 5 ratings or more) is randomly split into train and
    test ratings.
  """
  data_grouped_by_user = data.groupby('userId')
  tr_list, te_list = list(), list()

  np.random.seed(98765)

  for i, (_, group) in enumerate(data_grouped_by_user):
    n_items_u = len(group)

    if n_items_u >= 5:
      idx = np.zeros(n_items_u, dtype='bool')
      idx[np.random.choice(
          n_items_u, size=int(test_prop * n_items_u),
          replace=False).astype('int64')] = True

      tr_list.append(group[np.logical_not(idx)])
      te_list.append(group[idx])
    else:
      tr_list.append(group)

    if i % 1000 == 0:
      print('%d users sampled' % i)
      sys.stdout.flush()

  data_tr = pd.concat(tr_list)
  data_te = pd.concat(te_list)

  return data_tr, data_te


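# Note on split_train_test_proportion above: with the default test_prop=0.2, a
# user with 10 ratings contributes int(0.2 * 10) = 2 ratings to the test split
# and 8 to the train split; users with fewer than 5 ratings keep all of their
# ratings in the train split.

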
def generate_data(raw_data, output_dir, n_heldout_users, min_uc, min_sc):
  """Generates and writes train, validation and test data.

  The raw_data is first split into train, validation and test sets by user.
  For the validation set, each user's ratings are randomly partitioned into
  two subsets following an (80, 20) split (see split_train_test_proportion),
  and written to validation_tr.csv and validation_te.csv. The same split is
  applied to the test set.

  Args:
    raw_data: a DataFrame of (userId, movieId, rating).
    output_dir: path to the output directory.
    n_heldout_users: this many users are held out for each of the validation
      and test sets.
    min_uc: filter out users with fewer than min_uc ratings.
    min_sc: filter out items with fewer than min_sc ratings.
  """
  raw_data, user_activity, item_popularity = filter_triplets(
      raw_data, min_uc, min_sc)
  sparsity = 1. * raw_data.shape[0] / (
      user_activity.shape[0] * item_popularity.shape[0])
  print('After filtering, there are %d watching events from %d users and %d '
        'movies (sparsity: %.3f%%)' %
        (raw_data.shape[0], user_activity.shape[0], item_popularity.shape[0],
         sparsity * 100))

  unique_uid = user_activity.index
  np.random.seed(98765)
  idx_perm = np.random.permutation(unique_uid.size)
  unique_uid = unique_uid[idx_perm]

  n_users = unique_uid.size
  tr_users = unique_uid[:(n_users - n_heldout_users * 2)]
  vd_users = unique_uid[(n_users - n_heldout_users * 2):
                        (n_users - n_heldout_users)]
  te_users = unique_uid[(n_users - n_heldout_users):]

  train_plays = raw_data.loc[raw_data['userId'].isin(tr_users)]
  unique_sid = pd.unique(train_plays['movieId'])

  show2id = dict((sid, i) for (i, sid) in enumerate(unique_sid))
  profile2id = dict((pid, i) for (i, pid) in enumerate(unique_uid))

  def numerize(tp):
    uid = [profile2id[x] for x in tp['userId']]
    sid = [show2id[x] for x in tp['movieId']]
    return pd.DataFrame(data={'uid': uid, 'sid': sid}, columns=['uid', 'sid'])

  pro_dir = output_dir
  if not os.path.exists(pro_dir):
    os.makedirs(pro_dir)

  with open(os.path.join(pro_dir, 'unique_sid.txt'), 'w') as f:
    for sid in unique_sid:
      f.write('%s\n' % sid)

  vad_plays = raw_data.loc[raw_data['userId'].isin(vd_users)]
  vad_plays = vad_plays.loc[vad_plays['movieId'].isin(unique_sid)]
  vad_plays_tr, vad_plays_te = split_train_test_proportion(vad_plays)

  test_plays = raw_data.loc[raw_data['userId'].isin(te_users)]
  test_plays = test_plays.loc[test_plays['movieId'].isin(unique_sid)]
  test_plays_tr, test_plays_te = split_train_test_proportion(test_plays)

  train_data = numerize(train_plays)
  train_data.to_csv(os.path.join(pro_dir, 'train.csv'), index=False)

  vad_data_tr = numerize(vad_plays_tr)
  vad_data_tr.to_csv(os.path.join(pro_dir, 'validation_tr.csv'), index=False)

  vad_data_te = numerize(vad_plays_te)
  vad_data_te.to_csv(os.path.join(pro_dir, 'validation_te.csv'), index=False)

  test_data_tr = numerize(test_plays_tr)
  test_data_tr.to_csv(os.path.join(pro_dir, 'test_tr.csv'), index=False)

  test_data_te = numerize(test_plays_te)
  test_data_te.to_csv(os.path.join(pro_dir, 'test_te.csv'), index=False)


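# A minimal sketch (not part of this script) of how the generated CSVs can be
# read back into a sparse user-item matrix; the helper name load_train_data
# and the use of scipy.sparse are illustrative assumptions:
#
#   import numpy as np
#   import pandas as pd
#   from scipy import sparse
#
#   def load_train_data(csv_file, n_items):
#     tp = pd.read_csv(csv_file)  # columns: uid, sid (see numerize above).
#     n_users = tp['uid'].max() + 1
#     rows, cols = tp['uid'].values, tp['sid'].values
#     return sparse.csr_matrix((np.ones_like(rows), (rows, cols)),
#                              dtype='float64', shape=(n_users, n_items))

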
def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--output_dir', type=str, default='',
                      help='Path where to save the datasets.')
  args = parser.parse_args()

  # MovieLens 20M.
  ml20m_zip = os.path.join(args.output_dir, 'ml20m.zip')
  ml20m_dir = os.path.join(args.output_dir, 'ml-20m/')
  ml20m_file = os.path.join(args.output_dir, 'ml-20m/ratings.csv')
  print('Downloading and extracting Movielens 20M data')
  urllib.request.urlretrieve(
      'http://files.grouplens.org/datasets/movielens/ml-20m.zip',
      ml20m_zip)
  with zipfile.ZipFile(ml20m_zip, 'r') as zipref:
    zipref.extract('ml-20m/ratings.csv', args.output_dir)
  os.remove(ml20m_zip)
  raw_data = pd.read_csv(ml20m_file, header=0)
  os.remove(ml20m_file)
  # Binarize the data (only keep ratings >= 4).
  raw_data = raw_data[raw_data['rating'] > 3.5]
  generate_data(
      raw_data, output_dir=ml20m_dir, n_heldout_users=10000, min_uc=5,
      min_sc=0)
  print('Done processing Movielens 20M.')

  # Million Song Data.
  print('Downloading and extracting Million Song Data')
  msd_zip = os.path.join(args.output_dir, 'msd.zip')
  msd_dir = os.path.join(args.output_dir, 'msd/')
  msd_file = os.path.join(args.output_dir, 'msd/train_triplets.txt')
  urllib.request.urlretrieve(
      'http://millionsongdataset.com/sites/default/files/challenge/train_triplets.txt.zip',
      msd_zip)
  with zipfile.ZipFile(msd_zip, 'r') as zipref:
    zipref.extractall(msd_dir)
  os.remove(msd_zip)
  raw_data = pd.read_csv(
      msd_file, sep='\t', header=None, names=['userId', 'movieId', 'count'])
  os.remove(msd_file)
  generate_data(
      raw_data, output_dir=msd_dir, n_heldout_users=50000, min_uc=20,
      min_sc=200)
  print('Done processing Million Song Data.')


if __name__ == '__main__':
  main()