google-research
159 строк · 5.2 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "task_util.h"16
17#include <algorithm>18#include <cstddef>19#include <iostream>20#include <memory>21#include <ostream>22#include <type_traits>23#include <utility>24#include <vector>25
26#include "algorithm.h"27#include "task.h"28#include "task.pb.h"29#include "definitions.h"30#include "executor.h"31#include "generator.h"32#include "memory.h"33#include "random_generator.h"34#include "google/protobuf/text_format.h"35#include "absl/flags/flag.h"36#include "absl/memory/memory.h"37
38namespace automl_zero {39
40using ::absl::make_unique; // NOLINT41using ::std::enable_if; // NOLINT42using ::std::endl; // NOLINT43using ::std::max; // NOLINT44using ::std::mt19937; // NOLINT45using ::std::is_same; // NOLINT46using ::std::unique_ptr; // NOLINT47using ::std::vector; // NOLINT48using ::std::pair; // NOLINT49using ::std::set; // NOLINT50
51// The values of the seeds below were chosen so that they span tasks of
52// varying difficulties (the difficulties are for the nonlinear tasks).
53vector<RandomSeedT> DefaultFirstParamSeeds() {54return {551001, // Easy.561012, // Medium (on easier side).571010, // Medium (on harder side).581000, // Hard.591006, // Easy.601008, // Medium (on easier side).611007, // Medium (on harder side).621003, // Hard.63};64}
65
66vector<RandomSeedT> DefaultFirstDataSeeds() {67return {11001, 11012, 11010, 11000, 11006, 11008, 11007, 11003};68}
69
70void FillTasksFromTaskSpec(71const TaskSpec& task_spec,72vector<unique_ptr<TaskInterface>>* return_tasks) {73const IntegerT num_tasks = task_spec.num_tasks();74CHECK_GT(num_tasks, 0);75vector<RandomSeedT> first_param_seeds =76task_spec.param_seeds_size() == 077? DefaultFirstParamSeeds()78: vector<RandomSeedT>(task_spec.param_seeds().begin(),79task_spec.param_seeds().end());80vector<RandomSeedT> first_data_seeds =81task_spec.data_seeds_size() == 082? DefaultFirstDataSeeds()83: vector<RandomSeedT>(task_spec.data_seeds().begin(),84task_spec.data_seeds().end());85CHECK(!first_param_seeds.empty());86CHECK(!first_data_seeds.empty());87
88RandomSeedT param_seed;89RandomSeedT data_seed;90for (IntegerT i = 0; i < num_tasks; ++i) {91param_seed =92i < first_param_seeds.size() ? first_param_seeds[i] : param_seed + 1;93data_seed =94i < first_data_seeds.size() ? first_data_seeds[i] : data_seed + 1;95
96const IntegerT task_index = return_tasks->size();97switch (task_spec.features_size()) {98case 2:99return_tasks->push_back(CreateTask<2>(task_index, param_seed,100data_seed, task_spec));101break;102case 4:103return_tasks->push_back(CreateTask<4>(task_index, param_seed,104data_seed, task_spec));105break;106case 8:107return_tasks->push_back(CreateTask<8>(task_index, param_seed,108data_seed, task_spec));109break;110case 16:111return_tasks->push_back(CreateTask<16>(task_index, param_seed,112data_seed, task_spec));113break;114case 32:115return_tasks->push_back(CreateTask<32>(task_index, param_seed,116data_seed, task_spec));117break;118default:119LOG(FATAL) << "Unsupported features size: "120<< task_spec.features_size() << std::endl;121}122}123}
124
125void FillTasks(126const TaskCollection& task_collection,127vector<unique_ptr<TaskInterface>>* return_tasks) {128// Check return targets are empty.129CHECK(return_tasks->empty());130for (const TaskSpec& task_spec : task_collection.tasks()) {131FillTasksFromTaskSpec(task_spec, return_tasks);132}133}
134
135void RandomizeTaskSeeds(TaskCollection* task_collection,136const RandomSeedT seed) {137RandomSeedT base_param_seed =138HashMix(static_cast<RandomSeedT>(85652777), seed);139mt19937 param_seed_bit_gen(base_param_seed);140RandomGenerator param_seed_gen(¶m_seed_bit_gen);141
142RandomSeedT base_data_seed =143HashMix(static_cast<RandomSeedT>(38272328), seed);144mt19937 data_seed_bit_gen(base_data_seed);145RandomGenerator data_seed_gen(&data_seed_bit_gen);146
147for (TaskSpec& task : *task_collection->mutable_tasks()) {148task.clear_param_seeds();149task.clear_data_seeds();150for (IntegerT i = 0; i < task.num_tasks(); i++) {151task.add_param_seeds(param_seed_gen.UniformRandomSeed());152}153for (IntegerT i = 0; i < task.num_tasks(); i++) {154task.add_data_seeds(data_seed_gen.UniformRandomSeed());155}156}157}
158
159} // namespace automl_zero160