google-research
162 строки · 5.1 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#include "random_generator.h"
16
17#include "definitions.h"
18#include "absl/memory/memory.h"
19#include "absl/random/distributions.h"
20#include "absl/time/clock.h"
21#include "absl/time/time.h"
22
23namespace automl_zero {
24
25using ::absl::GetCurrentTimeNanos;
26using ::absl::make_unique;
27using ::std::mt19937;
28using ::std::numeric_limits;
29using ::std::string;
30
31RandomGenerator::RandomGenerator(mt19937* bit_gen) : bit_gen_(bit_gen) {}
32
33float RandomGenerator::GaussianFloat(float mean, float stdev) {
34return ::absl::Gaussian<float>(*bit_gen_, mean, stdev);
35}
36
37IntegerT RandomGenerator::UniformInteger(IntegerT low, IntegerT high) {
38// TODO(ereal): change this to IntegerT and change the values provided by
39// LeanClient::PutGetAndCount. Probably affects random number generation.
40CHECK_GE(low, std::numeric_limits<int32_t>::min());
41CHECK_LE(high, std::numeric_limits<int32_t>::max());
42return ::absl::Uniform<int32_t>(*bit_gen_, low, high);
43}
44
45RandomSeedT RandomGenerator::UniformRandomSeed() {
46return absl::Uniform<RandomSeedT>(
47absl::IntervalOpen, *bit_gen_,
481, std::numeric_limits<RandomSeedT>::max());
49}
50
51double RandomGenerator::UniformDouble(double low, double high) {
52return ::absl::Uniform<double>(*bit_gen_, low, high);
53}
54
55float RandomGenerator::UniformFloat(float low, float high) {
56return ::absl::Uniform<float>(*bit_gen_, low, high);
57}
58
59ProbabilityT RandomGenerator::UniformProbability(
60const ProbabilityT low, const ProbabilityT high) {
61return ::absl::Uniform<ProbabilityT>(*bit_gen_, low, high);
62}
63
64string RandomGenerator::UniformString(const size_t size) {
65string random_string;
66for (size_t i = 0; i < size; ++i) {
67char random_char;
68const IntegerT char_index = UniformInteger(0, 64);
69if (char_index < 26) {
70random_char = char_index + 97; // Maps 0-25 to 'a'-'z'.
71} else if (char_index < 52) {
72random_char = char_index - 26 + 65; // Maps 26-51 to 'A'-'Z'.
73} else if (char_index < 62) {
74random_char = char_index - 52 + 48; // Maps 52-61 to '0'-'9'.
75} else if (char_index == 62) {
76random_char = '_';
77} else if (char_index == 63) {
78random_char = '~';
79} else {
80LOG(FATAL) << "Code should not get here." << std::endl;
81}
82random_string.push_back(random_char);
83}
84return random_string;
85}
86
87FeatureIndexT RandomGenerator::FeatureIndex(
88const FeatureIndexT features_size) {
89// TODO(ereal): below should have FeatureIndexT instead of InstructionIndexT;
90// affects random number generation.
91return absl::Uniform<InstructionIndexT>(*bit_gen_, 0, features_size);
92}
93
94AddressT RandomGenerator::ScalarInAddress() {
95return absl::Uniform<AddressT>(*bit_gen_, 0, kMaxScalarAddresses);
96}
97
98AddressT RandomGenerator::VectorInAddress() {
99return absl::Uniform<AddressT>(*bit_gen_, 0, kMaxVectorAddresses);
100}
101
102AddressT RandomGenerator::MatrixInAddress() {
103return absl::Uniform<AddressT>(*bit_gen_, 0, kMaxMatrixAddresses);
104}
105
106AddressT RandomGenerator::ScalarOutAddress() {
107return absl::Uniform<AddressT>(
108*bit_gen_, kFirstOutScalarAddress, kMaxScalarAddresses);
109}
110
111AddressT RandomGenerator::VectorOutAddress() {
112return absl::Uniform<AddressT>(
113*bit_gen_, kFirstOutVectorAddress, kMaxVectorAddresses);
114}
115
116AddressT RandomGenerator::MatrixOutAddress() {
117return absl::Uniform<AddressT>(
118*bit_gen_, kFirstOutMatrixAddress, kMaxMatrixAddresses);
119}
120
121Choice2T RandomGenerator::Choice2() {
122return static_cast<Choice2T>(absl::Uniform<IntegerT>(*bit_gen_, 0, 2));
123}
124
125Choice3T RandomGenerator::Choice3() {
126return static_cast<Choice3T>(absl::Uniform<IntegerT>(*bit_gen_, 0, 3));
127}
128
129IntegerT RandomGenerator::UniformPopulationSize(
130IntegerT high) {
131return static_cast<IntegerT>(absl::Uniform<uint32_t>(*bit_gen_, 0, high));
132}
133
134double RandomGenerator::UniformActivation(
135double low, double high) {
136return absl::Uniform<double>(absl::IntervalOpen, *bit_gen_, low, high);
137}
138
139double RandomGenerator::GaussianActivation(
140const double mean, const double stdev) {
141return ::absl::Gaussian<double>(*bit_gen_, mean, stdev);
142}
143
144double RandomGenerator::BetaActivation(
145const double alpha, const double beta) {
146return ::absl::Beta<double>(*bit_gen_, alpha, beta);
147}
148
149RandomGenerator::RandomGenerator()
150: bit_gen_owned_(make_unique<mt19937>(GenerateRandomSeed())),
151bit_gen_(bit_gen_owned_.get()) {}
152
153RandomSeedT GenerateRandomSeed() {
154RandomSeedT seed = 0;
155while (seed == 0) {
156seed = static_cast<RandomSeedT>(
157GetCurrentTimeNanos() % numeric_limits<RandomSeedT>::max());
158}
159return seed;
160}
161
162} // namespace automl_zero
163