run_individually_fair_clustering.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for the fair clustering algorithm."""

import datetime
import json
import random
from typing import Sequence

from absl import app
from absl import flags
from absl import logging
import numpy as np

from individually_fair_clustering import fair_clustering_algorithms
from individually_fair_clustering import fair_clustering_utils

_INPUT = flags.DEFINE_string(
    "input",
    "",
    "Path to the input file. The input file must contain one line per point,"
    " each line being a tab-separated vector with each entry a floating point"
    " number.",
)
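# For illustration, a hypothetical 3-dimensional point would appear in the
# input file as a single line "0.5\t1.25\t-3.0" (entries separated by tabs).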
_OUTPUT = flags.DEFINE_string(
    "output",
    "",
    "Path to the output file. The output file is a text-encoded JSON object"
    " containing the results of one run of the algorithm.",
)
_K = flags.DEFINE_integer("k", 5, "Number of centers for the clustering.")

# The following parameters affect how the fairness constraints are defined
# (for all algorithms).
# To allow an efficient definition of the fairness distance bound for each
# point, we do not do all-pairs distance computations. Instead we use a
# sample of points: for each point p, the threshold d(p) on the maximum
# distance allowed between p and a center near p is the distance from p to
# its sample_rank_for_threshold-th closest point among
# sample_size_for_threshold sampled points, times multiplier_for_threshold.
_SAMPLE_SIZE_FOR_THRESHOLD = flags.DEFINE_integer(
    "sample_size_for_threshold",
    500,
    "Size for the sample used to determine the fairness threshold.",
)
_SAMPLE_RANK_FOR_THRESHOLD = flags.DEFINE_integer(
    "sample_rank_for_threshold", 50, "Rank of the distance for the threshold."
)
_MULTIPLIER_FOR_THRESHOLD = flags.DEFINE_float(
    "multiplier_for_threshold", 1.0, "Multiplier for the distance threshold."
)

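# For illustration only (the real implementation lives in
# fair_clustering_utils.ComputeDistanceThreshold and its exact details may
# differ), the thresholds described above could be computed roughly as:
#
#   rng = np.random.default_rng(0)
#   sample = dataset[rng.choice(dataset.shape[0], size=500, replace=False)]
#   dists = np.linalg.norm(dataset[:, None, :] - sample[None, :, :], axis=-1)
#   dist_threshold_vec = 1.0 * np.sort(dists, axis=1)[:, 50]
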
# The next parameters affect the runs of the slower algorithms used as
# baselines. For all algorithms except LSPP, Greedy, and VanillaKMeans, if the
# size of the dataset is larger than large_scale_dataset_size, we sample
# sample_size_for_slow_algorithms points and run the algorithm on them to find
# a solution. This solution is then evaluated on the whole dataset.
_SAMPLE_SIZE_FOR_SLOW_ALGORITHMS = flags.DEFINE_integer(
    "sample_size_for_slow_algorithms",
    4000,
    "Number of elements used in the input for the slow algorithms.",
)
_LARGE_SCALE_DATASET_SIZE = flags.DEFINE_integer(
    "large_scale_dataset_size",
    10000,
    "Number of elements of a dataset above which sampling is used for slow "
    "algorithms.",
)

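# For example, with the defaults above a hypothetical 50000-point dataset
# would be subsampled down to 4000 points before running ICML20, while LSPP,
# Greedy, and VanillaKMeans would still run on all 50000 points.
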
_ALGORITHM = flags.DEFINE_string(
    "algorithm",
    "LSPP",
    "Name of the algorithm, among: LSPP, ICML20, Greedy, VanillaKMeans.",
)


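# Example invocation (the paths are placeholders; assumes the script is run
# as a module from the repository root):
#
#   python -m individually_fair_clustering.run_individually_fair_clustering \
#     --input=/path/to/points.tsv \
#     --output=/tmp/lspp_result.json \
#     --k=10 --algorithm=LSPP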
def main(argv: Sequence[str]) -> None:
  del argv
  assert _INPUT.value
  assert _OUTPUT.value
  assert _K.value > 0

  dataset = fair_clustering_utils.ReadData(_INPUT.value)

  logging.info("Computing the thresholds")
  dist_threshold_vec = fair_clustering_utils.ComputeDistanceThreshold(
      dataset,
      _SAMPLE_SIZE_FOR_THRESHOLD.value,
      _SAMPLE_RANK_FOR_THRESHOLD.value,
      _MULTIPLIER_FOR_THRESHOLD.value,
  )
  logging.info("[Done] computing the thresholds")
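  # dist_threshold_vec[i] is the individual fairness bound d(p) for the i-th
  # point, i.e. the maximum distance allowed between that point and a center
  # near it (see the threshold flags above).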

  # Since ICML20 is not scalable, when using a large-scale dataset the
  # algorithm is restricted to a subset of the elements. The evaluation,
  # however, is done on the whole dataset.
  slow_algos = set(["ICML20"])

  if (
      _ALGORITHM.value in slow_algos
      and dataset.shape[0] >= _LARGE_SCALE_DATASET_SIZE.value
  ):
    logging.info("Using a sample as the algorithm is not scalable")
    input_positions = random.sample(
        list(range(dataset.shape[0])), _SAMPLE_SIZE_FOR_SLOW_ALGORITHMS.value
    )
    # Sorting keeps the sampled points in dataset order, so the threshold
    # vector below stays aligned with dataset_input.
    input_positions.sort()
    dataset_input = np.array([dataset[i] for i in input_positions])
    dist_threshold_vec_input = np.array(
        [dist_threshold_vec[i] for i in input_positions]
    )
  else:  # Using the full dataset and threshold.
    logging.info("Using the full dataset")
    dataset_input = dataset
    dist_threshold_vec_input = dist_threshold_vec

  logging.info(
      "Algorithm starts, running on dataset of size %d", dataset_input.shape[0]
  )
  start = datetime.datetime.now()
  if _ALGORITHM.value == "LSPP":
    centers = fair_clustering_algorithms.LocalSearchPlusPlus(
        dataset=dataset_input,
        k=_K.value,
        dist_threshold_vec=dist_threshold_vec_input,
        coeff_anchor=3.0,
        coeff_search=1.0,
        number_of_iterations=5000,
        use_lloyd=True,
    )
  elif _ALGORITHM.value == "ICML20":
    centers = fair_clustering_algorithms.LocalSearchICML2020(
        dataset=dataset_input,
        k=_K.value,
        dist_threshold_vec=dist_threshold_vec_input,
        coeff_anchor=3.0,
        coeff_search=1.0,
        epsilon=0.01,
        use_lloyd=False,
    )
  elif _ALGORITHM.value == "Greedy":
    centers = fair_clustering_algorithms.Greedy(
        dataset=dataset_input,
        k=_K.value,
        dist_threshold_vec=dist_threshold_vec_input,
        coeff_anchor=3.0,
    )
  elif _ALGORITHM.value == "VanillaKMeans":
    centers = fair_clustering_algorithms.VanillaKMeans(
        dataset=dataset_input, k=_K.value
    )
  else:
    raise RuntimeError("Algorithm not supported")
  end = datetime.datetime.now()
  duration = end - start
  logging.info("Algorithm completes.")

170
  # notice that in any case the evaluation is done on the whole dataset.
171
  results = {
172
      "k": _K.value,
173
      "time": duration.total_seconds(),
174
      "k-means-cost": fair_clustering_utils.KMeansCost(dataset, centers),
175
      "max-bound-ratio": fair_clustering_utils.MaxFairnessCost(
176
          dataset, centers, dist_threshold_vec
177
      ),
178
      "algorithm": _ALGORITHM.value,
179
  }
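  # The resulting JSON has a shape like the following (values are purely
  # illustrative):
  #   {"k": 5, "time": 12.3, "k-means-cost": 104.2,
  #    "max-bound-ratio": 0.97, "algorithm": "LSPP"}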
  logging.info(results)

  with open(_OUTPUT.value, "w") as outfile:
    json.dump(results, outfile)


if __name__ == "__main__":
  app.run(main)