google-research
258 строк · 9.3 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Runs the RegularizedEvolution algorithm locally.
16
17#include <algorithm>
18#include <iostream>
19#include <limits>
20#include <memory>
21#include <random>
22
23
24#include "algorithm.h"
25#include "task_util.h"
26#include "task.pb.h"
27#include "definitions.h"
28#include "instruction.pb.h"
29#include "evaluator.h"
30#include "experiment.pb.h"
31#include "experiment_util.h"
32#include "fec_cache.h"
33#include "generator.h"
34#include "mutator.h"
35#include "random_generator.h"
36#include "regularized_evolution.h"
37#include "train_budget.h"
38#include "google/protobuf/text_format.h"
39#include "absl/flags/flag.h"
40#include "absl/flags/parse.h"
41#include "absl/time/time.h"
42
43typedef automl_zero::IntegerT IntegerT;
44typedef automl_zero::RandomSeedT RandomSeedT;
45typedef automl_zero::InstructionIndexT InstructionIndexT;
46
47ABSL_FLAG(
48std::string, search_experiment_spec, "",
49"Specification for the experiment. Must be an SearchExperimentSpec "
50"proto in text-format. Required.");
51ABSL_FLAG(
52std::string, final_tasks, "",
53"The tasks to use for the final evaluation. Must be a TaskCollection "
54"proto in text format. Required.");
55ABSL_FLAG(
56IntegerT, max_experiments, 1,
57"Number of experiments to run. The code may end up running fewer "
58"if `sufficient_fitness` is set. If `0`, runs indefinitely.");
59ABSL_FLAG(
60RandomSeedT, random_seed, 0,
61"Seed for random generator. Use `0` to not specify a seed (creates a new "
62"seed each time). If running multiple experiments, this seed is set at the "
63"beginning of the first experiment. Does not affect tasks.");
64ABSL_FLAG(
65bool, randomize_task_seeds, false,
66"If true, the data in T_search and T_select is randomized for every "
67"experiment (including the first one). That is, any seeds specified in "
68"the search_tasks inside the search_experiment_spec or in the "
69"select_tasks are ignored. (Seeds in final_tasks are still "
70"respected, however).");
71ABSL_FLAG(
72std::string, select_tasks, "",
73"The tasks to use in T_select. Must be a TaskCollection proto "
74"in text-format. Required.");
75ABSL_FLAG(
76double, sufficient_fitness, std::numeric_limits<double>::max(),
77"Experimentation stops when any experiment reaches this select fitness. "
78"If not specified, keeps experimenting until max_experiments is reached.");
79
80namespace automl_zero {
81
82namespace {
83using ::absl::GetCurrentTimeNanos; // NOLINT
84using ::absl::GetFlag; // NOLINT
85using ::absl::make_unique; // NOLINT
86using ::std::cout; // NOLINT
87using ::std::endl; // NOLINT
88using ::std::make_shared; // NOLINT
89using ::std::mt19937; // NOLINT
90using ::std::numeric_limits; // NOLINT
91using ::std::shared_ptr; // NOLINT
92using ::std::unique_ptr; // NOLINT
93using ::std::vector; // NOLINT
94} // namespace
95
96void run() {
97// Set random seed.
98RandomSeedT random_seed = GetFlag(FLAGS_random_seed);
99if (random_seed == 0) {
100random_seed = GenerateRandomSeed();
101}
102mt19937 bit_gen(random_seed);
103RandomGenerator rand_gen(&bit_gen);
104cout << "Random seed = " << random_seed << endl;
105
106// Build reusable search and select structures.
107CHECK(!GetFlag(FLAGS_search_experiment_spec).empty());
108auto experiment_spec = ParseTextFormat<SearchExperimentSpec>(
109GetFlag(FLAGS_search_experiment_spec));
110const double sufficient_fitness = GetFlag(FLAGS_sufficient_fitness);
111const IntegerT max_experiments = GetFlag(FLAGS_max_experiments);
112Generator generator(
113experiment_spec.initial_population(),
114experiment_spec.setup_size_init(),
115experiment_spec.predict_size_init(),
116experiment_spec.learn_size_init(),
117ExtractOps(experiment_spec.setup_ops()),
118ExtractOps(experiment_spec.predict_ops()),
119ExtractOps(experiment_spec.learn_ops()), &bit_gen,
120&rand_gen);
121unique_ptr<TrainBudget> train_budget;
122if (experiment_spec.has_train_budget()) {
123train_budget =
124BuildTrainBudget(experiment_spec.train_budget(), &generator);
125}
126Mutator mutator(
127experiment_spec.allowed_mutation_types(),
128experiment_spec.mutate_prob(),
129ExtractOps(experiment_spec.setup_ops()),
130ExtractOps(experiment_spec.predict_ops()),
131ExtractOps(experiment_spec.learn_ops()),
132experiment_spec.mutate_setup_size_min(),
133experiment_spec.mutate_setup_size_max(),
134experiment_spec.mutate_predict_size_min(),
135experiment_spec.mutate_predict_size_max(),
136experiment_spec.mutate_learn_size_min(),
137experiment_spec.mutate_learn_size_max(),
138&bit_gen, &rand_gen);
139auto select_tasks =
140ParseTextFormat<TaskCollection>(GetFlag(FLAGS_select_tasks));
141
142// Run search experiments and select best algorithm.
143IntegerT num_experiments = 0;
144double best_select_fitness = numeric_limits<double>::lowest();
145shared_ptr<const Algorithm> best_algorithm = make_shared<const Algorithm>();
146while (true) {
147// Randomize T_search tasks.
148if (GetFlag(FLAGS_randomize_task_seeds)) {
149RandomizeTaskSeeds(experiment_spec.mutable_search_tasks(),
150rand_gen.UniformRandomSeed());
151}
152
153// Build non-reusable search structures.
154unique_ptr<FECCache> functional_cache =
155experiment_spec.has_fec() ?
156make_unique<FECCache>(experiment_spec.fec()) :
157nullptr;
158Evaluator evaluator(
159experiment_spec.fitness_combination_mode(),
160experiment_spec.search_tasks(),
161&rand_gen, functional_cache.get(), train_budget.get(),
162experiment_spec.max_abs_error());
163RegularizedEvolution regularized_evolution(
164&rand_gen, experiment_spec.population_size(),
165experiment_spec.tournament_size(),
166experiment_spec.progress_every(),
167&generator, &evaluator, &mutator);
168
169// Run one experiment.
170cout << "Running evolution experiment (on the T_search tasks)..." << endl;
171regularized_evolution.Init();
172const IntegerT remaining_train_steps =
173experiment_spec.max_train_steps() -
174regularized_evolution.NumTrainSteps();
175regularized_evolution.Run(remaining_train_steps, kUnlimitedTime);
176cout << "Experiment done. Retrieving candidate algorithm." << endl;
177
178// Extract best algorithm based on T_search.
179double unused_pop_mean, unused_pop_stdev, search_fitness;
180shared_ptr<const Algorithm> candidate_algorithm =
181make_shared<const Algorithm>();
182regularized_evolution.PopulationStats(
183&unused_pop_mean, &unused_pop_stdev,
184&candidate_algorithm, &search_fitness);
185cout << "Search fitness for candidate algorithm = "
186<< search_fitness << endl;
187
188// Randomize T_select tasks.
189if (GetFlag(FLAGS_randomize_task_seeds)) {
190RandomizeTaskSeeds(&select_tasks, rand_gen.UniformRandomSeed());
191}
192mt19937 select_bit_gen(rand_gen.UniformRandomSeed());
193RandomGenerator select_rand_gen(&select_bit_gen);
194
195// Keep track of the best model on the T_select tasks.
196cout << "Evaluating candidate algorithm from experiment "
197<< "(on T_select tasks)... " << endl;
198Evaluator select_evaluator(
199MEAN_FITNESS_COMBINATION,
200select_tasks,
201&select_rand_gen,
202nullptr, // functional_cache
203nullptr, // train_budget
204experiment_spec.max_abs_error());
205const double select_fitness =
206select_evaluator.Evaluate(*candidate_algorithm);
207cout << "Select fitness for candidate algorithm = "
208<< select_fitness << endl;
209if (select_fitness >= best_select_fitness) {
210best_select_fitness = select_fitness;
211best_algorithm = candidate_algorithm;
212cout << "Select fitness is the best so far. " << endl;
213}
214
215// Consider stopping experiments.
216if (sufficient_fitness > 0.0 &&
217best_select_fitness > sufficient_fitness) {
218// Stop if we reached the specified `sufficient_fitness`.
219break;
220}
221++num_experiments;
222if (max_experiments != 0 && num_experiments >= max_experiments) {
223// Stop if we reached the maximum number of experiments.
224break;
225}
226}
227
228// Do a final evaluation on unseen tasks.
229cout << endl;
230cout << "Final evaluation of best algorithm "
231<< "(on unseen tasks)..." << endl;
232const auto final_tasks =
233ParseTextFormat<TaskCollection>(GetFlag(FLAGS_final_tasks));
234mt19937 final_bit_gen(rand_gen.UniformRandomSeed());
235RandomGenerator final_rand_gen(&final_bit_gen);
236Evaluator final_evaluator(
237MEAN_FITNESS_COMBINATION,
238final_tasks,
239&final_rand_gen,
240nullptr, // functional_cache
241nullptr, // train_budget
242experiment_spec.max_abs_error());
243const double final_fitness =
244final_evaluator.Evaluate(*best_algorithm);
245
246cout << "Final evaluation fitness (on unseen data) = "
247<< final_fitness << endl;
248cout << "Algorithm found: " << endl
249<< best_algorithm->ToReadable() << endl;
250}
251
252} // namespace automl_zero
253
254int main(int argc, char** argv) {
255absl::ParseCommandLine(argc, argv);
256automl_zero::run();
257return 0;
258}
259