google-research
187 lines · 6.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main file for the fair clustering algorithm."""
17
18import datetime19import json20import random21from typing import Sequence22
23from absl import app24from absl import flags25from absl import logging26import numpy as np27
28from individually_fair_clustering import fair_clustering_algorithms29from individually_fair_clustering import fair_clustering_utils30
_INPUT = flags.DEFINE_string(
    "input",
    "",
    "Path to the input file. The input file must contain one line per point,"
    " with each line being a tab separated vector represented as the entries,"
    " each entry as a floating point number.",
)
_OUTPUT = flags.DEFINE_string(
    "output",
    "",
    "Path to the output file. The output file is a json encoded in text format"
    " containing the results of one run of the algorithm.",
)
_K = flags.DEFINE_integer("k", 5, "Number of centers for the clustering.")

# The following parameters affect how the fairness constraints are defined
# (for all algorithms).
# To allow an efficient definition of the fairness distance bound for each
# point we do not do all-pairs distance computations. Instead we use a sample
# of points: for each point p, the threshold d(p) on the maximum distance
# allowed between p and its closest center is defined as the distance from p
# to the sample_rank_for_threshold-th closest point among
# sample_size_for_threshold sampled points, times multiplier_for_threshold.
_SAMPLE_SIZE_FOR_THRESHOLD = flags.DEFINE_integer(
    "sample_size_for_threshold",
    500,
    "Size for the sample used to determine the fairness threshold.",
)
_SAMPLE_RANK_FOR_THRESHOLD = flags.DEFINE_integer(
    "sample_rank_for_threshold", 50, "Rank of the distance for the threshold."
)
_MULTIPLIER_FOR_THRESHOLD = flags.DEFINE_float(
    "multiplier_for_threshold", 1.0, "Multiplier for the distance threshold."
)

# The next parameters affect the runs of the slower algorithms used as
# baselines. For all algorithms except LSPP, Greedy, VanillaKMeans,
# if the size of the dataset is larger than large_scale_dataset_size, we sample
# sample_size_for_slow_algorithms points and run the algorithm on them to find
# a solution. Then, of course, this solution is evaluated on the whole dataset.
_SAMPLE_SIZE_FOR_SLOW_ALGORITHMS = flags.DEFINE_integer(
    "sample_size_for_slow_algorithms",
    4000,
    "Number of elements used in the input for the slow algorithms",
)
_LARGE_SCALE_DATASET_SIZE = flags.DEFINE_integer(
    "large_scale_dataset_size",
    10000,
    "Number of elements of a dataset that require using sampling for slow "
    "algorithms",
)

# NOTE: fixed help-string typo "IMCL20" -> "ICML20"; the dispatch code in
# main() only recognizes the spelling "ICML20".
_ALGORITHM = flags.DEFINE_string(
    "algorithm",
    "LSPP",
    "name of the algorithm among: LSPP, ICML20, Greedy, VanillaKMeans",
)
89
def main(argv):
  """Runs one individually-fair clustering algorithm and saves its metrics.

  Reads the dataset from --input, computes a per-point fairness distance
  threshold, runs the algorithm selected by --algorithm (possibly on a
  sample, for the slow baselines), evaluates the resulting centers on the
  *whole* dataset, and writes a JSON dict of results to --output.

  Args:
    argv: Unused positional command-line arguments.

  Raises:
    app.UsageError: If a required flag is missing or --k is not positive.
    RuntimeError: If --algorithm is not a supported algorithm name.
  """
  del argv
  # Validate flags with explicit errors: `assert` would be silently stripped
  # when running under `python -O`.
  if not _INPUT.value:
    raise app.UsageError("--input must be specified.")
  if not _OUTPUT.value:
    raise app.UsageError("--output must be specified.")
  if _K.value <= 0:
    raise app.UsageError("--k must be positive.")

  dataset = fair_clustering_utils.ReadData(_INPUT.value)

  logging.info("Computing the thresholds")
  dist_threshold_vec = fair_clustering_utils.ComputeDistanceThreshold(
      dataset,
      _SAMPLE_SIZE_FOR_THRESHOLD.value,
      _SAMPLE_RANK_FOR_THRESHOLD.value,
      _MULTIPLIER_FOR_THRESHOLD.value,
  )
  logging.info("[Done] computing the thresholds")

  # Since ICML20 is not scalable, on a large-scale dataset that algorithm is
  # restricted to a subset of the elements. The evaluation, however, is done
  # on the whole dataset.
  slow_algos = {"ICML20"}

  if (
      _ALGORITHM.value in slow_algos
      and dataset.shape[0] >= _LARGE_SCALE_DATASET_SIZE.value
  ):
    logging.info("Using a sample as the algorithm is not scalable")
    # Sorting keeps the sampled points in their original dataset order.
    input_positions = sorted(
        random.sample(
            range(dataset.shape[0]), _SAMPLE_SIZE_FOR_SLOW_ALGORITHMS.value
        )
    )
    dataset_input = np.array([dataset[i] for i in input_positions])
    dist_threshold_vec_input = np.array(
        [dist_threshold_vec[i] for i in input_positions]
    )
  else:  # Using the full dataset and threshold.
    logging.info("Using the full dataset")
    dataset_input = dataset
    dist_threshold_vec_input = dist_threshold_vec

  logging.info(
      "Algorithm starts, running on dataset of size %d", dataset_input.shape[0]
  )
  start = datetime.datetime.now()
  centers = _compute_centers(dataset_input, dist_threshold_vec_input)
  duration = datetime.datetime.now() - start
  logging.info("Algorithm completes.")

  # Notice that in any case the evaluation is done on the whole dataset,
  # even when the algorithm itself ran on a sample.
  results = {
      "k": _K.value,
      "time": duration.total_seconds(),
      "k-means-cost": fair_clustering_utils.KMeansCost(dataset, centers),
      "max-bound-ratio": fair_clustering_utils.MaxFairnessCost(
          dataset, centers, dist_threshold_vec
      ),
      "algorithm": _ALGORITHM.value,
  }
  logging.info(results)

  with open(_OUTPUT.value, "w") as outfile:
    json.dump(results, outfile)


def _compute_centers(dataset_input, dist_threshold_vec_input):
  """Dispatches to the clustering algorithm selected by --algorithm.

  Args:
    dataset_input: Points the algorithm runs on (possibly a sample of the
      full dataset).
    dist_threshold_vec_input: Fairness distance thresholds aligned with
      dataset_input.

  Returns:
    The centers produced by the selected algorithm.

  Raises:
    RuntimeError: If --algorithm is not a supported algorithm name.
  """
  if _ALGORITHM.value == "LSPP":
    return fair_clustering_algorithms.LocalSearchPlusPlus(
        dataset=dataset_input,
        k=_K.value,
        dist_threshold_vec=dist_threshold_vec_input,
        coeff_anchor=3.0,
        coeff_search=1.0,
        number_of_iterations=5000,
        use_lloyd=True,
    )
  if _ALGORITHM.value == "ICML20":
    return fair_clustering_algorithms.LocalSearchICML2020(
        dataset=dataset_input,
        k=_K.value,
        dist_threshold_vec=dist_threshold_vec_input,
        coeff_anchor=3.0,
        coeff_search=1.0,
        epsilon=0.01,
        use_lloyd=False,
    )
  if _ALGORITHM.value == "Greedy":
    return fair_clustering_algorithms.Greedy(
        dataset=dataset_input,
        k=_K.value,
        dist_threshold_vec=dist_threshold_vec_input,
        coeff_anchor=3.0,
    )
  if _ALGORITHM.value == "VanillaKMeans":
    return fair_clustering_algorithms.VanillaKMeans(
        dataset=dataset_input, k=_K.value
    )
  raise RuntimeError("Algorithm not supported")
185
# Script entry point: absl parses the flags, then calls main(argv).
if __name__ == "__main__":
  app.run(main)