# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocessing data helper functions."""


import os
import pickle
import numpy as np
import tensorflow.compat.v2 as tf


def process_dataset(to_skip_dict, random=False):
  """Splits a (user, item) dataset into train, valid and test.

  Args:
    to_skip_dict: Dict mapping each user id to a sorted list of item ids.
    random: Bool, whether to pick the valid and test items at random. If
      False, the valid and test items are the last two items per user.

  Returns:
    examples: Dictionary mapping the splits 'train', 'valid', 'test' to NumPy
      arrays containing the corresponding CF pairs, and 'to_skip' to a
      dictionary containing the filter lists for each user.
  """
  examples = {}
  examples['to_skip'] = to_skip_dict
  examples['train'] = []
  examples['valid'] = []
  examples['test'] = []
  for uid in examples['to_skip']:
    if random:
      np.random.shuffle(examples['to_skip'][uid])
    # The last item is held out for test, the second-to-last for valid, and
    # all remaining items go to train.
    examples['test'].append([uid, examples['to_skip'][uid][-1]])
    examples['valid'].append([uid, examples['to_skip'][uid][-2]])
    for iid in examples['to_skip'][uid][:-2]:
      examples['train'].append([uid, iid])
  for split in ['train', 'valid', 'test']:
    examples[split] = np.array(examples[split]).astype('int64')
  return examples
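# A worked toy example of the split logic above (illustrative ids, not part
# of the original code):
#
#   to_skip = {0: [10, 11, 12, 13]}
#   splits = process_dataset(to_skip, random=False)
#   # splits['test']  -> [[0, 13]]
#   # splits['valid'] -> [[0, 12]]
#   # splits['train'] -> [[0, 10], [0, 11]]
#
# The 'to_skip' entry keeps the full per-user item lists, serving as the
# per-user filters mentioned in the docstring.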


def save_as_pickle(dataset_path, examples_dict):
  """Saves the train, valid, test and to_skip splits as pickle files.

  Args:
    dataset_path: String path to the dataset directory.
    examples_dict: Dictionary mapping the splits 'train', 'valid', 'test' to
      NumPy arrays containing the corresponding CF pairs, and 'to_skip' to a
      dictionary containing the filter lists for each user.
  """
  for dataset_split in ['train', 'valid', 'test', 'to_skip']:
    save_path = os.path.join(dataset_path, dataset_split + '.pickle')
    # tf.io.gfile handles both local paths and remote filesystems (e.g. GCS).
    with tf.io.gfile.GFile(save_path, 'wb') as save_file:
      pickle.dump(examples_dict[dataset_split], save_file)

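# A minimal, self-contained usage sketch (illustrative only, not part of the
# original module): the toy user/item ids below are made up, and the output
# directory is a temporary folder.
if __name__ == '__main__':
  import tempfile

  toy_to_skip = {
      0: [10, 11, 12, 13],  # user 0 interacted with items 10..13, in order.
      1: [11, 14, 15],
  }
  toy_examples = process_dataset(toy_to_skip, random=False)
  print('train pairs:', toy_examples['train'].tolist())
  print('valid pairs:', toy_examples['valid'].tolist())
  print('test pairs:', toy_examples['test'].tolist())

  tmp_dir = tempfile.mkdtemp()
  save_as_pickle(tmp_dir, toy_examples)
  print('wrote train/valid/test/to_skip pickles to', tmp_dir)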