google-research
163 строки · 5.2 Кб
1// Copyright 2023 The Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "matroid_intersection.h"
16
17#include <cassert>
18#include <iostream>
19#include <map>
20#include <queue>
21#include <set>
22#include <vector>
23
24#include "absl/container/btree_map.h"
25#include "matroid.h"
26#include "submodular_function.h"
27
28void MaxIntersection(Matroid* matroid_a, Matroid* matroid_b,
29const std::vector<int>& elements) {
30matroid_a->Reset();
31matroid_b->Reset();
32// Adjacency lists;
33std::map<int, std::vector<int>> exchange_graph;
34while (true) {
35// Greedily add elements to the solution;
36for (int element : elements) {
37if (matroid_a->InCurrent(element)) {
38continue;
39}
40if (matroid_a->CanAdd(element) && matroid_b->CanAdd(element)) {
41matroid_a->Add(element);
42matroid_b->Add(element);
43}
44}
45
46// Construct the exchange graph.
47exchange_graph.clear();
48for (int element : elements) {
49if (matroid_a->InCurrent(element)) {
50continue;
51}
52for (int a_swap : matroid_a->GetAllSwaps(element)) {
53exchange_graph[a_swap].push_back(element);
54}
55for (int b_swap : matroid_b->GetAllSwaps(element)) {
56exchange_graph[element].push_back(b_swap);
57}
58}
59
60// Find an augmenting path via BFS.
61std::map<int, int> bfs_parent;
62std::queue<int> queue;
63int aug_path_dest = -1;
64for (int element : elements) {
65if (matroid_a->InCurrent(element)) {
66continue;
67}
68if (matroid_a->CanAdd(element)) {
69bfs_parent[element] = -1;
70queue.push(element);
71}
72}
73while (!queue.empty()) {
74const int element = queue.front();
75queue.pop();
76if (!matroid_b->InCurrent(element) && matroid_b->CanAdd(element)) {
77aug_path_dest = element;
78break;
79}
80for (int neighbor : exchange_graph[element]) {
81if (!bfs_parent.count(neighbor)) {
82bfs_parent[neighbor] = element;
83queue.push(neighbor);
84}
85}
86}
87
88if (aug_path_dest == -1) {
89// No augmenting path found.
90break;
91}
92
93// Swap along the augmenting path.
94std::cerr << "we are applying an augmenting path" << std::endl;
95int out_element = aug_path_dest;
96int in_element = bfs_parent[aug_path_dest];
97while (in_element != -1) {
98matroid_a->Swap(out_element, in_element);
99matroid_b->Swap(out_element, in_element);
100out_element = bfs_parent[in_element];
101in_element = bfs_parent[out_element];
102}
103matroid_a->Add(out_element);
104matroid_b->Add(out_element);
105}
106
107assert(matroid_a->CurrentIsFeasible());
108assert(matroid_b->CurrentIsFeasible());
109}
110
111// Returns if an element is needed to be removed from `matroid_` to insert
112// `element`. Returns "-1" if no element is needed to be remove and "-2" if
113// the element cannot be swapped.
114int MinWeightElementToRemove(Matroid* matroid,
115absl::btree_map<int, double>& weight,
116const std::set<int>& const_elements,
117const int element) {
118if (matroid->CanAdd(element)) {
119return -1;
120}
121int best_element = -2;
122for (const int& swap : matroid->GetAllSwaps(element)) {
123if (const_elements.find(swap) != const_elements.end()) continue;
124if (best_element < 0 || weight[best_element] > weight[swap]) {
125best_element = swap;
126}
127}
128return best_element;
129}
130
131void SubMaxIntersection(Matroid* matroid_a, Matroid* matroid_b,
132SubmodularFunction* sub_func_f,
133const std::set<int>& const_elements,
134const std::vector<int>& universe) {
135// DO NOT reset the matroids here.
136absl::btree_map<int, double> weight;
137for (const int& element : universe) {
138if (const_elements.count(element)) continue; // don't add const_elements
139int first_swap =
140MinWeightElementToRemove(matroid_a, weight, const_elements, element);
141int second_swap =
142MinWeightElementToRemove(matroid_b, weight, const_elements, element);
143if (first_swap == -2 || second_swap == -2) continue;
144double total_decrease = weight[first_swap] + weight[second_swap];
145double cont_element = sub_func_f->DeltaAndIncreaseOracleCall(element);
146if (2 * total_decrease <= cont_element) {
147if (first_swap >= 0) {
148matroid_a->Remove(first_swap);
149matroid_b->Remove(first_swap);
150sub_func_f->Remove(first_swap);
151}
152if (second_swap >= 0 && first_swap != second_swap) {
153matroid_a->Remove(second_swap);
154matroid_b->Remove(second_swap);
155sub_func_f->Remove(second_swap);
156}
157matroid_a->Add(element);
158matroid_b->Add(element);
159sub_func_f->Add(element);
160weight[element] = cont_element;
161}
162}
163}
164