google-research
378 lines · 11.8 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Regularizers defined in Unified GSL paper."""
17
18import abc19from typing import Callable, Optional20from ml_collections import config_dict21import tensorflow as tf22import tensorflow_gnn as tfgnn23
24
class BaseRegularizer(abc.ABC):
  """Abstract base for regularizers over model and label GraphTensors.

  Subclasses implement `call`. Some regularizers use only the model
  GraphTensor and ignore the label one.
  """

  @abc.abstractmethod
  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    """Computes the scalar regularization term for `model_graph`."""
    pass

  def __call__(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    # Forward everything to the subclass implementation by keyword.
    kwargs = dict(
        model_graph=model_graph,
        label_graph=label_graph,
        edge_set_name=edge_set_name,
        weights_feature_name=weights_feature_name,
    )
    return self.call(**kwargs)
57
class ClosenessRegularizer(BaseRegularizer):
  """Call Returns ||A_model - A_label||_F^2."""

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    """Returns squared Frobenius distance between model and label adjacency.

    Args:
      model_graph: GraphTensor whose edge weights form A_model.
      label_graph: GraphTensor whose edge weights form A_label. Required
        (asserted non-None). If it lacks `weights_feature_name`, its edges
        are treated as having weight 1.
      edge_set_name: edge set to compare in both graphs.
      weights_feature_name: name of the per-edge weight feature.

    Returns:
      Scalar tensor ||A_model - A_label||_F^2.
    """
    assert label_graph is not None
    # If A and B where vectors (e.g., rasterized adjacency matrices):
    # ||A - B||_F^2 = ||A - B||^2_2 == (A-B)^T (A-B) = A^T A + B^T B - 2 A^T B
    # The first two terms of the RHS are easy to compute: sum-of-squares.
    # The last entry, however, require us to know the *common* edges in the two
    # graph tensors. For this, we sort the edges of one and use tf.searchsorted.
    # "EX:" stands for "Running Example".
    # EX: == [w6, w3, w24]
    model_weight = model_graph.edge_sets[edge_set_name][weights_feature_name]
    if weights_feature_name in label_graph.edge_sets[edge_set_name].features:
      label_weight = label_graph.edge_sets[edge_set_name][weights_feature_name]
    else:
      # Unweighted label graph: every existing edge counts as weight 1.
      label_weight = tf.ones(
          label_graph.edge_sets[edge_set_name].sizes, dtype=tf.float32
      )

    # Both graphs must wire the same node sets, otherwise the rasterized
    # edge ids computed below would not be comparable.
    assert (
        model_graph.edge_sets[edge_set_name].adjacency.source_name
        == label_graph.edge_sets[edge_set_name].adjacency.source_name
    )
    assert (
        model_graph.edge_sets[edge_set_name].adjacency.target_name
        == label_graph.edge_sets[edge_set_name].adjacency.target_name
    )

    tgt_name = model_graph.edge_sets[edge_set_name].adjacency.target_name
    src_name = model_graph.edge_sets[edge_set_name].adjacency.source_name

    # EX: == 5 (i.e., 5 nodes in each graph).
    size_target = tf.reduce_sum(model_graph.node_sets[tgt_name].sizes)
    # TODO(baharef): add an assert checking if the two graphs have the same
    # number of nodes.
    if tgt_name == src_name:
      size_source = size_target
    else:
      size_source = tf.reduce_sum(model_graph.node_sets[src_name].sizes)
    tf.assert_equal(
        size_source,
        tf.reduce_sum(label_graph.node_sets[src_name].sizes),
        'model_graph and label_graph have different number of source nodes.',
    )

    label_adj = label_graph.edge_sets[edge_set_name].adjacency

    # tf can sort vectors. We combine pairs of ints (source & target vectors) to
    # int vector by finding a suitable "base", multiplying the source by the
    # "base" and adding target.
    combined_label_indices = (  # EX:=[4, 0, 2, 0]*5+[4, 0, 1, 3]=[24, 0, 11, 3]
        # EX: source=[4, 0, 2, 0] target=[4, 0, 1, 3]
        tf.cast(label_adj.source, tf.int64) * tf.cast(size_target, tf.int64)
        + tf.cast(label_adj.target, tf.int64)
    )
    model_adj = model_graph.edge_sets[edge_set_name].adjacency
    combined_model_indices = (  # EX: = [1, 0, 4]*5 + [1, 3, 4] = [6, 3, 24]
        # EX: source=[0, 1, 4] target=[3, 1, 4].
        tf.cast(model_adj.source, tf.int64) * tf.cast(size_target, tf.int64)
        + tf.cast(model_adj.target, tf.int64)
    )

    # Add phantom node (to prevent gather on empty array). Excluded from "EX:".
    # The phantom id size_source*size_target is >= every valid edge id, so it
    # also keeps tf.searchsorted positions within bounds below.
    combined_label_indices = tf.concat(
        [
            combined_label_indices,
            tf.cast(tf.expand_dims(size_source * size_target, 0), tf.int64),
        ],
        axis=0,
    )
    label_weight = tf.concat(
        [label_weight, tf.zeros(1, dtype=label_weight.dtype)], 0
    )

    # EX: [1, 3, 2, 0]
    argsort = tf.argsort(combined_label_indices)
    # EX: [0, 3, 11, 24]
    sorted_combined_label_indices = tf.gather(combined_label_indices, argsort)
    # EX: [2, 1, 3]
    positions = tf.searchsorted(
        sorted_combined_label_indices, combined_model_indices
    )

    # Boolean array. Entry is set to True if edge in model `GraphTensor` is also
    # present in label `GraphTensor`.
    correct_positions = (  # EX: [False, True, True]
        # EX: [11, 3, 24]
        tf.gather(sorted_combined_label_indices, positions)
        # EX: [6, 3, 24]
        == combined_model_indices
    )

    # Order label weights, in an order matching edge order of model.
    label_weight_reordered = tf.gather(  # EX: [W11, W3, W24]
        tf.gather(  # EX: = [W0, W3, W11, W24]
            # EX: = [W24, W0, W11, W3]
            label_weight,
            argsort,
        ),
        positions,
    )
    # Promote integer weights to float so the products/sums below are float.
    if not model_weight.dtype.is_floating:
      model_weight = tf.cast(model_weight, tf.float32)
    if not label_weight_reordered.dtype.is_floating:
      label_weight_reordered = tf.cast(label_weight_reordered, tf.float32)
    a_times_b = (  # EX: 0*0 + w3*W3 + w24*W24
        # EX: [False, True, True] * [w6, w3, w24] == [0, w3, w24]
        tf.where(correct_positions, model_weight, tf.zeros_like(model_weight))
        * tf.where(  # EX: [False, True, True] * [W11, W3, W24] = [0, W3, W24]
            correct_positions,
            label_weight_reordered,
            tf.zeros_like(label_weight_reordered),
        )
    )

    # A^T A + B^T B - 2 A^T B, per the expansion in the comment at the top.
    regularizer = (
        tf.reduce_sum(model_weight**2)
        + tf.reduce_sum(label_weight**2)
        - 2 * tf.reduce_sum(a_times_b)
    )
    return regularizer
188
def euclidean_distance_squared(v1, v2):
  """Returns the squared Euclidean distance along the last axis of v1, v2."""
  return tf.reduce_sum(tf.math.squared_difference(v1, v2), axis=-1)
193
class SmoothnessRegularizer(BaseRegularizer):
  r"""Call Returns \sum_{ij} A_{ij} dist(v_i, v_j)."""

  def __init__(
      self,
      source_feature_name = tfgnn.HIDDEN_STATE,
      distance_fn = euclidean_distance_squared,
      target_feature_name = None,
      differentiable_wrt_features = False,
  ):
    """Configures the per-edge distance used by `call`.

    Args:
      source_feature_name: node feature read at edge sources.
      distance_fn: per-edge distance between source/target feature rows.
      target_feature_name: node feature read at edge targets; defaults to
        `source_feature_name` when None.
      differentiable_wrt_features: if False, gradients are stopped through
        the distances so only the edge weights receive gradient.
    """
    self._distance_fn = distance_fn
    self._source_feature_name = source_feature_name
    self._target_feature_name = target_feature_name or source_feature_name
    self._differentiable_wrt_features = differentiable_wrt_features

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    """Returns the weighted sum of per-edge feature distances."""
    del label_graph  # This regularizer only looks at the model graph.
    edges = model_graph.edge_sets[edge_set_name]
    adjacency = edges.adjacency
    # Gather one feature row per edge for each endpoint.
    endpoint_features = []
    for node_set_name, node_indices, feature_name in (
        (adjacency.source_name, adjacency.source, self._source_feature_name),
        (adjacency.target_name, adjacency.target, self._target_feature_name),
    ):
      endpoint_features.append(
          tf.gather(
              model_graph.node_sets[node_set_name][feature_name], node_indices
          )
      )
    pairwise_distance = self._distance_fn(*endpoint_features)
    if not self._differentiable_wrt_features:
      # Treat distances as constants: gradient flows only via edge weights.
      pairwise_distance = tf.stop_gradient(pairwise_distance)
    return tf.reduce_sum(edges[weights_feature_name] * pairwise_distance)
234
class SparseConnectRegularizer(BaseRegularizer):
  """Call Returns ||A||_F^2."""

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    """Returns the sum of squared edge weights of the model graph."""
    del label_graph  # This regularizer only looks at the model graph.
    weights = model_graph.edge_sets[edge_set_name][weights_feature_name]
    return tf.reduce_sum(tf.square(weights))
250
class LogBarrier(BaseRegularizer):
  """Call returns -1^T . log (A . 1) == -log(A.sum(1)).sum(0)."""

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights'
  ):
    """Returns the negative sum of logs of per-source-node degree weight."""
    del label_graph  # This regularizer only looks at the model graph.
    edge_set = model_graph.edge_sets[edge_set_name]
    adjacency = edge_set.adjacency
    num_source_nodes = tf.reduce_sum(
        model_graph.node_sets[adjacency.source_name].sizes
    )
    # (A . 1): total outgoing edge weight per source node.
    per_node_weight = tf.math.unsorted_segment_sum(
        edge_set[weights_feature_name], adjacency.source, num_source_nodes
    )
    # Small epsilon keeps log finite for nodes with zero total weight.
    return -tf.reduce_sum(tf.math.log(per_node_weight + 1e-5))
273
class InformationRegularizer(BaseRegularizer):
  """Call returns A[i][j] * log (A[i][j]/r) + (1 - A[i][j]) * log ((1 - A[i][j])/(1 - r))."""

  def __init__(self, r, do_sigmoid):
    """Configures the Bernoulli prior.

    Args:
      r: prior edge probability; must lie strictly in (0, 1).
      do_sigmoid: if True, edge weights are logits and a sigmoid is applied
        before computing the divergence.
    """
    self._r = r
    self._do_sigmoid = do_sigmoid

  def call(
      self,
      *,
      model_graph,
      label_graph = None,
      edge_set_name = tfgnn.EDGES,
      weights_feature_name = 'weights',
  ):
    """Returns sum over edges of KL(Bernoulli(w) || Bernoulli(r))."""
    del label_graph  # This regularizer only looks at the model graph.
    weights = model_graph.edge_sets[edge_set_name][weights_feature_name]
    # If the weights are coming from a soft Bernoulli, a sigmoid has already
    # been applied on the weights.
    if self._do_sigmoid:
      weights = tf.sigmoid(weights)
    # Numerical stability: at (or extremely near) 0 and 1 one of the two KL
    # terms is a 0 * log(0) limit and must be dropped from the sum.
    close_to_0 = weights < 0.0000001
    close_to_1 = weights > 0.9999999
    # Bug fix: tf.where alone does not prevent NaN *gradients* — if the
    # discarded branch evaluates to NaN/Inf (log(0) here), backprop through
    # tf.where still produces NaN. Sanitize the log inputs on the masked-out
    # entries first; selected forward values are unchanged.
    safe_for_pos = tf.where(close_to_0, tf.ones_like(weights), weights)
    safe_for_neg = tf.where(close_to_1, tf.zeros_like(weights), weights)
    pos_term = weights * tf.math.log(safe_for_pos / self._r)
    neg_term = (1 - weights) * tf.math.log(
        (1 - safe_for_neg) / (1 - self._r)
    )

    return tf.reduce_sum(
        tf.where(
            close_to_0,
            neg_term,
            tf.where(
                close_to_1,
                pos_term,
                pos_term + neg_term,
            ),
        )
    )
313
def add_loss_regularizers(
    model,
    model_graph,
    label_graph,
    cfg,
):
  """Adding corresponding regularizers to the model.

  Args:
    model: the keras model to add the regularizer for.
    model_graph: the graph generated at this stage.
    label_graph: the input graph provided in the data.
    cfg: the regularizer config values.

  Returns:
    A keras model with the regularizers added in the loss.
  """
  # One spec per regularizer: (config prefix, zero-arg factory, whether the
  # regularizer receives label_graph). The information factory is a lambda so
  # cfg.information_* is only read when that regularizer is enabled.
  specs = (
      ('smoothness', SmoothnessRegularizer, False),
      ('sparseconnect', SparseConnectRegularizer, False),
      ('closeness', ClosenessRegularizer, True),
      ('logbarrier', LogBarrier, True),
      (
          'information',
          lambda: InformationRegularizer(
              cfg.information_r, cfg.information_do_sigmoid
          ),
          True,
      ),
  )
  for prefix, make_regularizer, needs_label in specs:
    if not getattr(cfg, prefix + '_enable'):
      continue
    regularizer = make_regularizer()
    loss_weight = getattr(cfg, prefix + '_w')
    model.add_loss(
        loss_weight
        * regularizer(
            model_graph=model_graph,
            label_graph=label_graph if needs_label else None,
        )
    )
  return model