google-research
113 строк · 4.0 Кб
// Copyright 2023 The Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "better_greedy_algorithm.h"

#include <algorithm>
#include <cassert>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "fairness_constraint.h"
#include "matroid.h"
#include "matroid_intersection.h"
#include "submodular_function.h"

28
29BetterGreedyAlgorithm::BetterGreedyAlgorithm(bool minimal)
30: minimal_(minimal) {}
31
32void BetterGreedyAlgorithm::Init(const SubmodularFunction& sub_func_f,
33const FairnessConstraint& fairness,
34const Matroid& matroid) {
35GreedyAlgorithm::Init(sub_func_f, fairness, matroid);
36function_by_color_.clear();
37for (int i = 0; i < fairness_->GetColorNum(); ++i) {
38function_by_color_.push_back(sub_func_f_->Clone());
39}
40}
41
42void BetterGreedyAlgorithm::Insert(int element) {
43const int color = fairness_->GetColor(element);
44Matroid* matroid = matroid_by_color_[color].get();
45SubmodularFunction* sub_func = function_by_color_[color].get();
46if (matroid->CanAdd(element)) {
47matroid->Add(element);
48sub_func->Add(element);
49} else {
50const std::vector<int> all_swaps = matroid->GetAllSwaps(element);
51if (all_swaps.empty()) {
52return;
53}
54const int best_swap = *std::min_element(
55all_swaps.begin(), all_swaps.end(), [&sub_func](int lhs, int rhs) {
56return sub_func->RemovalDeltaAndIncreaseOracleCall(lhs) <
57sub_func->RemovalDeltaAndIncreaseOracleCall(rhs);
58});
59if (sub_func->RemovalDeltaAndIncreaseOracleCall(best_swap) <
60sub_func->DeltaAndIncreaseOracleCall(element)) {
61matroid->Swap(element, best_swap);
62sub_func->Swap(element, best_swap);
63}
64}
65}
66
67double BetterGreedyAlgorithm::GetSolutionValue() {
68// Get feasible solution.
69std::vector<int> all_elements;
70for (int i = 0; i < matroid_by_color_.size(); ++i) {
71const Matroid* matroid = matroid_by_color_[i].get();
72const std::vector<int> elements = matroid->GetCurrent();
73all_elements.insert(all_elements.end(), elements.begin(), elements.end());
74}
75MaxIntersection(matroid_.get(), fairness_matroid_.get(), all_elements);
76std::vector<int> solution = matroid_->GetCurrent();
77assert(fairness_->IsFeasible(solution));
78
79if (!minimal_) {
80// Populate fairness_ and sub_func_f_.
81for (int element : solution) {
82fairness_->Add(element);
83sub_func_f_->Add(element);
84}
85
86// Add more elements greedily.
87std::vector<int> elements_left;
88std::sort(all_elements.begin(), all_elements.end());
89std::sort(solution.begin(), solution.end());
90std::set_difference(all_elements.begin(), all_elements.end(),
91solution.begin(), solution.end(),
92std::inserter(elements_left, elements_left.begin()));
93std::sort(elements_left.begin(), elements_left.end(),
94[this](int lhs, int rhs) {
95return sub_func_f_->DeltaAndIncreaseOracleCall(lhs) >
96sub_func_f_->DeltaAndIncreaseOracleCall(rhs);
97});
98for (int element : elements_left) {
99if (fairness_->CanAdd(element) && matroid_->CanAdd(element)) {
100fairness_->Add(element);
101matroid_->Add(element);
102solution.push_back(element);
103}
104}
105}
106
107solution_ = solution;
108return sub_func_f_->ObjectiveAndIncreaseOracleCall(solution_);
109}
110
111std::string BetterGreedyAlgorithm::GetAlgorithmName() const {
112return "Slightly better greedy algorithm";
113}
114