google-research
69 lines · 2.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""preprocessing data helper functions."""
17
18
19import os20import pickle21import numpy as np22import tensorflow.compat.v2 as tf23
24
def process_dataset(to_skip_dict, random=False):
  """Splits a (user, item) dataset into train, valid and test.

  Args:
    to_skip_dict: Dict mapping each user id to a sorted list of item ids
      (assumed to hold at least two items per user — TODO confirm with caller).
    random: Bool, whether to pick the valid/test items at random. When True the
      per-user item list is shuffled in place (this mutates the caller's
      lists). When False, valid and test are the last two items per user.

  Returns:
    examples: Dictionary mapping splits 'train', 'valid', 'test' to int64
    Numpy arrays containing the corresponding CF pairs, and 'to_skip' to a
    Dictionary containing filters for each user.
  """
  examples = {
      'to_skip': to_skip_dict,
      'train': [],
      'valid': [],
      'test': [],
  }
  for uid, items in to_skip_dict.items():
    if random:
      # In-place shuffle: randomizes which items land in valid/test.
      np.random.shuffle(items)
    # Last item -> test, second-to-last -> valid, everything before -> train.
    examples['test'].append([uid, items[-1]])
    examples['valid'].append([uid, items[-2]])
    examples['train'].extend([uid, iid] for iid in items[:-2])
  for split in ('train', 'valid', 'test'):
    examples[split] = np.array(examples[split]).astype('int64')
  return examples
56
def save_as_pickle(dataset_path, examples_dict):
  """Saves data to train, valid, test and to_skip pickle files.

  Args:
    dataset_path: String path to dataset directory.
    examples_dict: Dictionary mapping splits 'train', 'valid', 'test' to Numpy
      arrays containing the corresponding CF pairs, and 'to_skip' to a
      Dictionary containing filters for each user.
  """
  for dataset_split in ['train', 'valid', 'test', 'to_skip']:
    save_path = os.path.join(dataset_path, dataset_split + '.pickle')
    # `tf.gfile` is not part of the TF2 namespace exposed by
    # tensorflow.compat.v2, so `tf.gfile.Open` raises AttributeError at
    # runtime; tf.io.gfile.GFile is the TF2 equivalent and also supports
    # non-local filesystems (e.g. GCS paths).
    with tf.io.gfile.GFile(save_path, 'wb') as save_file:
      pickle.dump(examples_dict[dataset_split], save_file)