# Repository: google-research
# 530 lines · 15.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions used in fair clustering algorithms."""
17
import collections
import math
import random
from typing import List, Sequence, Set

import numpy as np
import sklearn.cluster
25
def ReadData(file_path):
  """Reads a tab-separated file of numeric rows into an array.

  Every non-blank line of the file is one point; its tab-separated fields are
  parsed with float(). Whitespace-only lines (for instance an empty line at the
  end of the file) are skipped instead of making float() raise.

  Args:
    file_path: path to the input file in tsv format.

  Returns:
    The dataset as a np.array of points each as np.array vector.

  Raises:
    ValueError: if a field of a non-blank line cannot be parsed as a float.
  """
  dataset = []
  with open(file_path, "r") as f:
    for line in f:
      # A blank line carries no point; float("") would raise on it.
      if not line.strip():
        continue
      dataset.append([float(field) for field in line.split("\t")])
  return np.array(dataset)
42
def DistanceToCenters(x, centers, p):
  """Computes the p-th power of the distance from a point to its closest center.

  Args:
    x: the point.
    centers: the centers.
    p: power.

  Returns:
    The smallest value of (Euclidean norm of x - c) ** p over the centers c, or
    math.inf when `centers` is empty.
  """
  best = math.inf
  for center in centers:
    assert len(center) == len(x)
    candidate = np.linalg.norm(x - center) ** p
    # Strict comparison: ties keep the earlier value and NaN never wins.
    best = candidate if candidate < best else best
  return best
64
def FurthestPointPosition(dataset, centers):
  """Finds the point of the dataset that is furthest from its nearest center.

  Args:
    dataset: the dataset.
    centers: the centers.

  Returns:
    The position in `dataset` of the first point whose distance to the nearest
    center is the largest.
  """
  furthest_pos = -1
  furthest_dist = -1
  for idx, point in enumerate(dataset):
    dist = DistanceToCenters(point, centers, 1)
    if dist > furthest_dist:
      furthest_pos, furthest_dist = idx, dist
  # Fails when no point was selected (e.g. the dataset is empty).
  assert furthest_pos >= 0
  return furthest_pos
88
def KMeansCost(dataset, centers):
  """Computes the k-means cost of a solution.

  Args:
    dataset: the dataset.
    centers: the centers.

  Returns:
    The sum, over all points, of the squared distance to the nearest center.
  """
  return sum((DistanceToCenters(point, centers, 2) for point in dataset), 0.0)
104
def MaxFairnessCost(dataset, centers, dist_threshold_vec):
  """Computes the max bound ratio on the dataset for a given solution.

  Args:
    dataset: the dataset.
    centers: the centers.
    dist_threshold_vec: the individual fairness distance thresholds of the
      points.

  Returns:
    The largest value, over the points, of the distance to the closest center
    divided by the threshold of the point (0.0 for an empty dataset).
  """
  worst_ratio = 0.0
  for idx, point in enumerate(dataset):
    nearest = DistanceToCenters(point, centers, 1)
    ratio = 1.0 * nearest / dist_threshold_vec[idx]
    worst_ratio = ratio if ratio > worst_ratio else worst_ratio
  return worst_ratio
129
def ComputeDistanceThreshold(
    dataset,
    sampled_points,
    rank_sampled,
    multiplier,
):
  """Computes a target distance for the individual fairness requirement.

  In order to allow the efficient definition of the fairness distance bound
  for each point we do not compute all pairs distances of points.
  Instead we use a sample of points. For each point p we define the threshold
  d(p) of the maximum distance that is allowed for a center near p to be the
  distance from p to its rank_sampled-th closest point among the
  sampled_points sampled points, times multiplier.

  The sample is drawn from a private random generator seeded with a fixed
  value, so repeated runs use the same thresholds and the state of the global
  `random` module is left untouched.

  Args:
    dataset: the dataset.
    sampled_points: number of points sampled.
    rank_sampled: rank (1-based) of the distance to the sampled points used in
      the definition of the threshold.
    multiplier: multiplier used.

  Returns:
    A np.array with one entry per point of the dataset holding the distance
    threshold of that point.

  Raises:
    ValueError: if rank_sampled is smaller than 1.
  """
  if rank_sampled < 1:
    # rank_sampled - 1 would silently wrap around to the largest distances.
    raise ValueError("rank_sampled must be at least 1.")
  # Fixed seed: multiple runs use the same thresholds. A dedicated generator
  # draws the same sample as seeding the global one, without overwriting the
  # global random state of the caller.
  rng = random.Random(100)
  sample = rng.sample(list(dataset), sampled_points)
  ret = np.zeros(len(dataset))
  for i, x in enumerate(dataset):
    distances = sorted(np.linalg.norm(x - s) for s in sample)
    ret[i] = multiplier * distances[rank_sampled - 1]
  return ret
168
# Lloyd's improvement algorithm
def IsFeasibleSolution(
    dataset,
    anchor_points_pos,
    candidate_centers_vec,
    dist_threshold_vec,
):
  """Checks whether a set of candidate centers is feasible.

  The candidate set is feasible when no anchor point is further from its
  nearest candidate center than the distance threshold of that anchor point.

  Args:
    dataset: the dataset.
    anchor_points_pos: position of the anchor points.
    candidate_centers_vec: vector of candidate centers.
    dist_threshold_vec: distance thresholds.

  Returns:
    True if the solution is feasible, False otherwise.
  """
  violated = any(
      DistanceToCenters(dataset[pos], candidate_centers_vec, 1)
      > dist_threshold_vec[pos]
      for pos in anchor_points_pos
  )
  return not violated
195
def Mean(dataset, positions):
  """Averages the points of the dataset found at the given positions.

  Args:
    dataset: the dataset.
    positions: position in dataset of the points to average.

  Returns:
    Average of the points.
  """
  assert positions
  total = np.zeros(len(dataset[0]))
  for idx in positions:
    # In-place accumulation into the float buffer.
    np.add(total, dataset[idx], out=total)
  np.divide(total, len(positions), out=total)
  return total
213
def LloydImprovementStepOneCluster(
    dataset,
    anchor_points_pos,
    curr_centers_vec,
    dist_threshold_vec,
    cluster_position,
    cluster_points_pos,
    approx_error=0.01,
):
  """Improve the current center respecting feasibility of the solution.

  Given a cluster of points, curr_centers_vec[cluster_position] is the center
  that will be updated. The center is moved towards the mean of the cluster:
  all the way if the resulting solution is feasible, otherwise as far along
  the segment from the current center to the mean as a binary search finds to
  be feasible. The input centers are not modified.

  Args:
    dataset: the set of points.
    anchor_points_pos: the positions in dataset for the anchor points.
    curr_centers_vec: the current centers as a list of np.array vectors.
    dist_threshold_vec: the individual fairness distance thresholds of the
      points.
    cluster_position: the cluster being improved.
    cluster_points_pos: the points in the cluster.
    approx_error: approximation error tolerated in the binary search; the
      search stops once the bracket of interpolation multipliers is narrower
      than this value.

  Returns:
    An improved center.

  Raises:
    ValueError: if approx_error is not positive (the search would not
      terminate).
  """
  if approx_error <= 0:
    raise ValueError("approx_error must be positive.")

  def _IsValidSwap(vec_in):
    # list() copies the container so that the caller's centers are never
    # modified, even if they are passed as a 2-D array instead of a list.
    new_centers_vec = list(curr_centers_vec)
    new_centers_vec[cluster_position] = vec_in
    return IsFeasibleSolution(
        dataset, anchor_points_pos, new_centers_vec, dist_threshold_vec
    )

  def _Interpolate(curr_vec, new_vec, mult_new_vec):
    return curr_vec + (new_vec - curr_vec) * mult_new_vec

  assert IsFeasibleSolution(
      dataset, anchor_points_pos, curr_centers_vec, dist_threshold_vec
  )
  curr_center_vec = np.array(curr_centers_vec[cluster_position])
  mean = Mean(dataset, cluster_points_pos)

  if _IsValidSwap(mean):
    return mean
  # Multiplier 0 is the current (feasible) center, multiplier 1 is the mean,
  # which was just found to be infeasible.
  highest_valid_mult = 0.0
  lowest_invalid_mult = 1.0
  # The bracket width is (invalid - valid); with the operands the other way
  # round the difference starts at -1 and the search would never run.
  while lowest_invalid_mult - highest_valid_mult >= approx_error:
    m = (lowest_invalid_mult + highest_valid_mult) / 2
    if _IsValidSwap(_Interpolate(curr_center_vec, mean, m)):
      highest_valid_mult = m
    else:
      lowest_invalid_mult = m
  return _Interpolate(curr_center_vec, mean, highest_valid_mult)
270
def LloydImprovement(
    dataset,
    anchor_points_pos,
    inital_centers_vec,
    dist_threshold_vec,
    num_iter=20,
):
  """Runs Lloyd-style improvement rounds that respect feasibility.

  Each round assigns every point to its nearest current center and then
  updates the centers one at a time, in order, through
  LloydImprovementStepOneCluster. Centers that received no point in a round
  are left unchanged in that round.

  Args:
    dataset: the set of points.
    anchor_points_pos: the positions in dataset for the anchor points.
    inital_centers_vec: the current centers.
    dist_threshold_vec: the individual fairness distance thresholds of the
      points.
    num_iter: number of iterations for the algorithm.

  Returns:
    An improved solution, as a list of np.array centers.
  """

  def _NearestCenterIndex(point, candidate_centers):
    # First index attaining the smallest distance; 0 when nothing is smaller
    # than infinity.
    best_index, best_dist = 0, math.inf
    for index, center in enumerate(candidate_centers):
      dist = np.linalg.norm(point - center)
      if dist < best_dist:
        best_index, best_dist = index, dist
    return best_index

  centers = [np.array(c) for c in inital_centers_vec]

  for _ in range(num_iter):
    # Assignment step: group point positions by their nearest center.
    members = {}
    for point_pos in range(len(dataset)):
      owner = _NearestCenterIndex(dataset[point_pos], centers)
      members.setdefault(owner, []).append(point_pos)
    # Update step: later centers see the already updated earlier ones.
    for center_pos in range(len(centers)):
      assigned = members.get(center_pos)
      if not assigned:
        continue
      centers[center_pos] = LloydImprovementStepOneCluster(
          dataset,
          anchor_points_pos,
          centers,
          dist_threshold_vec,
          center_pos,
          assigned,
      )
  return centers
323
# Bookkeeping class for local search
class TopTwoClosestToCenters:
  """Bookkeeping class used in local search.

  The class stores and updates efficiently the 2 closest centers for each point.
  Centers are identified by their position in the dataset, and all stored
  distances are squared Euclidean distances.
  """

  def __init__(self, dataset, centers_ids):
    """Constructor.

    Args:
      dataset: the dataset.
      centers_ids: the positions of the centers.
    """
    assert len(dataset) > 2
    assert len(centers_ids) >= 2

    self.dataset = dataset
    # All the fields below use the position of the center in dataset, not the
    # center itself.
    # Positions of the current centers.
    self.centers = set(centers_ids)
    # Mapping from center pos to the set of pos of the points having that
    # center as closest center.
    self.center_to_min_dist_cluster = collections.defaultdict(set)
    # Mapping from center pos to the set of pos of the points having that
    # center as second closest center.
    self.center_to_second_dist_cluster = collections.defaultdict(set)
    # Mapping of points to (closest center pos, distance squared).
    self.point_to_min_dist_center_and_distance = {}
    # Mapping of points to (second closest center pos, distance squared).
    self.point_to_second_dist_center_and_distance = {}
    for point_pos, _ in enumerate(dataset):
      self.InitializeDatastructureForPoint(point_pos)

  def InitializeDatastructureForPoint(self, point_pos):
    """Recomputes the two closest centers of a point from scratch.

    The per-point entries are dropped and every current center is proposed
    again. The point is not removed here from the per-center sets it may
    still belong to.

    Args:
      point_pos: the position of the point.
    """
    self.point_to_min_dist_center_and_distance.pop(point_pos, None)
    self.point_to_second_dist_center_and_distance.pop(point_pos, None)
    for center_pos in self.centers:
      self.ProposeAsCenter(point_pos, center_pos)

  def ProposeAsCenter(self, pos_point, pos_center_to_add):
    """Updates the datastructure proposing a point as a new center.

    Args:
      pos_point: the position of the point.
      pos_center_to_add: the position of the center to be added.
    """
    d = (
        np.linalg.norm(
            self.dataset[pos_point] - self.dataset[pos_center_to_add]
        )
        ** 2
    )
    # Never initialized point: the proposed center becomes its closest one.
    if pos_point not in self.point_to_min_dist_center_and_distance:
      assert pos_point not in self.point_to_second_dist_center_and_distance
      self.point_to_min_dist_center_and_distance[pos_point] = (
          pos_center_to_add,
          d,
      )
      self.center_to_min_dist_cluster[pos_center_to_add].add(pos_point)
      return
    # The proposed center already is the closest one: nothing to do.
    if (
        self.point_to_min_dist_center_and_distance[pos_point][0]
        == pos_center_to_add
    ):
      return

    if d < self.point_to_min_dist_center_and_distance[pos_point][1]:
      # New first center. Move first to second.
      old_first_center = self.point_to_min_dist_center_and_distance[pos_point][
          0
      ]
      self.center_to_min_dist_cluster[old_first_center].remove(pos_point)

      if pos_point in self.point_to_second_dist_center_and_distance:
        self.center_to_second_dist_cluster[
            self.point_to_second_dist_center_and_distance[pos_point][0]
        ].remove(pos_point)
      self.point_to_second_dist_center_and_distance[pos_point] = (
          self.point_to_min_dist_center_and_distance[pos_point]
      )
      self.center_to_second_dist_cluster[old_first_center].add(pos_point)

      self.point_to_min_dist_center_and_distance[pos_point] = (
          pos_center_to_add,
          d,
      )
      self.center_to_min_dist_cluster[pos_center_to_add].add(pos_point)
    else:  # not first
      # Not initialized second.
      if pos_point not in self.point_to_second_dist_center_and_distance:
        self.point_to_second_dist_center_and_distance[pos_point] = (
            pos_center_to_add,
            d,
        )
        self.center_to_second_dist_cluster[pos_center_to_add].add(pos_point)
        return
      # The proposed center already is the second closest one.
      if (
          self.point_to_second_dist_center_and_distance[pos_point][0]
          == pos_center_to_add
      ):
        return

      if d < self.point_to_second_dist_center_and_distance[pos_point][1]:
        # New second center replacing the previous second one.
        self.center_to_second_dist_cluster[
            self.point_to_second_dist_center_and_distance[pos_point][0]
        ].remove(pos_point)
        self.point_to_second_dist_center_and_distance[pos_point] = (
            pos_center_to_add,
            d,
        )
        self.center_to_second_dist_cluster[pos_center_to_add].add(pos_point)

  def CostAfterSwap(self, pos_center_to_remove, pos_center_to_add):
    """Computes the cost of a proposed swap.

    This function does not change the data structure. It runs in O(n) time.

    Args:
      pos_center_to_remove: proposed center to be removed.
      pos_center_to_add: proposed center to be added.

    Returns:
      The cost after the swap.
    """
    center_to_add = self.dataset[pos_center_to_add]
    total_cost = 0
    for point_pos, point in enumerate(self.dataset):
      cost_point = np.linalg.norm(point - center_to_add) ** 2
      if (
          self.point_to_min_dist_center_and_distance[point_pos][0]
          != pos_center_to_remove
      ):
        # The closest center survives the swap.
        cost_point = min(
            cost_point, self.point_to_min_dist_center_and_distance[point_pos][1]
        )
      else:
        # The closest center is removed: fall back to the second closest.
        cost_point = min(
            cost_point,
            self.point_to_second_dist_center_and_distance[point_pos][1],
        )
      total_cost += cost_point
    return total_cost

  def SwapCenters(self, pos_center_to_remove, pos_center_to_add):
    """Updates the data structure swapping two centers.

    Args:
      pos_center_to_remove: center to remove.
      pos_center_to_add: center to add.
    """
    # Points having the removed center among their two closest centers must be
    # recomputed from scratch.
    invalidated_points = (
        self.center_to_min_dist_cluster[pos_center_to_remove]
        | self.center_to_second_dist_cluster[pos_center_to_remove]
    )
    for point in invalidated_points:
      min_c = self.point_to_min_dist_center_and_distance[point][0]
      self.center_to_min_dist_cluster[min_c].remove(point)
      second_c = self.point_to_second_dist_center_and_distance[point][0]
      self.center_to_second_dist_cluster[second_c].remove(point)

    self.centers.remove(pos_center_to_remove)
    del self.center_to_min_dist_cluster[pos_center_to_remove]
    del self.center_to_second_dist_cluster[pos_center_to_remove]
    self.centers.add(pos_center_to_add)
    for pos in invalidated_points:
      self.InitializeDatastructureForPoint(pos)
    for pos in range(len(self.dataset)):
      self.ProposeAsCenter(pos, pos_center_to_add)

  def SampleWithD2Distribution(self):
    """Sample a random point with prob. proportional to distance squared.

    The distance used is the stored squared distance of each point to its
    closest center.

    Returns:
      The position of the sampled point.
    """
    costs = [
        self.point_to_min_dist_center_and_distance[i][1]
        for i in range(len(self.dataset))
    ]
    sum_cost = 0
    for cost in costs:
      sum_cost += cost
    remaining = random.random() * sum_cost
    for pos, cost in enumerate(costs):
      remaining -= cost
      if remaining <= 0:
        return pos
    # Floating point rounding (or non-finite costs) can leave a residual after
    # the last point; never step past the end of the dataset. Prefer the last
    # point with a positive cost.
    for pos in range(len(costs) - 1, -1, -1):
      if costs[pos] > 0:
        return pos
    return len(costs) - 1
518
def VanillaKMeans(dataset, k):
  """Runs the plain (not fair) k-means baseline of scikit-learn.

  Args:
    dataset: the set of points.
    k: the number of clusters.

  Returns:
    The cluster centers.
  """
  model = sklearn.cluster.KMeans(n_clusters=k)
  model.fit(dataset)
  return model.cluster_centers_