google-research

Форк
0
/
datasets.py 
591 строка · 19.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Dataset definitions for in-memory graph structure learning."""
17
import copy
18
import io
19
import math
20
import os
21
import random
22
from typing import List, Mapping, MutableMapping, Tuple
23

24
import numpy as np
25
import scipy.sparse
26
import tensorflow as tf
27
import tensorflow_gnn as tfgnn
28
import tensorflow_hub as tfhub
29

30
from ugsl import tfgnn_datasets
31

32

33
class GSLGraphData:
  """Wraps graph datasets to be used for graph structure learning.

  GSLGraphData can take a given tensor as a generated adjacency and incorporate
  it in the graph tensor.

  The class is used as a mixin: it calls `self.node_counts()` and
  `self.context()` and reads `self._splits`, none of which are defined here,
  and most accessors below only raise `NotImplementedError`. It therefore has
  to be combined with a dataset class that supplies them, and that dataset
  class must be initialized before `GSLGraphData.__init__` runs, because the
  constructor immediately builds the input graph tensor.
  """

  def __init__(
      self,
      remove_noise_ratio: float = 0.0,
      add_noise_ratio: float = 0.0,
  ):
    """Builds and stores the (optionally noisy) input graph tensor.

    Args:
      remove_noise_ratio: ratio of the existing edges to remove.
      add_noise_ratio: ratio of the non-existing edges to add.
    """
    super().__init__()
    # Saving the generated noisy adjacency to reuse.
    self._cached_noisy_adjacency = None
    self._input_gt = self.as_graph_tensor_noisy_adjacency(
        remove_noise_ratio=remove_noise_ratio, add_noise_ratio=add_noise_ratio
    )

  def node_sets(self):
    """Node sets of the graph; must be supplied by the dataset class."""
    raise NotImplementedError

  def splits(self):
    """Returns a shallow copy of `self._splits`."""
    return copy.copy(self._splits)

  def num_classes(self):
    """Number of label classes; must be supplied by the dataset class."""
    raise NotImplementedError('num_classes')

  def node_split(self):
    """Train/validation/test node split; supplied by the dataset class."""
    raise NotImplementedError()

  def labels(self):
    """Training labels; must be supplied by the dataset class."""
    raise NotImplementedError()

  def test_labels(self):
    """Test labels; must be supplied by the dataset class."""
    raise NotImplementedError()

  @property
  def labeled_nodeset(self):
    """Name of the labeled node set; supplied by the dataset class."""
    raise NotImplementedError()

  def node_features_dicts_without_labels(
      self,
  ):
    """Node features without labels; supplied by the dataset class."""
    raise NotImplementedError()

  def edge_lists(
      self,
  ):
    """Edge lists keyed by edge type; supplied by the dataset class."""
    raise NotImplementedError()

  def as_graph_tensor(self):
    """Full graph as a `GraphTensor`; supplied by the dataset class."""
    raise NotImplementedError()

  def node_features_dicts(
      self,
  ):
    """Node features per node set; supplied by the dataset class."""
    raise NotImplementedError()

  def get_input_graph_tensor(self):
    """Returns the graph tensor that was built in the constructor."""
    return self._input_gt

  def as_graph_tensor_given_adjacency(
      self,
      adjacency_tensor,
      edge_weights,
      node_features,
      make_undirected: bool = False,
      add_self_loops: bool = False,
  ):
    """Returns `GraphTensor` holding the entire graph.

    Args:
      adjacency_tensor: edge list whose row 0 holds the source indices and
        whose row 1 holds the target indices.
      edge_weights: one weight per edge of `adjacency_tensor`.
      node_features: features stored under the key 'feat' of every node set.
      make_undirected: if True, every edge is also added in reverse.
      add_self_loops: if True, one weight-1 self-loop is added per node.
    """
    return tfgnn.GraphTensor.from_pieces(
        node_sets=self.node_sets_given_features(node_features),
        edge_sets=self.edge_sets_given_adjacency(
            adjacency_tensor,
            edge_weights,
            make_undirected,
            add_self_loops,
        ),
        context=self.context(),
    )

  def node_sets_given_features(
      self, node_features
  ):
    """Returns node sets of entire graph (dict: node set name -> NodeSet).

    The dataset's own feature dicts are consulted only for their node set
    names; every node set receives `node_features` under the key 'feat'.
    """
    node_counts = self.node_counts()
    features_dicts = self.node_features_dicts()
    node_set_names = set(node_counts.keys()).union(features_dicts.keys())
    return {
        name: tfgnn.NodeSet.from_fields(
            sizes=tf.convert_to_tensor([node_counts[name]]),
            features={'feat': node_features})
        for name in node_set_names
    }

  def edge_sets_given_adjacency(
      self,
      edge_list,
      edge_weights,
      make_undirected: bool = False,
      add_self_loops: bool = False,
  ):
    """Returns edge sets of entire graph (dict: edge set name -> EdgeSet).

    Args:
      edge_list: row 0 holds the source indices, row 1 the target indices.
      edge_weights: one weight per column of `edge_list`.
      make_undirected: if True, every edge is also added in reverse, carrying
        the same weight as the original edge.
      add_self_loops: if True, one weight-1 self-loop is added per node.

    Returns:
      A dict with the single key `tfgnn.EDGES`; the weights are stored in the
      edge feature 'weights'.
    """
    if make_undirected:
      # Reversing the first axis swaps the source and target rows, which
      # yields the mirrored edges in the same column order.
      edge_list = tf.concat([edge_list, edge_list[::-1]], axis=-1)
      # The mirrored copy of edge i keeps the weight of edge i, so the weights
      # are repeated in their original order. (Reversing them here would
      # hand the mirrored edge i the weight of edge E-1-i.)
      edge_weights = tf.concat([edge_weights, edge_weights], axis=-1)
    if add_self_loops:
      node_counts = self.node_counts()
      all_nodes = tf.range(node_counts[tfgnn.NODES], dtype=edge_list.dtype)
      self_connections = tf.stack([all_nodes, all_nodes], axis=0)
      # The following line adds self_connections to the existing edges.
      # It is possible for an edge to be both available in the edge_list and
      # also in the self_connections.
      edge_list = tf.concat([edge_list, self_connections], axis=-1)
      edge_weights = tf.concat(
          [edge_weights, tf.ones(node_counts[tfgnn.NODES])], axis=-1
      )
    return {
        tfgnn.EDGES: tfgnn.EdgeSet.from_fields(
            sizes=tf.shape(edge_list)[1:2],
            adjacency=tfgnn.Adjacency.from_indices(
                source=(tfgnn.NODES, edge_list[0]),
                target=(tfgnn.NODES, edge_list[1]),
            ),
            features={'weights': edge_weights},
        )
    }

  def as_graph_tensor_noisy_adjacency(
      self,
      remove_noise_ratio: float,
      add_noise_ratio: float,
      make_undirected: bool = False,
      add_self_loops: bool = False,
  ):
    """Returns `GraphTensor` holding the entire graph with noisy edges.

    The noisy edge sets are cached after the first call, so later calls reuse
    them regardless of the arguments passed.

    Args:
      remove_noise_ratio: ratio of the existing edges to remove.
      add_noise_ratio: ratio of the non-existing edges to add.
      make_undirected: if True, every edge is also added in reverse.
      add_self_loops: if True, one self-loop is added per node.
    """
    return tfgnn.GraphTensor.from_pieces(
        node_sets=self.node_sets(),
        edge_sets=self.edge_sets_noisy_adjacency(
            add_noise_ratio=add_noise_ratio,
            remove_noise_ratio=remove_noise_ratio,
            make_undirected=make_undirected,
            add_self_loops=add_self_loops,
        ),
        context=self.context(),
    )

  def edge_sets_noisy_adjacency(
      self,
      add_noise_ratio: float,
      remove_noise_ratio: float,
      make_undirected: bool = False,
      add_self_loops: bool = False,
  ):
    """Returns noisy edge sets of entire graph (dict: edge set name -> EdgeSet).

    Edges to add are drawn with the `random` module as independent
    (source, target) pairs, so a drawn pair may repeat an existing edge or be
    a self-loop. Edges to remove are sampled without replacement from the
    original edges. The result is cached on the instance; once cached it is
    returned as-is, whatever arguments are passed.

    Args:
      add_noise_ratio: ratio of the non-existing edges to add.
      remove_noise_ratio: ratio of the existing edges to remove.
      make_undirected: if True, every edge is also added in reverse.
      add_self_loops: if True, one self-loop is added per node.
    """
    # Compare against None so that a cached-but-empty dict is reused too.
    if self._cached_noisy_adjacency is not None:
      return self._cached_noisy_adjacency
    edge_sets = {}
    node_counts = self.node_counts()
    for edge_type, edge_list in self.edge_lists().items():
      source_node_set_name, edge_set_name, target_node_set_name = edge_type
      # NOTE(review): the source node count is also used to draw target ids
      # and self-loops, which presumably assumes source and target node sets
      # coincide -- confirm before using with heterogeneous graphs.
      number_of_nodes = node_counts[source_node_set_name]
      sources = edge_list[0].numpy()
      targets = edge_list[1].numpy()
      number_of_edges = len(sources)
      if add_noise_ratio:
        number_of_edges_to_add = math.floor(
            ((number_of_nodes * number_of_nodes) / 2 - number_of_edges)
            * add_noise_ratio
        )
        sources_to_add = np.array(
            random.choices(range(number_of_nodes), k=number_of_edges_to_add)
        )
        targets_to_add = np.array(
            random.choices(range(number_of_nodes), k=number_of_edges_to_add)
        )
      else:
        # Typed empty arrays keep the concatenations below integer-valued.
        sources_to_add = np.array([], dtype=sources.dtype)
        targets_to_add = np.array([], dtype=targets.dtype)
      if remove_noise_ratio:
        number_of_edges_to_remove = math.floor(
            number_of_edges * remove_noise_ratio
        )
        edge_indices_to_remove = random.sample(
            range(0, number_of_edges), number_of_edges_to_remove
        )
        noisy_sources = np.delete(sources, edge_indices_to_remove)
        noisy_targets = np.delete(targets, edge_indices_to_remove)
      else:
        noisy_sources, noisy_targets = sources, targets
      noisy_sources = tf.constant(
          np.concatenate((noisy_sources, sources_to_add)), dtype=tf.int32
      )
      noisy_targets = tf.constant(
          np.concatenate((noisy_targets, targets_to_add)), dtype=tf.int32
      )
      edge_list = tf.stack([noisy_sources, noisy_targets])
      if make_undirected:
        # Reversing the first axis swaps sources and targets.
        edge_list = tf.concat([edge_list, edge_list[::-1]], axis=-1)
      if add_self_loops:
        all_nodes = tf.range(number_of_nodes, dtype=edge_list.dtype)
        self_connections = tf.stack([all_nodes, all_nodes], axis=0)
        edge_list = tf.concat([edge_list, self_connections], axis=-1)
      edge_sets[edge_set_name] = tfgnn.EdgeSet.from_fields(
          sizes=tf.shape(edge_list)[1:2],
          adjacency=tfgnn.Adjacency.from_indices(
              source=(source_node_set_name, edge_list[0]),
              target=(target_node_set_name, edge_list[1]),
          ),
      )
    self._cached_noisy_adjacency = edge_sets
    return edge_sets
245

246

247
class GSLPlanetoidGraphData(tfgnn_datasets.PlanetoidGraphData, GSLGraphData):
  """Wraps Planetoid graph datasets to be used for graph structure learning.

  Besides the initial input adjacency matrix, GSLGraphData can take a given
  tensor as a generated adjacency and incorporate it in the graph tensor.
  """

  def __init__(
      self,
      dataset_name,
      remove_noise_ratio,
      add_noise_ratio,
  ):
    """Loads the Planetoid dataset, then builds the noisy input graph.

    Args:
      dataset_name: name forwarded to `tfgnn_datasets.PlanetoidGraphData`.
      remove_noise_ratio: ratio of the existing edges to remove.
      add_noise_ratio: ratio of the non-existing edges to add.
    """
    # The dataset half goes first: GSLGraphData.__init__ builds the input
    # graph tensor right away and reads the dataset's nodes and edges.
    tfgnn_datasets.PlanetoidGraphData.__init__(self, dataset_name)
    GSLGraphData.__init__(
        self,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
266

267

268
class GcnBenchmarkFileGraphData(tfgnn_datasets.NodeClassificationGraphData):
  """Adapt npz with format of github.com/shchur/gnn-benchmark into TF-GNN.

  The graph has a single node set (`tfgnn.NODES`) and a single edge set
  (`tfgnn.EDGES`). Nodes are shuffled with a fixed seed; the first tenth is
  used for training, the next tenth for validation and the rest for testing.

  NOTE: This can be moved to TF-GNN (tfgnn/experimental/in_memory/datasets.py).
  """

  def __init__(self, dataset_path):
    """Loads .npz file following shchur's format.

    Args:
      dataset_path: path of the .npz file, readable through `tf.io.gfile`.

    Raises:
      ValueError: if `dataset_path` does not exist.
    """
    if not tf.io.gfile.exists(dataset_path):
      raise ValueError('Dataset file not found: ' + dataset_path)

    # The label mask returned by the loader is not needed here.
    adj_matrix, attr_matrix, labels, _ = _load_npz_to_sparse_graph(
        dataset_path)

    # Edges are the coordinates of the non-zero adjacency entries.
    edge_indices = tf.convert_to_tensor(adj_matrix.nonzero())
    self._edge_lists = {(tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): edge_indices}

    num_nodes = attr_matrix.shape[0]
    self._node_features_dicts = {
        tfgnn.NODES: {
            'feat': tf.convert_to_tensor(attr_matrix),
            '#id': tf.range(num_nodes),
        }
    }
    self._node_counts = {tfgnn.NODES: num_nodes}
    self._num_classes = labels.max() + 1
    self._test_labels = tf.convert_to_tensor(labels)

    # Fixed seed, so the split is the same on every load.
    permutation = np.random.default_rng(seed=1234).permutation(num_nodes)
    num_train_examples = num_nodes // 10
    num_validate_examples = num_nodes // 10
    num_validate_plus_train = num_validate_examples + num_train_examples
    train_indices = permutation[:num_train_examples]
    validate_indices = permutation[num_train_examples:num_validate_plus_train]
    test_indices = permutation[num_validate_plus_train:]

    self._node_split = tfgnn_datasets.NodeSplit(
        tf.convert_to_tensor(train_indices),
        tf.convert_to_tensor(validate_indices),
        tf.convert_to_tensor(test_indices))

    # Training labels are a copy of all labels with the test nodes set to -1.
    train_labels = labels + 0  # Make a copy.
    train_labels[test_indices] = -1
    self._train_labels = tf.convert_to_tensor(train_labels)
    super().__init__()

  def node_counts(self):
    """Returns the dict: node set name -> number of nodes."""
    return self._node_counts

  def edge_lists(self):
    """Returns the dict: (source set, edge set, target set) -> edge indices."""
    return self._edge_lists

  def num_classes(self):
    """Returns the largest label value plus one."""
    return self._num_classes

  def node_split(self):
    """Returns the `NodeSplit` of train, validation and test node indices."""
    return self._node_split

  def labels(self):
    """Returns the labels with every test node's label replaced by -1."""
    return self._train_labels

  def test_labels(self):
    """Returns the labels of all nodes, unmasked."""
    return self._test_labels

  @property
  def labeled_nodeset(self):
    """Name of the node set that carries the labels."""
    return tfgnn.NODES

  def node_features_dicts_without_labels(self):
    """Returns the dict: node set name -> {'feat': ..., '#id': ...}."""
    return self._node_features_dicts
339

340

341
# Module-level alias for the protected download helper of `tfgnn_datasets`,
# so the classes in this file can call it under a local name.
_maybe_download_file = tfgnn_datasets._maybe_download_file  # pylint: disable=protected-access
342

343

344
class GcnBenchmarkUrlGraphData(GcnBenchmarkFileGraphData):
  """Loads a gnn-benchmark .npz dataset identified by its URL.

  The file is stored in `cache_dir` under the basename of the URL and then
  loaded by `GcnBenchmarkFileGraphData`.
  """

  def __init__(
      self, npz_url,
      cache_dir = os.path.expanduser(
          os.path.join('~', 'data', 'gnn-benchmark'))):
    """Fetches the .npz file into the cache directory and loads it.

    Args:
      npz_url: URL of the .npz file.
      cache_dir: directory holding the local copy of the file.
    """
    local_path = os.path.join(cache_dir, os.path.basename(npz_url))
    # NOTE(review): presumably skips the download when the file is already
    # cached -- confirm against `tfgnn_datasets._maybe_download_file`.
    _maybe_download_file(npz_url, local_path)
    super().__init__(local_path)
353

354

355
def _load_npz_to_sparse_graph(file_name):
  """Loads a graph stored as .npz in the gnn-benchmark format.

  Copied from experimental/users/tsitsulin/gcns/cgcn/utilities/graph.py.

  Args:
    file_name: path of the .npz file, readable through `tf.io.gfile`.

  Returns:
    Tuple `(adj_matrix, attr_matrix, labels, label_mask)`. `adj_matrix` is a
    scipy CSR matrix. `attr_matrix` is the densified attribute matrix when the
    file stores attributes in CSR form, or the stored array otherwise. When
    labels are stored in CSR form, `label_mask` and `labels` are the row and
    column indices of the non-zero entries; otherwise `labels` is the stored
    array and `label_mask` is an all-True array of the same shape.

  Raises:
    ValueError: if the file contains no attributes or no labels.
  """
  # Read through a context manager so that the file handle gets closed.
  with tf.io.gfile.GFile(file_name, 'rb') as gfile:
    file_bytes = gfile.read()

  # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
  # which can execute code; only load files from trusted sources.
  with np.load(io.BytesIO(file_bytes), allow_pickle=True) as fin:
    loader = dict(fin)

  def csr_from(prefix):
    """Rebuilds the CSR matrix stored under the `<prefix>_*` keys."""
    return scipy.sparse.csr_matrix(
        (loader[prefix + '_data'], loader[prefix + '_indices'],
         loader[prefix + '_indptr']),
        shape=loader[prefix + '_shape'])

  adj_matrix = csr_from('adj')

  if 'attr_data' in loader:
    # Attributes are stored as a sparse CSR matrix
    attr_matrix = csr_from('attr').todense()
  elif 'attr_matrix' in loader:
    # Attributes are stored as a (dense) np.ndarray
    attr_matrix = loader['attr_matrix']
  else:
    raise ValueError('No attributes in the data file: ' + file_name)

  if 'labels_data' in loader:
    # Labels are stored as a CSR matrix: the row index of a non-zero entry is
    # the labeled item and its column index is the label.
    label_mask, labels = csr_from('labels').nonzero()
  elif 'labels' in loader:
    # Labels are stored as a numpy array
    labels = loader['labels']
    label_mask = np.ones(labels.shape, dtype=np.bool_)
  else:
    raise ValueError('No labels in the data file: ' + file_name)

  return adj_matrix, attr_matrix, labels, label_mask
393

394

395
class GSLAmazonPhotosGraphData(GcnBenchmarkUrlGraphData, GSLGraphData):
  """Wraps GCN Benchmark datasets to be used for graph structure learning."""

  def __init__(
      self,
      dataset_name,
      remove_noise_ratio,
      add_noise_ratio,
  ):
    """Loads the Amazon Photos graph, then builds the noisy input graph.

    Args:
      dataset_name: not used; the Amazon Photos file is always loaded.
      remove_noise_ratio: ratio of the existing edges to remove.
      add_noise_ratio: ratio of the non-existing edges to add.
    """
    del dataset_name  # Unused: the URL below is fixed.
    # The dataset half goes first: GSLGraphData.__init__ builds the input
    # graph tensor right away and reads the dataset's nodes and edges.
    GcnBenchmarkUrlGraphData.__init__(
        self,
        'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'
        'amazon_electronics_photo.npz')
    GSLGraphData.__init__(
        self,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
413

414

415
class StackOverflowGraphlessData(tfgnn_datasets.NodeClassificationGraphData):
  """Stackoverflow dataset contains node features and labels (but no edges).

  Node features are BERT embeddings of the texts. Labels and embeddings are
  cached as `labels.npy` and `embeddings.npy` in the cache directory and are
  only recomputed when one of the two files is missing. Nodes are shuffled
  with a fixed seed; the first tenth is used for training, the next tenth for
  validation and the rest for testing.
  """

  def __init__(
      self, cache_dir = os.path.expanduser(
          os.path.join('~', 'data', 'stackoverflow-bert'))):
    """Loads the cached labels and embeddings, creating them if needed.

    Args:
      cache_dir: directory holding `labels.npy` and `embeddings.npy`.
    """
    labels_path = os.path.join(cache_dir, 'labels.npy')
    embeddings_path = os.path.join(cache_dir, 'embeddings.npy')

    is_cached = (tf.io.gfile.exists(labels_path) and
                 tf.io.gfile.exists(embeddings_path))
    if not is_cached:
      if not tf.io.gfile.exists(cache_dir):
        tf.io.gfile.makedirs(cache_dir)
      # Download.
      self._download_dataset_extract_features(labels_path, embeddings_path)

    # Context managers make sure both file handles get closed.
    with tf.io.gfile.GFile(embeddings_path, 'rb') as fin:
      node_features = np.load(fin)
    with tf.io.gfile.GFile(labels_path, 'rb') as fin:
      node_labels = np.load(fin)

    num_nodes = node_features.shape[0]
    self._node_counts = {tfgnn.NODES: num_nodes}
    self._num_classes = node_labels.max() + 1
    self._test_labels = tf.convert_to_tensor(node_labels)
    # The dataset has no edges: an empty edge list with two rows.
    self._edge_lists = {
        (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): (
            tf.zeros(shape=[2, 0], dtype=tf.int32))}
    self._node_features_dicts = {
        tfgnn.NODES: {
            'feat': tf.convert_to_tensor(node_features, dtype=tf.float32),
            '#id': tf.range(num_nodes),
        }
    }

    # Fixed seed, so the split is the same on every load.
    permutation = np.random.default_rng(seed=1234).permutation(num_nodes)
    num_train_examples = num_nodes // 10
    num_validate_examples = num_nodes // 10
    num_validate_plus_train = num_validate_examples + num_train_examples
    train_indices = permutation[:num_train_examples]
    validate_indices = permutation[num_train_examples:num_validate_plus_train]
    test_indices = permutation[num_validate_plus_train:]

    self._node_split = tfgnn_datasets.NodeSplit(
        tf.convert_to_tensor(train_indices),
        tf.convert_to_tensor(validate_indices),
        tf.convert_to_tensor(test_indices))

    # Training labels are a copy of all labels with the test nodes set to -1.
    train_labels = node_labels + 0  # Make a copy.
    train_labels[test_indices] = -1
    self._train_labels = tf.convert_to_tensor(train_labels)
    super().__init__()

  def node_counts(self):
    """Returns the dict: node set name -> number of nodes."""
    return self._node_counts

  def edge_lists(self):
    """Returns the dict holding the single, empty edge list."""
    return self._edge_lists

  def num_classes(self):
    """Returns the largest label value plus one."""
    return self._num_classes

  def node_split(self):
    """Returns the `NodeSplit` of train, validation and test node indices."""
    return self._node_split

  def labels(self):
    """Returns the labels with every test node's label replaced by -1."""
    return self._train_labels

  def test_labels(self):
    """Returns the labels of all nodes, unmasked."""
    return self._test_labels

  @property
  def labeled_nodeset(self):
    """Name of the node set that carries the labels."""
    return tfgnn.NODES

  def node_features_dicts_without_labels(self):
    """Returns the dict: node set name -> {'feat': ..., '#id': ...}."""
    return self._node_features_dicts

  def _download_dataset_extract_features(
      self, labels_path, embeddings_path):
    """Downloads the texts, embeds them with BERT and saves the results.

    Each line of the downloaded file holds two integer labels and a text,
    separated by tabs. Texts are de-duplicated (the last occurrence wins) and
    the second label is the one that gets saved.

    Args:
      labels_path: destination of the saved label array; its directory is
        also where the downloaded file is stored.
      embeddings_path: destination of the saved embedding array.
    """
    cache_dir = os.path.dirname(labels_path)
    url = ('https://raw.githubusercontent.com/rashadulrakib/'
           'short-text-clustering-enhancement/master/data/stackoverflow/'
           'traintest')
    tab_separated_filepath = os.path.join(cache_dir, 'traintest.tsv')
    _maybe_download_file(url, tab_separated_filepath)

    data_cluster = {}
    with tf.io.gfile.GFile(tab_separated_filepath, 'r') as f:
      for line in f:
        l1, l2, text = line.strip().split('\t')
        data_cluster[text] = (int(l1), int(l2))

    def remove_cls_sep(masks):
      """Zeroes, in place, the first and the last set position of each row."""
      # Index of the last 1 in every row, computed before any entry changes.
      last_1s = np.sum(masks, axis=1) - 1
      masks[:, 0] = 0
      masks[np.arange(masks.shape[0]), last_1s] = 0
      return masks

    def bert_embs(texts):
      """Returns the masked mean of the BERT token outputs of each text."""
      text_preprocessed = bert_preprocess_model(texts)
      bert_results = bert_model(text_preprocessed)
      masks = np.expand_dims(
          remove_cls_sep(text_preprocessed['input_mask'].numpy()), axis=2)
      emb = (np.sum(bert_results['sequence_output'].numpy() * masks, axis=1)
             / np.sum(masks, axis=1))
      return emb

    # Instantiate BERT model.
    bert_preprocess_model = tfhub.KerasLayer(
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
    bert_model = tfhub.KerasLayer(
        'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3')

    # Map keys of `data_cluster` through `bert_model`, 100 texts at a time.
    data_cluster_keys = list(data_cluster.keys())
    embeddings = []
    for i in range(0, len(data_cluster_keys), 100):
      embeddings.append(bert_embs(data_cluster_keys[i:i+100]))
    embeddings = np.vstack(embeddings)
    labels = np.array([data_cluster[t][1] for t in data_cluster_keys])

    with tf.io.gfile.GFile(labels_path, 'wb') as fout:
      np.save(fout, labels)

    with tf.io.gfile.GFile(embeddings_path, 'wb') as fout:
      np.save(fout, embeddings)
539

540

541
class GSLStackOverflowGraphlessData(StackOverflowGraphlessData, GSLGraphData):
  """Wraps Stackoverflow datasets to be used for graph structure learning."""

  def __init__(
      self,
      remove_noise_ratio,
      add_noise_ratio,
      cache_dir = os.path.expanduser(
          os.path.join('~', 'data', 'stackoverflow-bert'))):
    """Loads the Stackoverflow data, then builds the noisy input graph.

    Args:
      remove_noise_ratio: ratio of the existing edges to remove.
      add_noise_ratio: ratio of the non-existing edges to add.
      cache_dir: directory holding the cached labels and embeddings.
    """
    # The dataset half goes first: GSLGraphData.__init__ builds the input
    # graph tensor right away and reads the dataset's nodes and edges.
    StackOverflowGraphlessData.__init__(self, cache_dir=cache_dir)
    GSLGraphData.__init__(
        self, remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio)
554

555

556
def get_in_memory_graph_data(
    dataset_name,
    remove_noise_ratio,
    add_noise_ratio,
):
  """Getting the dataset based on the name.

  Args:
    dataset_name: the name of the dataset to prepare. One of 'cora',
      'citeseer', 'pubmed', 'amazon_photos' or 'stackoverflow'.
    remove_noise_ratio: ratio of the existing edge to remove.
    add_noise_ratio: ratio of the non-existing edges to add.

  Returns:
    The graph data to be used in training.
  Raises:
    ValueError: if the name of the dataset is not defined.
  """
  if dataset_name in ('cora', 'citeseer', 'pubmed'):
    return GSLPlanetoidGraphData(
        dataset_name,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
  if dataset_name == 'amazon_photos':
    return GSLAmazonPhotosGraphData(
        dataset_name,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
  if dataset_name == 'stackoverflow':
    return GSLStackOverflowGraphlessData(
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
  # Formatting (rather than concatenating) keeps the documented ValueError
  # even when the name is not a string, e.g. None.
  raise ValueError('Unknown Dataset name: {}'.format(dataset_name))
592

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.