google-research
215 строк · 7.6 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "evaluator.h"16
17#include <algorithm>18#include <iomanip>19#include <ios>20#include <limits>21#include <memory>22#include <string>23
24#include "task.h"25#include "task_util.h"26#include "task.pb.h"27#include "definitions.h"28#include "executor.h"29#include "random_generator.h"30#include "train_budget.h"31#include "google/protobuf/text_format.h"32#include "absl/algorithm/container.h"33#include "absl/flags/flag.h"34#include "absl/memory/memory.h"35#include "absl/strings/str_cat.h"36
37namespace automl_zero {38
39using ::absl::c_linear_search; // NOLINT40using ::absl::GetFlag; // NOLINT41using ::absl::make_unique; // NOLINT42using ::std::cout; // NOLINT43using ::std::endl; // NOLINT44using ::std::fixed; // NOLINT45using ::std::make_shared; // NOLINT46using ::std::min; // NOLINT47using ::std::mt19937; // NOLINT48using ::std::nth_element; // NOLINT49using ::std::pair; // NOLINT50using ::std::setprecision; // NOLINT51using ::std::vector; // NOLINT52using ::std::unique_ptr; // NOLINT53using internal::CombineFitnesses;54
55constexpr IntegerT kMinNumTrainExamples = 10;56constexpr RandomSeedT kFunctionalCacheRandomSeed = 235732282;57
58Evaluator::Evaluator(const FitnessCombinationMode fitness_combination_mode,59const TaskCollection& task_collection,60RandomGenerator* rand_gen,61FECCache* functional_cache,62TrainBudget* train_budget,63const double max_abs_error)64: fitness_combination_mode_(fitness_combination_mode),65task_collection_(task_collection),66train_budget_(train_budget),67rand_gen_(rand_gen),68functional_cache_(functional_cache),69functional_cache_bit_gen_owned_(70make_unique<mt19937>(kFunctionalCacheRandomSeed)),71functional_cache_rand_gen_owned_(72make_unique<RandomGenerator>(functional_cache_bit_gen_owned_.get())),73functional_cache_rand_gen_(functional_cache_rand_gen_owned_.get()),74best_fitness_(-1.0),75max_abs_error_(max_abs_error),76num_train_steps_completed_(0) {77FillTasks(task_collection_, &tasks_);78CHECK_GT(tasks_.size(), 0);79}
80
81double Evaluator::Evaluate(const Algorithm& algorithm) {82// Compute the mean fitness across all tasks.83vector<double> task_fitnesses;84task_fitnesses.reserve(tasks_.size());85vector<double> debug_fitnesses;86vector<IntegerT> debug_num_train_examples;87vector<IntegerT> task_indexes; // Tasks to use.88// Use all the tasks.89for (IntegerT i = 0; i < tasks_.size(); ++i) {90task_indexes.push_back(i);91}92for (IntegerT task_index : task_indexes) {93const unique_ptr<TaskInterface>& task = tasks_[task_index];94CHECK_GE(task->MaxTrainExamples(), kMinNumTrainExamples);95const IntegerT num_train_examples =96train_budget_ == nullptr ?97task->MaxTrainExamples() :98train_budget_->TrainExamples(algorithm, task->MaxTrainExamples());99double curr_fitness = -1.0;100curr_fitness = Execute(*task, num_train_examples, algorithm);101task_fitnesses.push_back(curr_fitness);102}103double combined_fitness =104CombineFitnesses(task_fitnesses, fitness_combination_mode_);105
106CHECK_GE(combined_fitness, kMinFitness);107CHECK_LE(combined_fitness, kMaxFitness);108
109return combined_fitness;110}
111
112double Evaluator::Execute(const TaskInterface& task,113const IntegerT num_train_examples,114const Algorithm& algorithm) {115switch (task.FeaturesSize()) {116case 2: {117const Task<2>& downcasted_task = *SafeDowncast<2>(&task);118return ExecuteImpl<2>(downcasted_task, num_train_examples, algorithm);119}120case 4: {121const Task<4>& downcasted_task = *SafeDowncast<4>(&task);122return ExecuteImpl<4>(downcasted_task, num_train_examples, algorithm);123}124case 8: {125const Task<8>& downcasted_task = *SafeDowncast<8>(&task);126return ExecuteImpl<8>(downcasted_task, num_train_examples, algorithm);127}128case 16: {129const Task<16>& downcasted_task = *SafeDowncast<16>(&task);130return ExecuteImpl<16>(downcasted_task, num_train_examples, algorithm);131}132case 32: {133const Task<32>& downcasted_task = *SafeDowncast<32>(&task);134return ExecuteImpl<32>(downcasted_task, num_train_examples, algorithm);135}136default:137LOG(FATAL) << "Unsupported features size." << endl;138}139}
140
141IntegerT Evaluator::GetNumTrainStepsCompleted() const {142return num_train_steps_completed_;143}
144
145template <FeatureIndexT F>146double Evaluator::ExecuteImpl(const Task<F>& task,147const IntegerT num_train_examples,148const Algorithm& algorithm) {149if (functional_cache_ != nullptr) {150CHECK_LE(functional_cache_->NumTrainExamples(), task.MaxTrainExamples());151CHECK_LE(functional_cache_->NumValidExamples(), task.ValidSteps());152functional_cache_bit_gen_owned_->seed(kFunctionalCacheRandomSeed);153Executor<F> functional_cache_executor(154algorithm, task, functional_cache_->NumTrainExamples(),155functional_cache_->NumValidExamples(), functional_cache_rand_gen_,156max_abs_error_);157vector<double> train_errors;158vector<double> valid_errors;159functional_cache_executor.Execute(&train_errors, &valid_errors);160num_train_steps_completed_ +=161functional_cache_executor.GetNumTrainStepsCompleted();162const size_t hash = functional_cache_->Hash(163train_errors, valid_errors, task.index_, num_train_examples);164pair<double, bool> fitness_and_found = functional_cache_->Find(hash);165if (fitness_and_found.second) {166// Cache hit.167functional_cache_->UpdateOrDie(hash, fitness_and_found.first);168return fitness_and_found.first;169} else {170// Cache miss.171Executor<F> executor(algorithm, task, num_train_examples,172task.ValidSteps(), rand_gen_, max_abs_error_);173double fitness = executor.Execute();174num_train_steps_completed_ += executor.GetNumTrainStepsCompleted();175functional_cache_->InsertOrDie(hash, fitness);176return fitness;177}178} else {179Executor<F> executor(180algorithm, task, num_train_examples, task.ValidSteps(),181rand_gen_, max_abs_error_);182const double fitness = executor.Execute();183num_train_steps_completed_ += executor.GetNumTrainStepsCompleted();184return fitness;185}186}
187
188namespace internal {189
// Returns the median of `values`.
//
// For an even number of elements this is the upper of the two middle elements
// (the element at index size / 2 in sorted order), not their average.
// For an empty input it returns quiet NaN instead of reading out of bounds.
//
// The argument is taken by value on purpose: nth_element reorders it.
double Median(std::vector<double> values) {  // Intentional copy.
  if (values.empty()) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  const auto middle = values.begin() + values.size() / 2;
  // Partial ordering is enough: only the middle element must be in place.
  std::nth_element(values.begin(), middle, values.end());
  return *middle;
}
195
196double CombineFitnesses(197const vector<double>& task_fitnesses,198const FitnessCombinationMode mode) {199if (mode == MEAN_FITNESS_COMBINATION) {200double combined_fitness = 0.0;201for (const double fitness : task_fitnesses) {202combined_fitness += fitness;203}204combined_fitness /= static_cast<double>(task_fitnesses.size());205return combined_fitness;206} else if (mode == MEDIAN_FITNESS_COMBINATION) {207return Median(task_fitnesses);208} else {209LOG(FATAL) << "Unsupported fitness combination." << endl;210}211}
212
213} // namespace internal214
215} // namespace automl_zero216