# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Regularizers defined in the Unified GSL paper."""

import abc
from typing import Callable, Optional
from ml_collections import config_dict
import tensorflow as tf
import tensorflow_gnn as tfgnn


class BaseRegularizer(abc.ABC):
  """Base class for calculating regularization on model and label GraphTensors.

  Some regularizers only accept model GraphTensor (and ignore label).
  """

  @abc.abstractmethod
  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    pass

  def __call__(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    return self.call(
        model_graph=model_graph,
        label_graph=label_graph,
        edge_set_name=edge_set_name,
        weights_feature_name=weights_feature_name,
    )


class ClosenessRegularizer(BaseRegularizer):
  """Call Returns ||A_model - A_label||_F^2."""

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    assert label_graph is not None
    # If A and B were vectors (e.g., rasterized adjacency matrices):
    # ||A - B||_F^2 = ||A - B||_2^2 = (A-B)^T (A-B) = A^T A + B^T B - 2 A^T B
    # The first two terms of the RHS are easy to compute: sum-of-squares.
    # The last term, however, requires us to know the *common* edges in the two
    # graph tensors. For this, we sort the edges of one and use tf.searchsorted.
    # "EX:" stands for "Running Example".
    # EX: == [w6, w3, w24]
    model_weight = model_graph.edge_sets[edge_set_name][weights_feature_name]
    if weights_feature_name in label_graph.edge_sets[edge_set_name].features:
      label_weight = label_graph.edge_sets[edge_set_name][weights_feature_name]
    else:
      label_weight = tf.ones(
          label_graph.edge_sets[edge_set_name].sizes, dtype=tf.float32
      )

    assert (
        model_graph.edge_sets[edge_set_name].adjacency.source_name
        == label_graph.edge_sets[edge_set_name].adjacency.source_name
    )
    assert (
        model_graph.edge_sets[edge_set_name].adjacency.target_name
        == label_graph.edge_sets[edge_set_name].adjacency.target_name
    )

    tgt_name = model_graph.edge_sets[edge_set_name].adjacency.target_name
    src_name = model_graph.edge_sets[edge_set_name].adjacency.source_name

    # EX: == 5  (i.e., 5 nodes in each graph).
    size_target = tf.reduce_sum(model_graph.node_sets[tgt_name].sizes)
    # TODO(baharef): add an assert checking if the two graphs have the same
    # number of nodes.
    if tgt_name == src_name:
      size_source = size_target
    else:
      size_source = tf.reduce_sum(model_graph.node_sets[src_name].sizes)
      tf.assert_equal(
          size_source,
          tf.reduce_sum(label_graph.node_sets[src_name].sizes),
          'model_graph and label_graph have different number of source nodes.',
      )

    label_adj = label_graph.edge_sets[edge_set_name].adjacency

    # tf can sort vectors. We combine the pairs of ints (source & target
    # vectors) into a single int vector by choosing a suitable "base",
    # multiplying the source by the "base", and adding the target.
    combined_label_indices = (  # EX:=[4, 0, 2, 0]*5+[4, 0, 1, 3]=[24, 0, 11, 3]
        # EX: source=[4, 0, 2, 0]       target=[4, 0, 1, 3]
        tf.cast(label_adj.source, tf.int64) * tf.cast(size_target, tf.int64)
        + tf.cast(label_adj.target, tf.int64)
    )
    model_adj = model_graph.edge_sets[edge_set_name].adjacency
    combined_model_indices = (  # EX: = [1, 0, 4]*5 + [1, 3, 4] = [6, 3, 24]
        # EX: source=[1, 0, 4]       target=[1, 3, 4].
        tf.cast(model_adj.source, tf.int64) * tf.cast(size_target, tf.int64)
        + tf.cast(model_adj.target, tf.int64)
    )

    # Add phantom node (to prevent gather on empty array). Excluded from "EX:".
    combined_label_indices = tf.concat(
        [
            combined_label_indices,
            tf.cast(tf.expand_dims(size_source * size_target, 0), tf.int64),
        ],
        axis=0,
    )
    label_weight = tf.concat(
        [label_weight, tf.zeros(1, dtype=label_weight.dtype)], 0
    )

    # EX: [1, 3, 2, 0]
    argsort = tf.argsort(combined_label_indices)
    # EX: [0, 3, 11, 24]
    sorted_combined_label_indices = tf.gather(combined_label_indices, argsort)
    # EX: [2, 1, 3]
    positions = tf.searchsorted(
        sorted_combined_label_indices, combined_model_indices
    )

    # Boolean array. Entry is set to True if edge in model `GraphTensor` is also
    # present in label `GraphTensor`.
    correct_positions = (  # EX: [False, True, True]
        # EX: [11, 3, 24]
        tf.gather(sorted_combined_label_indices, positions)
        # EX: [6, 3, 24]
        == combined_model_indices
    )

    # Order label weights, in an order matching edge order of model.
    label_weight_reordered = tf.gather(  # EX: [W11, W3, W24]
        tf.gather(  # EX: = [W0, W3, W11, W24]
            # EX: = [W24, W0, W11, W3]
            label_weight,
            argsort,
        ),
        positions,
    )
    if not model_weight.dtype.is_floating:
      model_weight = tf.cast(model_weight, tf.float32)
    if not label_weight_reordered.dtype.is_floating:
      label_weight_reordered = tf.cast(label_weight_reordered, tf.float32)
    a_times_b = (  # EX: 0*0 + w3*W3 + w24*W24
        # EX: [False, True, True] * [w6, w3, w24] == [0, w3, w24]
        tf.where(correct_positions, model_weight, tf.zeros_like(model_weight))
        * tf.where(  # EX: [False, True, True] * [W11, W3, W24] = [0, W3, W24]
            correct_positions,
            label_weight_reordered,
            tf.zeros_like(label_weight_reordered),
        )
    )

    regularizer = (
        tf.reduce_sum(model_weight**2)
        + tf.reduce_sum(label_weight**2)
        - 2 * tf.reduce_sum(a_times_b)
    )
    return regularizer
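

# --------------------------------------------------------------------------
# Hedged illustration (not part of the original module): a minimal sketch of
# how `ClosenessRegularizer` could be exercised on the running example from
# the comments above. The helper name `_closeness_regularizer_example`, the
# concrete weight values, and the dense cross-check are assumptions added
# here for illustration only.
def _closeness_regularizer_example():
  """Builds the two tiny "EX:" graphs and cross-checks against a dense norm."""

  def _graph(source, target, weights):
    # A homogeneous graph with 5 nodes and one scalar weight per edge.
    return tfgnn.GraphTensor.from_pieces(
        node_sets={
            tfgnn.NODES: tfgnn.NodeSet.from_fields(sizes=tf.constant([5]))
        },
        edge_sets={
            tfgnn.EDGES: tfgnn.EdgeSet.from_fields(
                sizes=tf.constant([len(weights)]),
                adjacency=tfgnn.Adjacency.from_indices(
                    source=(tfgnn.NODES, tf.constant(source)),
                    target=(tfgnn.NODES, tf.constant(target)),
                ),
                features={'weights': tf.constant(weights, tf.float32)},
            )
        },
    )

  # Model edges (1,1), (0,3), (4,4); label edges (4,4), (0,0), (2,1), (0,3).
  model = _graph([1, 0, 4], [1, 3, 4], [0.5, 0.25, 1.0])
  label = _graph([4, 0, 2, 0], [4, 0, 1, 3], [0.75, 1.0, 0.5, 0.25])
  sparse_value = ClosenessRegularizer()(model_graph=model, label_graph=label)

  def _dense(graph):
    # Scatter the edge weights into a dense 5x5 adjacency matrix.
    adj = graph.edge_sets[tfgnn.EDGES].adjacency
    indices = tf.stack([adj.source, adj.target], axis=-1)
    weights = graph.edge_sets[tfgnn.EDGES]['weights']
    return tf.scatter_nd(indices, weights, [5, 5])

  dense_value = tf.reduce_sum((_dense(model) - _dense(label)) ** 2)
  tf.debugging.assert_near(sparse_value, dense_value)  # Both equal 1.5625.
  return sparse_value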


def euclidean_distance_squared(v1, v2):
  displacement = v1 - v2
  return tf.reduce_sum(displacement**2, axis=-1)


class SmoothnessRegularizer(BaseRegularizer):
  r"""Call Returns \sum_{ij} A_{ij} dist(v_i, v_j)."""

  def __init__(
      self,
      source_feature_name = tfgnn.HIDDEN_STATE,
      distance_fn = euclidean_distance_squared,
      target_feature_name = None,
      differentiable_wrt_features = False,
  ):
    self._distance_fn = distance_fn
    self._source_feature_name = source_feature_name
    self._target_feature_name = target_feature_name or source_feature_name
    self._differentiable_wrt_features = differentiable_wrt_features

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    del label_graph
    edge_set = model_graph.edge_sets[edge_set_name]
    source_ns = edge_set.adjacency.source_name
    target_ns = edge_set.adjacency.target_name
    source_features = tf.gather(
        model_graph.node_sets[source_ns][self._source_feature_name],
        edge_set.adjacency.source,
    )
    target_features = tf.gather(
        model_graph.node_sets[target_ns][self._target_feature_name],
        edge_set.adjacency.target,
    )
    distance = self._distance_fn(source_features, target_features)
    if not self._differentiable_wrt_features:
      distance = tf.stop_gradient(distance)
    return tf.reduce_sum(edge_set[weights_feature_name] * distance)
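

# Hedged note (added in this copy): `SmoothnessRegularizer` accepts any
# `distance_fn` mapping two [num_edges, dim] feature tensors to a [num_edges]
# vector of distances. For instance, a cosine distance could be plugged in;
# the helper below is an assumed example, not part of the original API.
def _cosine_distance(v1, v2):
  v1 = tf.math.l2_normalize(v1, axis=-1)
  v2 = tf.math.l2_normalize(v2, axis=-1)
  return 1.0 - tf.reduce_sum(v1 * v2, axis=-1)
# Example: SmoothnessRegularizer(distance_fn=_cosine_distance).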


class SparseConnectRegularizer(BaseRegularizer):
  """Call Returns ||A||_F^2."""

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    del label_graph
    edge_set = model_graph.edge_sets[edge_set_name]
    return tf.reduce_sum(edge_set[weights_feature_name] ** 2)


class LogBarrier(BaseRegularizer):
  """Call returns -1^T . log (A . 1) == -log(A.sum(1)).sum(0)."""

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    del label_graph
    weights = model_graph.edge_sets[edge_set_name][weights_feature_name]
    adj = model_graph.edge_sets[edge_set_name].adjacency
    src_name = model_graph.edge_sets[edge_set_name].adjacency.source_name
    num_src_nodes = tf.reduce_sum(model_graph.node_sets[src_name].sizes)
    column_sum = tf.math.unsorted_segment_sum(
        weights, adj.source, num_src_nodes
    )
    column_sum += 1e-5  # Avoid log(0) = -inf.
    return -tf.reduce_sum(tf.math.log(column_sum))


class InformationRegularizer(BaseRegularizer):
  """Call returns A[i][j] * log (A[i][j]/r) + (1 - A[i][j]) * log ((1 - A[i][j])/(1 - r))."""

  def __init__(self, r, do_sigmoid):
    self._r = r
    self._do_sigmoid = do_sigmoid

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights',
  ):
    del label_graph
    weights = model_graph.edge_sets[edge_set_name][weights_feature_name]
    # If the weights come from a soft Bernoulli sample, a sigmoid has already
    # been applied to them; otherwise apply it here.
    if self._do_sigmoid:
      weights = tf.sigmoid(weights)
    # Guard against numerical issues near 0 and 1, where one of the log terms
    # below is undefined.
    close_to_0 = weights < 0.0000001
    close_to_1 = weights > 0.9999999
    pos_term = weights * tf.math.log(weights / self._r)
    neg_term = (1 - weights) * tf.math.log((1 - weights) / (1 - self._r))

    return tf.reduce_sum(
        tf.where(
            close_to_0,
            neg_term,
            tf.where(
                close_to_1,
                pos_term,
                pos_term + neg_term,
            ),
        )
    )


def add_loss_regularizers(
    model,
    model_graph,
    label_graph,
    cfg,
):
  """Adds the corresponding regularizers to the model loss.

  Args:
    model: the keras model to add the regularizers to.
    model_graph: the graph generated at this stage.
    label_graph: the input graph provided in the data.
    cfg: the regularizer config values.

  Returns:
    A keras model with the regularizers added to the loss.
  """
  if cfg.smoothness_enable:
    smoothness_regularizer = SmoothnessRegularizer()
    model.add_loss(
        cfg.smoothness_w
        * smoothness_regularizer(
            model_graph=model_graph,
            label_graph=None,
        )
    )
  if cfg.sparseconnect_enable:
    sparseconnect_regularizer = SparseConnectRegularizer()
    model.add_loss(
        cfg.sparseconnect_w
        * sparseconnect_regularizer(
            model_graph=model_graph,
            label_graph=None,
        )
    )
  if cfg.closeness_enable:
    closeness_regularizer = ClosenessRegularizer()
    model.add_loss(
        cfg.closeness_w
        * closeness_regularizer(
            model_graph=model_graph,
            label_graph=label_graph,
        )
    )
  if cfg.logbarrier_enable:
    log_barrier_regularizer = LogBarrier()
    model.add_loss(
        cfg.logbarrier_w
        * log_barrier_regularizer(
            model_graph=model_graph,
            label_graph=label_graph,
        )
    )
  if cfg.information_enable:
    information_regularizer = InformationRegularizer(
        cfg.information_r, cfg.information_do_sigmoid
    )
    model.add_loss(
        cfg.information_w
        * information_regularizer(
            model_graph=model_graph,
            label_graph=label_graph,
        )
    )
  return model
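

# --------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): one plausible way to
# build the config consumed by `add_loss_regularizers`. Only fields read above
# are set; the particular weight values are placeholders for illustration.
def _example_regularizer_config():
  cfg = config_dict.ConfigDict()
  cfg.smoothness_enable = True
  cfg.smoothness_w = 0.1
  cfg.sparseconnect_enable = True
  cfg.sparseconnect_w = 0.01
  cfg.closeness_enable = True
  cfg.closeness_w = 1.0
  cfg.logbarrier_enable = False
  cfg.logbarrier_w = 0.0
  cfg.information_enable = False
  cfg.information_w = 0.0
  cfg.information_r = 0.5
  cfg.information_do_sigmoid = True
  return cfg
# Typical call site (assumed): model = add_loss_regularizers(
#     model, model_graph, label_graph, _example_regularizer_config()).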