google-research
392 строки · 12.7 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "mutator.h"16
17#include <memory>18#include <vector>19
20#include "definitions.h"21#include "random_generator.h"22#include "absl/memory/memory.h"23
24namespace automl_zero {25
26using ::absl::make_unique; // NOLINT27using ::std::endl; // NOLINT28using ::std::make_shared; // NOLINT29using ::std::mt19937; // NOLINT30using ::std::shared_ptr; // NOLINT31using ::std::vector; // NOLINT32
33Mutator::Mutator(34const MutationTypeList& allowed_actions,35const double mutate_prob,36const vector<Op>& allowed_setup_ops,37const vector<Op>& allowed_predict_ops,38const vector<Op>& allowed_learn_ops,39const IntegerT setup_size_min,40const IntegerT setup_size_max,41const IntegerT predict_size_min,42const IntegerT predict_size_max,43const IntegerT learn_size_min,44const IntegerT learn_size_max,45mt19937* bit_gen,46RandomGenerator* rand_gen)47: allowed_actions_(allowed_actions),48mutate_prob_(mutate_prob),49allowed_setup_ops_(allowed_setup_ops),50allowed_predict_ops_(allowed_predict_ops),51allowed_learn_ops_(allowed_learn_ops),52mutate_setup_(!allowed_setup_ops_.empty()),53mutate_predict_(!allowed_predict_ops_.empty()),54mutate_learn_(!allowed_learn_ops_.empty()),55setup_size_min_(setup_size_min),56setup_size_max_(setup_size_max),57predict_size_min_(predict_size_min),58predict_size_max_(predict_size_max),59learn_size_min_(learn_size_min),60learn_size_max_(learn_size_max),61bit_gen_(bit_gen),62rand_gen_(rand_gen),63randomizer_(64allowed_setup_ops_,65allowed_predict_ops_,66allowed_learn_ops_,67bit_gen_,68rand_gen_) {}69
70vector<MutationType> ConvertToMutationType(71const vector<IntegerT>& mutation_actions_as_ints) {72vector<MutationType> mutation_actions;73mutation_actions.reserve(mutation_actions_as_ints.size());74for (const IntegerT action_as_int : mutation_actions_as_ints) {75mutation_actions.push_back(static_cast<MutationType>(action_as_int));76}77return mutation_actions;78}
79
80void Mutator::Mutate(shared_ptr<const Algorithm>* algorithm) {81if (mutate_prob_ >= 1.0 || rand_gen_->UniformProbability() < mutate_prob_) {82auto mutated = make_unique<Algorithm>(**algorithm);83MutateImpl(mutated.get());84algorithm->reset(mutated.release());85}86}
87
88void Mutator::Mutate(const IntegerT num_mutations,89shared_ptr<const Algorithm>* algorithm) {90if (mutate_prob_ >= 1.0 || rand_gen_->UniformProbability() < mutate_prob_) {91auto mutated = make_unique<Algorithm>(**algorithm);92for (IntegerT i = 0; i < num_mutations; ++i) {93MutateImpl(mutated.get());94}95algorithm->reset(mutated.release());96}97}
98
99Mutator::Mutator()100: allowed_actions_(ParseTextFormat<MutationTypeList>(101"mutation_types: [ "102" ALTER_PARAM_MUTATION_TYPE, "103" RANDOMIZE_INSTRUCTION_MUTATION_TYPE, "104" RANDOMIZE_COMPONENT_FUNCTION_MUTATION_TYPE "105"]")),106mutate_prob_(0.5),107allowed_setup_ops_(108{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP}),109allowed_predict_ops_(110{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP}),111allowed_learn_ops_(112{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP}),113mutate_setup_(!allowed_setup_ops_.empty()),114mutate_predict_(!allowed_predict_ops_.empty()),115mutate_learn_(!allowed_learn_ops_.empty()),116setup_size_min_(2),117setup_size_max_(4),118predict_size_min_(3),119predict_size_max_(5),120learn_size_min_(4),121learn_size_max_(6),122bit_gen_owned_(make_unique<mt19937>(GenerateRandomSeed())),123bit_gen_(bit_gen_owned_.get()),124rand_gen_owned_(make_unique<RandomGenerator>(bit_gen_)),125rand_gen_(rand_gen_owned_.get()),126randomizer_(127allowed_setup_ops_,128allowed_predict_ops_,129allowed_learn_ops_,130bit_gen_,131rand_gen_) {}132
133void Mutator::MutateImpl(Algorithm* algorithm) {134CHECK(!allowed_actions_.mutation_types().empty());135const size_t action_index =136absl::Uniform<size_t>(*bit_gen_, 0,137allowed_actions_.mutation_types_size());138const MutationType action = allowed_actions_.mutation_types(action_index);139switch (action) {140case ALTER_PARAM_MUTATION_TYPE:141AlterParam(algorithm);142return;143case RANDOMIZE_INSTRUCTION_MUTATION_TYPE:144RandomizeInstruction(algorithm);145return;146case RANDOMIZE_COMPONENT_FUNCTION_MUTATION_TYPE:147RandomizeComponentFunction(algorithm);148return;149case IDENTITY_MUTATION_TYPE:150return;151case INSERT_INSTRUCTION_MUTATION_TYPE:152InsertInstruction(algorithm);153return;154case REMOVE_INSTRUCTION_MUTATION_TYPE:155RemoveInstruction(algorithm);156return;157case TRADE_INSTRUCTION_MUTATION_TYPE:158TradeInstruction(algorithm);159return;160case RANDOMIZE_ALGORITHM_MUTATION_TYPE:161RandomizeAlgorithm(algorithm);162return;163// Do not add a default clause here. All actions should be supported.164}165}
166
167void Mutator::AlterParam(Algorithm* algorithm) {168switch (ComponentFunction()) {169case kSetupComponentFunction: {170if (!algorithm->setup_.empty()) {171InstructionIndexT index = InstructionIndex(algorithm->setup_.size());172algorithm->setup_[index] =173make_shared<const Instruction>(174*algorithm->setup_[index], rand_gen_);175}176return;177}178case kPredictComponentFunction: {179if (!algorithm->predict_.empty()) {180InstructionIndexT index = InstructionIndex(algorithm->predict_.size());181algorithm->predict_[index] =182make_shared<const Instruction>(183*algorithm->predict_[index], rand_gen_);184}185return;186}187case kLearnComponentFunction: {188if (!algorithm->learn_.empty()) {189InstructionIndexT index = InstructionIndex(algorithm->learn_.size());190algorithm->learn_[index] =191make_shared<const Instruction>(192*algorithm->learn_[index], rand_gen_);193}194return;195}196}197LOG(FATAL) << "Control flow should not reach here.";198}
199
200void Mutator::RandomizeInstruction(Algorithm* algorithm) {201switch (ComponentFunction()) {202case kSetupComponentFunction: {203if (!algorithm->setup_.empty()) {204InstructionIndexT index = InstructionIndex(algorithm->setup_.size());205algorithm->setup_[index] =206make_shared<const Instruction>(SetupOp(), rand_gen_);207}208return;209}210case kPredictComponentFunction: {211if (!algorithm->predict_.empty()) {212InstructionIndexT index = InstructionIndex(algorithm->predict_.size());213algorithm->predict_[index] =214make_shared<const Instruction>(PredictOp(), rand_gen_);215}216return;217}218case kLearnComponentFunction: {219if (!algorithm->learn_.empty()) {220InstructionIndexT index = InstructionIndex(algorithm->learn_.size());221algorithm->learn_[index] =222make_shared<const Instruction>(LearnOp(), rand_gen_);223}224return;225}226}227LOG(FATAL) << "Control flow should not reach here.";228}
229
230void Mutator::RandomizeComponentFunction(Algorithm* algorithm) {231switch (ComponentFunction()) {232case kSetupComponentFunction: {233randomizer_.RandomizeSetup(algorithm);234return;235}236case kPredictComponentFunction: {237randomizer_.RandomizePredict(algorithm);238return;239}240case kLearnComponentFunction: {241randomizer_.RandomizeLearn(algorithm);242return;243}244}245LOG(FATAL) << "Control flow should not reach here.";246}
247
248void Mutator::InsertInstruction(Algorithm* algorithm) {249Op op; // Operation for the new instruction.250vector<shared_ptr<const Instruction>>* component_function; // To modify.251switch (ComponentFunction()) {252case kSetupComponentFunction: {253if (algorithm->setup_.size() >= setup_size_max_ - 1) return;254op = SetupOp();255component_function = &algorithm->setup_;256break;257}258case kPredictComponentFunction: {259if (algorithm->predict_.size() >= predict_size_max_ - 1) return;260op = PredictOp();261component_function = &algorithm->predict_;262break;263}264case kLearnComponentFunction: {265if (algorithm->learn_.size() >= learn_size_max_ - 1) return;266op = LearnOp();267component_function = &algorithm->learn_;268break;269}270}271InsertInstructionUnconditionally(op, component_function);272}
273
274void Mutator::RemoveInstruction(Algorithm* algorithm) {275vector<shared_ptr<const Instruction>>* component_function; // To modify.276switch (ComponentFunction()) {277case kSetupComponentFunction: {278if (algorithm->setup_.size() <= setup_size_min_) return;279component_function = &algorithm->setup_;280break;281}282case kPredictComponentFunction: {283if (algorithm->predict_.size() <= predict_size_min_) return;284component_function = &algorithm->predict_;285break;286}287case kLearnComponentFunction: {288if (algorithm->learn_.size() <= learn_size_min_) return;289component_function = &algorithm->learn_;290break;291}292}293RemoveInstructionUnconditionally(component_function);294}
295
296void Mutator::TradeInstruction(Algorithm* algorithm) {297Op op; // Operation for the new instruction.298vector<shared_ptr<const Instruction>>* component_function; // To modify.299switch (ComponentFunction()) {300case kSetupComponentFunction: {301op = SetupOp();302component_function = &algorithm->setup_;303break;304}305case kPredictComponentFunction: {306op = PredictOp();307component_function = &algorithm->predict_;308break;309}310case kLearnComponentFunction: {311op = LearnOp();312component_function = &algorithm->learn_;313break;314}315}316InsertInstructionUnconditionally(op, component_function);317RemoveInstructionUnconditionally(component_function);318}
319
320void Mutator::RandomizeAlgorithm(Algorithm* algorithm) {321if (mutate_setup_) {322randomizer_.RandomizeSetup(algorithm);323}324if (mutate_predict_) {325randomizer_.RandomizePredict(algorithm);326}327if (mutate_learn_) {328randomizer_.RandomizeLearn(algorithm);329}330}
331
332void Mutator::InsertInstructionUnconditionally(333const Op op, vector<shared_ptr<const Instruction>>* component_function) {334const InstructionIndexT position =335InstructionIndex(component_function->size() + 1);336component_function->insert(337component_function->begin() + position,338make_shared<const Instruction>(op, rand_gen_));339}
340
341void Mutator::RemoveInstructionUnconditionally(342vector<shared_ptr<const Instruction>>* component_function) {343CHECK_GT(component_function->size(), 0);344const InstructionIndexT position =345InstructionIndex(component_function->size());346component_function->erase(component_function->begin() + position);347}
348
349Op Mutator::SetupOp() {350IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(351*bit_gen_, 0, allowed_setup_ops_.size());352return allowed_setup_ops_[op_index];353}
354
355Op Mutator::PredictOp() {356IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(357*bit_gen_, 0, allowed_predict_ops_.size());358return allowed_predict_ops_[op_index];359}
360
361Op Mutator::LearnOp() {362IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(363*bit_gen_, 0, allowed_learn_ops_.size());364return allowed_learn_ops_[op_index];365}
366
367InstructionIndexT Mutator::InstructionIndex(368const InstructionIndexT component_function_size) {369return absl::Uniform<InstructionIndexT>(370*bit_gen_, 0, component_function_size);371}
372
373ComponentFunctionT Mutator::ComponentFunction() {374vector<ComponentFunctionT> allowed_component_functions;375allowed_component_functions.reserve(4);376if (mutate_setup_) {377allowed_component_functions.push_back(kSetupComponentFunction);378}379if (mutate_predict_) {380allowed_component_functions.push_back(kPredictComponentFunction);381}382if (mutate_learn_) {383allowed_component_functions.push_back(kLearnComponentFunction);384}385CHECK(!allowed_component_functions.empty())386<< "Must mutate at least one component function." << endl;387const IntegerT index =388absl::Uniform<IntegerT>(*bit_gen_, 0, allowed_component_functions.size());389return allowed_component_functions[index];390}
391
392} // namespace automl_zero393