google-research

Форк
0
392 строки · 12.7 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "mutator.h"
16

17
#include <memory>
18
#include <vector>
19

20
#include "definitions.h"
21
#include "random_generator.h"
22
#include "absl/memory/memory.h"
23

24
namespace automl_zero {
25

26
using ::absl::make_unique;  // NOLINT
27
using ::std::endl;  // NOLINT
28
using ::std::make_shared;  // NOLINT
29
using ::std::mt19937;  // NOLINT
30
using ::std::shared_ptr;  // NOLINT
31
using ::std::vector;  // NOLINT
32

33
// Builds a mutator from explicit settings.
//
// `allowed_actions` lists the mutation types that MutateImpl may pick from.
// `mutate_prob` is the chance that a Mutate() call mutates at all.
// The three op lists are the ops that may appear in newly created
// setup / predict / learn instructions. An empty list disables mutation of
// that component function (see the `mutate_*_` flags).
// The `*_size_min` / `*_size_max` pairs bound Insert/RemoveInstruction.
// `bit_gen` and `rand_gen` are stored as raw pointers without taking
// ownership, so they presumably must outlive this mutator (confirm with
// callers).
Mutator::Mutator(const MutationTypeList& allowed_actions,
                 const double mutate_prob,
                 const vector<Op>& allowed_setup_ops,
                 const vector<Op>& allowed_predict_ops,
                 const vector<Op>& allowed_learn_ops,
                 const IntegerT setup_size_min, const IntegerT setup_size_max,
                 const IntegerT predict_size_min,
                 const IntegerT predict_size_max,
                 const IntegerT learn_size_min, const IntegerT learn_size_max,
                 mt19937* bit_gen, RandomGenerator* rand_gen)
    : allowed_actions_(allowed_actions),
      mutate_prob_(mutate_prob),
      // NOTE: the initializers further down read these members (not the
      // parameters), which is only valid because members are initialized in
      // declaration order and these are declared first in the header.
      allowed_setup_ops_(allowed_setup_ops),
      allowed_predict_ops_(allowed_predict_ops),
      allowed_learn_ops_(allowed_learn_ops),
      // A component function is mutable only if it has ops to draw from.
      mutate_setup_(!allowed_setup_ops_.empty()),
      mutate_predict_(!allowed_predict_ops_.empty()),
      mutate_learn_(!allowed_learn_ops_.empty()),
      setup_size_min_(setup_size_min), setup_size_max_(setup_size_max),
      predict_size_min_(predict_size_min), predict_size_max_(predict_size_max),
      learn_size_min_(learn_size_min), learn_size_max_(learn_size_max),
      // Borrowed generators (not owned).
      bit_gen_(bit_gen), rand_gen_(rand_gen),
      randomizer_(allowed_setup_ops_, allowed_predict_ops_, allowed_learn_ops_,
                  bit_gen_, rand_gen_) {}
69

70
// Maps raw integer mutation-type codes to `MutationType` values, keeping the
// input order. Each value is cast as-is; no range validation is done here.
vector<MutationType> ConvertToMutationType(
    const vector<IntegerT>& mutation_actions_as_ints) {
  const size_t count = mutation_actions_as_ints.size();
  vector<MutationType> converted(count);
  for (size_t i = 0; i < count; ++i) {
    converted[i] = static_cast<MutationType>(mutation_actions_as_ints[i]);
  }
  return converted;
}
79

80
// Applies a single mutation to `*algorithm` with probability `mutate_prob_`.
//
// When a mutation happens, the algorithm is copied, the copy is mutated, and
// `*algorithm` is repointed at the copy. The original object is never
// modified. This is exactly `Mutate(1, algorithm)`; delegating keeps the
// probability check and the copy/swap logic in one place instead of two.
void Mutator::Mutate(shared_ptr<const Algorithm>* algorithm) {
  Mutate(1, algorithm);
}
87

88
void Mutator::Mutate(const IntegerT num_mutations,
89
                     shared_ptr<const Algorithm>* algorithm) {
90
  if (mutate_prob_ >= 1.0 || rand_gen_->UniformProbability() < mutate_prob_) {
91
    auto mutated = make_unique<Algorithm>(**algorithm);
92
    for (IntegerT i = 0; i < num_mutations; ++i) {
93
      MutateImpl(mutated.get());
94
    }
95
    algorithm->reset(mutated.release());
96
  }
97
}
98

99
Mutator::Mutator()
100
    : allowed_actions_(ParseTextFormat<MutationTypeList>(
101
        "mutation_types: [ "
102
        "  ALTER_PARAM_MUTATION_TYPE, "
103
        "  RANDOMIZE_INSTRUCTION_MUTATION_TYPE, "
104
        "  RANDOMIZE_COMPONENT_FUNCTION_MUTATION_TYPE "
105
        "]")),
106
      mutate_prob_(0.5),
107
      allowed_setup_ops_(
108
          {NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP}),
109
      allowed_predict_ops_(
110
          {NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP}),
111
      allowed_learn_ops_(
112
          {NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP}),
113
      mutate_setup_(!allowed_setup_ops_.empty()),
114
      mutate_predict_(!allowed_predict_ops_.empty()),
115
      mutate_learn_(!allowed_learn_ops_.empty()),
116
      setup_size_min_(2),
117
      setup_size_max_(4),
118
      predict_size_min_(3),
119
      predict_size_max_(5),
120
      learn_size_min_(4),
121
      learn_size_max_(6),
122
      bit_gen_owned_(make_unique<mt19937>(GenerateRandomSeed())),
123
      bit_gen_(bit_gen_owned_.get()),
124
      rand_gen_owned_(make_unique<RandomGenerator>(bit_gen_)),
125
      rand_gen_(rand_gen_owned_.get()),
126
      randomizer_(
127
          allowed_setup_ops_,
128
          allowed_predict_ops_,
129
          allowed_learn_ops_,
130
          bit_gen_,
131
          rand_gen_) {}
132

133
// Applies one mutation to `algorithm` in place. The mutation type is drawn
// uniformly at random (via `bit_gen_`) from `allowed_actions_`, which must be
// non-empty (CHECK-fails otherwise).
void Mutator::MutateImpl(Algorithm* algorithm) {
  CHECK(!allowed_actions_.mutation_types().empty());
  const size_t action_index =
      absl::Uniform<size_t>(*bit_gen_, 0,
                            allowed_actions_.mutation_types_size());
  const MutationType action = allowed_actions_.mutation_types(action_index);
  switch (action) {
    case ALTER_PARAM_MUTATION_TYPE:
      AlterParam(algorithm);
      return;
    case RANDOMIZE_INSTRUCTION_MUTATION_TYPE:
      RandomizeInstruction(algorithm);
      return;
    case RANDOMIZE_COMPONENT_FUNCTION_MUTATION_TYPE:
      RandomizeComponentFunction(algorithm);
      return;
    case IDENTITY_MUTATION_TYPE:
      return;
    case INSERT_INSTRUCTION_MUTATION_TYPE:
      InsertInstruction(algorithm);
      return;
    case REMOVE_INSTRUCTION_MUTATION_TYPE:
      RemoveInstruction(algorithm);
      return;
    case TRADE_INSTRUCTION_MUTATION_TYPE:
      TradeInstruction(algorithm);
      return;
    case RANDOMIZE_ALGORITHM_MUTATION_TYPE:
      RandomizeAlgorithm(algorithm);
      return;
    // Do not add a default clause here. All actions should be supported.
  }
  // Every case above returns, so this point is reached only for a value
  // outside the enumerated mutation types (e.g. an unchecked integer cast).
  // Fail loudly, like the other dispatchers in this file, rather than
  // silently skipping the mutation.
  LOG(FATAL) << "Unsupported mutation type: " << static_cast<int>(action);
}
166

167
// Replaces one random instruction of one random component function with a
// variant built from it and `rand_gen_`, via the
// `Instruction(const Instruction&, RandomGenerator*)` constructor. Presumably
// that keeps the op and perturbs a parameter; confirm in instruction.h.
// An empty component function is left untouched.
void Mutator::AlterParam(Algorithm* algorithm) {
  switch (ComponentFunction()) {
    case kSetupComponentFunction: {
      auto& instructions = algorithm->setup_;
      if (instructions.empty()) return;
      const InstructionIndexT index = InstructionIndex(instructions.size());
      instructions[index] =
          make_shared<const Instruction>(*instructions[index], rand_gen_);
      return;
    }
    case kPredictComponentFunction: {
      auto& instructions = algorithm->predict_;
      if (instructions.empty()) return;
      const InstructionIndexT index = InstructionIndex(instructions.size());
      instructions[index] =
          make_shared<const Instruction>(*instructions[index], rand_gen_);
      return;
    }
    case kLearnComponentFunction: {
      auto& instructions = algorithm->learn_;
      if (instructions.empty()) return;
      const InstructionIndexT index = InstructionIndex(instructions.size());
      instructions[index] =
          make_shared<const Instruction>(*instructions[index], rand_gen_);
      return;
    }
  }
  LOG(FATAL) << "Control flow should not reach here.";
}
199

200
// Overwrites one random instruction of one random component function with a
// brand-new instruction. The new instruction's op is drawn from that
// component function's allowed ops, and it is constructed with `rand_gen_`.
// An empty component function is left untouched.
//
// The position is drawn in its own statement before the op. InstructionIndex
// and the *Op() helpers share `bit_gen_`, so this order is part of the
// reproducible random stream.
void Mutator::RandomizeInstruction(Algorithm* algorithm) {
  switch (ComponentFunction()) {
    case kSetupComponentFunction: {
      auto& instructions = algorithm->setup_;
      if (instructions.empty()) return;
      const InstructionIndexT index = InstructionIndex(instructions.size());
      instructions[index] =
          make_shared<const Instruction>(SetupOp(), rand_gen_);
      return;
    }
    case kPredictComponentFunction: {
      auto& instructions = algorithm->predict_;
      if (instructions.empty()) return;
      const InstructionIndexT index = InstructionIndex(instructions.size());
      instructions[index] =
          make_shared<const Instruction>(PredictOp(), rand_gen_);
      return;
    }
    case kLearnComponentFunction: {
      auto& instructions = algorithm->learn_;
      if (instructions.empty()) return;
      const InstructionIndexT index = InstructionIndex(instructions.size());
      instructions[index] =
          make_shared<const Instruction>(LearnOp(), rand_gen_);
      return;
    }
  }
  LOG(FATAL) << "Control flow should not reach here.";
}
229

230
// Picks one mutable component function at random and hands the whole
// algorithm to the matching `randomizer_` method for that component function.
void Mutator::RandomizeComponentFunction(Algorithm* algorithm) {
  const ComponentFunctionT target = ComponentFunction();
  switch (target) {
    case kSetupComponentFunction:
      randomizer_.RandomizeSetup(algorithm);
      return;
    case kPredictComponentFunction:
      randomizer_.RandomizePredict(algorithm);
      return;
    case kLearnComponentFunction:
      randomizer_.RandomizeLearn(algorithm);
      return;
  }
  LOG(FATAL) << "Control flow should not reach here.";
}
247

248
// Inserts a new instruction at a random position of one random component
// function. The new instruction's op is drawn from that component function's
// allowed ops.
//
// Insertion is skipped once the component function holds at least
// `<component>_size_max_ - 1` instructions, so this mutation never grows it
// to `<component>_size_max_` or beyond.
//
// Each case acts and returns directly. There are no uninitialized
// `op`/target locals that could reach InsertInstructionUnconditionally when
// ComponentFunction() yields an unexpected value; that case fails loudly
// instead, matching AlterParam and the other dispatchers.
void Mutator::InsertInstruction(Algorithm* algorithm) {
  switch (ComponentFunction()) {
    case kSetupComponentFunction:
      if (algorithm->setup_.size() >= setup_size_max_ - 1) return;
      // The op is drawn before the insertion position (as the argument is
      // evaluated before the callee runs), preserving the RNG draw order.
      InsertInstructionUnconditionally(SetupOp(), &algorithm->setup_);
      return;
    case kPredictComponentFunction:
      if (algorithm->predict_.size() >= predict_size_max_ - 1) return;
      InsertInstructionUnconditionally(PredictOp(), &algorithm->predict_);
      return;
    case kLearnComponentFunction:
      if (algorithm->learn_.size() >= learn_size_max_ - 1) return;
      InsertInstructionUnconditionally(LearnOp(), &algorithm->learn_);
      return;
  }
  LOG(FATAL) << "Control flow should not reach here.";
}
273

274
// Removes one random instruction from one random component function, unless
// that component function already holds `<component>_size_min_` instructions
// or fewer. This mutation therefore never shrinks it below the minimum.
//
// Each case acts and returns directly, so an unexpected ComponentFunction()
// value can no longer reach RemoveInstructionUnconditionally with an
// uninitialized target pointer (undefined behavior). It fails loudly
// instead, like the other dispatchers in this file.
void Mutator::RemoveInstruction(Algorithm* algorithm) {
  switch (ComponentFunction()) {
    case kSetupComponentFunction:
      if (algorithm->setup_.size() <= setup_size_min_) return;
      RemoveInstructionUnconditionally(&algorithm->setup_);
      return;
    case kPredictComponentFunction:
      if (algorithm->predict_.size() <= predict_size_min_) return;
      RemoveInstructionUnconditionally(&algorithm->predict_);
      return;
    case kLearnComponentFunction:
      if (algorithm->learn_.size() <= learn_size_min_) return;
      RemoveInstructionUnconditionally(&algorithm->learn_);
      return;
  }
  LOG(FATAL) << "Control flow should not reach here.";
}
295

296
// Inserts a new random instruction into one random component function and
// then removes a random instruction from it, leaving its size unchanged.
// Size limits are not consulted. The removed instruction may be the one that
// was just inserted.
//
// Each case acts and returns directly. Unlike the previous form, no
// uninitialized `op`/target locals can reach the helpers when
// ComponentFunction() yields an unexpected value; that case fails loudly.
void Mutator::TradeInstruction(Algorithm* algorithm) {
  switch (ComponentFunction()) {
    case kSetupComponentFunction:
      InsertInstructionUnconditionally(SetupOp(), &algorithm->setup_);
      RemoveInstructionUnconditionally(&algorithm->setup_);
      return;
    case kPredictComponentFunction:
      InsertInstructionUnconditionally(PredictOp(), &algorithm->predict_);
      RemoveInstructionUnconditionally(&algorithm->predict_);
      return;
    case kLearnComponentFunction:
      InsertInstructionUnconditionally(LearnOp(), &algorithm->learn_);
      RemoveInstructionUnconditionally(&algorithm->learn_);
      return;
  }
  LOG(FATAL) << "Control flow should not reach here.";
}
319

320
// Re-randomizes every mutable component function: those whose allowed-op
// list is non-empty. They are always processed in the order setup, predict,
// learn; the randomizer shares this mutator's generators, so the order
// presumably matters for reproducibility.
void Mutator::RandomizeAlgorithm(Algorithm* algorithm) {
  if (mutate_setup_) randomizer_.RandomizeSetup(algorithm);
  if (mutate_predict_) randomizer_.RandomizePredict(algorithm);
  if (mutate_learn_) randomizer_.RandomizeLearn(algorithm);
}
331

332
void Mutator::InsertInstructionUnconditionally(
333
    const Op op, vector<shared_ptr<const Instruction>>* component_function) {
334
  const InstructionIndexT position =
335
      InstructionIndex(component_function->size() + 1);
336
  component_function->insert(
337
      component_function->begin() + position,
338
      make_shared<const Instruction>(op, rand_gen_));
339
}
340

341
void Mutator::RemoveInstructionUnconditionally(
342
    vector<shared_ptr<const Instruction>>* component_function) {
343
  CHECK_GT(component_function->size(), 0);
344
  const InstructionIndexT position =
345
      InstructionIndex(component_function->size());
346
  component_function->erase(component_function->begin() + position);
347
}
348

349
// Returns an op drawn uniformly from `allowed_setup_ops_`.
// Within this file it is only reached when `mutate_setup_` is set (i.e. the
// list is non-empty); an empty list would make the index out of range.
Op Mutator::SetupOp() {
  return allowed_setup_ops_[absl::Uniform<DeprecatedOpIndexT>(
      *bit_gen_, 0, allowed_setup_ops_.size())];
}
354

355
// Returns an op drawn uniformly from `allowed_predict_ops_`.
// Within this file it is only reached when `mutate_predict_` is set (i.e. the
// list is non-empty); an empty list would make the index out of range.
Op Mutator::PredictOp() {
  return allowed_predict_ops_[absl::Uniform<DeprecatedOpIndexT>(
      *bit_gen_, 0, allowed_predict_ops_.size())];
}
360

361
// Returns an op drawn uniformly from `allowed_learn_ops_`.
// Within this file it is only reached when `mutate_learn_` is set (i.e. the
// list is non-empty); an empty list would make the index out of range.
Op Mutator::LearnOp() {
  return allowed_learn_ops_[absl::Uniform<DeprecatedOpIndexT>(
      *bit_gen_, 0, allowed_learn_ops_.size())];
}
366

367
// Draws an index uniformly from the half-open range
// [0, component_function_size) using `bit_gen_`. Per absl::Uniform's
// contract, an empty range (size 0) yields the lower bound, 0.
InstructionIndexT Mutator::InstructionIndex(
    const InstructionIndexT component_function_size) {
  const InstructionIndexT lower_bound = 0;
  return absl::Uniform<InstructionIndexT>(*bit_gen_, lower_bound,
                                          component_function_size);
}
372

373
// Picks one mutable component function uniformly at random. The candidates
// are collected in the fixed order setup, predict, learn, and only those
// whose `mutate_*_` flag is set are included. CHECK-fails if no component
// function is mutable.
ComponentFunctionT Mutator::ComponentFunction() {
  vector<ComponentFunctionT> allowed_component_functions;
  allowed_component_functions.reserve(3);  // At most one per component.
  const auto allow_if = [&allowed_component_functions](
                            const bool enabled, const ComponentFunctionT fn) {
    if (enabled) allowed_component_functions.push_back(fn);
  };
  allow_if(mutate_setup_, kSetupComponentFunction);
  allow_if(mutate_predict_, kPredictComponentFunction);
  allow_if(mutate_learn_, kLearnComponentFunction);
  CHECK(!allowed_component_functions.empty())
      << "Must mutate at least one component function." << endl;
  const IntegerT index =
      absl::Uniform<IntegerT>(*bit_gen_, 0, allowed_component_functions.size());
  return allowed_component_functions[index];
}
391

392
}  // namespace automl_zero
393

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.