google-research

Форк
0
/
run_search_experiment.cc 
258 строк · 9.3 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
// Runs the RegularizedEvolution algorithm locally.
16

17
#include <algorithm>
18
#include <iostream>
19
#include <limits>
20
#include <memory>
21
#include <random>
22

23

24
#include "algorithm.h"
25
#include "task_util.h"
26
#include "task.pb.h"
27
#include "definitions.h"
28
#include "instruction.pb.h"
29
#include "evaluator.h"
30
#include "experiment.pb.h"
31
#include "experiment_util.h"
32
#include "fec_cache.h"
33
#include "generator.h"
34
#include "mutator.h"
35
#include "random_generator.h"
36
#include "regularized_evolution.h"
37
#include "train_budget.h"
38
#include "google/protobuf/text_format.h"
39
#include "absl/flags/flag.h"
40
#include "absl/flags/parse.h"
41
#include "absl/time/time.h"
42

43
// Global-scope aliases for project types. The ABSL_FLAG definitions in this
// file live outside namespace automl_zero and name IntegerT / RandomSeedT
// unqualified, so they need these aliases.
using IntegerT = automl_zero::IntegerT;
using RandomSeedT = automl_zero::RandomSeedT;
using InstructionIndexT = automl_zero::InstructionIndexT;
46

47
ABSL_FLAG(
48
    std::string, search_experiment_spec, "",
49
    "Specification for the experiment. Must be an SearchExperimentSpec "
50
    "proto in text-format. Required.");
51
ABSL_FLAG(
52
    std::string, final_tasks, "",
53
    "The tasks to use for the final evaluation. Must be a TaskCollection "
54
    "proto in text format. Required.");
55
ABSL_FLAG(
56
    IntegerT, max_experiments, 1,
57
    "Number of experiments to run. The code may end up running fewer "
58
    "if `sufficient_fitness` is set. If `0`, runs indefinitely.");
59
ABSL_FLAG(
60
    RandomSeedT, random_seed, 0,
61
    "Seed for random generator. Use `0` to not specify a seed (creates a new "
62
    "seed each time). If running multiple experiments, this seed is set at the "
63
    "beginning of the first experiment. Does not affect tasks.");
64
ABSL_FLAG(
65
    bool, randomize_task_seeds, false,
66
    "If true, the data in T_search and T_select is randomized for every "
67
    "experiment (including the first one). That is, any seeds specified in "
68
    "the search_tasks inside the search_experiment_spec or in the "
69
    "select_tasks are ignored. (Seeds in final_tasks are still "
70
    "respected, however).");
71
ABSL_FLAG(
72
    std::string, select_tasks, "",
73
    "The tasks to use in T_select. Must be a TaskCollection proto "
74
    "in text-format. Required.");
75
ABSL_FLAG(
76
    double, sufficient_fitness, std::numeric_limits<double>::max(),
77
    "Experimentation stops when any experiment reaches this select fitness. "
78
    "If not specified, keeps experimenting until max_experiments is reached.");
79

80
namespace automl_zero {
81

82
namespace {
83
using ::absl::GetCurrentTimeNanos;  // NOLINT
84
using ::absl::GetFlag;  // NOLINT
85
using ::absl::make_unique;  // NOLINT
86
using ::std::cout;  // NOLINT
87
using ::std::endl;  // NOLINT
88
using ::std::make_shared;  // NOLINT
89
using ::std::mt19937;  // NOLINT
90
using ::std::numeric_limits;  // NOLINT
91
using ::std::shared_ptr;  // NOLINT
92
using ::std::unique_ptr;  // NOLINT
93
using ::std::vector;  // NOLINT
94
}  // namespace
95

96
void run() {
97
  // Set random seed.
98
  RandomSeedT random_seed = GetFlag(FLAGS_random_seed);
99
  if (random_seed == 0) {
100
    random_seed = GenerateRandomSeed();
101
  }
102
  mt19937 bit_gen(random_seed);
103
  RandomGenerator rand_gen(&bit_gen);
104
  cout << "Random seed = " << random_seed << endl;
105

106
  // Build reusable search and select structures.
107
  CHECK(!GetFlag(FLAGS_search_experiment_spec).empty());
108
  auto experiment_spec = ParseTextFormat<SearchExperimentSpec>(
109
      GetFlag(FLAGS_search_experiment_spec));
110
  const double sufficient_fitness = GetFlag(FLAGS_sufficient_fitness);
111
  const IntegerT max_experiments = GetFlag(FLAGS_max_experiments);
112
  Generator generator(
113
      experiment_spec.initial_population(),
114
      experiment_spec.setup_size_init(),
115
      experiment_spec.predict_size_init(),
116
      experiment_spec.learn_size_init(),
117
      ExtractOps(experiment_spec.setup_ops()),
118
      ExtractOps(experiment_spec.predict_ops()),
119
      ExtractOps(experiment_spec.learn_ops()), &bit_gen,
120
      &rand_gen);
121
  unique_ptr<TrainBudget> train_budget;
122
  if (experiment_spec.has_train_budget()) {
123
    train_budget =
124
        BuildTrainBudget(experiment_spec.train_budget(), &generator);
125
  }
126
  Mutator mutator(
127
      experiment_spec.allowed_mutation_types(),
128
      experiment_spec.mutate_prob(),
129
      ExtractOps(experiment_spec.setup_ops()),
130
      ExtractOps(experiment_spec.predict_ops()),
131
      ExtractOps(experiment_spec.learn_ops()),
132
      experiment_spec.mutate_setup_size_min(),
133
      experiment_spec.mutate_setup_size_max(),
134
      experiment_spec.mutate_predict_size_min(),
135
      experiment_spec.mutate_predict_size_max(),
136
      experiment_spec.mutate_learn_size_min(),
137
      experiment_spec.mutate_learn_size_max(),
138
      &bit_gen, &rand_gen);
139
  auto select_tasks =
140
      ParseTextFormat<TaskCollection>(GetFlag(FLAGS_select_tasks));
141

142
  // Run search experiments and select best algorithm.
143
  IntegerT num_experiments = 0;
144
  double best_select_fitness = numeric_limits<double>::lowest();
145
  shared_ptr<const Algorithm> best_algorithm = make_shared<const Algorithm>();
146
  while (true) {
147
    // Randomize T_search tasks.
148
    if (GetFlag(FLAGS_randomize_task_seeds)) {
149
      RandomizeTaskSeeds(experiment_spec.mutable_search_tasks(),
150
                            rand_gen.UniformRandomSeed());
151
    }
152

153
    // Build non-reusable search structures.
154
    unique_ptr<FECCache> functional_cache =
155
        experiment_spec.has_fec() ?
156
            make_unique<FECCache>(experiment_spec.fec()) :
157
            nullptr;
158
    Evaluator evaluator(
159
        experiment_spec.fitness_combination_mode(),
160
        experiment_spec.search_tasks(),
161
        &rand_gen, functional_cache.get(), train_budget.get(),
162
        experiment_spec.max_abs_error());
163
    RegularizedEvolution regularized_evolution(
164
        &rand_gen, experiment_spec.population_size(),
165
        experiment_spec.tournament_size(),
166
        experiment_spec.progress_every(),
167
        &generator, &evaluator, &mutator);
168

169
    // Run one experiment.
170
    cout << "Running evolution experiment (on the T_search tasks)..." << endl;
171
    regularized_evolution.Init();
172
    const IntegerT remaining_train_steps =
173
        experiment_spec.max_train_steps() -
174
        regularized_evolution.NumTrainSteps();
175
    regularized_evolution.Run(remaining_train_steps, kUnlimitedTime);
176
    cout << "Experiment done. Retrieving candidate algorithm." << endl;
177

178
    // Extract best algorithm based on T_search.
179
    double unused_pop_mean, unused_pop_stdev, search_fitness;
180
    shared_ptr<const Algorithm> candidate_algorithm =
181
        make_shared<const Algorithm>();
182
    regularized_evolution.PopulationStats(
183
        &unused_pop_mean, &unused_pop_stdev,
184
        &candidate_algorithm, &search_fitness);
185
    cout << "Search fitness for candidate algorithm = "
186
         << search_fitness << endl;
187

188
    // Randomize T_select tasks.
189
    if (GetFlag(FLAGS_randomize_task_seeds)) {
190
      RandomizeTaskSeeds(&select_tasks, rand_gen.UniformRandomSeed());
191
    }
192
    mt19937 select_bit_gen(rand_gen.UniformRandomSeed());
193
    RandomGenerator select_rand_gen(&select_bit_gen);
194

195
    // Keep track of the best model on the T_select tasks.
196
    cout << "Evaluating candidate algorithm from experiment "
197
         << "(on T_select tasks)... " << endl;
198
    Evaluator select_evaluator(
199
        MEAN_FITNESS_COMBINATION,
200
        select_tasks,
201
        &select_rand_gen,
202
        nullptr,  // functional_cache
203
        nullptr,  // train_budget
204
        experiment_spec.max_abs_error());
205
    const double select_fitness =
206
        select_evaluator.Evaluate(*candidate_algorithm);
207
    cout << "Select fitness for candidate algorithm = "
208
         << select_fitness << endl;
209
    if (select_fitness >= best_select_fitness) {
210
      best_select_fitness = select_fitness;
211
      best_algorithm = candidate_algorithm;
212
      cout << "Select fitness is the best so far. " << endl;
213
    }
214

215
    // Consider stopping experiments.
216
    if (sufficient_fitness > 0.0 &&
217
        best_select_fitness > sufficient_fitness) {
218
      // Stop if we reached the specified `sufficient_fitness`.
219
      break;
220
    }
221
    ++num_experiments;
222
    if (max_experiments != 0 && num_experiments >= max_experiments) {
223
      // Stop if we reached the maximum number of experiments.
224
      break;
225
    }
226
  }
227

228
  // Do a final evaluation on unseen tasks.
229
  cout << endl;
230
  cout << "Final evaluation of best algorithm "
231
       << "(on unseen tasks)..." << endl;
232
  const auto final_tasks =
233
      ParseTextFormat<TaskCollection>(GetFlag(FLAGS_final_tasks));
234
  mt19937 final_bit_gen(rand_gen.UniformRandomSeed());
235
  RandomGenerator final_rand_gen(&final_bit_gen);
236
  Evaluator final_evaluator(
237
      MEAN_FITNESS_COMBINATION,
238
      final_tasks,
239
      &final_rand_gen,
240
      nullptr,  // functional_cache
241
      nullptr,  // train_budget
242
      experiment_spec.max_abs_error());
243
  const double final_fitness =
244
      final_evaluator.Evaluate(*best_algorithm);
245

246
  cout << "Final evaluation fitness (on unseen data) = "
247
       << final_fitness << endl;
248
  cout << "Algorithm found: " << endl
249
       << best_algorithm->ToReadable() << endl;
250
}
251

252
}  // namespace automl_zero
253

254
int main(int argc, char** argv) {
255
  absl::ParseCommandLine(argc, argv);
256
  automl_zero::run();
257
  return 0;
258
}
259

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.