google-research

Форк
0
/
process_movielens.py 
94 строки · 2.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# Copyright 2019 Google LLC
17
#
18
# Licensed under the Apache License, Version 2.0 (the "License");
19
# you may not use this file except in compliance with the License.
20
# You may obtain a copy of the License at
21
#
22
#     https://www.apache.org/licenses/LICENSE-2.0
23
#
24
# Unless required by applicable law or agreed to in writing, software
25
# distributed under the License is distributed on an "AS IS" BASIS,
26
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27
# See the License for the specific language governing permissions and
28
# limitations under the License.
29
"""Collaborative Filtering MovieLens dataset pre-processing."""
30

31
from __future__ import absolute_import
32
from __future__ import division
33
from __future__ import print_function
34

35

36
from absl import app
37
from absl import flags
38

39
import numpy as np
40
import tensorflow.compat.v2 as tf
41
from hyperbolic.utils.preprocess import process_dataset
42
from hyperbolic.utils.preprocess import save_as_pickle
43

44

45
FLAGS = flags.FLAGS
46
flags.DEFINE_string(
47
    'dataset_path',
48
    default='data/ml-1m/ratings.dat',
49
    help='Path to raw dataset')
50
flags.DEFINE_string(
51
    'save_dir_path',
52
    default='data/ml-1m/',
53
    help='Path to saving directory')
54

55

56
def movielens_to_dict(dataset_file):
57
  """Maps raw dataset file to a Dictonary.
58

59
  Args:
60
    dataset_file: Path to file containing interactions in a format
61
      uid::iid::rate::time.
62

63
  Returns:
64
    Dictionary containing users as keys, and a numpy array of items the user
65
    interacted with, sorted by the time of interaction.
66
  """
67
  all_examples = {}
68
  with tf.gfile.Open(dataset_file, 'r') as lines:
69
    for line in lines:
70
      line = line.strip('\n').split('::')
71
      uid = int(line[0])-1
72
      iid = int(line[1])-1
73
      timestamp = int(line[3])
74
      if uid in all_examples:
75
        all_examples[uid].append((iid, timestamp))
76
      else:
77
        all_examples[uid] = [(iid, timestamp)]
78
  for uid in all_examples:
79
    sorted_items = sorted(all_examples[uid], key=lambda p: p[1])
80
    all_examples[uid] = np.array([pair[0] for pair in sorted_items
81
                                 ]).astype('int64')
82
  return all_examples
83

84

85
def main(_):
86
  dataset_path = FLAGS.dataset_path
87
  save_path = FLAGS.save_dir_path
88
  sorted_dict = movielens_to_dict(dataset_path)
89
  dataset_examples = process_dataset(sorted_dict)
90
  save_as_pickle(save_path, dataset_examples)
91

92

93
if __name__ == '__main__':
94
  app.run(main)
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.