google-research
94 строки · 2.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collaborative Filtering MovieLens dataset pre-processing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf

from hyperbolic.utils.preprocess import process_dataset
from hyperbolic.utils.preprocess import save_as_pickle


FLAGS = flags.FLAGS

# Location of the raw MovieLens ratings file read by main().
flags.DEFINE_string(
    name='dataset_path',
    default='data/ml-1m/ratings.dat',
    help='Path to raw dataset')

# Directory handed to save_as_pickle() by main().
flags.DEFINE_string(
    name='save_dir_path',
    default='data/ml-1m/',
    help='Path to saving directory')


def movielens_to_dict(dataset_file):
  """Maps a raw MovieLens ratings file to a dictionary of item sequences.

  Each non-empty line must have the form uid::iid::rate::time. User and item
  ids are shifted down by one (so raw 1-based ids become 0-based), and the
  rating field is ignored.

  Args:
    dataset_file: Path to file containing interactions in a format
      uid::iid::rate::time.

  Returns:
    Dictionary whose keys are the shifted user ids and whose values are int64
    numpy arrays of the shifted item ids that user interacted with, sorted by
    the time of interaction. Interactions sharing a timestamp keep the order
    in which they appear in the file.

  Raises:
    ValueError: If an id or timestamp field is not an integer.
    IndexError: If a non-empty line has fewer than four '::'-separated fields.
  """
  user_interactions = collections.defaultdict(list)
  # `tf` is tensorflow.compat.v2, whose API has no `tf.gfile`;
  # tf.io.gfile.GFile is the TF2 file API.
  with tf.io.gfile.GFile(dataset_file, 'r') as lines:
    for line in lines:
      line = line.strip()
      if not line:
        # Tolerate blank lines, e.g. an empty last line in the file.
        continue
      fields = line.split('::')
      uid = int(fields[0]) - 1
      iid = int(fields[1]) - 1
      timestamp = int(fields[3])
      user_interactions[uid].append((iid, timestamp))

  all_examples = {}
  for uid, interactions in user_interactions.items():
    # list.sort is stable, so equal timestamps keep their file order.
    interactions.sort(key=lambda pair: pair[1])
    all_examples[uid] = np.array(
        [iid for iid, _ in interactions]).astype('int64')
  return all_examples


def main(_):
  """Runs MovieLens preprocessing using the paths given on the command line.

  Reads interactions from --dataset_path with movielens_to_dict, passes the
  resulting per-user item sequences through process_dataset, and hands the
  output to save_as_pickle together with --save_dir_path.

  Args:
    _: Command-line arguments forwarded by app.run; unused.
  """
  source_path, target_dir = FLAGS.dataset_path, FLAGS.save_dir_path
  user_sequences = movielens_to_dict(source_path)
  processed_examples = process_dataset(user_sequences)
  save_as_pickle(target_dir, processed_examples)


if __name__ == '__main__':
  # absl's app.run parses the command-line flags, then invokes main(argv).
  app.run(main)