google-research

Форк
0
/
task_util_test.cc 
566 строк · 22.1 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "task_util.h"
16

17
#include <cmath>
18
#include <ostream>
19
#include <string>
20
#include <unordered_set>
21
#include <vector>
22

23
#include "gmock/gmock.h"
24
#include "gtest/gtest.h"
25
#include "absl/container/node_hash_set.h"
26
#include "absl/strings/str_cat.h"
27
#include "Eigen/Core"
28
#include "algorithm.h"
29
#include "definitions.h"
30
#include "executor.h"
31
#include "generator.h"
32
#include "memory.h"
33
#include "random_generator.h"
34
#include "task.h"
35
#include "task.pb.h"
36
#include "test_util.h"
37

38
namespace automl_zero {
39

40
using ::absl::StrCat;  // NOLINT
41
using ::Eigen::Map;
42
using ::std::abs;  // NOLINT
43
using ::std::function;  // NOLINT
44
using ::std::make_pair;  // NOLINT
45
using ::std::pair;  // NOLINT
46
using ::std::vector;  // NOLINT
47
using ::std::unique_ptr;  // NOLINT
48
using ::testing::Test;
49
using test_only::GenerateTask;
50

51
// Number of training / validation examples requested in the task specs
// built by the tests in this file.
constexpr IntegerT kNumTrainExamples{1000};
constexpr IntegerT kNumValidExamples{100};
// The two constants below are not referenced by any test in this file.
constexpr double kLargeMaxAbsError{1e9};
constexpr IntegerT kNumAllTrainExamples{1000};
55

56
bool ScalarEq(const Scalar& scalar1, const Scalar scalar2) {
57
  return abs(scalar1 - scalar2) < kDataTolerance;
58
}
59

60
template <FeatureIndexT F>
61
bool VectorEq(const Vector<F>& vector1,
62
              const ::std::vector<double>& vector2) {
63
  Map<const Vector<F>> vector2_eigen(vector2.data());
64
  return (vector1 - vector2_eigen).norm() < kDataTolerance;
65
}
66

67
// Scalar and Vector are trivially destructible.
68
const Vector<4> kZeroVector = Vector<4>::Zero(4, 1);
69
const Vector<4> kOnesVector = Vector<4>::Ones(4, 1);
70

71
// Checks that FillTasks() expands one TaskSpec with num_tasks: 2 (and two
// param/data seed pairs) into two tasks whose sampled features, labels and
// indexes match tasks generated individually from the same seed pairs.
TEST(FillTasksTest, WorksCorrectly) {
  // Builds the single-task reference for one (param seed, data seed) pair.
  const auto make_expected_task = [](const IntegerT param_seed,
                                     const IntegerT data_seed) {
    return GenerateTask<4>(StrCat("scalar_2layer_nn_regression_task {} "
                                  "num_train_examples: ",
                                  kNumTrainExamples,
                                  " "
                                  "num_valid_examples: ",
                                  kNumValidExamples,
                                  " "
                                  "eval_type: RMS_ERROR "
                                  "param_seeds: ",
                                  param_seed,
                                  " "
                                  "data_seeds: ",
                                  data_seed, " "));
  };
  Task<4> expected_task_0 = make_expected_task(1000, 10000);
  Task<4> expected_task_1 = make_expected_task(1001, 10001);

  TaskCollection task_collection;
  TaskSpec* task = task_collection.add_tasks();
  task->set_features_size(4);
  task->set_num_train_examples(kNumTrainExamples);
  task->set_num_valid_examples(kNumValidExamples);
  task->set_num_tasks(2);
  task->set_eval_type(RMS_ERROR);
  task->add_data_seeds(10000);
  task->add_data_seeds(10001);
  task->add_param_seeds(1000);
  task->add_param_seeds(1001);
  task->mutable_scalar_2layer_nn_regression_task();
  vector<unique_ptr<TaskInterface>> owned_tasks;
  FillTasks(task_collection, &owned_tasks);

  // Everything below indexes tasks[0] and tasks[1]; stop here with a clear
  // failure if FillTasks did not produce exactly two tasks, rather than
  // indexing past the end of the vector.
  ASSERT_EQ(owned_tasks.size(), 2u);

  vector<Task<4>*> tasks;
  for (const unique_ptr<TaskInterface>& owned_task : owned_tasks) {
    tasks.push_back(SafeDowncast<4>(owned_task.get()));
  }

  EXPECT_EQ(tasks[0]->MaxTrainExamples(), kNumTrainExamples);

  // Spot-check one training and one validation feature vector per task.
  EXPECT_LT(
      (tasks[0]->train_features_[478] -
       expected_task_0.train_features_[478]).norm(),
      kDataTolerance);
  EXPECT_LT(
      (tasks[1]->train_features_[478] -
       expected_task_1.train_features_[478]).norm(),
      kDataTolerance);
  EXPECT_LT(
      (tasks[0]->valid_features_[94] -
       expected_task_0.valid_features_[94]).norm(),
      kDataTolerance);
  EXPECT_LT(
      (tasks[1]->valid_features_[94] -
       expected_task_1.valid_features_[94]).norm(),
      kDataTolerance);

  // Spot-check the labels at the same positions.
  EXPECT_LT(
      abs(tasks[0]->train_labels_[478] -
          expected_task_0.train_labels_[478]),
      kDataTolerance);
  EXPECT_LT(
      abs(tasks[1]->train_labels_[478] -
          expected_task_1.train_labels_[478]),
      kDataTolerance);
  EXPECT_LT(
      abs(tasks[0]->valid_labels_[94] -
          expected_task_0.valid_labels_[94]),
      kDataTolerance);
  EXPECT_LT(
      abs(tasks[1]->valid_labels_[94] -
          expected_task_1.valid_labels_[94]),
      kDataTolerance);

  // Tasks are expected to be numbered in creation order.
  EXPECT_EQ(tasks[0]->index_, 0);
  EXPECT_EQ(tasks[1]->index_, 1);
}
153

154
// The eval_type written in the textual spec must be carried over to the
// generated task.
TEST(FillTaskTest, FillsEvalType) {
  const std::string task_spec_string =
      "scalar_linear_regression_task {} "
      "num_train_examples: 1000 "
      "num_valid_examples: 100 "
      "param_seeds: 1000 "
      "data_seeds: 1000 "
      "eval_type: RMS_ERROR";
  const Task<4> generated = GenerateTask<4>(task_spec_string);
  const TaskSpec parsed_spec = ParseTextFormat<TaskSpec>(task_spec_string);
  EXPECT_EQ(generated.eval_type_, parsed_spec.eval_type());
}
167

168
// unit_test_zeros_task: every feature vector and every label, in both the
// training and the validation split, is expected to be zero.
TEST(FillTaskWithZerosTest, WorksCorrectly) {
  auto dataset =
      GenerateTask<4>(StrCat("unit_test_zeros_task {} "
                                "eval_type: ACCURACY "
                                "num_train_examples: ",
                                kNumTrainExamples,
                                " "
                                "num_valid_examples: ",
                                kNumValidExamples,
                                " "
                                "num_tasks: 1 "
                                "features_size: 4 "));
  // Training features were previously left unchecked; verify them the same
  // way the validation features are verified (and as the ones-task test
  // does for its training split).
  for (const Vector<4>& feature : dataset.train_features_) {
    EXPECT_TRUE(feature.isApprox(kZeroVector));
  }
  for (const Scalar& label : dataset.train_labels_) {
    EXPECT_FLOAT_EQ(label, 0.0);
  }
  for (const Vector<4>& feature : dataset.valid_features_) {
    EXPECT_TRUE(feature.isApprox(kZeroVector));
  }
  for (const Scalar& label : dataset.valid_labels_) {
    EXPECT_FLOAT_EQ(label, 0.0);
  }
}
190

191
// unit_test_ones_task: every feature vector and every label, in both the
// training and the validation split, is expected to be one.
TEST(FillTaskWithOnesTest, WorksCorrectly) {
  const std::string spec = StrCat(
      "unit_test_ones_task {} eval_type: ACCURACY num_train_examples: ",
      kNumTrainExamples, " num_valid_examples: ", kNumValidExamples,
      " num_tasks: 1 features_size: 4 ");
  auto ones_task = GenerateTask<4>(spec);

  // Training split.
  for (const Vector<4>& train_feature : ones_task.train_features_) {
    EXPECT_TRUE(train_feature.isApprox(kOnesVector));
  }
  for (const Scalar& train_label : ones_task.train_labels_) {
    EXPECT_FLOAT_EQ(train_label, 1.0);
  }

  // Validation split.
  for (const Vector<4>& valid_feature : ones_task.valid_features_) {
    EXPECT_TRUE(valid_feature.isApprox(kOnesVector));
  }
  for (const Scalar& valid_label : ones_task.valid_labels_) {
    EXPECT_FLOAT_EQ(valid_label, 1.0);
  }
}
216

217
// unit_test_increment_task: checks the first and last example of each
// split. The first is expected to be all zeros and the last to equal
// (number of examples - 1) in every feature coefficient and in the label.
TEST(FillTaskWithIncrementingIntegersTest, WorksCorrectly) {
  const std::string spec = StrCat(
      "unit_test_increment_task {} eval_type: ACCURACY num_train_examples: ",
      kNumTrainExamples, " num_valid_examples: ", kNumValidExamples,
      " num_tasks: 1 features_size: 4 ");
  auto increment_task = GenerateTask<4>(spec);

  const IntegerT last_train = kNumTrainExamples - 1;
  const IntegerT last_valid = kNumValidExamples - 1;

  // Training split.
  EXPECT_TRUE(increment_task.train_features_[0].isApprox(kZeroVector));
  EXPECT_TRUE(increment_task.train_features_[last_train].isApprox(
      kOnesVector * static_cast<double>(last_train)));
  EXPECT_FLOAT_EQ(increment_task.train_labels_[0], 0.0);
  EXPECT_FLOAT_EQ(increment_task.train_labels_[last_train],
                  static_cast<double>(last_train));

  // Validation split.
  EXPECT_TRUE(increment_task.valid_features_[0].isApprox(kZeroVector));
  EXPECT_TRUE(increment_task.valid_features_[last_valid].isApprox(
      kOnesVector * static_cast<double>(last_valid)));
  EXPECT_FLOAT_EQ(increment_task.valid_labels_[0], 0.0);
  EXPECT_FLOAT_EQ(increment_task.valid_labels_[last_valid],
                  static_cast<double>(last_valid));
}
250

251
// Changing either the param seed or the data seed must yield a task that
// compares unequal to the others.
TEST(FillTaskWithNonlinearDataTest, DifferentForDifferentSeeds) {
  // Generates the 2-layer NN regression task for one seed pair.
  const auto generate = [](const IntegerT param_seed,
                           const IntegerT data_seed) {
    return GenerateTask<4>(StrCat(
        "scalar_2layer_nn_regression_task {} num_train_examples: ",
        kNumTrainExamples, " num_valid_examples: ", kNumValidExamples,
        " eval_type: RMS_ERROR param_seeds: ", param_seed,
        " data_seeds: ", data_seed, " "));
  };
  const Task<4> base_task = generate(1000, 10000);
  const Task<4> other_param_seed_task = generate(1001, 10000);
  const Task<4> other_data_seed_task = generate(1000, 10001);

  EXPECT_NE(base_task, other_param_seed_task);
  EXPECT_NE(base_task, other_data_seed_task);
  EXPECT_NE(other_param_seed_task, other_data_seed_task);
}
289

290
// Generating twice from the same spec (same param and data seeds) must give
// tasks that compare equal.
TEST(FillTaskWithNonlinearDataTest, SameForSameSeed) {
  const std::string spec = StrCat(
      "scalar_2layer_nn_regression_task {} num_train_examples: ",
      kNumTrainExamples, " num_valid_examples: ", kNumValidExamples,
      " eval_type: RMS_ERROR param_seeds: 1000 data_seeds: 10000 ");
  const Task<4> first_run = GenerateTask<4>(spec);
  const Task<4> second_run = GenerateTask<4>(spec);
  EXPECT_EQ(first_run, second_run);
}
315

316
// Pins a handful of generated features and labels to recorded values so
// that any change in the data generated for these seeds is detected.
TEST(FillTaskWithNonlinearDataTest, PermanenceTest) {
  const std::string spec = StrCat(
      "scalar_2layer_nn_regression_task {} num_train_examples: ",
      kNumTrainExamples, " num_valid_examples: ", kNumValidExamples,
      " eval_type: RMS_ERROR param_seeds: 1000 data_seeds: 10000 ");
  Task<4> pinned_task = GenerateTask<4>(spec);

  // Recorded feature vectors.
  const vector<double> train_feature_first = {1.30836, -0.192507, 0.549877,
                                              -0.667065};
  const vector<double> train_feature_994 = {-0.265714, 1.38325, 0.775253,
                                            1.78923};
  const vector<double> valid_feature_first = {1.39658, 0.293097, -0.504938,
                                              -1.09144};
  const vector<double> valid_feature_94 = {-0.224309, 1.78054, 1.24783,
                                           0.54083};
  EXPECT_TRUE(VectorEq<4>(pinned_task.train_features_[0],
                          train_feature_first));
  EXPECT_TRUE(VectorEq<4>(pinned_task.train_features_[994],
                          train_feature_994));
  EXPECT_TRUE(VectorEq<4>(pinned_task.valid_features_[0],
                          valid_feature_first));
  EXPECT_TRUE(VectorEq<4>(pinned_task.valid_features_[94],
                          valid_feature_94));

  // Recorded labels.
  EXPECT_TRUE(ScalarEq(pinned_task.train_labels_[0], 1.508635));
  EXPECT_TRUE(ScalarEq(pinned_task.train_labels_[994], -2.8410525));
  EXPECT_TRUE(ScalarEq(pinned_task.valid_labels_[0], 0.0));
  EXPECT_TRUE(ScalarEq(pinned_task.valid_labels_[98], -0.66133333));
}
345

346
// Removes all param seeds and data seeds from every TaskSpec in
// `task_collection`; other fields are left as they are.
void ClearSeeds(TaskCollection* task_collection) {
  const int num_specs = task_collection->tasks_size();
  for (int spec_index = 0; spec_index < num_specs; ++spec_index) {
    TaskSpec* spec = task_collection->mutable_tasks(spec_index);
    spec->clear_param_seeds();
    spec->clear_data_seeds();
  }
}
352

353
// After randomization each TaskSpec must carry exactly num_tasks param
// seeds and num_tasks data seeds, and the number of specs must not change.
TEST(RandomizeTaskSeedsTest, FillsCorrectNumberOfRandomSeeds) {
  auto task_collection = ParseTextFormat<TaskCollection>(
      "tasks {num_tasks: 8} "
      "tasks {num_tasks: 3} ");
  const auto seed = GenerateRandomSeed();
  RandomizeTaskSeeds(&task_collection, seed);

  EXPECT_EQ(task_collection.tasks_size(), 2);

  const TaskSpec& spec_with_8_tasks = task_collection.tasks(0);
  EXPECT_EQ(spec_with_8_tasks.param_seeds_size(), 8);
  EXPECT_EQ(spec_with_8_tasks.data_seeds_size(), 8);

  const TaskSpec& spec_with_3_tasks = task_collection.tasks(1);
  EXPECT_EQ(spec_with_3_tasks.param_seeds_size(), 3);
  EXPECT_EQ(spec_with_3_tasks.data_seeds_size(), 3);
}
365

366
// Randomizing two identical collections with the same seed must produce
// the same task seeds (spot-checked at one position per TaskSpec).
TEST(RandomizeTaskSeedsTest, SameForSameSeed) {
  const RandomSeedT seed = GenerateRandomSeed();
  auto task_collection_1 = ParseTextFormat<TaskCollection>(
      "tasks {num_tasks: 8} "
      "tasks {num_tasks: 3} ");
  TaskCollection task_collection_2 = task_collection_1;
  RandomizeTaskSeeds(&task_collection_1, seed);
  RandomizeTaskSeeds(&task_collection_2, seed);

  // First TaskSpec, seed index 5.
  const TaskSpec& first_spec_1 = task_collection_1.tasks(0);
  const TaskSpec& first_spec_2 = task_collection_2.tasks(0);
  EXPECT_EQ(first_spec_1.param_seeds(5), first_spec_2.param_seeds(5));
  EXPECT_EQ(first_spec_1.data_seeds(5), first_spec_2.data_seeds(5));

  // Second TaskSpec, seed index 1.
  const TaskSpec& second_spec_1 = task_collection_1.tasks(1);
  const TaskSpec& second_spec_2 = task_collection_2.tasks(1);
  EXPECT_EQ(second_spec_1.param_seeds(1), second_spec_2.param_seeds(1));
  EXPECT_EQ(second_spec_1.data_seeds(1), second_spec_2.data_seeds(1));
}
383

384
// Randomizing two identical collections with two different fixed seeds
// must produce different task seeds (spot-checked at one position per
// TaskSpec).
TEST(RandomizeTaskSeedsTest, DifferentForDifferentSeeds) {
  const RandomSeedT seed1 = 519801251;
  const RandomSeedT seed2 = 208594758;
  auto task_collection_1 = ParseTextFormat<TaskCollection>(
      "tasks {num_tasks: 8} "
      "tasks {num_tasks: 3} ");
  TaskCollection task_collection_2 = task_collection_1;
  RandomizeTaskSeeds(&task_collection_1, seed1);
  RandomizeTaskSeeds(&task_collection_2, seed2);

  // First TaskSpec, seed index 5.
  const TaskSpec& first_spec_1 = task_collection_1.tasks(0);
  const TaskSpec& first_spec_2 = task_collection_2.tasks(0);
  EXPECT_NE(first_spec_1.param_seeds(5), first_spec_2.param_seeds(5));
  EXPECT_NE(first_spec_1.data_seeds(5), first_spec_2.data_seeds(5));

  // Second TaskSpec, seed index 1.
  const TaskSpec& second_spec_1 = task_collection_1.tasks(1);
  const TaskSpec& second_spec_2 = task_collection_2.tasks(1);
  EXPECT_NE(second_spec_1.param_seeds(1), second_spec_2.param_seeds(1));
  EXPECT_NE(second_spec_1.data_seeds(1), second_spec_2.data_seeds(1));
}
402

403
// The last generated param seed, taken modulo 5, must eventually take
// every value in Range(0, 5) and never a value outside it.
TEST(RandomizeTaskSeedsTest, CoversParamSeeds) {
  IntegerT num_tasks = 0;
  auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_last_param_seed_mod_5 = [&task_collection, &num_tasks,
                                           seed]() {
    // RandomizeTaskSeeds is deterministic, so the only way to obtain a new
    // seed on each draw is to request one more task than the previous draw.
    ++num_tasks;
    TaskSpec* spec = task_collection.mutable_tasks(0);
    spec->set_num_tasks(num_tasks);

    ClearSeeds(&task_collection);
    RandomizeTaskSeeds(&task_collection, seed);
    const auto& param_seeds = task_collection.tasks(0).param_seeds();
    const RandomSeedT last_param_seed = *param_seeds.rbegin();
    return last_param_seed % 5;
  };

  EXPECT_TRUE(IsEventually(
      function<RandomSeedT(void)>(draw_last_param_seed_mod_5),
      Range<RandomSeedT>(0, 5), Range<RandomSeedT>(0, 5)));
}
423

424
// The last generated data seed, taken modulo 5, must eventually take every
// value in Range(0, 5) and never a value outside it.
TEST(RandomizeTaskSeedsTest, CoversDataSeeds) {
  IntegerT num_tasks = 0;
  auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_last_data_seed_mod_5 = [&task_collection, &num_tasks,
                                          seed]() {
    // Request one more task on every draw so that a new last seed appears.
    ++num_tasks;
    TaskSpec* spec = task_collection.mutable_tasks(0);
    spec->set_num_tasks(num_tasks);
    ClearSeeds(&task_collection);

    RandomizeTaskSeeds(&task_collection, seed);
    const auto& data_seeds = task_collection.tasks(0).data_seeds();
    const RandomSeedT last_data_seed = *data_seeds.rbegin();
    return last_data_seed % 5;
  };

  EXPECT_TRUE(IsEventually(
      function<RandomSeedT(void)>(draw_last_data_seed_mod_5),
      Range<RandomSeedT>(0, 5), Range<RandomSeedT>(0, 5)));
}
442

443
// The pair (last param seed mod 3, last data seed mod 3) must eventually
// cover the full 3x3 Cartesian product.
TEST(RandomizeTaskSeedsTest, ParamAndDataSeedsAreIndependent) {
  IntegerT num_tasks = 0;
  auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_seed_pair = [&task_collection, &num_tasks, seed]() {
    // Request one more task on every draw so that new last seeds appear.
    ++num_tasks;
    TaskSpec* spec = task_collection.mutable_tasks(0);
    spec->set_num_tasks(num_tasks);
    ClearSeeds(&task_collection);
    RandomizeTaskSeeds(&task_collection, seed);

    // Last param seed paired with last data seed of the same TaskSpec.
    const auto& param_seeds = task_collection.tasks(0).param_seeds();
    const RandomSeedT last_param_seed = *param_seeds.rbegin();
    const auto& data_seeds = task_collection.tasks(0).data_seeds();
    const RandomSeedT last_data_seed = *data_seeds.rbegin();
    return make_pair(last_param_seed % 3, last_data_seed % 3);
  };

  EXPECT_TRUE(IsEventually(
      function<pair<RandomSeedT, RandomSeedT>(void)>(draw_seed_pair),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
}
465

466
// The last two param seeds of a single TaskSpec, each taken modulo 3, must
// eventually cover the full 3x3 Cartesian product.
TEST(RandomizeTaskSeedsTest, ParamSeedsAreIndepdendentWithinTaskSpec) {
  // Starts at 1 so that the first draw already requests two tasks, which
  // guarantees two seeds to read back.
  IntegerT num_tasks = 1;
  auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_last_two_param_seeds = [&task_collection, &num_tasks,
                                          seed]() {
    ++num_tasks;
    TaskSpec* spec = task_collection.mutable_tasks(0);
    spec->set_num_tasks(num_tasks);
    ClearSeeds(&task_collection);
    RandomizeTaskSeeds(&task_collection, seed);

    // Walk backwards from the end: last seed first, then the one before.
    const auto& param_seeds = task_collection.tasks(0).param_seeds();
    auto reverse_it = param_seeds.rbegin();
    const RandomSeedT last_seed = *reverse_it;
    ++reverse_it;
    const RandomSeedT second_to_last_seed = *reverse_it;
    return make_pair(last_seed % 3, second_to_last_seed % 3);
  };

  EXPECT_TRUE(IsEventually(
      function<pair<RandomSeedT, RandomSeedT>(void)>(
          draw_last_two_param_seeds),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
}
489

490
// The last two data seeds of a single TaskSpec, each taken modulo 3, must
// eventually cover the full 3x3 Cartesian product.
TEST(RandomizeTaskSeedsTest, DataSeedsAreIndepdendentWithinTaskSpec) {
  // Starts at 1 so that the first draw already requests two tasks, which
  // guarantees two seeds to read back.
  IntegerT num_tasks = 1;
  auto task_collection = ParseTextFormat<TaskCollection>("tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_last_two_data_seeds = [&task_collection, &num_tasks,
                                         seed]() {
    ++num_tasks;
    TaskSpec* spec = task_collection.mutable_tasks(0);
    spec->set_num_tasks(num_tasks);
    ClearSeeds(&task_collection);
    RandomizeTaskSeeds(&task_collection, seed);

    // Walk backwards from the end: last seed first, then the one before.
    const auto& data_seeds = task_collection.tasks(0).data_seeds();
    auto reverse_it = data_seeds.rbegin();
    const RandomSeedT last_seed = *reverse_it;
    ++reverse_it;
    const RandomSeedT second_to_last_seed = *reverse_it;
    return make_pair(last_seed % 3, second_to_last_seed % 3);
  };

  EXPECT_TRUE(IsEventually(
      function<pair<RandomSeedT, RandomSeedT>(void)>(
          draw_last_two_data_seeds),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
}
513

514
// The last param seeds of two different TaskSpecs, each taken modulo 3,
// must eventually cover the full 3x3 Cartesian product.
TEST(RandomizeTaskSeedsTest, ParamSeedsAreIndepdendentAcrossTaskSpecs) {
  IntegerT num_tasks = 1;
  auto task_collection = ParseTextFormat<TaskCollection>(
      "tasks {} "
      "tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_param_seed_per_spec = [&task_collection, &num_tasks,
                                         seed]() {
    // Both TaskSpecs are grown together, one extra task per draw.
    ++num_tasks;
    TaskSpec* first_spec = task_collection.mutable_tasks(0);
    first_spec->set_num_tasks(num_tasks);
    TaskSpec* second_spec = task_collection.mutable_tasks(1);
    second_spec->set_num_tasks(num_tasks);
    ClearSeeds(&task_collection);
    RandomizeTaskSeeds(&task_collection, seed);

    // Last param seed of each TaskSpec.
    const auto& first_param_seeds = task_collection.tasks(0).param_seeds();
    const RandomSeedT first_last_seed = *first_param_seeds.rbegin();
    const auto& second_param_seeds = task_collection.tasks(1).param_seeds();
    const RandomSeedT second_last_seed = *second_param_seeds.rbegin();
    return make_pair(first_last_seed % 3, second_last_seed % 3);
  };

  EXPECT_TRUE(IsEventually(
      function<pair<RandomSeedT, RandomSeedT>(void)>(
          draw_param_seed_per_spec),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
}
539

540
// The last data seeds of two different TaskSpecs, each taken modulo 3,
// must eventually cover the full 3x3 Cartesian product.
TEST(RandomizeTaskSeedsTest, DataSeedsAreIndepdendentAcrossTaskSpecs) {
  IntegerT num_tasks = 1;
  auto task_collection = ParseTextFormat<TaskCollection>(
      "tasks {} "
      "tasks {} ");
  const RandomSeedT seed = GenerateRandomSeed();

  const auto draw_data_seed_per_spec = [&task_collection, &num_tasks,
                                        seed]() {
    // Both TaskSpecs are grown together, one extra task per draw.
    ++num_tasks;
    TaskSpec* first_spec = task_collection.mutable_tasks(0);
    first_spec->set_num_tasks(num_tasks);
    TaskSpec* second_spec = task_collection.mutable_tasks(1);
    second_spec->set_num_tasks(num_tasks);
    ClearSeeds(&task_collection);
    RandomizeTaskSeeds(&task_collection, seed);

    // Last data seed of each TaskSpec.
    const auto& first_data_seeds = task_collection.tasks(0).data_seeds();
    const RandomSeedT first_last_seed = *first_data_seeds.rbegin();
    const auto& second_data_seeds = task_collection.tasks(1).data_seeds();
    const RandomSeedT second_last_seed = *second_data_seeds.rbegin();
    return make_pair(first_last_seed % 3, second_last_seed % 3);
  };

  EXPECT_TRUE(IsEventually(
      function<pair<RandomSeedT, RandomSeedT>(void)>(
          draw_data_seed_per_spec),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3)),
      CartesianProduct(Range<RandomSeedT>(0, 3), Range<RandomSeedT>(0, 3))));
}
565

566
}  // namespace automl_zero
567

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.