google-research

Форк
0
530 строк · 15.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utility functions used in fair clustering algorithms."""
17

18
import collections
19
import math
20
import random
21
from typing import List, Sequence, Set
22
import numpy as np
23
import sklearn.cluster
24

25

26
def ReadData(file_path):
  """Reads a dataset of numeric points from a tab-separated file.

  Each non-blank line holds one point whose coordinates are separated by tab
  characters. Blank or whitespace-only lines (for example the empty line that
  often ends a file) carry no data and are skipped; previously they made
  float("") raise a ValueError.

  Args:
    file_path: path to the input file in tsv format.

  Returns:
    The dataset as a np.array with one row per point.
  """
  dataset = []
  with open(file_path, "r") as f:
    for line in f:
      # Skip lines with no data instead of failing on float("").
      if not line.strip():
        continue
      dataset.append([float(value) for value in line.split("\t")])
  return np.array(dataset)
41

42

43
def DistanceToCenters(
    x, centers, p
):
  """Returns the Euclidean distance from x to its closest center, to the p.

  Args:
    x: the point.
    centers: iterable of centers; each must have the same length as x.
    p: exponent applied to the distance.

  Returns:
    The smallest value of ||x - c|| ** p over the centers c, or math.inf when
    centers is empty.
  """
  closest = math.inf
  for center in centers:
    assert len(center) == len(x)
    closest = min(closest, np.linalg.norm(x - center) ** p)
  return closest
63

64

65
def FurthestPointPosition(
    dataset, centers
):
  """Finds the point of the dataset that is farthest from its nearest center.

  Ties are resolved in favour of the earliest position.

  Args:
    dataset: the dataset.
    centers: the centers.

  Returns:
    The position in dataset of the farthest point.
  """
  best_pos, best_dist = -1, -1
  for idx, point in enumerate(dataset):
    dist = DistanceToCenters(point, centers, 1)
    if dist > best_dist:
      best_pos, best_dist = idx, dist
  # An empty dataset leaves no candidate.
  assert best_pos >= 0
  return best_pos
87

88

89
def KMeansCost(dataset, centers):
  """Computes the k-means cost of a solution.

  Args:
    dataset: the dataset.
    centers: the centers.

  Returns:
    The sum over all points of the squared distance to the closest center
    (0.0 for an empty dataset).
  """
  return sum(
      (DistanceToCenters(point, centers, 2) for point in dataset), 0.0
  )
103

104

105
def MaxFairnessCost(
    dataset,
    centers,
    dist_threshold_vec,
):
  """Computes the largest distance-to-threshold ratio over the dataset.

  Args:
    dataset: the dataset.
    centers: the centers.
    dist_threshold_vec: the individual fairness distance thresholds of the
      points, indexed by position in dataset.

  Returns:
    The maximum over points of (distance to the closest center) divided by
    the point's threshold, or 0.0 if no ratio is positive.
  """
  worst = 0.0
  for idx, point in enumerate(dataset):
    ratio = 1.0 * DistanceToCenters(point, centers, 1) / dist_threshold_vec[idx]
    worst = max(worst, ratio)
  return worst
128

129

130
def ComputeDistanceThreshold(
    dataset,
    sampled_points,
    rank_sampled,
    multiplier,
):
  """Computes a per-point target distance for the individual fairness bound.

  To avoid computing all pairwise distances, a fixed sample of
  `sampled_points` points is drawn from the dataset. For each point x the
  threshold d(x) is `multiplier` times the distance from x to its
  `rank_sampled`-th closest sampled point (rank 1 is the closest).

  The sample is drawn from a private random.Random(100) generator, so repeated
  runs use the same thresholds without reseeding or otherwise disturbing the
  caller's global `random` module state.

  Args:
    dataset: the dataset.
    sampled_points: number of points sampled.
    rank_sampled: 1-based rank of the distance to the sampled points used in
      the definition of the threshold.
    multiplier: multiplier used.

  Returns:
    A np.array of length len(dataset) whose i-th entry is the threshold of
    dataset[i].

  Raises:
    ValueError: if rank_sampled < 1 (rank 0 would silently pick the farthest
      sampled point through negative indexing).
  """
  if rank_sampled < 1:
    raise ValueError("rank_sampled must be >= 1, got %r" % (rank_sampled,))
  # A dedicated generator seeded with 100 yields the same sample as
  # `random.seed(100); random.sample(...)` on the module-level generator, but
  # leaves the global RNG untouched (the old code reseeded it from the clock).
  sample = random.Random(100).sample(list(dataset), sampled_points)
  ret = np.zeros(len(dataset))
  for i, x in enumerate(dataset):
    distances = sorted(np.linalg.norm(x - s) for s in sample)
    ret[i] = multiplier * distances[rank_sampled - 1]
  return ret
167

168

169
# Lloyd's improvement algorithm.
170
def IsFeasibleSolution(
    dataset,
    anchor_points_pos,
    candidate_centers_vec,
    dist_threshold_vec,
):
  """Checks that every anchor point lies within its threshold of a center.

  Args:
    dataset: the dataset.
    anchor_points_pos: positions in dataset of the anchor points.
    candidate_centers_vec: vector of candidate centers.
    dist_threshold_vec: distance thresholds, indexed by position in dataset.

  Returns:
    False as soon as some anchor point is farther than its threshold from its
    nearest candidate center, True otherwise.
  """
  return not any(
      DistanceToCenters(dataset[pos], candidate_centers_vec, 1)
      > dist_threshold_vec[pos]
      for pos in anchor_points_pos
  )
194

195

196
def Mean(dataset, positions):
  """Computes the centroid of a subset of the dataset.

  Args:
    dataset: the dataset.
    positions: non-empty collection of positions in dataset to average.

  Returns:
    The coordinate-wise average of the selected points as a float np.array.
  """
  assert positions
  count = len(positions)
  accumulator = np.zeros(len(dataset[0]))
  for idx in positions:
    np.add(accumulator, dataset[idx], out=accumulator)
  return accumulator / count
212

213

214
def LloydImprovementStepOneCluster(
    dataset,
    anchor_points_pos,
    curr_centers_vec,
    dist_threshold_vec,
    cluster_position,
    cluster_points_pos,
    approx_error = 0.01,
):
  """Improves one center while preserving feasibility of the solution.

  The center curr_centers_vec[cluster_position] is moved towards the mean of
  the points in cluster_points_pos. If using the mean itself keeps the
  solution feasible, the mean is returned. Otherwise a binary search over the
  segment from the current center (multiplier 0) to the mean (multiplier 1)
  finds, up to approx_error, a multiplier for which the solution is still
  feasible, and returns the corresponding point. curr_centers_vec itself is
  not modified.

  Args:
    dataset: the set of points.
    anchor_points_pos: the positions in dataset for the anchor points.
    curr_centers_vec: the current centers as a list of np.array vectors; the
      current solution must be feasible.
    dist_threshold_vec: the individual fairness distance thresholds of the
      points.
    cluster_position: index in curr_centers_vec of the center being improved.
    cluster_points_pos: positions in dataset of the points in the cluster;
      must be non-empty.
    approx_error: approximation error tolerated in the binary search; must be
      positive whenever the search is needed.

  Returns:
    An improved center.

  Raises:
    ValueError: if the binary search is needed and approx_error <= 0, since
      the search would then never terminate.
  """

  def _IsValidSwap(vec_in):
    # list() yields a fresh list both for a list of vectors and for a 2-D
    # np.ndarray; slicing an ndarray would return a view, and the assignment
    # below would then overwrite the caller's center.
    new_centers_vec = list(curr_centers_vec)
    new_centers_vec[cluster_position] = vec_in
    return IsFeasibleSolution(
        dataset, anchor_points_pos, new_centers_vec, dist_threshold_vec
    )

  def _Interpolate(curr_vec, new_vec, mult_new_vec):
    # Point at fraction mult_new_vec of the way from curr_vec to new_vec.
    return curr_vec + (new_vec - curr_vec) * mult_new_vec

  assert IsFeasibleSolution(
      dataset, anchor_points_pos, curr_centers_vec, dist_threshold_vec
  )
  curr_center_vec = np.array(curr_centers_vec[cluster_position])
  mean = Mean(dataset, cluster_points_pos)

  if _IsValidSwap(mean):
    return mean
  if approx_error <= 0:
    raise ValueError(
        "approx_error must be positive, got %r" % (approx_error,)
    )
  # highest_valid_mult is always a multiplier known to be feasible (0 is the
  # current center) and lowest_invalid_mult one found infeasible (1 is the
  # mean, just rejected). The gap is lowest_invalid - highest_valid; the
  # reversed subtraction used before was negative from the start, so the
  # search never ran and the center was never moved.
  highest_valid_mult = 0.0
  lowest_invalid_mult = 1.0
  while lowest_invalid_mult - highest_valid_mult >= approx_error:
    m = (lowest_invalid_mult + highest_valid_mult) / 2
    if _IsValidSwap(_Interpolate(curr_center_vec, mean, m)):
      highest_valid_mult = m
    else:
      lowest_invalid_mult = m
  return _Interpolate(curr_center_vec, mean, highest_valid_mult)
269

270

271
def LloydImprovement(
    dataset,
    anchor_points_pos,
    inital_centers_vec,
    dist_threshold_vec,
    num_iter = 20,
):
  """Runs a feasibility-preserving variant of Lloyd's algorithm.

  Each iteration assigns every point to its nearest current center and then,
  cluster by cluster, moves each non-empty cluster's center towards the
  cluster mean with LloydImprovementStepOneCluster, which keeps the solution
  feasible. Centers updated earlier in an iteration are visible to later
  cluster updates of the same iteration.

  Args:
    dataset: the set of points.
    anchor_points_pos: the positions in dataset for the anchor points.
    inital_centers_vec: the starting centers.
    dist_threshold_vec: the individual fairness distance thresholds of the
      points.
    num_iter: number of iterations for the algorithm.

  Returns:
    The improved centers as a list of np.array vectors.
  """
  centers = [np.array(c) for c in inital_centers_vec]

  def _NearestCenter(point_pos):
    # Index of the closest center in `centers`; ties keep the lowest index.
    best_idx, best_dist = 0, math.inf
    for idx, center in enumerate(centers):
      dist = np.linalg.norm(dataset[point_pos] - center)
      if dist < best_dist:
        best_idx, best_dist = idx, dist
    return best_idx

  for _ in range(num_iter):
    members = collections.defaultdict(list)
    for point_pos in range(len(dataset)):
      members[_NearestCenter(point_pos)].append(point_pos)
    for center_idx in range(len(centers)):
      cluster = members[center_idx]
      if cluster:
        centers[center_idx] = LloydImprovementStepOneCluster(
            dataset,
            anchor_points_pos,
            centers,
            dist_threshold_vec,
            center_idx,
            cluster,
        )
  return centers
322

323

324
# Bookkeeping class for local search
325
class TopTwoClosestToCenters:
  """Bookkeeping class used in local search.

  For every point of the dataset the class stores its closest and second
  closest center together with the squared Euclidean distance to each, and
  updates them incrementally when centers are swapped. Centers are always
  identified by their position in the dataset, not by their coordinates.
  """

  def __init__(self, dataset, centers_ids):
    """Constructor.

    Args:
      dataset: the dataset; points are identified by their position in it.
      centers_ids: the positions in dataset of the centers (at least two).
    """
    assert len(dataset) > 2
    assert len(centers_ids) >= 2

    self.dataset = dataset
    # All the fields below refer to centers by their position in dataset.
    # Positions of the current centers.
    self.centers = set(centers_ids)
    # Center position -> set of positions of the points whose closest center
    # it is.
    self.center_to_min_dist_cluster = collections.defaultdict(set)
    # Center position -> set of positions of the points whose second closest
    # center it is.
    self.center_to_second_dist_cluster = collections.defaultdict(set)
    # Point position -> (closest center position, squared distance to it).
    self.point_to_min_dist_center_and_distance = {}
    # Point position -> (second closest center position, squared distance to
    # it).
    self.point_to_second_dist_center_and_distance = {}
    for point_pos, _ in enumerate(dataset):
      self.InitializeDatastructureForPoint(point_pos)

  def InitializeDatastructureForPoint(self, point_pos):
    """Recomputes the two closest centers of a point from scratch.

    Drops the point's cached closest/second-closest entries and proposes every
    current center again. It does not remove the point from the
    center_to_*_cluster sets; callers that re-initialize an already indexed
    point must do that first (as SwapCenters does).

    Args:
      point_pos: the position of the point in dataset.
    """
    if point_pos in self.point_to_min_dist_center_and_distance:
      del self.point_to_min_dist_center_and_distance[point_pos]
    if point_pos in self.point_to_second_dist_center_and_distance:
      del self.point_to_second_dist_center_and_distance[point_pos]
    for center_pos in self.centers:
      self.ProposeAsCenter(point_pos, center_pos)

  def ProposeAsCenter(self, pos_point, pos_center_to_add):
    """Updates the datastructure proposing a point as a new center.

    If the proposed center is strictly closer to the point than one of its
    two recorded centers (or one of those slots is still empty), it takes
    that slot and the per-center cluster sets are updated accordingly.

    Args:
      pos_point: the position of the point.
      pos_center_to_add: the position of the center to be added.
    """
    # Squared Euclidean distance between the point and the proposed center.
    d = (
        np.linalg.norm(
            self.dataset[pos_point] - self.dataset[pos_center_to_add]
        )
        ** 2
    )
    # Never initialized point: the first proposal becomes the closest center.
    if pos_point not in self.point_to_min_dist_center_and_distance:
      assert pos_point not in self.point_to_second_dist_center_and_distance
      self.point_to_min_dist_center_and_distance[pos_point] = (
          pos_center_to_add,
          d,
      )
      self.center_to_min_dist_cluster[pos_center_to_add].add(pos_point)
      return
    # Already the closest center: nothing to do.
    if (
        self.point_to_min_dist_center_and_distance[pos_point][0]
        == pos_center_to_add
    ):
      return

    if d < self.point_to_min_dist_center_and_distance[pos_point][1]:
      # New first center. Move first to second; the old second is dropped.
      old_first_center = self.point_to_min_dist_center_and_distance[pos_point][
          0
      ]
      self.center_to_min_dist_cluster[old_first_center].remove(pos_point)

      if pos_point in self.point_to_second_dist_center_and_distance:
        self.center_to_second_dist_cluster[
            self.point_to_second_dist_center_and_distance[pos_point][0]
        ].remove(pos_point)
      self.point_to_second_dist_center_and_distance[pos_point] = (
          self.point_to_min_dist_center_and_distance[pos_point]
      )
      self.center_to_second_dist_cluster[old_first_center].add(pos_point)

      self.point_to_min_dist_center_and_distance[pos_point] = (
          pos_center_to_add,
          d,
      )
      self.center_to_min_dist_cluster[pos_center_to_add].add(pos_point)
    else:  # not first
      # Second slot not yet filled: take it.
      if pos_point not in self.point_to_second_dist_center_and_distance:
        self.point_to_second_dist_center_and_distance[pos_point] = (
            pos_center_to_add,
            d,
        )
        self.center_to_second_dist_cluster[pos_center_to_add].add(pos_point)
        return
      # Already the second closest center: nothing to do.
      if (
          self.point_to_second_dist_center_and_distance[pos_point][0]
          == pos_center_to_add
      ):
        return

      # Strictly closer than the current second: replace it.
      if d < self.point_to_second_dist_center_and_distance[pos_point][1]:
        self.center_to_second_dist_cluster[
            self.point_to_second_dist_center_and_distance[pos_point][0]
        ].remove(pos_point)
        self.point_to_second_dist_center_and_distance[pos_point] = (
            pos_center_to_add,
            d,
        )
        self.center_to_second_dist_cluster[pos_center_to_add].add(pos_point)

  def CostAfterSwap(
      self, pos_center_to_remove, pos_center_to_add
  ):
    """Computes the cost of a proposed swap.

    The cost is the sum over all points of the squared distance to the
    closest center in the solution where pos_center_to_remove is replaced by
    pos_center_to_add. Points whose closest center is being removed fall back
    to their recorded second closest center. This function does not change
    the data structure and makes a single pass over the dataset.

    Args:
      pos_center_to_remove: proposed center to be removed.
      pos_center_to_add: proposed center to be added.

    Returns:
      The cost after the swap.
    """
    center_to_add = self.dataset[pos_center_to_add]
    total_cost = 0
    for point_pos, point in enumerate(self.dataset):
      cost_point = np.linalg.norm(point - center_to_add) ** 2
      if (
          self.point_to_min_dist_center_and_distance[point_pos][0]
          != pos_center_to_remove
      ):
        cost_point = min(
            cost_point, self.point_to_min_dist_center_and_distance[point_pos][1]
        )
      else:
        cost_point = min(
            cost_point,
            self.point_to_second_dist_center_and_distance[point_pos][1],
        )
      total_cost += cost_point
    return total_cost

  def SwapCenters(
      self, pos_center_to_remove, pos_center_to_add
  ):
    """Updates the data structure swapping two centers.

    Points that had the removed center as closest or second closest are
    detached from all cluster sets and recomputed from scratch; every point
    is then offered the new center.

    Args:
      pos_center_to_remove: center to remove.
      pos_center_to_add: center to add.
    """
    invalidated_points = (
        self.center_to_min_dist_cluster[pos_center_to_remove]
        | self.center_to_second_dist_cluster[pos_center_to_remove]
    )
    for point in invalidated_points:
      min_c = self.point_to_min_dist_center_and_distance[point][0]
      self.center_to_min_dist_cluster[min_c].remove(point)
      second_c = self.point_to_second_dist_center_and_distance[point][0]
      self.center_to_second_dist_cluster[second_c].remove(point)

    self.centers.remove(pos_center_to_remove)
    del self.center_to_min_dist_cluster[pos_center_to_remove]
    del self.center_to_second_dist_cluster[pos_center_to_remove]
    self.centers.add(pos_center_to_add)
    for pos in invalidated_points:
      self.InitializeDatastructureForPoint(pos)
    for pos in range(len(self.dataset)):
      self.ProposeAsCenter(pos, pos_center_to_add)

  def SampleWithD2Distribution(self):
    """Samples a point with probability proportional to its D^2 cost.

    The cost of a point is the squared distance to its closest center, so
    points farther from the current centers are proportionally more likely to
    be chosen.

    Returns:
      The position in dataset of the sampled point.
    """
    costs = [
        self.point_to_min_dist_center_and_distance[i][1]
        for i in range(len(self.dataset))
    ]
    sum_cost = sum(costs)
    sampled_random = random.random() * sum_cost
    for pos, cost in enumerate(costs):
      sampled_random -= cost
      if sampled_random <= 0:
        return pos
    # Rounding in the subtractions above can leave a tiny positive remainder
    # after the last point; the previous open-ended walk then ran past the
    # end of the dataset and raised KeyError. Return the last point with a
    # positive cost, where the walk should have stopped.
    return max(
        (pos for pos, cost in enumerate(costs) if cost > 0),
        default=len(costs) - 1,
    )
517

518

519
def VanillaKMeans(dataset, k):
  """Fits standard (fairness-unaware) k-means as a baseline.

  Args:
    dataset: the set of points.
    k: the number of clusters.

  Returns:
    The cluster centers found by sklearn's KMeans.
  """
  model = sklearn.cluster.KMeans(n_clusters=k)
  model.fit(dataset)
  return model.cluster_centers_
531

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.