// google-research
// 62 lines · 2.1 KB
// Copyright 2023 The Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "greedy_algorithm.h"

#include <memory>
#include <string>
#include <vector>

#include "fairness_constraint.h"
#include "matroid.h"
#include "matroid_intersection.h"
#include "submodular_function.h"

26void GreedyAlgorithm::Init(const SubmodularFunction& sub_func_f,27const FairnessConstraint& fairness,28const Matroid& matroid) {29Algorithm::Init(sub_func_f, fairness, matroid);30matroid_by_color_.clear();31solution_.clear();32for (int i = 0; i < fairness_->GetColorNum(); ++i) {33matroid_by_color_.push_back(matroid_->Clone());34}35fairness_matroid_ = fairness_->LowerBoundsToMatroid();36}
37
38void GreedyAlgorithm::Insert(int element) {39const int color = fairness_->GetColor(element);40Matroid* matroid = matroid_by_color_[color].get();41if (matroid->CanAdd(element)) {42matroid->Add(element);43}44}
45
46double GreedyAlgorithm::GetSolutionValue() {47std::vector<int> all_elements;48for (int i = 0; i < matroid_by_color_.size(); ++i) {49const Matroid* matroid = matroid_by_color_[i].get();50const std::vector<int> elements = matroid->GetCurrent();51all_elements.insert(all_elements.end(), elements.begin(), elements.end());52}53MaxIntersection(matroid_.get(), fairness_matroid_.get(), all_elements);54solution_ = matroid_->GetCurrent();55return sub_func_f_->ObjectiveAndIncreaseOracleCall(solution_);56}
57
58std::vector<int> GreedyAlgorithm::GetSolutionVector() { return solution_; }59
60std::string GreedyAlgorithm::GetAlgorithmName() const {61return "Basic greedy algorithm";62}
63