google-research

Форк
0
60 строк · 2.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implements helper functions to compute alignment measure."""
17

18
import tensorflow as tf
19

20

21
def plot_class_alignment(batch,
22
                         labels,
23
                         num_labels,
24
                         step,
25
                         tf_summary_key='alignment'):
26
  """Plots class level alignment as a summary scalar."""
27
  for label in range(num_labels):
28
    indices = tf.squeeze(
29
        tf.where(tf.equal(labels, tf.cast(label, tf.int64))), axis=1)
30
    class_batch = tf.gather(batch, indices)
31
    class_list = tf.unstack(class_batch)
32
    if len(class_list) > 1:
33
      alignment = compute_alignment(class_list)
34
      tf.summary.scalar(
35
          '%s/%s' % (tf_summary_key, str(label)), alignment, step=step)
36

37

38
def compute_alignment(input_list):
39
  """Computes alignment measure given a list of vectors in O(n) time."""
40
  if len(input_list) < 2:
41
    return None
42

43
  # Compute mean norm.
44
  norms = [tf.norm(v) for v in input_list]
45
  norms_mean = tf.math.reduce_mean(norms)
46

47
  # Normalize each vector by mean norm.
48
  normalized_input_list = [
49
      v / (tf.where(tf.math.greater(norms_mean, 0), norms_mean, 1.0))
50
      for v in input_list
51
  ]
52
  n = len(normalized_input_list)
53

54
  # O(n) implementation of alignment.
55
  sum_norm_square = tf.math.square(
56
      tf.norm(tf.math.reduce_sum(normalized_input_list, axis=0)))
57
  norm_squares = [tf.math.square(tf.norm(v)) for v in normalized_input_list]
58
  norm_squares_sum = tf.math.reduce_sum(norm_squares)
59
  alignment = (sum_norm_square - norm_squares_sum) / (n * (n - 1))
60
  return alignment
61

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.