google-research
89 строк · 2.6 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "randomizer.h"

#include <memory>
#include <utility>

#include "algorithm.h"
#include "random_generator.h"
21
22namespace automl_zero {
23
24using ::std::make_shared; // NOLINT
25using ::std::mt19937; // NOLINT
26using ::std::shared_ptr; // NOLINT
27using ::std::vector; // NOLINT
28
29Randomizer::Randomizer(
30vector<Op> allowed_setup_ops,
31vector<Op> allowed_predict_ops,
32vector<Op> allowed_learn_ops,
33mt19937* bit_gen,
34RandomGenerator* rand_gen)
35: allowed_setup_ops_(allowed_setup_ops),
36allowed_predict_ops_(allowed_predict_ops),
37allowed_learn_ops_(allowed_learn_ops),
38bit_gen_(bit_gen),
39rand_gen_(rand_gen) {}
40
41void Randomizer::Randomize(Algorithm* algorithm) {
42if (!allowed_setup_ops_.empty()) {
43RandomizeSetup(algorithm);
44}
45if (!allowed_predict_ops_.empty()) {
46RandomizePredict(algorithm);
47}
48if (!allowed_learn_ops_.empty()) {
49RandomizeLearn(algorithm);
50}
51}
52
53void Randomizer::RandomizeSetup(Algorithm* algorithm) {
54for (shared_ptr<const Instruction>& instruction : algorithm->setup_) {
55instruction = make_shared<const Instruction>(SetupOp(), rand_gen_);
56}
57}
58
59void Randomizer::RandomizePredict(Algorithm* algorithm) {
60for (shared_ptr<const Instruction>& instruction : algorithm->predict_) {
61instruction = make_shared<const Instruction>(PredictOp(), rand_gen_);
62}
63}
64
65void Randomizer::RandomizeLearn(Algorithm* algorithm) {
66for (shared_ptr<const Instruction>& instruction : algorithm->learn_) {
67instruction = make_shared<const Instruction>(LearnOp(), rand_gen_);
68}
69}
70
71Op Randomizer::SetupOp() {
72IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(
73*bit_gen_, 0, allowed_setup_ops_.size());
74return allowed_setup_ops_[op_index];
75}
76
77Op Randomizer::PredictOp() {
78IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(
79*bit_gen_, 0, allowed_predict_ops_.size());
80return allowed_predict_ops_[op_index];
81}
82
83Op Randomizer::LearnOp() {
84IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(
85*bit_gen_, 0, allowed_learn_ops_.size());
86return allowed_learn_ops_[op_index];
87}
88
89} // namespace automl_zero
90