google-research

Форк
0
312 строк · 10.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Interfaces for reading raw graph datasets."""
17

18
import abc
19
import json
20
import os
21
from typing import Set
22

23
from absl import logging
24
import networkx as nx
25
import numpy as np
26
import pandas as pd
27
import scipy.sparse as sp
28
import sklearn.preprocessing
29
import tensorflow as tf
30

31

32

33
class Dataset(abc.ABC):
34
  """Abstract base class for datasets."""
35

36
  senders: np.ndarray
37
  receivers: np.ndarray
38
  node_features: np.ndarray
39
  node_labels: np.ndarray
40
  train_nodes: np.ndarray
41
  validation_nodes: np.ndarray
42
  test_nodes: np.ndarray
43

44
  def num_nodes(self):
45
    """Returns the number of nodes in the dataset."""
46
    return len(self.node_labels)
47

48
  def num_edges(self):
49
    """Returns the number of edges in the dataset."""
50
    return len(self.senders)
51

52

53
class DummyDataset(Dataset):
54
  """A dummy dataset for testing."""
55

56
  NUM_DUMMY_TRAINING_SAMPLES: int = 3
57
  NUM_DUMMY_VALIDATION_SAMPLES: int = 3
58
  NUM_DUMMY_TEST_SAMPLES: int = 3
59
  NUM_DUMMY_FEATURES: int = 5
60
  NUM_DUMMY_CLASSES: int = 3
61

62
  def __init__(self):
63
    num_samples = (
64
        DummyDataset.NUM_DUMMY_TRAINING_SAMPLES +
65
        DummyDataset.NUM_DUMMY_VALIDATION_SAMPLES +
66
        DummyDataset.NUM_DUMMY_TEST_SAMPLES)
67
    self.senders = np.arange(num_samples)
68
    self.receivers = np.roll(np.arange(num_samples), -1)
69
    self.node_features = np.repeat(
70
        np.arange(num_samples), DummyDataset.NUM_DUMMY_FEATURES)
71
    self.node_features = self.node_features.reshape(
72
        (num_samples, DummyDataset.NUM_DUMMY_FEATURES))
73
    self.node_labels = np.zeros(num_samples)
74
    self.train_nodes = np.arange(
75
        DummyDataset.NUM_DUMMY_TRAINING_SAMPLES)
76
    self.validation_nodes = np.arange(
77
        DummyDataset.NUM_DUMMY_TRAINING_SAMPLES,
78
        DummyDataset.NUM_DUMMY_TRAINING_SAMPLES +
79
        DummyDataset.NUM_DUMMY_VALIDATION_SAMPLES)
80
    self.test_nodes = np.arange(
81
        DummyDataset.NUM_DUMMY_TRAINING_SAMPLES +
82
        DummyDataset.NUM_DUMMY_VALIDATION_SAMPLES, num_samples)
83

84

85
class OGBTransductiveDataset(Dataset):
86
  """Reads Open Graph Benchmark (OGB) datasets."""
87

88
  def __init__(self, dataset_name, dataset_path):
89
    super(OGBTransductiveDataset, self).__init__()
90
    self.name = dataset_name.replace('-disjoint', '').replace('-', '_')
91
    base_path = os.path.join(dataset_path, self.name)
92

93
    if self.name == 'ogbn_arxiv':
94
      split_property = 'split/time/'
95
    elif self.name == 'ogbn_mag':
96
      split_property = 'split/time/paper/'
97
    elif self.name == 'ogbn_products':
98
      split_property = 'split/sales_ranking/'
99
    elif self.name == 'ogbn_proteins':
100
      split_property = 'split/species/'
101
    else:
102
      raise ValueError('Unsupported dataset.')
103

104
    train_split_file = os.path.join(
105
        base_path, split_property, 'train.csv.gz')
106
    validation_split_file = os.path.join(
107
        base_path, split_property, 'valid.csv.gz')
108
    test_split_file = os.path.join(
109
        base_path, split_property, 'test.csv.gz')
110

111
    if self.name == 'ogbn_mag':
112
      node_feature_file = os.path.join(base_path,
113
                                       'raw/node-feat/paper/node-feat.csv.gz')
114
      node_label_file = os.path.join(base_path,
115
                                     'raw/node-label/paper/node-label.csv.gz')
116
    else:
117
      node_feature_file = os.path.join(base_path, 'raw/node-feat.csv.gz')
118
      node_label_file = os.path.join(base_path, 'raw/node-label.csv.gz')
119

120
    logging.info('Reading node features...')
121
    self.node_features = pd.read_csv(
122
        node_feature_file, header=None).values.astype(np.float32)
123
    logging.info('Node features loaded.')
124

125
    logging.info('Reading node labels...')
126
    self.node_labels = pd.read_csv(
127
        node_label_file, header=None).values.astype(np.int64).squeeze()
128
    logging.info('Node labels loaded.')
129

130
    if self.name == 'ogbn_mag':
131
      edge_file = os.path.join(
132
          base_path, 'raw/relations/paper___cites___paper/edge.csv.gz')
133
    else:
134
      edge_file = os.path.join(base_path, 'raw/edge.csv.gz')
135

136
    logging.info('Reading edges...')
137
    senders_receivers = pd.read_csv(
138
        edge_file, header=None).values.T.astype(np.int64)
139
    self.senders, self.receivers = senders_receivers
140
    logging.info('Edges loaded.')
141

142
    logging.info('Reading train, validation and test splits...')
143
    self.train_nodes = pd.read_csv(
144
        train_split_file, header=None).values.T.astype(np.int64).squeeze()
145
    self.validation_nodes = pd.read_csv(
146
        validation_split_file, header=None).values.T.astype(np.int64).squeeze()
147
    self.test_nodes = pd.read_csv(
148
        test_split_file, header=None).values.T.astype(np.int64).squeeze()
149
    logging.info('Loaded train, test and validation splits.')
150

151

152
class OGBDisjointDataset(OGBTransductiveDataset):
153
  """A disjoint version of a OGB dataset, with no inter-split edges."""
154

155
  def __init__(self, dataset_name, dataset_path):
156
    super(OGBDisjointDataset, self).__init__(dataset_name, dataset_path)
157
    self.name = dataset_name
158

159
    train_split = set(self.train_nodes.flat)
160
    validation_split = set(self.validation_nodes.flat)
161
    test_split = set(self.test_nodes.flat)
162
    splits = [train_split, validation_split, test_split]
163

164
    def _compute_split_index(elem):
165
      elem_index = None
166
      for index, split in enumerate(splits):
167
        if elem in split:
168
          if elem_index is not None:
169
            raise ValueError(f'Node {elem} present in multiple splits.')
170
          elem_index = index
171
      if elem_index is None:
172
        raise ValueError(f'Node {elem} present in none of the splits.')
173
      return elem_index
174

175
    senders_split_indices = np.vectorize(_compute_split_index)(self.senders)
176
    receivers_split_indices = np.vectorize(_compute_split_index)(self.receivers)
177
    in_same_split = (senders_split_indices == receivers_split_indices)
178

179
    self.senders = self.senders[in_same_split]
180
    self.receivers = self.receivers[in_same_split]
181

182

183

184

185
class GraphSAINTTransductiveDataset(Dataset):
186
  """Reads a GraphSAINT-format transductive dataset."""
187

188
  def __init__(self, dataset_name, dataset_path):
189
    super(GraphSAINTTransductiveDataset, self).__init__()
190

191
    self.name = dataset_name
192
    base_name = dataset_name.replace('-disjoint', '')
193
    base_name = base_name.replace('-transductive', '')
194

195

196
    self.base_name = base_name
197
    base_path = os.path.join(dataset_path, base_name)
198

199
    logging.info('Reading graph data...')
200
    self.adj_full = sp.load_npz(
201
        tf.io.gfile.GFile(os.path.join(base_path, 'adj_full.npz'), 'rb'))
202
    graph = nx.from_scipy_sparse_matrix(self.adj_full)
203
    graph_data = nx.readwrite.node_link_data(graph)
204
    logging.info('Graph data loaded.')
205

206
    self.senders = [e[0] for e in graph.edges]
207
    self.receivers = [e[1] for e in graph.edges]
208

209
    train_nodes = []
210
    validation_nodes = []
211
    test_nodes = []
212

213
    splits = json.load(
214
        tf.io.gfile.GFile(os.path.join(base_path, 'role.json'), 'r'))
215
    train_split = set(splits['tr'])
216
    validation_split = set(splits['va'])
217
    test_split = set(splits['te'])
218

219
    for node in graph_data['nodes']:
220
      node_id = node['id']
221
      if node_id in validation_split:
222
        validation_nodes.append(node_id)
223
      elif node_id in test_split:
224
        test_nodes.append(node_id)
225
      elif node_id in train_split:
226
        train_nodes.append(node_id)
227
      else:
228
        raise ValueError(f'Node {node_id} not present in any split.')
229

230
    self.train_nodes = np.asarray(train_nodes)
231
    self.validation_nodes = np.asarray(validation_nodes)
232
    self.test_nodes = np.asarray(test_nodes)
233

234
    logging.info('Reading node features...')
235
    node_features = np.load(
236
        tf.io.gfile.GFile(os.path.join(base_path, 'feats.npy'), 'rb'))
237
    logging.info('Node features loaded.')
238

239
    logging.info('Preprocessing node features...')
240
    train_node_features = node_features[self.train_nodes]
241
    scaler = sklearn.preprocessing.StandardScaler()
242
    scaler.fit(train_node_features)
243
    self.node_features = scaler.transform(node_features)
244
    logging.info('Node features preprocessed.')
245

246
    logging.info('Reading node labels...')
247
    class_map = json.load(
248
        tf.io.gfile.GFile(os.path.join(base_path, 'class_map.json'), 'r'))
249
    labels = [class_map[node_id] for node_id in sorted(class_map)]
250
    self.node_labels = np.asarray(labels).squeeze()
251
    logging.info('Node labels loaded.')
252

253

254
class GraphSAINTDisjointDataset(GraphSAINTTransductiveDataset):
255
  """Reads a GraphSAINT-format disjoint dataset."""
256

257
  def __init__(self, dataset_name, dataset_path):
258
    super(GraphSAINTDisjointDataset, self).__init__(dataset_name, dataset_path)
259

260
    self.name = dataset_name
261

262
    train_split = set(self.train_nodes)
263
    validation_split = set(self.validation_nodes)
264
    test_split = set(self.test_nodes)
265

266
    graph_train = _get_graph_for_split(self.adj_full, train_split)
267
    graph_validation = _get_graph_for_split(self.adj_full, validation_split)
268
    graph_test = _get_graph_for_split(self.adj_full, test_split)
269
    graph = nx.union_all((graph_train, graph_validation, graph_test))
270

271
    self.senders = [e[0] for e in graph.edges]
272
    self.receivers = [e[1] for e in graph.edges]
273

274

275
def _get_graph_for_split(adj_full,
276
                         split_set):
277
  """Returns the induced subgraph for the required split."""
278
  def edge_generator():
279
    senders, receivers = adj_full.nonzero()
280
    for sender, receiver in zip(senders, receivers):
281
      if sender in split_set and receiver in split_set:
282
        yield sender, receiver
283

284
  graph_split = nx.Graph()
285
  graph_split.add_nodes_from(split_set)
286
  graph_split.add_edges_from(edge_generator())
287
  return graph_split
288

289

290
def get_dataset(dataset_name, dataset_path):
291
  """Returns a graph dataset."""
292
  special_dataset_fns = {
293
      'dummy': DummyDataset,
294
  }
295
  if dataset_name in special_dataset_fns:
296
    return special_dataset_fns[dataset_name]()
297

298
  if dataset_name.startswith('ogb'):
299
    if dataset_name.endswith('disjoint'):
300
      return OGBDisjointDataset(dataset_name, dataset_path)
301
    return OGBTransductiveDataset(dataset_name, dataset_path)
302

303
  graphsaint_datasets = ['reddit', 'yelp', 'flickr']
304
  if any(dataset_name.startswith(name) for name in graphsaint_datasets):
305
    if dataset_name.endswith('disjoint'):
306
      return GraphSAINTDisjointDataset(dataset_name, dataset_path)
307
    if dataset_name.endswith('transductive'):
308
      return GraphSAINTTransductiveDataset(dataset_name, dataset_path)
309
    raise ValueError(
310
        'Please prefix dataset_name with `transductive` or `disjoint`.')
311

312
  raise ValueError(f'Unsupported dataset: {dataset_name}.')
313

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.