google-research

Форк
0
89 строк · 2.6 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "randomizer.h"
16

17
#include <memory>
18

19
#include "algorithm.h"
20
#include "random_generator.h"
21

22
namespace automl_zero {
23

24
using ::std::make_shared;  // NOLINT
25
using ::std::mt19937;  // NOLINT
26
using ::std::shared_ptr;  // NOLINT
27
using ::std::vector;  // NOLINT
28

29
Randomizer::Randomizer(
30
    vector<Op> allowed_setup_ops,
31
    vector<Op> allowed_predict_ops,
32
    vector<Op> allowed_learn_ops,
33
    mt19937* bit_gen,
34
    RandomGenerator* rand_gen)
35
    : allowed_setup_ops_(allowed_setup_ops),
36
      allowed_predict_ops_(allowed_predict_ops),
37
      allowed_learn_ops_(allowed_learn_ops),
38
      bit_gen_(bit_gen),
39
      rand_gen_(rand_gen) {}
40

41
void Randomizer::Randomize(Algorithm* algorithm) {
42
  if (!allowed_setup_ops_.empty()) {
43
    RandomizeSetup(algorithm);
44
  }
45
  if (!allowed_predict_ops_.empty()) {
46
    RandomizePredict(algorithm);
47
  }
48
  if (!allowed_learn_ops_.empty()) {
49
    RandomizeLearn(algorithm);
50
  }
51
}
52

53
void Randomizer::RandomizeSetup(Algorithm* algorithm) {
54
  for (shared_ptr<const Instruction>& instruction : algorithm->setup_) {
55
    instruction = make_shared<const Instruction>(SetupOp(), rand_gen_);
56
  }
57
}
58

59
void Randomizer::RandomizePredict(Algorithm* algorithm) {
60
  for (shared_ptr<const Instruction>& instruction : algorithm->predict_) {
61
    instruction = make_shared<const Instruction>(PredictOp(), rand_gen_);
62
  }
63
}
64

65
void Randomizer::RandomizeLearn(Algorithm* algorithm) {
66
  for (shared_ptr<const Instruction>& instruction : algorithm->learn_) {
67
    instruction = make_shared<const Instruction>(LearnOp(), rand_gen_);
68
  }
69
}
70

71
Op Randomizer::SetupOp() {
72
  IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(
73
      *bit_gen_, 0, allowed_setup_ops_.size());
74
  return allowed_setup_ops_[op_index];
75
}
76

77
Op Randomizer::PredictOp() {
78
  IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(
79
      *bit_gen_, 0, allowed_predict_ops_.size());
80
  return allowed_predict_ops_[op_index];
81
}
82

83
Op Randomizer::LearnOp() {
84
  IntegerT op_index = absl::Uniform<DeprecatedOpIndexT>(
85
      *bit_gen_, 0, allowed_learn_ops_.size());
86
  return allowed_learn_ops_[op_index];
87
}
88

89
}  // namespace automl_zero
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.