google-research

Форк
0
131 строка · 3.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utils for the fair correlation clustering algorithm.
17
"""
18

19
import collections
20
import math
21
import numpy as np
22
import sklearn.metrics
23

24

25
def BooleanVectorsFromGraph(graph):
26
  """Create a boolean encoding for the nodes in the graph.
27

28
  Starting from the graph, it creates a set of boolean vectors where u,v,
29
  has an entry 1 for each positive edge (0 for negative edge). Selfloops
30
  are assumed positive.
31

32
  Args:
33
    graph: graph in nx.Graph format.
34
  Returns:
35
    the nxn bolean matrix with the encoding.
36
  """
37
  n = graph.number_of_nodes()
38
  vectors = np.identity(n)
39
  for u, v, d in graph.edges(data=True):
40
    if d['weight'] > 0:
41
      vectors[u][v] = 1
42
      vectors[v][u] = 1
43
  return vectors
44

45

46
def PairwiseFairletCosts(graph):
47
  """Create a matrix with the fairlet cost.
48

49
  Args:
50
    graph: graph in nx.Graph format.
51
  Returns:
52
    the nxn matrix with the fairlet cost for each pair of nodes.
53
  """
54
  assert max(list(graph.nodes())) == graph.number_of_nodes() - 1
55
  assert min(list(graph.nodes())) == 0
56

57
  bool_vectors = BooleanVectorsFromGraph(graph)
58
  distance_matrix = sklearn.metrics.pairwise_distances(
59
      bool_vectors, metric='l1')
60
  # This counts twice the negative edge inside each u,v fairlet, so we deduct
61
  # one for each such pair.
62
  for u, v, d in graph.edges(data=True):
63
    if d['weight'] < 0:
64
      distance_matrix[u][v] -= 1
65
      distance_matrix[v][u] -= 1
66
  return distance_matrix
67

68

69
def ClusterIdMap(solution):
70
  """Create a map from node to cluster id.
71

72
  Args:
73
    solution: list of clusters.
74
  Returns:
75
    the map from node id to cluster id.
76
  """
77
  clust_assignment = {}
78
  for i, clust in enumerate(solution):
79
    for elem in clust:
80
      clust_assignment[elem] = i
81
  return clust_assignment
82

83

84
def FractionalColorImbalance(graph, solution, alpha):
85
  """Evaluates the color imbalance of solution.
86

87
  Computes the fraction of nodes that are above the threshold for color
88
  representation.
89

90
  Args:
91
    graph: in nx.Graph format.
92
    solution: list of clusters.
93
    alpha: representation constraint.
94
  Returns:
95
    the fraction of nodes that are above the threshold for color.
96
  """
97
  total_violation = 0
98
  nodes = 0
99
  for cluster in solution:
100
    color_count = collections.defaultdict(int)
101
    for elem in cluster:
102
      color_count[graph.nodes[elem]['color']] += 1
103
    for count in color_count.values():
104
      imbalance = max(0, count - math.floor(float(len(cluster)) * alpha))
105
      total_violation += imbalance
106
    nodes += len(cluster)
107
  return 1.0 * total_violation / nodes
108

109

110
def CorrelationClusteringError(graph, solution):
111
  """Evaluates  the correlation clustering error of solution.
112

113
  Computes the fraction of edges that are misclassified by the algorithm.
114

115
  Args:
116
    graph: in nx.Graph format.
117
    solution: list of clusters.
118
  Returns:
119
    the fraction of edges that are incorrectly classified.
120
  """
121
  clust_assignment = ClusterIdMap(solution)
122
  errors = 0
123
  corrects = 0
124
  for u, v, d in graph.edges(data=True):
125
    if (d['weight'] > 0 and clust_assignment[u] != clust_assignment[v]) or \
126
        (d['weight'] < 0 and clust_assignment[u] == clust_assignment[v]):
127
      errors += 1
128
    elif (d['weight'] > 0 and clust_assignment[u] == clust_assignment[v]) or \
129
        (d['weight'] < 0 and clust_assignment[u] != clust_assignment[v]):
130
      corrects += 1
131
  return float(errors) / (errors + corrects)
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.