google-research

Форк
0
/
tfgnn_datasets.py 
1093 строки · 39.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
17
#
18
# Licensed under the Apache License, Version 2.0 (the "License");
19
# you may not use this file except in compliance with the License.
20
# You may obtain a copy of the License at
21
#
22
#     http://www.apache.org/licenses/LICENSE-2.0
23
#
24
# Unless required by applicable law or agreed to in writing, software
25
# distributed under the License is distributed on an "AS IS" BASIS,
26
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27
# See the License for the specific language governing permissions and
28
# limitations under the License.
29
# ==============================================================================
30
"""Infrastructure and implementation of in-memory graph data.

Instantiating an object will download a dataset, and cache it locally. The
datasets will be cached on ~/data/ogb (for "ogbn-" and "ogbl-" datasets), which
can be overridden by setting environment variable `OGB_CACHE_DIR`; and on
~/data/planetoid (for "cora", "citeseer", "pubmed"), which can be overridden by
environment variable `PLANETOID_CACHE_DIR`.

High-level Abstract Classes:

  * `InMemoryGraphData`: provides nodes, edges, and features, for a
     homogeneous or a heterogeneous graph.
  * `NodeClassificationGraphData`: an `InMemoryGraphData` that also provides
    list of {train, test, validation} nodes, as well as their labels.
  * `LinkPredictionGraphData`: an `InMemoryGraphData` that also provides lists
    of edges in {train, test, validation} partitions.


`InMemoryGraphData` implementations can provide

  * a single GraphTensor for training on one big graph (e.g., for node
    classification with `tf_trainer.py` or `keras_trainer.py`),
  * a big graph from which in-memory sampling (e.g., `int_arithmetic_sampler`)
    can create a dataset of sampled subgraphs (encoded as `tfgnn.GraphTensor`).

All `InMemoryGraphData` implementations automatically inherit the methods:

  * `as_graph_tensor()` and `graph_schema()`.
  * These methods can be plugged into TF-GNN models and training loops, e.g.,
    for node classification (see `tf_trainer.py` and `keras_trainer.py`).
  * In addition, they can be plugged into in-memory sampling (see
    `int_arithmetic_sampler.py`, and example trainer script,
    `keras_minibatch_trainer.py`).


Concrete implementations:

  * Node classification (inheriting `NodeClassificationGraphData`)

    * `OgbnData`: Wraps node classification graph data from OGB, i.e., with
      name prefix of "ogbn-", such as, "ogbn-arxiv".

    * `PlanetoidGraphData`: wraps graph data that are popularized by GCN paper
      (cora, citeseer, pubmed).

  * Link prediction (inheriting `LinkPredictionGraphData`)

    * `OgblData`: Wraps link prediction graph data from OGB, i.e., with name
      prefix of "ogbl-", such as, "ogbl-ddi".


# Usage Example.

```
graph_data = datasets.OgbnData('ogbn-arxiv')

# Optionally, make graph undirected.
graph_data = graph_data.with_undirected_edges(True)

# Optionally, add self-loops.
graph_data = graph_data.with_self_loops(True)

# To get GraphTensor and GraphSchema at any graph data:
graph_tensor = graph_data.as_graph_tensor()
graph_schema = graph_data.graph_schema()

spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)
# or optionally, by "relaxing" the batch dimension of `graph_tensor` (to None):
# spec = graph_tensor.spec.relax(num_nodes=True, num_edges=True)
```

The first line is equivalent to
`graph_data = datasets.get_in_memory_graph_data('ogbn-arxiv')`, which is more
general, because it can load other data types:
  * ogbn-* are node-classification datasets for OGB.
  * 'pubmed', 'cora', 'citeseer', correspond to transductive graphs used in
    Planetoid (Yang et al, ICML'16).


`graph_tensor` (type `GraphTensor`) contains all nodes, edges, and features.
For a node-classification dataset, its context also holds the seed nodes of
the selected split(s), and, if requested via `.with_labels_as_features(True)`,
the labeled node set gets a "label" feature. **For test nodes**, the label
feature will be `-1` unless the "test" split is selected. If you want to
explicitly get all labels from all partitions, you may:

```
graph_data = graph_data.with_split(['train', 'test', 'validation'])
graph_tensor = graph_data.with_labels_as_features(True).as_graph_tensor()
```

Chaining `with_*` calls can reduce verbosity. For example,
```
graph_data = (
    datasets.OgbnData('ogbn-arxiv').with_undirected_edges(True)
    .with_self_loops(True))
graph_tensor = graph_data.as_graph_tensor()
```
"""
127
import abc
128
import copy
129
import io
130
import json
131
import os
132
import pickle
133
import sys
134
from typing import Any, List, Mapping, NamedTuple, Tuple, Union, Optional
135
import urllib.request
136

137
import numpy as np
138
import scipy
139
import tensorflow as tf
140
import tensorflow_gnn as tfgnn
141

142

143
class InMemoryGraphData(abc.ABC):
  """Abstract class for holding graph data in-memory (nodes, edges, features).

  Subclasses must implement methods `node_features_dicts()`, `node_counts()`,
  and `edge_lists()`, and may override `node_sets()` and `context()`. They
  inherit methods `graph_schema()`, `edge_sets()`, and `as_graph_tensor()`
  based on those.
  """

  def __init__(self, make_undirected=False, add_self_loops=False):
    """Initializes the export options shared by all graph data.

    Args:
      make_undirected: see `with_undirected_edges()`.
      add_self_loops: see `with_self_loops()`.
    """
    self._make_undirected = make_undirected
    self._add_self_loops = add_self_loops

  @property
  def name(self):
    """Returns name of dataset object. Can be overridden to return data name."""
    return self.__class__.__name__

  def with_undirected_edges(self, make_undirected):
    """Returns same graph data but with undirected edges added (or removed).

    Subsequent calls to `.graph_schema()` and to `.as_graph_tensor()` will be
    affected. Specifically, the generated output `tfgnn.GraphTensor` (by
    `.as_graph_tensor()`) will reverse all homogeneous edge sets (where the
    source node set equals the target node set). Suppose edge `(i, j)` is
    included in *homogeneous* edge set "MyEdgeSet", then output `GraphTensor`
    will also contain edge `(j, i)` on edge set "MyEdgeSet". If edge `(j, i)`
    already exists, then it will be duplicated.

    If make_undirected == True:

      * output of `.as_graph_tensor()` will contain only edge-set names that are
        returned by `.edge_sets()`, where each homogeneous edge-set with M edges
        will be expanded to M*2 edges with edge `M+k` reversing edge `k`.
      * output of `.graph_schema()` will contain only edge-sets returned by
        `edge_sets`.

    If make_undirected == False:

      * output of `.as_graph_tensor()` will contain, for each edge set "EdgeSet"
        (returned by `.edge_sets()`) a new edge-set "rev_EdgeSet" that reverses
        the "EdgeSet".
      * output of `.graph_schema()` will have both "EdgeSet" and "rev_EdgeSet".
      * `with_reverse_edge_sets()` is an equivalent and a more explicit method
        to add reverse edge sets to the graph tensor and its schema.

    Args:
      make_undirected: If True, subsequent calls to `.graph_schema()` and
        `.as_graph_tensor()` will export an undirected graph. If False, a
        directed graph (with additional "rev_*" edges).

    Returns:
      A shallow copy of `self` with the option set; `self` is not modified.
    """
    modified = copy.copy(self)
    modified._make_undirected = make_undirected  # pylint: disable=protected-access -- same class.
    return modified

  def with_reverse_edge_sets(self):
    """Returns same graph data but with reverse edge sets added."""
    # Calling `with_undirected_edges` with `False` input automatically makes the
    # output of `.as_graph_tensor()` contain, for each edge set "EdgeSet"
    # (returned by `.edge_sets()`) a new edge-set "rev_EdgeSet" that reverses
    # the "EdgeSet". Similarly, output of `.graph_schema()` will have both
    # "EdgeSet" and "rev_EdgeSet".
    return self.with_undirected_edges(False)

  def with_self_loops(self, add_self_loops):
    """Returns same graph data but with self-loops added (or removed).

    If add_self_loops == True, then subsequent calls to `.as_graph_tensor()`
    will contain edges `[(i, i) for i in range(N_j)]`, for each homogeneous edge
    set j, where `N_j` is the number of nodes in the node set connected by edge
    set `j`.

    NOTE: self-loops will be added *regardless* of whether they already exist.
    If the dataset already has self-loops, calling this will duplicate the
    self-loop edges.

    Args:
      add_self_loops: If set, self-loops will be amended on subsequent calls to
        `.as_graph_tensor()`. If not, no self-loops will be automatically added.

    Returns:
      A shallow copy of `self` with the option set; `self` is not modified.
    """
    modified = copy.copy(self)
    modified._add_self_loops = add_self_loops  # pylint: disable=protected-access -- same class.
    return modified

  @abc.abstractmethod
  def node_counts(self):
    """Returns total number of graph nodes per node set."""
    raise NotImplementedError()

  @abc.abstractmethod
  def node_features_dicts(self):
    """Returns 2-level dict: NodeSetName->FeatureName->Feature tensor.

    For every node set (`"x"`), feature tensor must have leading dimension equal
    to number of nodes in node set (`.node_counts()["x"]`). Other dimensions are
    dataset specific.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def edge_lists(self):
    """Returns dict from "edge type tuple" to int Tensor of shape (2, num_edges).

    "edge type tuple" is a three-tuple of strings:
      `(source node set name, edge set name, target node set name)`,
    where `edge set name` must be unique.
    """
    raise NotImplementedError()

  def node_sets(self):
    """Returns node sets of entire graph (dict: node set name -> NodeSet).

    Node sets are the union of the keys of `node_counts()` and
    `node_features_dicts()`; a node set without features gets an empty
    feature dict. Names are visited in sorted order, so the output order does
    not depend on (hash-randomized) set iteration order.
    """
    node_counts = self.node_counts()
    features_dicts = self.node_features_dicts()
    node_set_names = sorted(set(node_counts).union(features_dicts))
    return {
        name: tfgnn.NodeSet.from_fields(
            sizes=as_tensor([node_counts[name]]),
            features=features_dicts.get(name, {}))
        for name in node_set_names
    }

  def context(self):
    """Returns the graph context. The base class has none (returns None)."""
    return None

  def as_graph_tensor(self):
    """Returns `GraphTensor` holding the entire graph."""
    return tfgnn.GraphTensor.from_pieces(
        node_sets=self.node_sets(), edge_sets=self.edge_sets(),
        context=self.context())

  def graph_schema(self):
    """`tfgnn.GraphSchema` instance corresponding to `as_graph_tensor()`."""
    schema = tfgnn.GraphSchema()
    # Populate node features specs: dtype and all but the leading (node)
    # dimension.
    for node_set_name, node_set in self.node_sets().items():
      node_features = schema.node_sets[node_set_name]
      for feat_name, feature in node_set.features.items():
        node_features.features[feat_name].dtype = feature.dtype.as_datatype_enum
        for dim in feature.shape[1:]:
          node_features.features[feat_name].shape.dim.add().size = dim

    # Populate edge specs.
    for edge_type in self.edge_lists():
      src_node_set_name, edge_set_name, dst_node_set_name = edge_type
      # Populate edges with adjacency and its transpose.
      schema.edge_sets[edge_set_name].source = src_node_set_name
      schema.edge_sets[edge_set_name].target = dst_node_set_name
      if not self._make_undirected:
        schema.edge_sets['rev_' + edge_set_name].source = dst_node_set_name
        schema.edge_sets['rev_' + edge_set_name].target = src_node_set_name

    return schema

  def edge_sets(self):
    """Returns edge sets of entire graph (dict: edge set name -> EdgeSet).

    Homogeneous edge sets (source node set == target node set) get reversed
    copies of their edges appended if `with_undirected_edges(True)` was set,
    and one self-loop per node appended if `with_self_loops(True)` was set.
    Unless edges are made undirected, every edge set "X" is accompanied by an
    edge set "rev_X" with sources and targets swapped.
    """
    edge_sets = {}
    node_counts = self.node_counts() if self._add_self_loops else None
    for edge_type, edge_list in self.edge_lists().items():
      source_node_set_name, edge_set_name, target_node_set_name = edge_type
      is_homogeneous = source_node_set_name == target_node_set_name

      if self._make_undirected and is_homogeneous:
        # `edge_list[::-1]` swaps the source and target rows, so edge `M + k`
        # reverses edge `k`.
        edge_list = tf.concat([edge_list, edge_list[::-1]], axis=-1)
      if self._add_self_loops and is_homogeneous:
        all_nodes = tf.range(node_counts[source_node_set_name],
                             dtype=edge_list.dtype)
        self_connections = tf.stack([all_nodes, all_nodes], axis=0)
        edge_list = tf.concat([edge_list, self_connections], axis=-1)

      # A single graph component holding all edges: sizes == [num_edges].
      sizes = tf.shape(edge_list)[1:2]
      edge_sets[edge_set_name] = tfgnn.EdgeSet.from_fields(
          sizes=sizes,
          adjacency=tfgnn.Adjacency.from_indices(
              source=(source_node_set_name, edge_list[0]),
              target=(target_node_set_name, edge_list[1])))
      if not self._make_undirected:
        edge_sets['rev_' + edge_set_name] = tfgnn.EdgeSet.from_fields(
            sizes=sizes,
            adjacency=tfgnn.Adjacency.from_indices(
                source=(target_node_set_name, edge_list[1]),
                target=(source_node_set_name, edge_list[0])))
    return edge_sets

  def save(self, filename):
    """Subclasses can save themselves to disk.

    Args:
      filename: destination file path.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()
325

326

327
class NodeSplit(NamedTuple):
  """Contains 1D int tensors holding positions of {train, valid, test} nodes.

  This is returned by `NodeClassificationGraphData.node_split()`. The positions
  index the leading dimension of the features of the labeled node set.
  """
  # Field order is part of the interface (positional construction/unpacking).
  train: tf.Tensor
  validation: tf.Tensor
  test: tf.Tensor
335

336

337
class EdgeSplit(NamedTuple):
  """Contains positive and negative edges in {train, test, valid} partitions.

  Each `tf.Tensor` will be of shape `[2, num_edges]` with dtype int64.
  """
  # Only positive edges are needed for training: the complement of the
  # (entire) graph can be used for negative edges.
  train_edges: tf.Tensor
  validation_edges: tf.Tensor
  test_edges: tf.Tensor
  negative_validation_edges: tf.Tensor
  negative_test_edges: tf.Tensor
349

350

351
class NodeClassificationGraphData(InMemoryGraphData):
  """Adapts `InMemoryGraphData` for node classification settings.

  Subclasses should provide the information needed for node classification:
  node labels, the name of the labeled node set, and the partition of its nodes
  into train:validation:test.
  """

  def __init__(self, split='train', use_labels_as_features=False):
    """Initializes the node-classification export options.

    Args:
      split: partition whose nodes `context()` exports as seed nodes. Use
        `with_split()` to select several partitions.
      use_labels_as_features: see `with_labels_as_features()`.
    """
    super().__init__()
    self._splits = [split]
    self._use_labels_as_features = use_labels_as_features

  def with_split(self, split='train'):
    """Returns same graph data but with specific partition(s).

    Args:
      split: one of {"train", "validation", "test"}, or a list/tuple of them.
        `context()` exports the nodes of all given partitions (concatenated in
        the given order) as seed nodes.

    Returns:
      A shallow copy of `self` with the split(s) set; `self` is not modified.

    Raises:
      ValueError: if an entry is not one of {"train", "validation", "test"}, or
        if an empty list/tuple is given (there would be no seed nodes).
    """
    # Copy into a fresh list so later mutation by the caller has no effect.
    splits = list(split) if isinstance(split, (tuple, list)) else [split]
    if not splits:
      raise ValueError(
          'split must contain at least one of {"train", "validation", "test"}.')
    for split_name in splits:
      if split_name not in ('train', 'validation', 'test'):
        raise ValueError(
            'split must be one of {"train", "validation", "test"}.')
    modified = copy.copy(self)
    modified._splits = splits  # pylint: disable=protected-access -- same class.
    return modified

  def with_labels_as_features(
      self, use_labels_as_features):
    """Returns same graph data with labels as an additional feature on nodes.

    The feature will be added to the node-set with name `self.labeled_nodeset`.

    Args:
      use_labels_as_features: Label feature will be added iff set to True.

    Returns:
      A shallow copy of `self` with the option set; `self` is not modified.
    """
    modified = copy.copy(self)
    modified._use_labels_as_features = use_labels_as_features  # pylint: disable=protected-access -- same class.
    return modified

  @property
  def splits(self):
    """Returns a copy of the selected split names."""
    return copy.copy(self._splits)

  @abc.abstractmethod
  def num_classes(self):
    """Number of node classes. Max of `labels` should be `< num_classes`."""
    raise NotImplementedError('num_classes')

  @abc.abstractmethod
  def node_split(self):
    """`NodeSplit` with attributes `train`, `validation`, `test` set.

    The attributes are set to indices of the `labeled_nodeset`. Specifically,
    they correspond to leading dimension of features of the node set.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def labels(self):
    """int vector containing labels for train & validation nodes.

    Size of vector is number of nodes in the labeled node set. In particular:
    `self.labels().shape[0] == self.node_counts()[self.labeled_nodeset]`.
    Specifically, the vector has as many entries as there are nodes belonging to
    the node set that this task aims to predict labels for.

    Entry `labels()[i]` will be -1 iff `i in self.node_split().test`. Otherwise,
    `labels()[i]` will be int in range [`0`, `self.num_classes() - 1`].
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def test_labels(self):
    """Like `labels()` but contains no -1's.

    Every {train, valid, test} node will have its class label.
    """
    raise NotImplementedError()

  @property
  @abc.abstractmethod
  def labeled_nodeset(self):
    """Name of node set which `labels` and `node_splits` reference."""
    raise NotImplementedError()

  @abc.abstractmethod
  def node_features_dicts_without_labels(self):
    """Like `node_features_dicts()`, but never includes the "label" feature."""
    raise NotImplementedError()

  def node_features_dicts(self):
    """Implements a method required by the base class.

    This method combines the data from `labels()` or `test_labels()` with the
    data from `node_features_dicts_without_labels()` into a single features
    dict.

    Subclasses need to implement aforementioned methods and may inherit this.

    Returns:
      NodeSetName -> FeatureName -> Feature Tensor.
    """
    # Shallow-copy both levels so adding "label" never mutates the dicts owned
    # by the subclass.
    node_features_dicts = {
        ns: dict(features)
        for ns, features in self.node_features_dicts_without_labels().items()}
    if self._use_labels_as_features:
      # `labels()` hides test labels (-1); expose them only if "test" is
      # among the selected splits.
      if 'test' in self._splits:
        label_feature = self.test_labels()
      else:
        label_feature = self.labels()
      # The labeled node set may have no other features.
      node_features_dicts.setdefault(
          self.labeled_nodeset, {})['label'] = label_feature

    return node_features_dicts

  def context(self):
    """Returns context holding the seed nodes of the selected split(s).

    The single context feature "seed_nodes.<labeled_nodeset>" has shape
    `[1, num_seed_nodes]`: the node indices of every selected split (see
    `with_split()`), concatenated in order.
    """
    node_split = self.node_split()
    seed_nodes = tf.concat(
        [getattr(node_split, split) for split in self._splits], axis=0)
    seed_nodes = tf.expand_dims(seed_nodes, axis=0)
    seed_feature_name = 'seed_nodes.' + self.labeled_nodeset

    return tfgnn.Context.from_fields(features={seed_feature_name: seed_nodes})

  def graph_schema(self):
    """Base-class schema plus the int64 seed-nodes context feature."""
    graph_schema = super().graph_schema()
    context_features = graph_schema.context.features
    context_features['seed_nodes.' + self.labeled_nodeset].dtype = (
        tf.int64.as_datatype_enum)
    return graph_schema

  def save(self, filename):
    """Saves the dataset as a numpy compressed (.npz) file.

    Calls each of (labeled_nodeset, test_labels, labels, node_split,
    edge_lists, node_counts, node_features_dicts_without_labels, num_classes)
    once, composes a flat dict (keys are json-encoded tuples), then writes it as
    a numpy file. A flat dict is needed as numpy only saves named arrays, not
    nested structures. Options set via `with_*()` methods are not saved.

    Args:
      filename: file path to save onto. ".npz" extension is recommended. Parent
        directory must exist.
    """
    features_without_labels = self.node_features_dicts_without_labels()
    node_split = self.node_split()

    attribute_dict = {
        ('num_classes',): self.num_classes(),
        ('node_split', 'train'): node_split.train.numpy(),
        ('node_split', 'test'): node_split.test.numpy(),
        ('node_split', 'validation'): node_split.validation.numpy(),
        ('labels',): self.labels().numpy(),
        ('test_labels',): self.test_labels().numpy(),
        ('labeled_nodeset',): self.labeled_nodeset,
    }

    # Edge sets: ('e', '#', source node set, edge set, target node set).
    for (src_name, es_name, tgt_name), es_indices in self.edge_lists().items():
      key = ('e', '#', src_name, es_name, tgt_name)
      attribute_dict[key] = es_indices.numpy()

    # Node features: ('n', node set, feature name).
    for ns_name, features in features_without_labels.items():
      for feature_name, feature_tensor in features.items():
        attribute_dict[('n', ns_name, feature_name)] = feature_tensor.numpy()

    # Node counts: ('nc', node set).
    for node_set_name, node_count in self.node_counts().items():
      attribute_dict[('nc', node_set_name)] = node_count

    bytes_io = io.BytesIO()
    attribute_dict = {json.dumps(k): v for k, v in attribute_dict.items()}
    np.savez_compressed(bytes_io, **attribute_dict)
    with tf.io.gfile.GFile(filename, 'wb') as f:
      f.write(bytes_io.getvalue())

  @staticmethod
  def load(filename):
    """Loads from disk `NodeClassificationGraphData` that was `save()`ed.

    Args:
      filename: path of a file written by `save()`.

    Returns:
      `NodeClassificationGraphData` holding the saved attributes. Options set
      via `with_*()` methods were not saved and are at their defaults.

    Raises:
      ValueError: if an edge-list entry has a malformed key.
    """
    with tf.io.gfile.GFile(filename, 'rb') as f:
      with np.load(f) as npz_file:
        # Materialize every array while the underlying file is still open.
        raw_dict = dict(npz_file)
    dataset_dict = {tuple(json.loads(k)): v for k, v in raw_dict.items()}

    edge_lists = {}
    node_features = {}
    node_counts = {}
    for key, array in dataset_dict.items():
      if key[0] == 'e':  # Edge list.
        if key[1] != '#':
          raise ValueError('Expecting ("e", "#", ...) but got %s' % str(key))
        src_name, es_name, tgt_name = key[2], key[3], key[4]
        edge_lists[(src_name, es_name, tgt_name)] = as_tensor(array)
      elif key[0] == 'n':  # Node feature.
        node_set_name, feature_name = key[1], key[2]
        node_features.setdefault(node_set_name, {})[feature_name] = (
            as_tensor(array))
      elif key[0] == 'nc':  # Node count.
        node_counts[key[1]] = int(array)

    return _PreloadedNodeClassificationGraphData(
        # Stored as a 0-d array; convert back to a Python int.
        num_classes=int(dataset_dict[('num_classes',)]),
        node_features_dicts_without_labels=node_features,
        node_counts=node_counts,
        edge_lists=edge_lists,
        node_split=NodeSplit(
            train=as_tensor(dataset_dict[('node_split', 'train')]),
            validation=as_tensor(dataset_dict[('node_split', 'validation')]),
            test=as_tensor(dataset_dict[('node_split', 'test')])),
        labels=as_tensor(dataset_dict[('labels',)]),
        test_labels=as_tensor(dataset_dict[('test_labels',)]),
        labeled_nodeset=str(dataset_dict[('labeled_nodeset',)]))
566

567

568
class _PreloadedNodeClassificationGraphData(NodeClassificationGraphData):
  """Dataset from pre-computed attributes (as used by `load()`).

  Every accessor returns the corresponding constructor argument unchanged.
  """

  def __init__(
      self, num_classes,
      node_features_dicts_without_labels,
      node_counts,
      edge_lists,
      node_split, labels, test_labels,
      labeled_nodeset):
    """Stores the pre-computed attributes.

    Args:
      num_classes: returned by `num_classes()`.
      node_features_dicts_without_labels: returned by
        `node_features_dicts_without_labels()`.
      node_counts: returned by `node_counts()`.
      edge_lists: returned by `edge_lists()`.
      node_split: `NodeSplit` returned by `node_split()`.
      labels: returned by `labels()`.
      test_labels: returned by `test_labels()`.
      labeled_nodeset: returned by the `labeled_nodeset` property.
    """
    super().__init__()
    self._num_classes = num_classes
    self._node_features_dicts_without_labels = (
        node_features_dicts_without_labels)
    self._node_counts = node_counts
    self._edge_lists = edge_lists
    self._node_split = node_split
    self._labels = labels
    self._test_labels = test_labels
    self._labeled_nodeset = labeled_nodeset

  def num_classes(self):
    return self._num_classes

  def node_features_dicts_without_labels(self):
    return self._node_features_dicts_without_labels

  def node_counts(self):
    return self._node_counts

  def edge_lists(self):
    return self._edge_lists

  def node_split(self):
    return self._node_split

  def labels(self):
    return self._labels

  def test_labels(self):
    return self._test_labels

  @property
  def labeled_nodeset(self):
    return self._labeled_nodeset
613

614

615
class _OgbGraph:
  """Wraps data exposed by OGB graph objects, while enforcing heterogeneity.

  A homogeneous OGB graph is exposed as a heterogeneous one with a single node
  set `tfgnn.NODES` and a single edge set `tfgnn.EDGES`. Attributes offered by
  this class are consistent with the APIs of `InMemoryGraphData`.
  """

  def __init__(self, graph):
    """Reads the OGB graph dict `graph` into the attributes defined below.

    Args:
      graph: Dict, described in
        https://github.com/snap-stanford/ogb/blob/master/ogb/io/README.md#2-saving-graph-list

    Raises:
      ValueError: if `graph` is heterogeneous (has key "edge_index_dict") but
        lacks key "num_nodes_dict" or "node_feat_dict".
    """
    if 'edge_index_dict' in graph:  # Heterogeneous graph.
      # Explicit check (not `assert`) so validation survives `python -O`.
      for required_key in ('num_nodes_dict', 'node_feat_dict'):
        if required_key not in graph:
          raise ValueError(
              'Heterogeneous OGB graph has "edge_index_dict" but no "%s".'
              % required_key)

      # node set name -> feature name -> feature matrix (numNodes x featDim).
      node_set = {node_set_name: {'feat': as_tensor(feat)}
                  for node_set_name, feat in graph['node_feat_dict'].items()
                  if feat is not None}
      # Populate remaining features: key "node_<name>" becomes feature
      # "<name>". Missing (None) entries are skipped, as for "feat" above.
      for key, node_set_name_to_feat in graph.items():
        if (not key.startswith('node_') or key == 'node_feat_dict' or
            node_set_name_to_feat is None):
          continue
        feat_name = key.split('node_', 1)[-1]
        for node_set_name, feat in node_set_name_to_feat.items():
          if feat is None:
            continue
          # The node set may lack a "feat" entry (its "feat" was None).
          node_set.setdefault(node_set_name, {})[feat_name] = as_tensor(feat)
      self._num_nodes_dict = graph['num_nodes_dict']
      self._node_feat_dict = node_set
      self._edge_index_dict = tf.nest.map_structure(
          as_tensor, graph['edge_index_dict'])
    else:  # Homogeneous graph. Make heterogeneous.
      if graph.get('node_feat', None) is not None:
        feat = as_tensor(graph['node_feat'])
      else:
        # No node features: use an empty [num_nodes, 0] float matrix.
        feat = tf.zeros([graph['num_nodes'], 0], dtype=tf.float32)

      self._edge_index_dict = {
          (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): as_tensor(
              graph['edge_index']),
      }
      self._num_nodes_dict = {tfgnn.NODES: graph['num_nodes']}
      self._node_feat_dict = {tfgnn.NODES: {'feat': feat}}

  @property
  def num_nodes_dict(self):
    """Maps "node set name" -> number of nodes."""
    return self._num_nodes_dict

  @property
  def node_feat_dict(self):
    """Maps "node set name" to dict of "feature name"->tf.Tensor."""
    return self._node_feat_dict

  @property
  def edge_index_dict(self):
    """Adjacency lists for all edge sets.

    Returns:
      Dict (source node set name, edge set name, target node set name) -> edges.
      Where `edges` is tf.Tensor of shape (2, num edges), with `edges[0]` and
      `edges[1]`, respectively, containing source and target node IDs (as 1D int
      tf.Tensor).
    """
    return self._edge_index_dict
686

687

688
def _get_ogbn_dataset(dataset_name, cache_dir = None):
689
  """Imports ogb and returns `NodePropPredDataset`."""
690
  # This is done on purpose: we only import ogb if an ogb dataset is requested.
691
  import ogb.nodeproppred  # pylint: disable=g-import-not-at-top
692
  return ogb.nodeproppred.NodePropPredDataset(dataset_name, root=cache_dir)
693

694

695
def _get_ogbl_dataset(dataset_name, cache_dir = None):
696
  """Imports ogb and returns `LinkPropPredDataset`."""
697
  # This is done on purpose: we only import ogb if an ogb dataset is requested.
698
  import ogb.linkproppred  # pylint: disable=g-import-not-at-top
699
  return ogb.linkproppred.LinkPropPredDataset(dataset_name, root=cache_dir)
700

701

702
class OgbnData(NodeClassificationGraphData):
  """Wraps node classification graph data of ogbn-* for in-memory learning."""

  def __init__(self, dataset_name, cache_dir=None):
    """Loads the OGB node-property-prediction dataset `dataset_name`.

    Args:
      dataset_name: Name of an ogbn-* dataset; forwarded to OGB.
      cache_dir: Directory passed to OGB as its cache root. If None, uses env
        variable `OGB_CACHE_DIR` or, if that is unset, `~/data/ogb`.
    """
    super().__init__()
    self._dataset_name = dataset_name
    if cache_dir is None:
      cache_dir = os.environ.get(
          'OGB_CACHE_DIR', os.path.expanduser(os.path.join('~', 'data', 'ogb')))

    self._ogb_dataset = _get_ogbn_dataset(dataset_name, cache_dir)
    self._graph, self._node_labels, self._node_split, self._labeled_nodeset = (
        OgbnData._to_heterogeneous(self._ogb_dataset))

    # Reshape labels from [N, 1] to [N].
    self._node_labels = self._node_labels[:, 0]

    # Train labels: a copy of all labels with test nodes masked to -1, so that
    # `labels()` does not expose test labels.
    self._train_labels = np.copy(self._node_labels)
    self._train_labels[self._node_split.test] = -1

    self._train_labels = as_tensor(self._train_labels)
    self._node_labels = as_tensor(self._node_labels)

  @property
  def name(self):
    """Name of the wrapped OGB dataset."""
    return self._dataset_name

  @staticmethod
  def _to_heterogeneous(
      ogb_dataset):
    """Returns heterogeneous dicts from homogeneous or heterogeneous OGB dataset.

    Args:
      ogb_dataset: OGBN dataset. It can be homogeneous (single node set type,
        single edge set type) or heterogeneous (various node/edge set types).
        The returned data structure is shaped as if the dataset were
        heterogeneous (i.e., every node/edge set is named). If the input is a
        homogeneous graph, the node set is named "nodes" and the edge set is
        named "edges".

    Returns:
      tuple: `(ogb_graph, node_labels, idx_split, labeled_nodeset)`, where:
        `ogb_graph` is instance of _OgbGraph.
        `node_labels`: np.array of labels, with .shape[0] equals number of nodes
          in node set with name `labeled_nodeset`.
        `idx_split`: instance of NodeSplit. Members `train`, `test` and
          `validation`, respectively, contain indices of nodes in node set with
          name `labeled_nodeset`.
        `labeled_nodeset`: name of node set that the node-classification task is
          designed over.

    Raises:
      ValueError: if a heterogeneous dataset lacks required keys or does not
        have exactly one labeled node set.
    """
    graph, node_labels = ogb_dataset[0]
    ogb_graph = _OgbGraph(graph)
    if 'edge_index_dict' in graph:  # Graph is heterogeneous
      # Explicit checks instead of `assert`, which is stripped under `-O`.
      for required_key in ('num_nodes_dict', 'node_feat_dict'):
        if required_key not in graph:
          raise ValueError(
              'Heterogeneous OGB graph is missing key: ' + required_key)
      labeled_nodeset = list(node_labels.keys())
      if len(labeled_nodeset) != 1:
        raise ValueError('Expecting OGB dataset with *one* node set with '
                         'labels. Found: ' + ', '.join(labeled_nodeset))
      labeled_nodeset = labeled_nodeset[0]

      node_labels = node_labels[labeled_nodeset]
      # idx_split is dict: {'train': {labeled_nodeset: np.array}, 'test': ...}.
      idx_split = ogb_dataset.get_idx_split()
      # Change to {'train': Tensor, 'test': Tensor, 'valid': Tensor}
      idx_split = {split_name: as_tensor(split_dict[labeled_nodeset])
                   for split_name, split_dict in idx_split.items()}
      # Third-party OGB class returns dict with key 'valid'. Make consistent
      # with TF nomenclature by renaming.
      idx_split['validation'] = idx_split.pop('valid')  # Rename
      idx_split = NodeSplit(**idx_split)

      return ogb_graph, node_labels, idx_split, labeled_nodeset

    # Homogeneous graph: copy other per-node arrays as features. A key
    # 'node_<name>' (other than 'node_feat') becomes feature '<name>'.
    for key, value in graph.items():
      if key != 'node_feat' and key.startswith('node_'):
        key = key.split('node_', 1)[-1]
        ogb_graph.node_feat_dict[tfgnn.NODES][key] = as_tensor(value)  # pytype: disable=unsupported-operands  # always-use-property-annotation
    idx_split = ogb_dataset.get_idx_split()
    idx_split['validation'] = idx_split.pop('valid')  # Rename
    idx_split = NodeSplit(**tf.nest.map_structure(
        tf.convert_to_tensor, idx_split))
    return ogb_graph, node_labels, idx_split, tfgnn.NODES

  def num_classes(self):
    """Number of label classes, as reported by the OGB dataset."""
    return self._ogb_dataset.num_classes

  def node_features_dicts_without_labels(self):
    """Returns node set name -> {feature name -> tf.Tensor}, plus '#id'.

    Every node set in `node_counts()` gets an int32 '#id' feature equal to
    `tf.range(count)`, including node sets that have no other features.
    """
    # Deep-copy dict (*but* without copying tf.Tensor objects).
    node_sets = self._graph.node_feat_dict
    node_sets = {node_set_name: dict(node_set.items())
                 for node_set_name, node_set in node_sets.items()}
    node_counts = self.node_counts()
    for node_set_name, count in node_counts.items():
      if node_set_name not in node_sets:
        node_sets[node_set_name] = {}
      feat_dict = node_sets[node_set_name]
      feat_dict['#id'] = tf.range(count, dtype=tf.int32)
    return node_sets

  @property
  def labeled_nodeset(self):
    """Name of the node set carrying the classification labels."""
    return self._labeled_nodeset

  def node_counts(self):
    """Maps node set name -> number of nodes."""
    return self._graph.num_nodes_dict

  def edge_lists(self):
    """Maps (source set, edge set, target set) -> tf.Tensor of shape (2, E)."""
    return self._graph.edge_index_dict

  def node_split(self):
    """NodeSplit with `train`, `validation` and `test` node-index tensors."""
    return self._node_split

  def labels(self):
    """tf.Tensor of labels for all labeled nodes; test nodes are set to -1."""
    return self._train_labels

  def test_labels(self):
    """tf.Tensor of length num_nodes containing train and test labels."""
    return self._node_labels
823

824

825
def _maybe_download_file(source_url, destination_path, make_dirs=True):
  """Downloads URL `source_url` onto file `destination_path` if not present.

  The payload is first written to a temporary sibling file. That file is renamed
  onto `destination_path` only after the download completes. An interrupted
  download therefore cannot leave a truncated file behind that later calls would
  mistake for a cached copy.

  Args:
    source_url: URL to fetch.
    destination_path: Target file path (any path understood by `tf.io.gfile`).
    make_dirs: If True, creates the parent directory of `destination_path`
      before downloading.
  """
  if tf.io.gfile.exists(destination_path):
    return

  dir_name = os.path.dirname(destination_path)
  # `dir_name` is empty for bare file names; there is nothing to create then.
  if make_dirs and dir_name:
    try:
      tf.io.gfile.makedirs(dir_name)
    except FileExistsError:
      pass

  tmp_path = destination_path + '.tmp'
  try:
    with urllib.request.urlopen(source_url) as fin:
      with tf.io.gfile.GFile(tmp_path, 'wb') as fout:
        fout.write(fin.read())
    tf.io.gfile.rename(tmp_path, destination_path, overwrite=True)
  except BaseException:
    # Remove the partial download so the next call retries from scratch.
    if tf.io.gfile.exists(tmp_path):
      tf.io.gfile.remove(tmp_path)
    raise
838

839

840
class PlanetoidGraphData(NodeClassificationGraphData):
  """Wraps Planetoid node-classification datasets.

  These datasets first appeared in the Planetoid [1] paper and were popularized
  by the GCN paper [2].

  [1] Yang et al, ICML'16
  [2] Kipf & Welling, ICLR'17.
  """

  def __init__(self, dataset_name, cache_dir=None):
    """Downloads (unless cached) and loads Planetoid dataset `dataset_name`.

    Args:
      dataset_name: One of 'pubmed', 'citeseer', 'cora'.
      cache_dir: Directory holding the raw `ind.<dataset_name>.*` files. If
        None, uses env variable `PLANETOID_CACHE_DIR` or, if that is unset,
        `~/data/planetoid`.

    Raises:
      ValueError: If `dataset_name` is not a supported Planetoid dataset.
    """
    super().__init__()
    self._dataset_name = dataset_name
    allowed_names = ('pubmed', 'citeseer', 'cora')

    url_template = (
        'https://github.com/kimiyoung/planetoid/blob/master/data/'
        'ind.%s.%s?raw=true')
    file_parts = ['ally', 'allx', 'graph', 'ty', 'tx', 'test.index']
    if dataset_name not in allowed_names:
      raise ValueError('Dataset must be one of: ' + ', '.join(allowed_names))
    if cache_dir is None:
      cache_dir = os.environ.get(
          'PLANETOID_CACHE_DIR', os.path.expanduser(
              os.path.join('~', 'data', 'planetoid')))
    base_path = os.path.join(cache_dir, 'ind.%s' % dataset_name)
    # Download all files.
    for file_part in file_parts:
      source_url = url_template % (dataset_name, file_part)
      destination_path = os.path.join(
          cache_dir, 'ind.%s.%s' % (dataset_name, file_part))
      _maybe_download_file(source_url, destination_path)

    # Load data files. Every handle is opened via `with` so it gets closed.
    # NOTE(security): these files are pickles downloaded from the network, and
    # unpickling untrusted data can execute arbitrary code. Only point
    # `url_template` / `cache_dir` at trusted sources.
    with tf.io.gfile.GFile(base_path + '.graph', 'rb') as f:
      edge_lists = pickle.load(f)
    allx = PlanetoidGraphData.load_x(base_path + '.allx')
    with tf.io.gfile.GFile(base_path + '.ally', 'rb') as f:
      ally = np.load(f, allow_pickle=True)

    testx = PlanetoidGraphData.load_x(base_path + '.tx')

    # Add test nodes. `.test.index` holds one integer node ID per line.
    with tf.io.gfile.GFile(base_path + '.test.index') as f:
      test_idx = list(map(int, f.read().split('\n')[:-1]))

    # Reserve one row for every ID in [min(test_idx), max(test_idx)]. IDs that
    # do not appear in `test_idx` keep all-zero features.
    # NOTE(review): assumes min(test_idx) == allx.shape[0]; confirm per dataset.
    num_test_examples = max(test_idx) - min(test_idx) + 1
    sparse_zeros = scipy.sparse.csr_matrix((num_test_examples, allx.shape[1]),
                                           dtype='float32')

    allx = scipy.sparse.vstack((allx, sparse_zeros))
    llallx = allx.tolil()
    llallx[test_idx] = testx
    self._allx = as_tensor(np.array(llallx.todense()))

    with tf.io.gfile.GFile(base_path + '.ty', 'rb') as f:
      testy = np.load(f, allow_pickle=True)
    ally = np.pad(ally, [(0, num_test_examples), (0, 0)], mode='constant')
    ally[test_idx] = testy

    self._num_nodes = len(edge_lists)
    self._num_classes = ally.shape[1]
    # `ally` rows are (presumably one-hot) label vectors; argmax -> class ID.
    self._node_labels = np.argmax(ally, axis=1)
    self._train_labels = self._node_labels + 0  # Copy.
    self._train_labels[test_idx] = -1
    self._node_labels = as_tensor(self._node_labels)
    self._train_labels = as_tensor(self._train_labels)
    self._test_idx = as_tensor(np.array(test_idx, dtype='int32'))
    self._node_split = None  # Populated on `node_split()`

    # Will be used to construct (sparse) adjacency matrix.
    adj_src = []
    adj_target = []
    for node, neighbors in edge_lists.items():
      adj_src.extend([node] * len(neighbors))
      adj_target.extend(neighbors)

    self._edge_list = as_tensor(np.stack([adj_src, adj_target], axis=0))

  @property
  def name(self):
    """Name of the Planetoid dataset ('pubmed', 'citeseer' or 'cora')."""
    return self._dataset_name

  @staticmethod
  def load_x(filename):
    """Unpickles the (sparse) feature matrix stored in `filename`."""
    if sys.version_info > (3, 0):
      # 'latin1' lets Python 3 read pickles (incl. NumPy data) presumably
      # written by Python 2.
      with tf.io.gfile.GFile(filename, 'rb') as f:
        return pickle.load(f, encoding='latin1')
    else:
      with tf.io.gfile.GFile(filename) as f:
        return np.load(f)

  def num_classes(self):
    """Number of label classes (columns of the label matrix)."""
    return self._num_classes

  def node_features_dicts_without_labels(self):
    """Returns {'nodes': {'feat': features, '#id': int32 range(num_nodes)}}."""
    features = {'feat': self._allx}
    features['#id'] = tf.range(self._num_nodes, dtype=tf.int32)
    return {tfgnn.NODES: features}

  def node_counts(self):
    """Maps the single node set name to the number of nodes."""
    return {tfgnn.NODES: self._num_nodes}

  def edge_lists(self):
    """Maps the single edge set key to a tf.Tensor of shape (2, num edges)."""
    return {(tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): self._edge_list}

  def node_split(self):
    """Returns the (lazily built, then cached) NodeSplit.

    Train nodes are IDs [0, labels_per_class * num_classes). Validation nodes
    are the following 500 IDs. Test nodes come from the `.test.index` file.
    `labels_per_class` is read from env variable `PLANETOID_LABELS_PER_CLASS`
    (default 20).
    """
    if self._node_split is None:
      # By default, we mimic Planetoid & GCN setup -- i.e., 20 labels per class.
      labels_per_class = int(os.environ.get('PLANETOID_LABELS_PER_CLASS', '20'))
      num_train_nodes = labels_per_class * self.num_classes()
      num_validation_nodes = 500
      train_ids = tf.range(num_train_nodes, dtype=tf.int32)
      validation_ids = tf.range(
          num_train_nodes,
          num_train_nodes + num_validation_nodes, dtype=tf.int32)
      self._node_split = NodeSplit(train=train_ids, validation=validation_ids,
                                   test=self._test_idx)
    return self._node_split

  @property
  def labeled_nodeset(self):
    """Name of the (only) node set, which carries the labels."""
    return tfgnn.NODES

  def labels(self):
    """tf.Tensor of labels for all nodes; test nodes are set to -1."""
    return self._train_labels

  def test_labels(self):
    """tf.Tensor of length num_nodes containing train and test labels."""
    return self._node_labels
968

969

970
class LinkPredictionGraphData(InMemoryGraphData):
  """Base class for in-memory graph datasets whose task is link prediction.

  Subclasses must implement `edge_split()` and the `target_edgeset` property.
  The endpoint properties below are derived from those together with
  `graph_schema()` and `node_counts()`.
  """

  @abc.abstractmethod
  def edge_split(self):
    """Returns edge endpoints for {train, test, valid} partitions."""
    raise NotImplementedError()

  @property
  @abc.abstractmethod
  def target_edgeset(self):
    """Name of edge set over which link prediction is defined."""
    raise NotImplementedError()

  def _target_edge_set_schema(self):
    """Returns the `graph_schema()` entry of edge set `target_edgeset`."""
    schema = self.graph_schema()
    return schema.edge_sets[self.target_edgeset]

  @property
  def source_node_set_name(self):
    """Name of the node set at the source end of `target_edgeset`."""
    return self._target_edge_set_schema().source

  @property
  def target_node_set_name(self):
    """Name of the node set at the target end of `target_edgeset`."""
    return self._target_edge_set_schema().target

  @property
  def num_source_nodes(self):
    """Node count of the source endpoint set of `target_edgeset`."""
    counts = self.node_counts()
    return counts[self.source_node_set_name]

  @property
  def num_target_nodes(self):
    """Node count of the target endpoint set of `target_edgeset`."""
    counts = self.node_counts()
    return counts[self.target_node_set_name]
1003

1004

1005
class OgblData(LinkPredictionGraphData):
  """Wraps link prediction datasets of ogbl-* for in-memory learning."""

  def __init__(self, dataset_name, cache_dir=None):
    """Loads the OGB link-property-prediction dataset `dataset_name`.

    Args:
      dataset_name: Name of an ogbl-* dataset; forwarded to OGB.
      cache_dir: Directory passed to OGB as its cache root. If None, uses env
        variable `OGB_CACHE_DIR` or, if that is unset, `~/data/ogb`.
    """
    super().__init__()
    self._dataset_name = dataset_name
    if cache_dir is None:
      cache_dir = os.environ.get(
          'OGB_CACHE_DIR', os.path.expanduser(os.path.join('~', 'data', 'ogb')))

    self._ogb_dataset = _get_ogbl_dataset(dataset_name, cache_dir)

    # NOTE(review): assumes each split stores positive edges under 'edge' and
    # (for 'valid'/'test') negative edges under 'edge_neg'; confirm for any
    # newly supported dataset.
    ogb_edge_dict = self._ogb_dataset.get_edge_split()
    self._edge_split = EdgeSplit(
        train_edges=as_tensor(ogb_edge_dict['train']['edge']),
        # Positive validation edges come from the 'valid' split, not 'train'.
        validation_edges=as_tensor(ogb_edge_dict['valid']['edge']),
        test_edges=as_tensor(ogb_edge_dict['test']['edge']),
        negative_validation_edges=as_tensor(ogb_edge_dict['valid']['edge_neg']),
        negative_test_edges=as_tensor(ogb_edge_dict['test']['edge_neg']))

    self._ogb_graph = _OgbGraph(self._ogb_dataset.graph)

  @property
  def name(self):
    """Name of the wrapped OGB dataset."""
    return self._dataset_name

  def node_features_dicts(self, add_id=True):
    """Returns node set name -> {feature name -> tf.Tensor}.

    Args:
      add_id: If True, every node set in `node_counts()` gets an int32 '#id'
        feature equal to `tf.range(count)`. This includes node sets without any
        other features, matching `OgbnData`.
    """
    # 2-level shallow copy: new outer/inner dicts that still reference the same
    # tf.Tensor objects, so callers may add or remove features freely.
    features = {node_set_name: copy.copy(node_set_features)
                for node_set_name, node_set_features
                in self._ogb_graph.node_feat_dict.items()}
    if add_id:
      for node_set_name, count in self.node_counts().items():
        node_set_features = features.setdefault(node_set_name, {})
        node_set_features['#id'] = tf.range(count, dtype=tf.int32)  # pytype: disable=unsupported-operands  # always-use-property-annotation
    return features

  def node_counts(self):
    """Maps node set name -> number of nodes (returns a copy)."""
    return dict(self._ogb_graph.num_nodes_dict)  # Return copy.

  def edge_lists(self):
    """Maps (source set, edge set, target set) -> tf.Tensor of shape (2, E)."""
    return dict(self._ogb_graph.edge_index_dict)  # Return shallow copy.

  def edge_split(self):
    """EdgeSplit holding positive and negative edges of every partition."""
    return self._edge_split

  @property
  def target_edgeset(self):
    """Link prediction is defined over the single homogeneous edge set."""
    return tfgnn.EDGES
1054

1055

1056
def get_in_memory_graph_data(dataset_name, cache_dir=None):
  """Returns the in-memory graph data wrapper for `dataset_name`.

  Args:
    dataset_name: Either an "ogbn-*" name (node classification), an "ogbl-*"
      name (link prediction), or one of "cora", "citeseer", "pubmed".
    cache_dir: Optional cache directory forwarded to the wrapper. If None, each
      wrapper uses its own default (env variable, else a `~/data/...` path).

  Raises:
    ValueError: If `dataset_name` is not recognized.
  """
  if dataset_name.startswith('ogbn-'):
    return OgbnData(dataset_name, cache_dir=cache_dir)
  elif dataset_name.startswith('ogbl-'):
    return OgblData(dataset_name, cache_dir=cache_dir)
  elif dataset_name in ('cora', 'citeseer', 'pubmed'):
    return PlanetoidGraphData(dataset_name, cache_dir=cache_dir)
  else:
    raise ValueError('Unknown Dataset name: ' + dataset_name)
1065

1066

1067
# Shorthand. For the single-argument calls in this file it is equivalent to
# `as_tensor = tf.convert_to_tensor`.
def as_tensor(obj):
  """Converts `obj` (e.g. a numpy array or Python list) into a tf.Tensor.

  Args:
    obj: Any value accepted by `tf.convert_to_tensor`.

  Returns:
    The result of `tf.convert_to_tensor(obj)`.
  """
  tensor = tf.convert_to_tensor(obj)
  return tensor
1071

1072

1073
def load_ogbn_graph_tensor(
    dataset_path, *, add_reverse_edge_sets=False
):
  """Loads OGBN graph data saved as numpy compressed (.npz) files as a GraphTensor.

  The .npz files can be produced from the original OGB dataset with
  tensorflow_gnn/converters/ogb/convert_ogb_to_npz.py.

  Args:
    dataset_path: Path to the saved OGBN numpy compressed (.npz) files.
    add_reverse_edge_sets: If True, reversed copies of the edge sets are added.

  Returns:
    A tfgnn.GraphTensor holding the full OGBN graph in memory, with the labels
    included as features.
  """
  loaded = NodeClassificationGraphData.load(dataset_path)
  with_labels = loaded.with_labels_as_features(True)
  final_data = (with_labels.with_reverse_edge_sets()
                if add_reverse_edge_sets else with_labels)
  return final_data.as_graph_tensor()
1094

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.