google-research
312 lines · 10.8 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Interfaces for reading raw graph datasets."""
17
18import abc19import json20import os21from typing import Set22
23from absl import logging24import networkx as nx25import numpy as np26import pandas as pd27import scipy.sparse as sp28import sklearn.preprocessing29import tensorflow as tf30
31
32
class Dataset(abc.ABC):
  """Abstract base class for datasets.

  Subclasses populate the edge endpoint arrays (`senders`, `receivers`),
  the per-node arrays (`node_features`, `node_labels`), and the index
  arrays naming which nodes belong to each split.
  """

  senders: np.ndarray
  receivers: np.ndarray
  node_features: np.ndarray
  node_labels: np.ndarray
  train_nodes: np.ndarray
  validation_nodes: np.ndarray
  test_nodes: np.ndarray

  def num_nodes(self):
    """Returns the number of nodes in the dataset (one label per node)."""
    labels = self.node_labels
    return len(labels)

  def num_edges(self):
    """Returns the number of edges in the dataset (one sender per edge)."""
    endpoints = self.senders
    return len(endpoints)
52
class DummyDataset(Dataset):
  """A dummy dataset for testing."""

  NUM_DUMMY_TRAINING_SAMPLES: int = 3
  NUM_DUMMY_VALIDATION_SAMPLES: int = 3
  NUM_DUMMY_TEST_SAMPLES: int = 3
  NUM_DUMMY_FEATURES: int = 5
  NUM_DUMMY_CLASSES: int = 3

  def __init__(self):
    num_train = DummyDataset.NUM_DUMMY_TRAINING_SAMPLES
    num_validation = DummyDataset.NUM_DUMMY_VALIDATION_SAMPLES
    num_samples = (
        num_train + num_validation + DummyDataset.NUM_DUMMY_TEST_SAMPLES)
    # Ring graph: node i has an edge to node (i + 1) % num_samples.
    self.senders = np.arange(num_samples)
    self.receivers = np.roll(self.senders, -1)
    # Node i's feature vector is i repeated NUM_DUMMY_FEATURES times.
    self.node_features = np.arange(num_samples).repeat(
        DummyDataset.NUM_DUMMY_FEATURES).reshape(
            num_samples, DummyDataset.NUM_DUMMY_FEATURES)
    self.node_labels = np.zeros(num_samples)
    # Splits are contiguous index ranges: train, then validation, then test.
    self.train_nodes = np.arange(num_train)
    self.validation_nodes = np.arange(num_train, num_train + num_validation)
    self.test_nodes = np.arange(num_train + num_validation, num_samples)
84
class OGBTransductiveDataset(Dataset):
  """Reads Open Graph Benchmark (OGB) datasets."""

  def __init__(self, dataset_name, dataset_path):
    super(OGBTransductiveDataset, self).__init__()
    # Normalize e.g. 'ogbn-arxiv-disjoint' -> 'ogbn_arxiv'.
    self.name = dataset_name.replace('-disjoint', '').replace('-', '_')
    base_path = os.path.join(dataset_path, self.name)

    # Each OGB dataset publishes its canonical split under a different key.
    split_properties = {
        'ogbn_arxiv': 'split/time/',
        'ogbn_mag': 'split/time/paper/',
        'ogbn_products': 'split/sales_ranking/',
        'ogbn_proteins': 'split/species/',
    }
    if self.name not in split_properties:
      raise ValueError('Unsupported dataset.')
    split_property = split_properties[self.name]

    train_split_file = os.path.join(base_path, split_property, 'train.csv.gz')
    validation_split_file = os.path.join(base_path, split_property,
                                         'valid.csv.gz')
    test_split_file = os.path.join(base_path, split_property, 'test.csv.gz')

    # ogbn-mag is heterogeneous: features, labels and citation edges are
    # stored under paper-specific sub-directories.
    if self.name == 'ogbn_mag':
      node_feature_file = os.path.join(
          base_path, 'raw/node-feat/paper/node-feat.csv.gz')
      node_label_file = os.path.join(
          base_path, 'raw/node-label/paper/node-label.csv.gz')
      edge_file = os.path.join(
          base_path, 'raw/relations/paper___cites___paper/edge.csv.gz')
    else:
      node_feature_file = os.path.join(base_path, 'raw/node-feat.csv.gz')
      node_label_file = os.path.join(base_path, 'raw/node-label.csv.gz')
      edge_file = os.path.join(base_path, 'raw/edge.csv.gz')

    logging.info('Reading node features...')
    self.node_features = pd.read_csv(
        node_feature_file, header=None).values.astype(np.float32)
    logging.info('Node features loaded.')

    logging.info('Reading node labels...')
    self.node_labels = pd.read_csv(
        node_label_file, header=None).values.astype(np.int64).squeeze()
    logging.info('Node labels loaded.')

    logging.info('Reading edges...')
    # Edge file rows are (sender, receiver) pairs; transpose to two arrays.
    senders_receivers = pd.read_csv(
        edge_file, header=None).values.T.astype(np.int64)
    self.senders, self.receivers = senders_receivers
    logging.info('Edges loaded.')

    logging.info('Reading train, validation and test splits...')
    self.train_nodes = pd.read_csv(
        train_split_file, header=None).values.T.astype(np.int64).squeeze()
    self.validation_nodes = pd.read_csv(
        validation_split_file, header=None).values.T.astype(np.int64).squeeze()
    self.test_nodes = pd.read_csv(
        test_split_file, header=None).values.T.astype(np.int64).squeeze()
    logging.info('Loaded train, test and validation splits.')
151
class OGBDisjointDataset(OGBTransductiveDataset):
  """A disjoint version of a OGB dataset, with no inter-split edges.

  Loads the transductive dataset, then drops every edge whose endpoints
  belong to different splits.
  """

  def __init__(self, dataset_name, dataset_path):
    super(OGBDisjointDataset, self).__init__(dataset_name, dataset_path)
    self.name = dataset_name

    # Build a node-id -> split-index lookup table (0=train, 1=validation,
    # 2=test) in one vectorized pass. This replaces a np.vectorize() call
    # per edge endpoint, which runs a Python function per element and is
    # very slow on OGB-sized edge lists.
    # NOTE(review): assumes node ids are 0..num_nodes()-1, i.e. every node
    # has a label row — true for the OGB csv layout read by the parent.
    split_indices = np.full(self.num_nodes(), -1, dtype=np.int64)
    splits = [self.train_nodes, self.validation_nodes, self.test_nodes]
    for index, split in enumerate(splits):
      split = np.asarray(split).reshape(-1)
      already_assigned = split_indices[split] != -1
      if already_assigned.any():
        node = split[already_assigned][0]
        raise ValueError(f'Node {node} present in multiple splits.')
      split_indices[split] = index

    # Every edge endpoint must belong to exactly one split.
    endpoints = np.concatenate([self.senders, self.receivers])
    unassigned = split_indices[endpoints] == -1
    if unassigned.any():
      node = endpoints[unassigned][0]
      raise ValueError(f'Node {node} present in none of the splits.')

    # Keep only intra-split edges.
    in_same_split = (
        split_indices[self.senders] == split_indices[self.receivers])
    self.senders = self.senders[in_same_split]
    self.receivers = self.receivers[in_same_split]
182
183
184
class GraphSAINTTransductiveDataset(Dataset):
  """Reads a GraphSAINT-format transductive dataset.

  Expects `adj_full.npz` (scipy sparse adjacency), `role.json` (split
  assignments), `feats.npy` (node features) and `class_map.json`
  (node-id -> label) under `dataset_path/<base_name>/`.
  """

  def __init__(self, dataset_name, dataset_path):
    super(GraphSAINTTransductiveDataset, self).__init__()

    self.name = dataset_name
    base_name = dataset_name.replace('-disjoint', '')
    base_name = base_name.replace('-transductive', '')

    self.base_name = base_name
    base_path = os.path.join(dataset_path, base_name)

    logging.info('Reading graph data...')
    self.adj_full = sp.load_npz(
        tf.io.gfile.GFile(os.path.join(base_path, 'adj_full.npz'), 'rb'))
    # NOTE(review): nx.from_scipy_sparse_matrix was removed in networkx 3.0
    # (renamed to from_scipy_sparse_array); this code requires networkx < 3.
    graph = nx.from_scipy_sparse_matrix(self.adj_full)
    graph_data = nx.readwrite.node_link_data(graph)
    logging.info('Graph data loaded.')

    # Store endpoints as int64 arrays, matching the Dataset interface
    # (previously these were plain Python lists).
    edges = np.asarray(list(graph.edges), dtype=np.int64).reshape(-1, 2)
    self.senders = edges[:, 0]
    self.receivers = edges[:, 1]

    train_nodes = []
    validation_nodes = []
    test_nodes = []

    splits = json.load(
        tf.io.gfile.GFile(os.path.join(base_path, 'role.json'), 'r'))
    train_split = set(splits['tr'])
    validation_split = set(splits['va'])
    test_split = set(splits['te'])

    for node in graph_data['nodes']:
      node_id = node['id']
      if node_id in validation_split:
        validation_nodes.append(node_id)
      elif node_id in test_split:
        test_nodes.append(node_id)
      elif node_id in train_split:
        train_nodes.append(node_id)
      else:
        raise ValueError(f'Node {node_id} not present in any split.')

    self.train_nodes = np.asarray(train_nodes)
    self.validation_nodes = np.asarray(validation_nodes)
    self.test_nodes = np.asarray(test_nodes)

    logging.info('Reading node features...')
    node_features = np.load(
        tf.io.gfile.GFile(os.path.join(base_path, 'feats.npy'), 'rb'))
    logging.info('Node features loaded.')

    logging.info('Preprocessing node features...')
    # Standardize using statistics computed on the training nodes only, so
    # no validation/test information leaks into the features.
    train_node_features = node_features[self.train_nodes]
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(train_node_features)
    self.node_features = scaler.transform(node_features)
    logging.info('Node features preprocessed.')

    logging.info('Reading node labels...')
    class_map = json.load(
        tf.io.gfile.GFile(os.path.join(base_path, 'class_map.json'), 'r'))
    # class_map keys are node ids serialized as JSON strings. Sort them
    # numerically so labels[i] is the label of node i: a plain sorted()
    # orders strings lexicographically ('0', '1', '10', '11', ..., '2'),
    # which misaligns labels with node indices for graphs with >= 10 nodes.
    labels = [class_map[node_id] for node_id in sorted(class_map, key=int)]
    self.node_labels = np.asarray(labels).squeeze()
    logging.info('Node labels loaded.')
253
class GraphSAINTDisjointDataset(GraphSAINTTransductiveDataset):
  """Reads a GraphSAINT-format disjoint dataset."""

  def __init__(self, dataset_name, dataset_path):
    super(GraphSAINTDisjointDataset, self).__init__(dataset_name, dataset_path)

    self.name = dataset_name

    # Rebuild the edge list keeping only intra-split edges: take the union
    # of the subgraphs induced by each split.
    split_sets = (set(self.train_nodes),
                  set(self.validation_nodes),
                  set(self.test_nodes))
    subgraphs = tuple(
        _get_graph_for_split(self.adj_full, split_set)
        for split_set in split_sets)
    graph = nx.union_all(subgraphs)

    self.senders = [edge[0] for edge in graph.edges]
    self.receivers = [edge[1] for edge in graph.edges]
274
def _get_graph_for_split(adj_full, split_set):
  """Returns the subgraph of `adj_full` induced by the nodes in `split_set`.

  Args:
    adj_full: Scipy sparse adjacency matrix of the full graph.
    split_set: Set of node ids belonging to one split.

  Returns:
    An undirected nx.Graph containing every node of the split and every
    edge of `adj_full` whose two endpoints both lie in the split.
  """
  graph_split = nx.Graph()
  graph_split.add_nodes_from(split_set)
  senders, receivers = adj_full.nonzero()
  graph_split.add_edges_from(
      (sender, receiver)
      for sender, receiver in zip(senders, receivers)
      if sender in split_set and receiver in split_set)
  return graph_split
289
def get_dataset(dataset_name, dataset_path):
  """Returns a graph dataset.

  Args:
    dataset_name: Dataset identifier, e.g. 'dummy', 'ogbn-arxiv',
      'ogbn-arxiv-disjoint', 'reddit-transductive' or 'yelp-disjoint'.
    dataset_path: Directory containing the raw dataset files (unused for
      special datasets such as 'dummy').

  Returns:
    A Dataset instance.

  Raises:
    ValueError: If the dataset name is unsupported, or a GraphSAINT-style
      name is missing its `transductive`/`disjoint` suffix.
  """
  special_dataset_fns = {
      'dummy': DummyDataset,
  }
  if dataset_name in special_dataset_fns:
    return special_dataset_fns[dataset_name]()

  if dataset_name.startswith('ogb'):
    if dataset_name.endswith('disjoint'):
      return OGBDisjointDataset(dataset_name, dataset_path)
    return OGBTransductiveDataset(dataset_name, dataset_path)

  graphsaint_datasets = ['reddit', 'yelp', 'flickr']
  if any(dataset_name.startswith(name) for name in graphsaint_datasets):
    if dataset_name.endswith('disjoint'):
      return GraphSAINTDisjointDataset(dataset_name, dataset_path)
    if dataset_name.endswith('transductive'):
      return GraphSAINTTransductiveDataset(dataset_name, dataset_path)
    # The checks above inspect the *end* of the name, so the message must
    # ask for a suffix (it previously said "prefix", which was misleading).
    raise ValueError(
        'Please suffix dataset_name with `transductive` or `disjoint`.')

  raise ValueError(f'Unsupported dataset: {dataset_name}.')