google-research
566 строк · 22.1 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "task_util.h"
16
17#include <cmath>
18#include <ostream>
19#include <string>
20#include <unordered_set>
21#include <vector>
22
23#include "gmock/gmock.h"
24#include "gtest/gtest.h"
25#include "absl/container/node_hash_set.h"
26#include "absl/strings/str_cat.h"
27#include "Eigen/Core"
28#include "algorithm.h"
29#include "definitions.h"
30#include "executor.h"
31#include "generator.h"
32#include "memory.h"
33#include "random_generator.h"
34#include "task.h"
35#include "task.pb.h"
36#include "test_util.h"
37
38namespace automl_zero {
39
40using ::absl::StrCat; // NOLINT
41using ::Eigen::Map;
42using ::std::abs; // NOLINT
43using ::std::function; // NOLINT
44using ::std::make_pair; // NOLINT
45using ::std::pair; // NOLINT
46using ::std::vector; // NOLINT
47using ::std::unique_ptr; // NOLINT
48using ::testing::Test;
49using test_only::GenerateTask;
50
51constexpr IntegerT kNumTrainExamples = 1000;
52constexpr IntegerT kNumValidExamples = 100;
53constexpr double kLargeMaxAbsError = 1000000000.0;
54constexpr IntegerT kNumAllTrainExamples = 1000;
55
56bool ScalarEq(const Scalar& scalar1, const Scalar scalar2) {
57return abs(scalar1 - scalar2) < kDataTolerance;
58}
59
60template <FeatureIndexT F>
61bool VectorEq(const Vector<F>& vector1,
62const ::std::vector<double>& vector2) {
63Map<const Vector<F>> vector2_eigen(vector2.data());
64return (vector1 - vector2_eigen).norm() < kDataTolerance;
65}
66
67// Scalar and Vector are trivially destructible.
68const Vector<4> kZeroVector = Vector<4>::Zero(4, 1);
69const Vector<4> kOnesVector = Vector<4>::Ones(4, 1);
70
71TEST(FillTasksTest, WorksCorrectly) {
72Task<4> expected_task_0 =
73GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
74"num_train_examples: ",
75kNumTrainExamples,
76" "
77"num_valid_examples: ",
78kNumValidExamples,
79" "
80"eval_type: RMS_ERROR "
81"param_seeds: 1000 "
82"data_seeds: 10000 "));
83Task<4> expected_task_1 =
84GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
85"num_train_examples: ",
86kNumTrainExamples,
87" "
88"num_valid_examples: ",
89kNumValidExamples,
90" "
91"eval_type: RMS_ERROR "
92"param_seeds: 1001 "
93"data_seeds: 10001 "));
94
95TaskCollection task_collection;
96TaskSpec* task = task_collection.add_tasks();
97task->set_features_size(4);
98task->set_num_train_examples(kNumTrainExamples);
99task->set_num_valid_examples(kNumValidExamples);
100task->set_num_tasks(2);
101task->set_eval_type(RMS_ERROR);
102task->add_data_seeds(10000);
103task->add_data_seeds(10001);
104task->add_param_seeds(1000);
105task->add_param_seeds(1001);
106task->mutable_scalar_2layer_nn_regression_task();
107vector<unique_ptr<TaskInterface>> owned_tasks;
108FillTasks(task_collection, &owned_tasks);
109vector<Task<4>*> tasks;
110for (const unique_ptr<TaskInterface>& owned_dataset : owned_tasks) {
111tasks.push_back(SafeDowncast<4>(owned_dataset.get()));
112}
113
114EXPECT_EQ(tasks[0]->MaxTrainExamples(), kNumTrainExamples);
115
116EXPECT_LT(
117(tasks[0]->train_features_[478] -
118expected_task_0.train_features_[478]).norm(),
119kDataTolerance);
120EXPECT_LT(
121(tasks[1]->train_features_[478] -
122expected_task_1.train_features_[478]).norm(),
123kDataTolerance);
124EXPECT_LT(
125(tasks[0]->valid_features_[94] -
126expected_task_0.valid_features_[94]).norm(),
127kDataTolerance);
128EXPECT_LT(
129(tasks[1]->valid_features_[94] -
130expected_task_1.valid_features_[94]).norm(),
131kDataTolerance);
132
133EXPECT_LT(
134abs(tasks[0]->train_labels_[478] -
135expected_task_0.train_labels_[478]),
136kDataTolerance);
137EXPECT_LT(
138abs(tasks[1]->train_labels_[478] -
139expected_task_1.train_labels_[478]),
140kDataTolerance);
141EXPECT_LT(
142abs(tasks[0]->valid_labels_[94] -
143expected_task_0.valid_labels_[94]),
144kDataTolerance);
145EXPECT_LT(
146abs(tasks[1]->valid_labels_[94] -
147expected_task_1.valid_labels_[94]),
148kDataTolerance);
149
150EXPECT_EQ(tasks[0]->index_, 0);
151EXPECT_EQ(tasks[1]->index_, 1);
152}
153
154TEST(FillTaskTest, FillsEvalType) {
155std::string task_spec_string =
156StrCat("scalar_linear_regression_task {} "
157"num_train_examples: 1000 "
158"num_valid_examples: 100 "
159"param_seeds: 1000 "
160"data_seeds: 1000 "
161"eval_type: RMS_ERROR");
162Task<4> dataset = GenerateTask<4>(task_spec_string);
163TaskSpec task_spec =
164ParseTextFormat<TaskSpec>(task_spec_string);
165EXPECT_EQ(dataset.eval_type_, task_spec.eval_type());
166}
167
168TEST(FillTaskWithZerosTest, WorksCorrectly) {
169auto dataset =
170GenerateTask<4>(StrCat("unit_test_zeros_task {} "
171"eval_type: ACCURACY "
172"num_train_examples: ",
173kNumTrainExamples,
174" "
175"num_valid_examples: ",
176kNumValidExamples,
177" "
178"num_tasks: 1 "
179"features_size: 4 "));
180for (const Scalar& label : dataset.train_labels_) {
181EXPECT_FLOAT_EQ(label, 0.0);
182}
183for (const Vector<4>& feature : dataset.valid_features_) {
184EXPECT_TRUE(feature.isApprox(kZeroVector));
185}
186for (const Scalar& label : dataset.valid_labels_) {
187EXPECT_FLOAT_EQ(label, 0.0);
188}
189}
190
191TEST(FillTaskWithOnesTest, WorksCorrectly) {
192auto dataset =
193GenerateTask<4>(StrCat("unit_test_ones_task {} "
194"eval_type: ACCURACY "
195"num_train_examples: ",
196kNumTrainExamples,
197" "
198"num_valid_examples: ",
199kNumValidExamples,
200" "
201"num_tasks: 1 "
202"features_size: 4 "));
203for (const Vector<4>& feature : dataset.train_features_) {
204EXPECT_TRUE(feature.isApprox(kOnesVector));
205}
206for (const Scalar& label : dataset.train_labels_) {
207EXPECT_FLOAT_EQ(label, 1.0);
208}
209for (const Vector<4>& feature : dataset.valid_features_) {
210EXPECT_TRUE(feature.isApprox(kOnesVector));
211}
212for (const Scalar& label : dataset.valid_labels_) {
213EXPECT_FLOAT_EQ(label, 1.0);
214}
215}
216
217TEST(FillTaskWithIncrementingIntegersTest, WorksCorrectly) {
218auto dataset =
219GenerateTask<4>(StrCat("unit_test_increment_task {} "
220"eval_type: ACCURACY "
221"num_train_examples: ",
222kNumTrainExamples,
223" "
224"num_valid_examples: ",
225kNumValidExamples,
226" "
227"num_tasks: 1 "
228"features_size: 4 "));
229
230EXPECT_TRUE(dataset.train_features_[0].isApprox(kZeroVector));
231EXPECT_TRUE(
232dataset.train_features_[kNumTrainExamples - 1].isApprox(
233kOnesVector * static_cast<double>(kNumTrainExamples - 1)));
234
235EXPECT_FLOAT_EQ(dataset.train_labels_[0], 0.0);
236EXPECT_FLOAT_EQ(
237dataset.train_labels_[kNumTrainExamples - 1],
238static_cast<double>(kNumTrainExamples - 1));
239
240EXPECT_TRUE(dataset.valid_features_[0].isApprox(kZeroVector));
241EXPECT_TRUE(
242dataset.valid_features_[kNumValidExamples - 1].isApprox(
243kOnesVector * static_cast<double>(kNumValidExamples - 1)));
244
245EXPECT_FLOAT_EQ(dataset.valid_labels_[0], 0.0);
246EXPECT_FLOAT_EQ(
247dataset.valid_labels_[kNumValidExamples - 1],
248static_cast<double>(kNumValidExamples - 1));
249}
250
251TEST(FillTaskWithNonlinearDataTest, DifferentForDifferentSeeds) {
252Task<4> dataset_1000_10000 =
253GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
254"num_train_examples: ",
255kNumTrainExamples,
256" "
257"num_valid_examples: ",
258kNumValidExamples,
259" "
260"eval_type: RMS_ERROR "
261"param_seeds: 1000 "
262"data_seeds: 10000 "));
263Task<4> dataset_1001_10000 =
264GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
265"num_train_examples: ",
266kNumTrainExamples,
267" "
268"num_valid_examples: ",
269kNumValidExamples,
270" "
271"eval_type: RMS_ERROR "
272"param_seeds: 1001 "
273"data_seeds: 10000 "));
274Task<4> dataset_1000_10001 =
275GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
276"num_train_examples: ",
277kNumTrainExamples,
278" "
279"num_valid_examples: ",
280kNumValidExamples,
281" "
282"eval_type: RMS_ERROR "
283"param_seeds: 1000 "
284"data_seeds: 10001 "));
285EXPECT_NE(dataset_1000_10000, dataset_1001_10000);
286EXPECT_NE(dataset_1000_10000, dataset_1000_10001);
287EXPECT_NE(dataset_1001_10000, dataset_1000_10001);
288}
289
290TEST(FillTaskWithNonlinearDataTest, SameForSameSeed) {
291Task<4> dataset_1000_10000_a =
292GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
293"num_train_examples: ",
294kNumTrainExamples,
295" "
296"num_valid_examples: ",
297kNumValidExamples,
298" "
299"eval_type: RMS_ERROR "
300"param_seeds: 1000 "
301"data_seeds: 10000 "));
302Task<4> dataset_1000_10000_b =
303GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
304"num_train_examples: ",
305kNumTrainExamples,
306" "
307"num_valid_examples: ",
308kNumValidExamples,
309" "
310"eval_type: RMS_ERROR "
311"param_seeds: 1000 "
312"data_seeds: 10000 "));
313EXPECT_EQ(dataset_1000_10000_a, dataset_1000_10000_b);
314}
315
316TEST(FillTaskWithNonlinearDataTest, PermanenceTest) {
317Task<4> dataset =
318GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
319"num_train_examples: ",
320kNumTrainExamples,
321" "
322"num_valid_examples: ",
323kNumValidExamples,
324" "
325"eval_type: RMS_ERROR "
326"param_seeds: 1000 "
327"data_seeds: 10000 "));
328EXPECT_TRUE(VectorEq<4>(
329dataset.train_features_[0],
330{1.30836, -0.192507, 0.549877, -0.667065}));
331EXPECT_TRUE(VectorEq<4>(
332dataset.train_features_[994],
333{-0.265714, 1.38325, 0.775253, 1.78923}));
334EXPECT_TRUE(VectorEq<4>(
335dataset.valid_features_[0],
336{1.39658, 0.293097, -0.504938, -1.09144}));
337EXPECT_TRUE(VectorEq<4>(
338dataset.valid_features_[94],
339{-0.224309, 1.78054, 1.24783, 0.54083}));
340EXPECT_TRUE(ScalarEq(dataset.train_labels_[0], 1.508635));
341EXPECT_TRUE(ScalarEq(dataset.train_labels_[994], -2.8410525));
342EXPECT_TRUE(ScalarEq(dataset.valid_labels_[0], 0.0));
343EXPECT_TRUE(ScalarEq(dataset.valid_labels_[98], -0.66133333));
344}
345
346void ClearSeeds(TaskCollection* task_collection) {
347for (TaskSpec& dataset : *task_collection->mutable_tasks()) {
348dataset.clear_param_seeds();
349dataset.clear_data_seeds();
350}
351}
352
353TEST(RandomizeTaskSeedsTest, FillsCorrectNumberOfRandomSeeds) {
354auto task_collection = ParseTextFormat<TaskCollection>(
355"tasks {num_tasks: 8} "
356"tasks {num_tasks: 3} ");
357
358RandomizeTaskSeeds(&task_collection, GenerateRandomSeed());
359EXPECT_EQ(task_collection.tasks_size(), 2);
360EXPECT_EQ(task_collection.tasks(0).param_seeds_size(), 8);
361EXPECT_EQ(task_collection.tasks(0).data_seeds_size(), 8);
362EXPECT_EQ(task_collection.tasks(1).param_seeds_size(), 3);
363EXPECT_EQ(task_collection.tasks(1).data_seeds_size(), 3);
364}
365
366TEST(RandomizeTaskSeedsTest, SameForSameSeed) {
367const RandomSeedT seed = GenerateRandomSeed();
368auto task_collection_1 = ParseTextFormat<TaskCollection>(
369"tasks {num_tasks: 8} "
370"tasks {num_tasks: 3} ");
371TaskCollection task_collection_2 = task_collection_1;
372RandomizeTaskSeeds(&task_collection_1, seed);
373RandomizeTaskSeeds(&task_collection_2, seed);
374EXPECT_EQ(task_collection_1.tasks(0).param_seeds(5),
375task_collection_2.tasks(0).param_seeds(5));
376EXPECT_EQ(task_collection_1.tasks(0).data_seeds(5),
377task_collection_2.tasks(0).data_seeds(5));
378EXPECT_EQ(task_collection_1.tasks(1).param_seeds(1),
379task_collection_2.tasks(1).param_seeds(1));
380EXPECT_EQ(task_collection_1.tasks(1).data_seeds(1),
381task_collection_2.tasks(1).data_seeds(1));
382}
383
384TEST(RandomizeTaskSeedsTest, DifferentForDifferentSeeds) {
385const RandomSeedT seed1 = 519801251;
386const RandomSeedT seed2 = 208594758;
387auto task_collection_1 = ParseTextFormat<TaskCollection>(
388"tasks {num_tasks: 8} "
389"tasks {num_tasks: 3} ");
390TaskCollection task_collection_2 = task_collection_1;
391RandomizeTaskSeeds(&task_collection_1, seed1);
392RandomizeTaskSeeds(&task_collection_2, seed2);
393EXPECT_NE(task_collection_1.tasks(0).param_seeds(5),
394task_collection_2.tasks(0).param_seeds(5));
395EXPECT_NE(task_collection_1.tasks(0).data_seeds(5),
396task_collection_2.tasks(0).data_seeds(5));
397EXPECT_NE(task_collection_1.tasks(1).param_seeds(1),
398task_collection_2.tasks(1).param_seeds(1));
399EXPECT_NE(task_collection_1.tasks(1).data_seeds(1),
400task_collection_2.tasks(1).data_seeds(1));
401}
402
403TEST(RandomizeTaskSeedsTest, CoversParamSeeds) {
404IntegerT num_tasks = 0;
405auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
406const RandomSeedT seed = GenerateRandomSeed();
407EXPECT_TRUE(IsEventually(
408function<RandomSeedT(void)>([&task_collection, &num_tasks, seed]() {
409// We need to keep increasing the number of tasks in order to
410// generate new seeds because the RandomizeTaskSeeds function is
411// deterministic.
412++num_tasks;
413task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
414
415ClearSeeds(&task_collection);
416RandomizeTaskSeeds(&task_collection, seed);
417const RandomSeedT param_seed =
418*task_collection.tasks(0).param_seeds().rbegin();
419return (param_seed % 5);
420}),
421Range<RandomSeedT>(0, 5), Range<RandomSeedT>(0, 5)));
422}
423
424TEST(RandomizeTaskSeedsTest, CoversDataSeeds) {
425IntegerT num_tasks = 0;
426auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
427const RandomSeedT seed = GenerateRandomSeed();
428EXPECT_TRUE(IsEventually(
429function<RandomSeedT(void)>([&task_collection, &num_tasks, seed]() {
430++num_tasks;
431task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
432ClearSeeds(&task_collection);
433
434// Return the last seed.
435RandomizeTaskSeeds(&task_collection, seed);
436const RandomSeedT data_seed =
437*task_collection.tasks(0).data_seeds().rbegin();
438return (data_seed % 5);
439}),
440Range<RandomSeedT>(0, 5), Range<RandomSeedT>(0, 5)));
441}
442
443TEST(RandomizeTaskSeedsTest, ParamAndDataSeedsAreIndependent) {
444IntegerT num_tasks = 0;
445auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
446const RandomSeedT seed = GenerateRandomSeed();
447EXPECT_TRUE(IsEventually(
448function<pair<RandomSeedT, RandomSeedT>(void)>([&task_collection,
449&num_tasks, seed]() {
450++num_tasks;
451task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
452ClearSeeds(&task_collection);
453RandomizeTaskSeeds(&task_collection, seed);
454
455// Return the last data seed and the last param seed.
456const RandomSeedT param_seed =
457*task_collection.tasks(0).param_seeds().rbegin();
458const RandomSeedT data_seed =
459*task_collection.tasks(0).data_seeds().rbegin();
460return (make_pair(param_seed % 3, data_seed % 3));
461}),
462CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
463CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
464}
465
466TEST(RandomizeTaskSeedsTest, ParamSeedsAreIndepdendentWithinTaskSpec) {
467IntegerT num_tasks = 1;
468auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
469const RandomSeedT seed = GenerateRandomSeed();
470EXPECT_TRUE(IsEventually(
471function<pair<RandomSeedT, RandomSeedT>(void)>([&task_collection,
472&num_tasks, seed]() {
473++num_tasks;
474task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
475ClearSeeds(&task_collection);
476RandomizeTaskSeeds(&task_collection, seed);
477
478// Return the last two seeds.
479auto param_seed_it =
480task_collection.tasks(0).param_seeds().rbegin();
481const RandomSeedT param_seed_1 = *param_seed_it;
482++param_seed_it;
483const RandomSeedT param_seed_2 = *param_seed_it;
484return (make_pair(param_seed_1 % 3, param_seed_2 % 3));
485}),
486CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
487CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
488}
489
490TEST(RandomizeTaskSeedsTest, DataSeedsAreIndepdendentWithinTaskSpec) {
491IntegerT num_tasks = 1;
492auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
493const RandomSeedT seed = GenerateRandomSeed();
494EXPECT_TRUE(IsEventually(
495function<pair<RandomSeedT, RandomSeedT>(void)>([&task_collection,
496&num_tasks, seed]() {
497++num_tasks;
498task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
499ClearSeeds(&task_collection);
500RandomizeTaskSeeds(&task_collection, seed);
501
502// Return the last two seeds.
503auto data_seed_it =
504task_collection.tasks(0).data_seeds().rbegin();
505const RandomSeedT data_seed_1 = *data_seed_it;
506++data_seed_it;
507const RandomSeedT data_seed_2 = *data_seed_it;
508return (make_pair(data_seed_1 % 3, data_seed_2 % 3));
509}),
510CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
511CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
512}
513
514TEST(RandomizeTaskSeedsTest, ParamSeedsAreIndepdendentAcrossTaskSpecs) {
515IntegerT num_tasks = 1;
516auto task_collection = ParseTextFormat<TaskCollection>(
517"tasks {} "
518"tasks {} ");
519const RandomSeedT seed = GenerateRandomSeed();
520EXPECT_TRUE(IsEventually(
521function<pair<RandomSeedT, RandomSeedT>(void)>([&task_collection,
522&num_tasks, seed]() {
523++num_tasks;
524task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
525task_collection.mutable_tasks(1)->set_num_tasks(num_tasks);
526ClearSeeds(&task_collection);
527RandomizeTaskSeeds(&task_collection, seed);
528
529// Return the last seed of each TaskSpec.
530const RandomSeedT param_seed_1 =
531*task_collection.tasks(0).param_seeds().rbegin();
532const RandomSeedT param_seed_2 =
533*task_collection.tasks(1).param_seeds().rbegin();
534return (make_pair(param_seed_1 % 3, param_seed_2 % 3));
535}),
536CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
537CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
538}
539
540TEST(RandomizeTaskSeedsTest, DataSeedsAreIndepdendentAcrossTaskSpecs) {
541IntegerT num_tasks = 1;
542auto task_collection = ParseTextFormat<TaskCollection>(
543"tasks {} "
544"tasks {} ");
545const RandomSeedT seed = GenerateRandomSeed();
546EXPECT_TRUE(IsEventually(
547function<pair<RandomSeedT, RandomSeedT>(void)>([&task_collection,
548&num_tasks, seed]() {
549++num_tasks;
550task_collection.mutable_tasks(0)->set_num_tasks(num_tasks);
551task_collection.mutable_tasks(1)->set_num_tasks(num_tasks);
552ClearSeeds(&task_collection);
553RandomizeTaskSeeds(&task_collection, seed);
554
555// Return the last seed of each TaskSpec.
556const RandomSeedT data_seed_1 =
557*task_collection.tasks(0).data_seeds().rbegin();
558const RandomSeedT data_seed_2 =
559*task_collection.tasks(1).data_seeds().rbegin();
560return (make_pair(data_seed_1 % 3, data_seed_2 % 3));
561}),
562CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
563CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
564}
565
566} // namespace automl_zero
567