google-research
60 строк · 2.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Implements helper functions to compute the alignment measure."""
17
18import tensorflow as tf
19
20
def plot_class_alignment(batch,
                         labels,
                         num_labels,
                         step,
                         tf_summary_key='alignment'):
    """Writes one alignment summary scalar per class found in the batch.

    For every label id in `range(num_labels)`, the rows of `batch` whose
    label equals that id are gathered and, when more than one row matches,
    their alignment is recorded under the tag '<tf_summary_key>/<label>'.

    Args:
      batch: Tensor whose first axis indexes examples (it is gathered along
        axis 0 and then unstacked into per-example tensors).
      labels: Per-example class ids; compared against an int64 scalar, so
        presumably a rank-1 int64 tensor -- TODO confirm against callers.
      num_labels: Number of classes; ids 0..num_labels-1 are visited.
      step: Step value forwarded to `tf.summary.scalar`.
      tf_summary_key: Prefix of the summary tag.
    """
    for class_id in range(num_labels):
        is_member = tf.equal(labels, tf.cast(class_id, tf.int64))
        # tf.where yields one index row per match; drop the trailing axis.
        row_ids = tf.squeeze(tf.where(is_member), axis=1)
        members = tf.unstack(tf.gather(batch, row_ids))
        if len(members) <= 1:
            # Fewer than two vectors: there is no pair to compare.
            continue
        score = compute_alignment(members)
        tag = '%s/%s' % (tf_summary_key, str(class_id))
        tf.summary.scalar(tag, score, step=step)
36
37
def compute_alignment(input_list):
    """Computes the alignment measure of a list of vectors in O(n) time.

    Every vector is first divided by the mean norm of the list (left
    unscaled when that mean is not positive). The result is the average
    inner product over all ordered pairs of distinct scaled vectors,
    obtained without enumerating pairs via the identity

      sum_{i != j} <v_i, v_j> = ||sum_i v_i||^2 - sum_i ||v_i||^2.

    Args:
      input_list: Sequence of tensors of identical shape.

    Returns:
      A scalar tensor holding the alignment, or None when fewer than two
      vectors are given (no pair exists).
    """
    n = len(input_list)
    if n < 2:
        return None

    # Compute mean norm.
    norms = [tf.norm(v) for v in input_list]
    norms_mean = tf.math.reduce_mean(norms)

    # Fall back to 1.0 so that an all-zero list does not divide by zero.
    # The divisor is identical for every vector, so it is built once here
    # rather than once per element inside the comprehension.
    divisor = tf.where(tf.math.greater(norms_mean, 0), norms_mean, 1.0)
    normalized_input_list = [v / divisor for v in input_list]

    # O(n) implementation of alignment.
    sum_norm_square = tf.math.square(
        tf.norm(tf.math.reduce_sum(normalized_input_list, axis=0)))
    norm_squares = [tf.math.square(tf.norm(v)) for v in normalized_input_list]
    norm_squares_sum = tf.math.reduce_sum(norm_squares)
    return (sum_norm_square - norm_squares_sum) / (n * (n - 1))
61