google-research
83 строки · 2.8 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "generator_test_util.h"
16
17#include "random_generator.h"
18#include "absl/memory/memory.h"
19
20namespace automl_zero {
21
22using ::absl::make_unique;
23using ::std::mt19937; // NOLINT
24using ::std::shared_ptr;
25using ::std::vector;
26
27Algorithm SimpleNoOpAlgorithm() {
28Generator generator(NO_OP_ALGORITHM, // Irrelevant.
296, // setup_size_init
303, // predict_size_init
319, // learn_size_init
32{}, {}, {}, nullptr, nullptr); // Irrelevant.
33return generator.NoOp();
34}
35
36Algorithm SimpleRandomAlgorithm() {
37mt19937 bit_gen;
38RandomGenerator rand_gen(&bit_gen);
39Generator generator(
40RANDOM_ALGORITHM, // Irrelevant.
416, // setup_size_init
423, // predict_size_init
439, // learn_size_init
44{SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
45{SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
46{SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
47&bit_gen,
48&rand_gen);
49return generator.Random();
50}
51
52Algorithm SimpleGz() {
53Generator generator(NO_OP_ALGORITHM, 0, 0, 0, {}, {}, {}, nullptr, nullptr);
54return generator.LinearModel(kDefaultLearningRate);
55}
56
57Algorithm SimpleGrTildeGrWithBias() {
58Generator generator(NO_OP_ALGORITHM, 0, 0, 0, {}, {}, {}, nullptr, nullptr);
59return generator.NeuralNet(
60kDefaultLearningRate, kDefaultInitScale, kDefaultInitScale);
61}
62
63void SetIncreasingDataInComponentFunction(
64vector<shared_ptr<const Instruction>>* component_function) {
65for (IntegerT position = 0;
66position < component_function->size();
67++position) {
68auto instruction =
69make_unique<Instruction>(*(*component_function)[position]);
70instruction->SetIntegerData(position);
71(*component_function)[position].reset(instruction.release());
72}
73}
74
75Algorithm SimpleIncreasingDataAlgorithm() {
76Algorithm algorithm = SimpleNoOpAlgorithm();
77SetIncreasingDataInComponentFunction(&algorithm.setup_);
78SetIncreasingDataInComponentFunction(&algorithm.predict_);
79SetIncreasingDataInComponentFunction(&algorithm.learn_);
80return algorithm;
81}
82
83} // namespace automl_zero
84