google-research

Форк
0
/
better_greedy_algorithm.cc 
113 строк · 4.0 Кб
1
// Copyright 2023 The Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "better_greedy_algorithm.h"
16

17
#include <algorithm>
18
#include <cassert>
19
#include <iterator>
20
#include <memory>
21
#include <string>
22
#include <vector>
23

24
#include "fairness_constraint.h"
25
#include "matroid.h"
26
#include "matroid_intersection.h"
27
#include "submodular_function.h"
28

29
BetterGreedyAlgorithm::BetterGreedyAlgorithm(bool minimal)
30
    : minimal_(minimal) {}
31

32
void BetterGreedyAlgorithm::Init(const SubmodularFunction& sub_func_f,
33
                                 const FairnessConstraint& fairness,
34
                                 const Matroid& matroid) {
35
  GreedyAlgorithm::Init(sub_func_f, fairness, matroid);
36
  function_by_color_.clear();
37
  for (int i = 0; i < fairness_->GetColorNum(); ++i) {
38
    function_by_color_.push_back(sub_func_f_->Clone());
39
  }
40
}
41

42
void BetterGreedyAlgorithm::Insert(int element) {
43
  const int color = fairness_->GetColor(element);
44
  Matroid* matroid = matroid_by_color_[color].get();
45
  SubmodularFunction* sub_func = function_by_color_[color].get();
46
  if (matroid->CanAdd(element)) {
47
    matroid->Add(element);
48
    sub_func->Add(element);
49
  } else {
50
    const std::vector<int> all_swaps = matroid->GetAllSwaps(element);
51
    if (all_swaps.empty()) {
52
      return;
53
    }
54
    const int best_swap = *std::min_element(
55
        all_swaps.begin(), all_swaps.end(), [&sub_func](int lhs, int rhs) {
56
          return sub_func->RemovalDeltaAndIncreaseOracleCall(lhs) <
57
                 sub_func->RemovalDeltaAndIncreaseOracleCall(rhs);
58
        });
59
    if (sub_func->RemovalDeltaAndIncreaseOracleCall(best_swap) <
60
        sub_func->DeltaAndIncreaseOracleCall(element)) {
61
      matroid->Swap(element, best_swap);
62
      sub_func->Swap(element, best_swap);
63
    }
64
  }
65
}
66

67
double BetterGreedyAlgorithm::GetSolutionValue() {
68
  // Get feasible solution.
69
  std::vector<int> all_elements;
70
  for (int i = 0; i < matroid_by_color_.size(); ++i) {
71
    const Matroid* matroid = matroid_by_color_[i].get();
72
    const std::vector<int> elements = matroid->GetCurrent();
73
    all_elements.insert(all_elements.end(), elements.begin(), elements.end());
74
  }
75
  MaxIntersection(matroid_.get(), fairness_matroid_.get(), all_elements);
76
  std::vector<int> solution = matroid_->GetCurrent();
77
  assert(fairness_->IsFeasible(solution));
78

79
  if (!minimal_) {
80
    // Populate fairness_ and sub_func_f_.
81
    for (int element : solution) {
82
      fairness_->Add(element);
83
      sub_func_f_->Add(element);
84
    }
85

86
    // Add more elements greedily.
87
    std::vector<int> elements_left;
88
    std::sort(all_elements.begin(), all_elements.end());
89
    std::sort(solution.begin(), solution.end());
90
    std::set_difference(all_elements.begin(), all_elements.end(),
91
                        solution.begin(), solution.end(),
92
                        std::inserter(elements_left, elements_left.begin()));
93
    std::sort(elements_left.begin(), elements_left.end(),
94
              [this](int lhs, int rhs) {
95
                return sub_func_f_->DeltaAndIncreaseOracleCall(lhs) >
96
                       sub_func_f_->DeltaAndIncreaseOracleCall(rhs);
97
              });
98
    for (int element : elements_left) {
99
      if (fairness_->CanAdd(element) && matroid_->CanAdd(element)) {
100
        fairness_->Add(element);
101
        matroid_->Add(element);
102
        solution.push_back(element);
103
      }
104
    }
105
  }
106

107
  solution_ = solution;
108
  return sub_func_f_->ObjectiveAndIncreaseOracleCall(solution_);
109
}
110

111
std::string BetterGreedyAlgorithm::GetAlgorithmName() const {
112
  return "Slightly better greedy algorithm";
113
}
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.