google-research
158 lines · 5.4 KB
1// Copyright 2023 The Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "two_pass_algorithm_with_conditioned_matroid.h"
16
17#include <algorithm>
18#include <iostream>
19#include <map>
20#include <memory>
21#include <string>
22#include <utility>
23#include <vector>
24
25#include "better_greedy_algorithm.h"
26#include "conditioned_matroid.h"
27#include "fairness_constraint.h"
28#include "matroid_intersection.h"
29#include "uniform_matroid.h"
30#include "utilities.h"
31
32void TwoPassAlgorithmWithConditionedMatroid::Init(
33const SubmodularFunction& sub_func_f, const FairnessConstraint& fairness,
34const Matroid& matroid) {
35Algorithm::Init(sub_func_f, fairness, matroid);
36bounds_ = fairness.GetBounds();
37universe_elements_.clear();
38}
39
40void TwoPassAlgorithmWithConditionedMatroid::Insert(int element) {
41universe_elements_.push_back(element);
42}
43
44void TwoPassAlgorithmWithConditionedMatroid::GreedyFirstPass() {
45matroid_->Reset();
46fairness_->Reset();
47sub_func_f_->Reset();
48BetterGreedyAlgorithm greedy(true);
49greedy.Init(*sub_func_f_, *fairness_, *matroid_);
50for (const int element : universe_elements_) {
51greedy.Insert(element);
52}
53double solution_value = greedy.GetSolutionValue();
54std::cout << "Solution value after greedy first pass: " << solution_value <<
55std::endl;
56first_round_solution_ = greedy.GetSolutionVector();
57}
58
59void TwoPassAlgorithmWithConditionedMatroid::FirstPass() {
60per_color_solutions_ =
61std::vector<std::vector<int>>(bounds_.size(), std::vector<int>());
62std::vector<std::unique_ptr<Matroid>> per_color_matroids;
63matroid_->Reset();
64per_color_matroids.reserve(bounds_.size());
65for (int i = 0; i < bounds_.size(); i++) {
66per_color_matroids.push_back(matroid_->Clone());
67}
68for (const int element : universe_elements_) {
69int color = fairness_->GetColor(element);
70if (per_color_matroids[color]->CanAdd(element)) {
71// && per_color_solutions_[color].size() < bounds_[color].first) {
72per_color_matroids[color]->Add(element);
73per_color_solutions_[color].push_back(element);
74}
75}
76}
77
78void TwoPassAlgorithmWithConditionedMatroid::FindFeasibleSolution() {
79std::vector<int> all_colors_solution;
80for (const auto& solution : per_color_solutions_) {
81all_colors_solution.insert(all_colors_solution.end(), solution.begin(),
82solution.end());
83}
84matroid_->Reset();
85fairness_->Reset();
86MaxIntersection(matroid_.get(), fairness_->LowerBoundsToMatroid().get(),
87all_colors_solution);
88first_round_solution_ = matroid_->GetCurrent();
89}
90
91void TwoPassAlgorithmWithConditionedMatroid::DivideSolution() {
92lower_bound_solutions_.clear();
93lower_bound_solutions_.push_back(std::vector<int>());
94lower_bound_solutions_.push_back(std::vector<int>());
95std::vector<int> num_picked_per_color(bounds_.size(), 0);
96for (const auto& element : first_round_solution_) {
97lower_bound_solutions_
98[(num_picked_per_color[fairness_->GetColor(element)]++) % 2]
99.push_back(element);
100}
101}
102
103std::vector<int> TwoPassAlgorithmWithConditionedMatroid::SecondPass(
104std::vector<int> start_solution) {
105matroid_->Reset();
106fairness_->Reset();
107sub_func_f_->Reset();
108weights_.clear();
109
110ConditionedMatroid condmatroid(*matroid_, start_solution);
111
112std::unique_ptr<Matroid> color_mat = fairness_->UpperBoundsToMatroid();
113color_mat->Reset();
114SubMaxIntersection(&condmatroid, color_mat.get(), sub_func_f_.get(), {},
115universe_elements_);
116
117std::vector<int> current_sol = color_mat->GetCurrent();
118std::vector<int> start_solution_not_chosen;
119for (int el : start_solution) {
120if (!color_mat->InCurrent(el)) {
121start_solution_not_chosen.push_back(el);
122}
123}
124
125// find the best subset of start_solution to add by
126// max {F(S U S_current) : S subset of S_start, S U S_current in I^C}.
127UniformMatroid dummy_mat(1'000'000'000);
128ConditionedMatroid cond_fairness(*color_mat, current_sol);
129// sub_func_f_ already has current sol.
130// dummy_mat has nothing.
131// cond_fairness also has nothing (it's reset when created).
132SubMaxIntersection(&dummy_mat, &cond_fairness, sub_func_f_.get(), {},
133start_solution_not_chosen);
134
135return append(current_sol, cond_fairness.GetCurrent());
136}
137
138double TwoPassAlgorithmWithConditionedMatroid::GetSolutionValue() {
139GreedyFirstPass();
140std::pair<std::vector<int>, double> answer[2];
141for (int i = 0; i < 2; ++i) {
142answer[i].first = SecondPass(lower_bound_solutions_[i]);
143answer[i].second =
144sub_func_f_->ObjectiveAndIncreaseOracleCall(answer[i].first);
145}
146
147final_solution_ =
148answer[0].second > answer[1].second ? answer[0].first : answer[1].first;
149return std::max(answer[0].second, answer[1].second);
150}
151
152std::vector<int> TwoPassAlgorithmWithConditionedMatroid::GetSolutionVector() {
153return final_solution_;
154}
155
156std::string TwoPassAlgorithmWithConditionedMatroid::GetAlgorithmName() const {
157return "Two pass algorithm (with CM)";
158}
159