google-research

Форк
0
215 строк · 7.6 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "evaluator.h"
16

17
#include <algorithm>
18
#include <iomanip>
19
#include <ios>
20
#include <limits>
21
#include <memory>
22
#include <string>
23

24
#include "task.h"
25
#include "task_util.h"
26
#include "task.pb.h"
27
#include "definitions.h"
28
#include "executor.h"
29
#include "random_generator.h"
30
#include "train_budget.h"
31
#include "google/protobuf/text_format.h"
32
#include "absl/algorithm/container.h"
33
#include "absl/flags/flag.h"
34
#include "absl/memory/memory.h"
35
#include "absl/strings/str_cat.h"
36

37
namespace automl_zero {
38

39
using ::absl::c_linear_search;  // NOLINT
40
using ::absl::GetFlag;  // NOLINT
41
using ::absl::make_unique;  // NOLINT
42
using ::std::cout;  // NOLINT
43
using ::std::endl;  // NOLINT
44
using ::std::fixed;  // NOLINT
45
using ::std::make_shared;  // NOLINT
46
using ::std::min;  // NOLINT
47
using ::std::mt19937;  // NOLINT
48
using ::std::nth_element;  // NOLINT
49
using ::std::pair;  // NOLINT
50
using ::std::setprecision;  // NOLINT
51
using ::std::vector;  // NOLINT
52
using ::std::unique_ptr;  // NOLINT
53
using internal::CombineFitnesses;
54

55
// Lower bound that Evaluator::Evaluate CHECKs each task's MaxTrainExamples()
// against before running the algorithm on that task.
constexpr IntegerT kMinNumTrainExamples = 10;
56
// Fixed seed for the bit generator dedicated to functional-cache executions.
// The generator is constructed with this seed and re-seeded with it before
// every functional-cache execution in Evaluator::ExecuteImpl.
constexpr RandomSeedT kFunctionalCacheRandomSeed = 235732282;
57

58
// Builds an evaluator over the tasks described by `task_collection`.
//
// `fitness_combination_mode` selects how per-task fitnesses are merged in
// Evaluate(). `rand_gen`, `functional_cache` and `train_budget` are stored as
// the raw pointers given here; `functional_cache` and `train_budget` are
// compared against nullptr before use elsewhere in this file, so they may be
// null. NOTE(review): the pointees presumably must outlive the evaluator --
// confirm against callers.
//
// A separate mt19937 seeded with kFunctionalCacheRandomSeed, wrapped in its
// own RandomGenerator, is created here for functional-cache executions.
//
// CHECK-fails if FillTasks() produces no tasks.
Evaluator::Evaluator(
    const FitnessCombinationMode fitness_combination_mode,
    const TaskCollection& task_collection,
    RandomGenerator* rand_gen,
    FECCache* functional_cache,
    TrainBudget* train_budget,
    const double max_abs_error)
    : fitness_combination_mode_(fitness_combination_mode),
      task_collection_(task_collection),
      train_budget_(train_budget),
      rand_gen_(rand_gen),
      functional_cache_(functional_cache),
      // Owned generator chain used only for functional-cache executions.
      functional_cache_bit_gen_owned_(
          make_unique<mt19937>(kFunctionalCacheRandomSeed)),
      functional_cache_rand_gen_owned_(
          make_unique<RandomGenerator>(functional_cache_bit_gen_owned_.get())),
      functional_cache_rand_gen_(functional_cache_rand_gen_owned_.get()),
      best_fitness_(-1.0),
      max_abs_error_(max_abs_error),
      num_train_steps_completed_(0) {
  // Materialize the task objects; an evaluator without tasks is invalid.
  FillTasks(task_collection_, &tasks_);
  CHECK_GT(tasks_.size(), 0);
}
80

81
double Evaluator::Evaluate(const Algorithm& algorithm) {
82
  // Compute the mean fitness across all tasks.
83
  vector<double> task_fitnesses;
84
  task_fitnesses.reserve(tasks_.size());
85
  vector<double> debug_fitnesses;
86
  vector<IntegerT> debug_num_train_examples;
87
  vector<IntegerT> task_indexes;  // Tasks to use.
88
  // Use all the tasks.
89
  for (IntegerT i = 0; i < tasks_.size(); ++i) {
90
    task_indexes.push_back(i);
91
  }
92
  for (IntegerT task_index : task_indexes) {
93
    const unique_ptr<TaskInterface>& task = tasks_[task_index];
94
    CHECK_GE(task->MaxTrainExamples(), kMinNumTrainExamples);
95
    const IntegerT num_train_examples =
96
        train_budget_ == nullptr ?
97
        task->MaxTrainExamples() :
98
        train_budget_->TrainExamples(algorithm, task->MaxTrainExamples());
99
    double curr_fitness = -1.0;
100
    curr_fitness = Execute(*task, num_train_examples, algorithm);
101
    task_fitnesses.push_back(curr_fitness);
102
  }
103
  double combined_fitness =
104
      CombineFitnesses(task_fitnesses, fitness_combination_mode_);
105

106
  CHECK_GE(combined_fitness, kMinFitness);
107
  CHECK_LE(combined_fitness, kMaxFitness);
108

109
  return combined_fitness;
110
}
111

112
double Evaluator::Execute(const TaskInterface& task,
113
                          const IntegerT num_train_examples,
114
                          const Algorithm& algorithm) {
115
  switch (task.FeaturesSize()) {
116
    case 2: {
117
      const Task<2>& downcasted_task = *SafeDowncast<2>(&task);
118
      return ExecuteImpl<2>(downcasted_task, num_train_examples, algorithm);
119
    }
120
    case 4: {
121
      const Task<4>& downcasted_task = *SafeDowncast<4>(&task);
122
      return ExecuteImpl<4>(downcasted_task, num_train_examples, algorithm);
123
    }
124
    case 8: {
125
      const Task<8>& downcasted_task = *SafeDowncast<8>(&task);
126
      return ExecuteImpl<8>(downcasted_task, num_train_examples, algorithm);
127
    }
128
    case 16: {
129
      const Task<16>& downcasted_task = *SafeDowncast<16>(&task);
130
      return ExecuteImpl<16>(downcasted_task, num_train_examples, algorithm);
131
    }
132
    case 32: {
133
      const Task<32>& downcasted_task = *SafeDowncast<32>(&task);
134
      return ExecuteImpl<32>(downcasted_task, num_train_examples, algorithm);
135
    }
136
    default:
137
      LOG(FATAL) << "Unsupported features size." << endl;
138
  }
139
}
140

141
// Returns the running total of train steps accumulated from every Executor
// run by this evaluator (functional-cache executions included).
IntegerT Evaluator::GetNumTrainStepsCompleted() const {
  return num_train_steps_completed_;
}
144

145
template <FeatureIndexT F>
146
double Evaluator::ExecuteImpl(const Task<F>& task,
147
                              const IntegerT num_train_examples,
148
                              const Algorithm& algorithm) {
149
  if (functional_cache_ != nullptr) {
150
    CHECK_LE(functional_cache_->NumTrainExamples(), task.MaxTrainExamples());
151
    CHECK_LE(functional_cache_->NumValidExamples(), task.ValidSteps());
152
    functional_cache_bit_gen_owned_->seed(kFunctionalCacheRandomSeed);
153
    Executor<F> functional_cache_executor(
154
        algorithm, task, functional_cache_->NumTrainExamples(),
155
        functional_cache_->NumValidExamples(), functional_cache_rand_gen_,
156
        max_abs_error_);
157
    vector<double> train_errors;
158
    vector<double> valid_errors;
159
    functional_cache_executor.Execute(&train_errors, &valid_errors);
160
    num_train_steps_completed_ +=
161
        functional_cache_executor.GetNumTrainStepsCompleted();
162
    const size_t hash = functional_cache_->Hash(
163
        train_errors, valid_errors, task.index_, num_train_examples);
164
    pair<double, bool> fitness_and_found = functional_cache_->Find(hash);
165
    if (fitness_and_found.second) {
166
      // Cache hit.
167
      functional_cache_->UpdateOrDie(hash, fitness_and_found.first);
168
      return fitness_and_found.first;
169
    } else {
170
      // Cache miss.
171
      Executor<F> executor(algorithm, task, num_train_examples,
172
                           task.ValidSteps(), rand_gen_, max_abs_error_);
173
      double fitness = executor.Execute();
174
      num_train_steps_completed_ += executor.GetNumTrainStepsCompleted();
175
      functional_cache_->InsertOrDie(hash, fitness);
176
      return fitness;
177
    }
178
  } else {
179
    Executor<F> executor(
180
        algorithm, task, num_train_examples, task.ValidSteps(),
181
        rand_gen_, max_abs_error_);
182
    const double fitness = executor.Execute();
183
    num_train_steps_completed_ += executor.GetNumTrainStepsCompleted();
184
    return fitness;
185
  }
186
}
187

188
namespace internal {
189

190
// Returns the element that would sit at index size() / 2 if `values` were
// sorted: the middle element for odd sizes and the upper of the two middle
// elements for even sizes (the two are not averaged).
//
// An empty input returns NaN instead of indexing into an empty vector, which
// would be undefined behaviour. NaN matches what the mean branch of
// CombineFitnesses yields for an empty input (0.0 / 0).
//
// `values` is taken by value because std::nth_element reorders its input.
double Median(std::vector<double> values) {  // Intentional copy.
  if (values.empty()) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  const std::size_t half_num_values = values.size() / 2;
  // Partial selection: only the element at `half_num_values` is guaranteed to
  // be in its sorted position afterwards.
  std::nth_element(
      values.begin(), values.begin() + half_num_values, values.end());
  return values[half_num_values];
}
195

196
double CombineFitnesses(
197
    const vector<double>& task_fitnesses,
198
    const FitnessCombinationMode mode) {
199
  if (mode == MEAN_FITNESS_COMBINATION) {
200
    double combined_fitness = 0.0;
201
    for (const double fitness : task_fitnesses) {
202
      combined_fitness += fitness;
203
    }
204
    combined_fitness /= static_cast<double>(task_fitnesses.size());
205
    return combined_fitness;
206
  } else if (mode == MEDIAN_FITNESS_COMBINATION) {
207
    return Median(task_fitnesses);
208
  } else {
209
    LOG(FATAL) << "Unsupported fitness combination." << endl;
210
  }
211
}
212

213
}  // namespace internal
214

215
}  // namespace automl_zero
216

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.