google-research

Форк
0
/
merger.py 
218 строк · 6.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Merger layer of the GSL Layer.
17

18
This step merges an input graph with a generated graph and returns a
19
GraphTensor as the final output.
20
"""
21
import tensorflow as tf
22

23
from ugsl import datasets
24

25

26
@tf.keras.utils.register_keras_serializable(package="GSL")
class Merger(tf.keras.layers.Layer):
  """Base class for the GSL merger layers.

  A merger turns a graph structure plus node embeddings into the layer's
  final GraphTensor output, using the wrapped `graph_data`. Concrete mergers
  subclass this and implement `call`.
  """

  def __init__(self, graph_data, **kwargs):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data; stored as-is and echoed by `get_config`.
      **kwargs: standard Keras `Layer` keyword arguments (e.g. `name`,
        `trainable`, `dtype`), forwarded to `tf.keras.layers.Layer`. The base
        Layer config returned by `get_config` contains these keys, so
        accepting them lets `from_config(get_config())` round-trip instead of
        raising `TypeError`.
    """
    super().__init__(**kwargs)
    self._graph_data = graph_data

  def get_config(self):
    """Returns the layer config: `graph_data` plus the base Layer config."""
    return dict(graph_data=self._graph_data, **super().get_config())
35

36

37
class WeightedSum(Merger):
  """Combines a generated adjacency with the given (input) adjacency.

  The generated edges and the edges of the input graph's "edges" edge set are
  concatenated into a single edge list. Generated edges keep their weights;
  every input edge gets weight `given_adjacency_weight`. Dropout is applied to
  the combined weights before the output GraphTensor is built.

  NOTE(review): this layer only concatenates; whether duplicate
  (source, target) pairs are actually summed depends on
  `graph_data.as_graph_tensor_given_adjacency` — confirm.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      given_adjacency_weight=1.0,
      **kwargs,
  ):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data.
      dropout_rate: dropout rate applied to the combined edge weights.
      given_adjacency_weight: weight assigned to every edge of the input graph.
      **kwargs: accepted and ignored, like the sibling mergers. This keeps
        `from_config(get_config())` working, since the config includes base
        Keras `Layer` keys such as `name`.
    """
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)
    self._given_adjacency_weight = given_adjacency_weight

  def call(self, inputs):
    """Merges the generated structure with the input graph.

    Args:
      inputs: sequence whose first element is the generated graph structure
        (exposing `sources`, `targets` and `weights`) and whose second element
        is the node embeddings used as node features.

    Returns:
      The GraphTensor produced by `graph_data.as_graph_tensor_given_adjacency`.
    """
    graph_structure = inputs[0]
    node_embeddings = inputs[1]
    input_gt = self._graph_data.get_input_graph_tensor()
    input_adjacency = input_gt.edge_sets["edges"].adjacency
    given_sources = input_adjacency.source
    given_targets = input_adjacency.target

    generated_weights = tf.convert_to_tensor(graph_structure.weights)
    # Use the dynamic shape so this also works when the number of input edges
    # is not statically known. Match the generated weights' dtype so the
    # concat below cannot fail on a dtype mismatch.
    given_weights = tf.fill(
        tf.shape(given_sources),
        tf.cast(self._given_adjacency_weight, generated_weights.dtype),
    )

    sources = tf.concat((graph_structure.sources, given_sources), axis=0)
    targets = tf.concat((graph_structure.targets, given_targets), axis=0)
    weights = tf.concat((generated_weights, given_weights), axis=0)
    return self._graph_data.as_graph_tensor_given_adjacency(
        [sources, targets],
        edge_weights=self._dropout_layer(weights),
        node_features=node_embeddings,
    )

  def get_config(self):
    """Returns the layer config, including the merger hyper-parameters."""
    return dict(
        dropout_rate=self._dropout_rate,
        given_adjacency_weight=self._given_adjacency_weight,
        **super().get_config(),
    )
83

84

85
class ToGraphTensor(Merger):
  """Builds a GraphTensor directly from a generated adjacency.

  The generated structure's rows (`sources`), columns (`targets`) and
  `weights` are used as-is, with dropout applied to the weights. No edges
  from the input graph are added.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      **kwargs,
  ):
    # Extra keyword arguments are accepted for interface compatibility with
    # the other mergers but are intentionally unused.
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs):
    """Converts `inputs[0]` (the structure) into a GraphTensor.

    `inputs[1]` supplies the node features of the resulting GraphTensor.
    """
    structure, embeddings = inputs[0], inputs[1]
    edge_list = [structure.sources, structure.targets]
    dropped_weights = self._dropout_layer(structure.weights)
    return self._graph_data.as_graph_tensor_given_adjacency(
        edge_list,
        edge_weights=dropped_weights,
        node_features=embeddings,
    )

  def get_config(self):
    """Returns the layer config with `dropout_rate` added."""
    base_config = super().get_config()
    return dict(dropout_rate=self._dropout_rate, **base_config)
110

111

112
class RandomGraphTensor(Merger):
  """Random-graph baseline merger.

  At construction time it samples a random edge list:
    * as many edges as the input graph's "edges" edge set;
    * endpoints drawn uniformly from [0, number_of_nodes), where
      number_of_nodes is the row count of the input "nodes" feature "feat";
    * weights drawn uniformly from [0, 1).

  The same sampled graph is reused on every call. The generated structure in
  `inputs[0]` is ignored; only the node embeddings in `inputs[1]` are used.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      **kwargs,
  ):
    """Samples the random graph.

    Args:
      graph_data: the GSL graph data (also stored by `Merger.__init__`).
      dropout_rate: dropout rate applied to the random edge weights.
      **kwargs: accepted for interface compatibility with the other mergers;
        ignored.
    """
    super().__init__(graph_data)  # Already sets self._graph_data.
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)
    input_gt = self._graph_data.get_input_graph_tensor()
    number_of_edges = input_gt.edge_sets["edges"].adjacency.source.shape[0]
    number_of_nodes = input_gt.node_sets["nodes"].features["feat"].shape[0]
    # Draw order (sources, targets, weights) is kept fixed so seeded runs
    # reproduce the same graph.
    self._random_sources = self._sample_node_ids(
        number_of_edges, number_of_nodes
    )
    self._random_targets = self._sample_node_ids(
        number_of_edges, number_of_nodes
    )
    self._random_weights = tf.random.uniform(
        shape=(number_of_edges,), minval=0, maxval=1.0, dtype=tf.float32
    )

  @staticmethod
  def _sample_node_ids(number_of_edges, number_of_nodes):
    """Draws `number_of_edges` int32 ids uniformly from [0, number_of_nodes)."""
    return tf.random.uniform(
        shape=(number_of_edges,),
        minval=0,
        maxval=number_of_nodes,
        dtype=tf.int32,
    )

  def call(self, inputs):
    """Returns the fixed random graph with `inputs[1]` as node features."""
    node_embeddings = inputs[1]
    graph_tensor = self._graph_data.as_graph_tensor_given_adjacency(
        tf.stack([self._random_sources, self._random_targets], axis=0),
        edge_weights=self._dropout_layer(self._random_weights),
        node_features=node_embeddings,
    )
    return graph_tensor

  def get_config(self):
    """Returns the layer config with `dropout_rate` added."""
    return dict(
        dropout_rate=self._dropout_rate,
        **super().get_config(),
    )
158

159

160
class InputGraphTensor(Merger):
  """Returns the input graph, ignoring the generated adjacency.

  Builds a GraphTensor from the input graph's "edges" edge set with every
  edge weight set to 1 (then passed through dropout). The node embeddings in
  `inputs[1]` are used as node features; the generated structure in
  `inputs[0]` is not used.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      **kwargs,
  ):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data.
      dropout_rate: dropout rate applied to the unit edge weights.
      **kwargs: accepted for interface compatibility with the other mergers;
        ignored.
    """
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs):
    """Returns the input graph with `inputs[1]` as node features."""
    node_embeddings = inputs[1]
    input_gt = self._graph_data.get_input_graph_tensor()
    adjacency = input_gt.edge_sets["edges"].adjacency
    sources = adjacency.source
    targets = adjacency.target
    # tf.shape (rather than the static `.shape`) also works when the edge
    # count is not known at trace time.
    weights = tf.ones(tf.shape(sources), dtype=tf.float32)

    return self._graph_data.as_graph_tensor_given_adjacency(
        [sources, targets],
        edge_weights=self._dropout_layer(weights),
        node_features=node_embeddings,
    )

  def get_config(self):
    """Returns the layer config with `dropout_rate` added."""
    return dict(
        dropout_rate=self._dropout_rate,
        **super().get_config(),
    )
192

193

194
def get_merger(graph_data, name, **kwargs):
  """Builds the merger registered under `name`.

  Args:
    graph_data: the GSL graph data, passed as the merger's first argument.
    name: merger identifier; one of "none", "weighted-sum", "random" or
      "input".
    **kwargs: forwarded unchanged to the selected merger's constructor.

  Returns:
    A new instance of the merger class matching `name`.

  Raises:
    ValueError: if `name` does not match any known merger.
  """
  registry = (
      ("none", ToGraphTensor),
      ("weighted-sum", WeightedSum),
      ("random", RandomGraphTensor),
      ("input", InputGraphTensor),
  )
  for merger_name, merger_cls in registry:
    if name == merger_name:
      return merger_cls(graph_data, **kwargs)
  raise ValueError(f"Merger {name} is not defined.")
219

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.