google-research
131 строка · 3.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utils for the fair correlation clustering algorithm.
17"""
18
19import collections
20import math
21import numpy as np
22import sklearn.metrics
23
24
def BooleanVectorsFromGraph(graph):
  """Encodes every node of `graph` as a 0/1 row of an n x n matrix.

  Row `u` holds a 1 in column `v` when (u, v) is an edge whose weight is
  strictly positive, and 0 otherwise. All diagonal entries are 1, so each
  node is treated as positively connected to itself.

  Args:
    graph: graph in nx.Graph format. Nodes are used directly as row/column
      indices, so they are expected to be the integers 0..n-1; every edge
      must carry a numeric 'weight' attribute.

  Returns:
    The n x n numpy float array with the 0/1 encoding.
  """
  encoding = np.identity(graph.number_of_nodes())
  positive_pairs = (
      (u, v) for u, v, attrs in graph.edges(data=True) if attrs['weight'] > 0)
  for u, v in positive_pairs:
    # The graph is undirected, so the encoding is kept symmetric.
    encoding[u, v] = encoding[v, u] = 1
  return encoding
44
45
def PairwiseFairletCosts(graph):
  """Builds the n x n matrix of costs for pairing nodes into fairlets.

  Entry (u, v) is the l1 distance between the boolean encodings of u and v
  produced by BooleanVectorsFromGraph. It is reduced by one when (u, v) is a
  negative edge.

  Args:
    graph: graph in nx.Graph format whose nodes are exactly the integers
      0..n-1 (checked with asserts) and whose edges carry a numeric 'weight'.

  Returns:
    The n x n matrix with the fairlet cost for each pair of nodes.
  """
  node_ids = list(graph.nodes())
  assert max(node_ids) == graph.number_of_nodes() - 1
  assert min(node_ids) == 0

  costs = sklearn.metrics.pairwise_distances(
      BooleanVectorsFromGraph(graph), metric='l1')
  # For a negative edge (u, v) the two encodings disagree in column u
  # (diagonal 1 vs. 0) and in column v (0 vs. diagonal 1). The l1 distance
  # therefore counts that single edge twice, so one of the two is removed.
  for u, v, attrs in graph.edges(data=True):
    if attrs['weight'] < 0:
      costs[u][v] -= 1
      costs[v][u] -= 1
  return costs
67
68
def ClusterIdMap(solution):
  """Maps every node to the index of the cluster that contains it.

  Args:
    solution: list of clusters, each an iterable of node ids.

  Returns:
    A dict from node id to the position of its cluster in `solution`. If a
    node appears in more than one cluster, the last such cluster wins.
  """
  return {
      node: cluster_id
      for cluster_id, cluster in enumerate(solution)
      for node in cluster
  }
82
83
def FractionalColorImbalance(graph, solution, alpha):
  """Evaluates the color imbalance of solution.

  For each cluster C, every color may appear at most floor(|C| * alpha)
  times. Occurrences of a color beyond that cap count as violating nodes.
  The result is the total number of violating nodes divided by the total
  number of nodes across all clusters.

  Args:
    graph: graph in nx.Graph format; every clustered node must have a
      'color' attribute.
    solution: list of clusters, each a sized collection of node ids.
    alpha: representation constraint, i.e. the largest fraction of a cluster
      that a single color may occupy.

  Returns:
    The fraction of clustered nodes that are above the per-color threshold,
    or 0.0 when the solution contains no nodes at all.
  """
  total_violation = 0
  total_nodes = 0
  for cluster in solution:
    size = len(cluster)
    if size == 0:
      # Empty clusters contribute neither nodes nor violations.
      continue
    # Largest number of same-colored nodes allowed in this cluster.
    cap = math.floor(float(size) * alpha)
    color_count = collections.Counter(
        graph.nodes[elem]['color'] for elem in cluster)
    total_violation += sum(
        max(0, count - cap) for count in color_count.values())
    total_nodes += size
  if total_nodes == 0:
    # Guard against ZeroDivisionError for an empty solution.
    return 0.0
  return float(total_violation) / total_nodes
108
109
def CorrelationClusteringError(graph, solution):
  """Evaluates the correlation clustering error of solution.

  A positive edge is misclassified when its endpoints land in different
  clusters. A negative edge is misclassified when its endpoints share a
  cluster. Edges whose weight is neither > 0 nor < 0 (e.g. 0) are ignored.

  Args:
    graph: graph in nx.Graph format whose edges carry a numeric 'weight'.
    solution: list of clusters; every endpoint of a signed edge must belong
      to some cluster.

  Returns:
    The fraction of signed edges that are incorrectly classified, or 0.0
    when the graph has no signed edges.

  Raises:
    KeyError: if an endpoint of a signed edge is not in any cluster.
  """
  clust_assignment = ClusterIdMap(solution)
  errors = 0
  corrects = 0
  for u, v, d in graph.edges(data=True):
    weight = d['weight']
    if weight > 0:
      wants_same_cluster = True
    elif weight < 0:
      wants_same_cluster = False
    else:
      # Unsigned edges are counted neither as correct nor as errors.
      continue
    same_cluster = clust_assignment[u] == clust_assignment[v]
    if same_cluster == wants_same_cluster:
      corrects += 1
    else:
      errors += 1
  total = errors + corrects
  if total == 0:
    # No signed edges: nothing can be misclassified (avoids ZeroDivisionError).
    return 0.0
  return float(errors) / total
132