google-research
340 строк · 13.3 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "generator.h"

#include <functional>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>

#include "algorithm_test_util.h"
#include "task.h"
#include "task_util.h"
#include "definitions.h"
#include "instruction.pb.h"
#include "evaluator.h"
#include "executor.h"
#include "random_generator.h"
#include "test_util.h"
#include "util.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/str_cat.h"
35
36namespace automl_zero {
37
38using ::absl::StrCat;
39using ::std::function;
40using ::std::mt19937;
41using test_only::GenerateTask;
42
43constexpr IntegerT kNumTrainExamples = 1000;
44constexpr IntegerT kNumValidExamples = 100;
45constexpr double kLargeMaxAbsError = 1000000000.0;
46
47TEST(GeneratorTest, NoOpHasNoOpInstructions) {
48Generator generator(
49NO_OP_ALGORITHM, // Irrelevant.
5010, // setup_size_init
5112, // predict_size_init
5213, // learn_size_init
53{}, // allowed_setup_ops, irrelevant.
54{}, // allowed_predict_ops, irrelevant.
55{}, // allowed_learn_ops, irrelevant.
56nullptr, // bit_gen, irrelevant.
57nullptr); // rand_gen, irrelevant.
58const InstructionIndexT setup_instruction_index = 2;
59const InstructionIndexT predict_instruction_index = 1;
60const InstructionIndexT learn_instruction_index = 3;
61Algorithm algorithm = generator.NoOp();
62EXPECT_EQ(algorithm.setup_[setup_instruction_index]->op_, NO_OP);
63EXPECT_EQ(algorithm.setup_[setup_instruction_index]->in1_, 0);
64EXPECT_EQ(algorithm.setup_[setup_instruction_index]->in2_, 0);
65EXPECT_EQ(algorithm.setup_[setup_instruction_index]->out_, 0);
66EXPECT_EQ(algorithm.setup_[setup_instruction_index]->GetActivationData(),
670.0);
68EXPECT_EQ(algorithm.setup_[setup_instruction_index]->GetFloatData0(), 0.0);
69EXPECT_EQ(algorithm.setup_[setup_instruction_index]->GetFloatData1(), 0.0);
70EXPECT_EQ(algorithm.setup_[setup_instruction_index]->GetFloatData2(), 0.0);
71EXPECT_EQ(algorithm.predict_[predict_instruction_index]->op_, NO_OP);
72EXPECT_EQ(algorithm.predict_[predict_instruction_index]->in1_, 0);
73EXPECT_EQ(algorithm.predict_[predict_instruction_index]->in2_, 0);
74EXPECT_EQ(algorithm.predict_[predict_instruction_index]->out_, 0);
75EXPECT_EQ(algorithm.predict_[predict_instruction_index]->GetActivationData(),
760.0);
77EXPECT_EQ(algorithm.predict_[predict_instruction_index]->GetFloatData0(),
780.0);
79EXPECT_EQ(algorithm.predict_[predict_instruction_index]->GetFloatData1(),
800.0);
81EXPECT_EQ(algorithm.predict_[predict_instruction_index]->GetFloatData2(),
820.0);
83EXPECT_EQ(algorithm.learn_[learn_instruction_index]->op_, NO_OP);
84EXPECT_EQ(algorithm.learn_[learn_instruction_index]->in1_, 0);
85EXPECT_EQ(algorithm.learn_[learn_instruction_index]->in2_, 0);
86EXPECT_EQ(algorithm.learn_[learn_instruction_index]->out_, 0);
87EXPECT_EQ(algorithm.learn_[learn_instruction_index]->GetActivationData(),
880.0);
89EXPECT_EQ(algorithm.learn_[learn_instruction_index]->GetFloatData0(), 0.0);
90EXPECT_EQ(algorithm.learn_[learn_instruction_index]->GetFloatData1(), 0.0);
91EXPECT_EQ(algorithm.learn_[learn_instruction_index]->GetFloatData2(), 0.0);
92}
93
94TEST(GeneratorTest, NoOpProducesCorrectComponentFunctionSize) {
95Generator generator(
96NO_OP_ALGORITHM, // Irrelevant.
9710, // setup_size_init
9812, // predict_size_init
9913, // learn_size_init
100{}, // allowed_setup_ops, irrelevant.
101{}, // allowed_predict_ops, irrelevant.
102{}, // allowed_learn_ops, irrelevant.
103nullptr, // bit_gen, irrelevant.
104nullptr); // rand_gen, irrelevant.
105Algorithm algorithm = generator.NoOp();
106EXPECT_EQ(algorithm.setup_.size(), 10);
107EXPECT_EQ(algorithm.predict_.size(), 12);
108EXPECT_EQ(algorithm.learn_.size(), 13);
109}
110
111TEST(GeneratorTest, Gz_Learns) {
112Generator generator(
113NO_OP_ALGORITHM, // Irrelevant.
11410, // setup_size_init, irrelevant
11512, // predict_size_init, irrelevant
11613, // learn_size_init, irrelevant
117{}, // allowed_setup_ops, irrelevant.
118{}, // allowed_predict_ops, irrelevant.
119{}, // allowed_learn_ops, irrelevant.
120nullptr, // bit_gen, irrelevant.
121nullptr); // rand_gen, irrelevant.
122Task<4> dataset =
123GenerateTask<4>(StrCat("scalar_linear_regression_task {} "
124"num_train_examples: ",
125kNumTrainExamples,
126" "
127"num_valid_examples: ",
128kNumValidExamples,
129" "
130"eval_type: RMS_ERROR "
131"param_seeds: 100 "
132"data_seeds: 1000 "));
133Algorithm algorithm = generator.LinearModel(kDefaultLearningRate);
134mt19937 bit_gen(10000);
135RandomGenerator rand_gen(&bit_gen);
136Executor<4> executor(algorithm, dataset, kNumTrainExamples, kNumValidExamples,
137&rand_gen, kLargeMaxAbsError);
138double fitness = executor.Execute();
139std::cout << "Gz_Learns fitness = " << fitness << std::endl;
140EXPECT_GE(fitness, 0.0);
141EXPECT_LE(fitness, 1.0);
142EXPECT_GT(fitness, 0.999);
143}
144
145TEST(GeneratorTest, LinearModel_Learns) {
146Generator generator(
147NO_OP_ALGORITHM, // Irrelevant.
14810, // setup_size_init, irrelevant
14912, // predict_size_init, irrelevant
15013, // learn_size_init, irrelevant
151{}, // allowed_setup_ops, irrelevant.
152{}, // allowed_predict_ops, irrelevant.
153{}, // allowed_learn_ops, irrelevant.
154nullptr, // bit_gen, irrelevant.
155nullptr); // rand_gen, irrelevant.
156Task<4> dataset =
157GenerateTask<4>(StrCat("scalar_linear_regression_task {} "
158"num_train_examples: ",
159kNumTrainExamples,
160" "
161"num_valid_examples: ",
162kNumValidExamples,
163" "
164"eval_type: RMS_ERROR "
165"param_seeds: 100 "
166"data_seeds: 1000 "));
167Algorithm algorithm = generator.LinearModel(kDefaultLearningRate);
168mt19937 bit_gen(10000);
169RandomGenerator rand_gen(&bit_gen);
170Executor<4> executor(algorithm, dataset, kNumTrainExamples, kNumValidExamples,
171&rand_gen, kLargeMaxAbsError);
172double fitness = executor.Execute();
173std::cout << "Gz_Learns fitness = " << fitness << std::endl;
174EXPECT_GE(fitness, 0.0);
175EXPECT_LE(fitness, 1.0);
176EXPECT_GT(fitness, 0.999);
177}
178
179TEST(GeneratorTest, GrTildeGrWithBias_PermanenceTest) {
180Generator generator(
181NO_OP_ALGORITHM, // Irrelevant.
1820, // setup_size_init, irrelevant.
1830, // predict_size_init, irrelevant.
1840, // learn_size_init, irrelevant.
185{}, // allowed_setup_ops, irrelevant.
186{}, // allowed_predict_ops, irrelevant.
187{}, // allowed_learn_ops, irrelevant.
188nullptr, // bit_gen, irrelevant.
189nullptr); // rand_gen, irrelevant.
190Task<4> dataset = GenerateTask<4>(StrCat(
191"scalar_2layer_nn_regression_task {} "
192"num_train_examples: ", kNumTrainExamples, " "
193"num_valid_examples: ", kNumValidExamples, " "
194"num_tasks: 1 "
195"eval_type: RMS_ERROR "
196"param_seeds: 1000 "
197"data_seeds: 10000 "));
198Algorithm algorithm = generator.NeuralNet(
199kDefaultLearningRate, kDefaultInitScale, kDefaultInitScale);
200mt19937 bit_gen(10000);
201RandomGenerator rand_gen(&bit_gen);
202Executor<4> executor(algorithm, dataset, kNumTrainExamples, kNumValidExamples,
203&rand_gen, kLargeMaxAbsError);
204double fitness = executor.Execute();
205std::cout << "GrTildeGrWithBias_PermanenceTest fitness = " << fitness
206<< std::endl;
207EXPECT_FLOAT_EQ(fitness, 0.80256736);
208}
209
210TEST(GeneratorTest, RandomInstructions) {
211mt19937 bit_gen;
212RandomGenerator rand_gen(&bit_gen);
213Generator generator(
214NO_OP_ALGORITHM, // Irrelevant.
2152, // setup_size_init
2164, // predict_size_init
2175, // learn_size_init
218{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
219{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
220{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
221&bit_gen, // bit_gen
222&rand_gen); // rand_gen
223const Algorithm no_op_algorithm = generator.NoOp();
224const IntegerT total_instructions = 2 + 4 + 5;
225EXPECT_TRUE(IsEventually(
226function<IntegerT(void)>([&](){
227Algorithm random_algorithm = generator.Random();
228return CountDifferentInstructions(random_algorithm, no_op_algorithm);
229}),
230Range<IntegerT>(0, total_instructions + 1), {total_instructions}));
231}
232
233TEST(GeneratorTest, RandomInstructionsProducesCorrectComponentFunctionSizes) {
234mt19937 bit_gen;
235RandomGenerator rand_gen(&bit_gen);
236Generator generator(
237NO_OP_ALGORITHM, // Irrelevant.
2382, // setup_size_init
2394, // predict_size_init
2405, // learn_size_init
241{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
242{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
243{NO_OP, SCALAR_SUM_OP, MATRIX_VECTOR_PRODUCT_OP, VECTOR_MEAN_OP},
244&bit_gen, // bit_gen
245&rand_gen); // rand_gen
246Algorithm algorithm = generator.Random();
247EXPECT_EQ(algorithm.setup_.size(), 2);
248EXPECT_EQ(algorithm.predict_.size(), 4);
249EXPECT_EQ(algorithm.learn_.size(), 5);
250}
251
252TEST(GeneratorTest, GzHasCorrectComponentFunctionSizes) {
253Generator generator(
254NO_OP_ALGORITHM, // Irrelevant.
2550, // setup_size_init, no padding.
2560, // predict_size_init, no padding.
2570, // learn_size_init, no padding.
258{}, // allowed_setup_ops, irrelevant.
259{}, // allowed_predict_ops, irrelevant.
260{}, // allowed_learn_ops, irrelevant.
261nullptr, // bit_gen, irrelevant.
262nullptr); // rand_gen, irrelevant.
263Algorithm algorithm = generator.LinearModel(kDefaultLearningRate);
264EXPECT_EQ(algorithm.setup_.size(), 1);
265EXPECT_EQ(algorithm.predict_.size(), 1);
266EXPECT_EQ(algorithm.learn_.size(), 4);
267}
268
269TEST(GeneratorTest, GzTildeGzHasCorrectComponentFunctionSizes) {
270Generator generator(
271NO_OP_ALGORITHM, // Irrelevant.
2720, // setup_size_init, no padding.
2730, // predict_size_init, no padding.
2740, // learn_size_init, no padding.
275{}, // allowed_setup_ops, irrelevant.
276{}, // allowed_predict_ops, irrelevant.
277{}, // allowed_learn_ops, irrelevant.
278nullptr, // bit_gen, irrelevant.
279nullptr); // rand_gen, irrelevant.
280Algorithm algorithm =
281generator.UnitTestNeuralNetNoBiasNoGradient(kDefaultLearningRate);
282EXPECT_EQ(algorithm.setup_.size(), 1);
283EXPECT_EQ(algorithm.predict_.size(), 3);
284EXPECT_EQ(algorithm.learn_.size(), 9);
285}
286
287TEST(GeneratorTest, GzTildeGzPadsComponentFunctionSizesCorrectly) {
288Generator generator(
289NO_OP_ALGORITHM, // Irrelevant.
29010, // setup_size_init
29112, // predict_size_init
29213, // learn_size_init
293{}, // allowed_setup_ops, irrelevant.
294{}, // allowed_predict_ops, irrelevant.
295{}, // allowed_learn_ops, irrelevant.
296nullptr, // bit_gen, irrelevant.
297nullptr); // rand_gen, irrelevant.
298Algorithm algorithm =
299generator.UnitTestNeuralNetNoBiasNoGradient(kDefaultLearningRate);
300EXPECT_EQ(algorithm.setup_.size(), 10);
301EXPECT_EQ(algorithm.predict_.size(), 12);
302EXPECT_EQ(algorithm.learn_.size(), 13);
303}
304
305TEST(GeneratorTest, GrTildeGrPadsComponentFunctionSizesCorrectly) {
306Generator generator(
307NO_OP_ALGORITHM, // Irrelevant.
30816, // setup_size_init
30918, // predict_size_init
31019, // learn_size_init
311{}, // allowed_setup_ops, irrelevant.
312{}, // allowed_predict_ops, irrelevant.
313{}, // allowed_learn_ops, irrelevant.
314nullptr, // bit_gen, irrelevant.
315nullptr); // rand_gen, irrelevant.
316Algorithm algorithm = generator.NeuralNet(
317kDefaultLearningRate, kDefaultInitScale, kDefaultInitScale);
318EXPECT_EQ(algorithm.setup_.size(), 16);
319EXPECT_EQ(algorithm.predict_.size(), 18);
320EXPECT_EQ(algorithm.learn_.size(), 19);
321}
322
323TEST(GeneratorTest, GzPadsComponentFunctionSizesCorrectly) {
324Generator generator(
325NO_OP_ALGORITHM, // Irrelevant.
32610, // setup_size_init
32712, // predict_size_init
32813, // learn_size_init
329{}, // allowed_setup_ops, irrelevant.
330{}, // allowed_predict_ops, irrelevant.
331{}, // allowed_learn_ops, irrelevant.
332nullptr, // bit_gen, irrelevant.
333nullptr); // rand_gen, irrelevant.
334Algorithm algorithm = generator.LinearModel(kDefaultLearningRate);
335EXPECT_EQ(algorithm.setup_.size(), 10);
336EXPECT_EQ(algorithm.predict_.size(), 12);
337EXPECT_EQ(algorithm.learn_.size(), 13);
338}
339
340} // namespace automl_zero
341