google-research

Форк
0
159 строк · 5.2 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "task_util.h"
16

17
#include <algorithm>
18
#include <cstddef>
19
#include <iostream>
20
#include <memory>
21
#include <ostream>
22
#include <type_traits>
23
#include <utility>
24
#include <vector>
25

26
#include "algorithm.h"
27
#include "task.h"
28
#include "task.pb.h"
29
#include "definitions.h"
30
#include "executor.h"
31
#include "generator.h"
32
#include "memory.h"
33
#include "random_generator.h"
34
#include "google/protobuf/text_format.h"
35
#include "absl/flags/flag.h"
36
#include "absl/memory/memory.h"
37

38
namespace automl_zero {
39

40
using ::absl::make_unique;  // NOLINT
41
using ::std::enable_if;  // NOLINT
42
using ::std::endl;  // NOLINT
43
using ::std::max;  // NOLINT
44
using ::std::mt19937;  // NOLINT
45
using ::std::is_same;  // NOLINT
46
using ::std::unique_ptr;  // NOLINT
47
using ::std::vector;  // NOLINT
48
using ::std::pair;  // NOLINT
49
using ::std::set;  // NOLINT
50

51
// The values of the seeds below were chosen so that they span tasks of
52
// varying difficulties (the difficulties are for the nonlinear tasks).
53
vector<RandomSeedT> DefaultFirstParamSeeds() {
54
  return {
55
      1001,  // Easy.
56
      1012,  // Medium (on easier side).
57
      1010,  // Medium (on harder side).
58
      1000,  // Hard.
59
      1006,  // Easy.
60
      1008,  // Medium (on easier side).
61
      1007,  // Medium (on harder side).
62
      1003,  // Hard.
63
  };
64
}
65

66
// Returns the default data seeds used for the first tasks of a spec. Each
// value equals the corresponding entry of DefaultFirstParamSeeds() plus 10000.
vector<RandomSeedT> DefaultFirstDataSeeds() {
  vector<RandomSeedT> seeds = {
      11001, 11012, 11010, 11000,
      11006, 11008, 11007, 11003,
  };
  return seeds;
}
69

70
void FillTasksFromTaskSpec(
71
    const TaskSpec& task_spec,
72
    vector<unique_ptr<TaskInterface>>* return_tasks) {
73
  const IntegerT num_tasks = task_spec.num_tasks();
74
  CHECK_GT(num_tasks, 0);
75
  vector<RandomSeedT> first_param_seeds =
76
      task_spec.param_seeds_size() == 0
77
          ? DefaultFirstParamSeeds()
78
          : vector<RandomSeedT>(task_spec.param_seeds().begin(),
79
                                task_spec.param_seeds().end());
80
  vector<RandomSeedT> first_data_seeds =
81
      task_spec.data_seeds_size() == 0
82
          ? DefaultFirstDataSeeds()
83
          : vector<RandomSeedT>(task_spec.data_seeds().begin(),
84
                                task_spec.data_seeds().end());
85
  CHECK(!first_param_seeds.empty());
86
  CHECK(!first_data_seeds.empty());
87

88
  RandomSeedT param_seed;
89
  RandomSeedT data_seed;
90
  for (IntegerT i = 0; i < num_tasks; ++i) {
91
    param_seed =
92
        i < first_param_seeds.size() ? first_param_seeds[i] : param_seed + 1;
93
    data_seed =
94
        i < first_data_seeds.size() ? first_data_seeds[i] : data_seed + 1;
95

96
    const IntegerT task_index = return_tasks->size();
97
    switch (task_spec.features_size()) {
98
      case 2:
99
        return_tasks->push_back(CreateTask<2>(task_index, param_seed,
100
                                                    data_seed, task_spec));
101
        break;
102
      case 4:
103
        return_tasks->push_back(CreateTask<4>(task_index, param_seed,
104
                                                    data_seed, task_spec));
105
        break;
106
      case 8:
107
        return_tasks->push_back(CreateTask<8>(task_index, param_seed,
108
                                                    data_seed, task_spec));
109
        break;
110
      case 16:
111
        return_tasks->push_back(CreateTask<16>(task_index, param_seed,
112
                                                     data_seed, task_spec));
113
        break;
114
      case 32:
115
        return_tasks->push_back(CreateTask<32>(task_index, param_seed,
116
                                                     data_seed, task_spec));
117
        break;
118
      default:
119
        LOG(FATAL) << "Unsupported features size: "
120
                   << task_spec.features_size() << std::endl;
121
    }
122
  }
123
}
124

125
void FillTasks(
126
    const TaskCollection& task_collection,
127
    vector<unique_ptr<TaskInterface>>* return_tasks) {
128
  // Check return targets are empty.
129
  CHECK(return_tasks->empty());
130
  for (const TaskSpec& task_spec : task_collection.tasks()) {
131
    FillTasksFromTaskSpec(task_spec, return_tasks);
132
  }
133
}
134

135
void RandomizeTaskSeeds(TaskCollection* task_collection,
136
                        const RandomSeedT seed) {
137
  RandomSeedT base_param_seed =
138
      HashMix(static_cast<RandomSeedT>(85652777), seed);
139
  mt19937 param_seed_bit_gen(base_param_seed);
140
  RandomGenerator param_seed_gen(&param_seed_bit_gen);
141

142
  RandomSeedT base_data_seed =
143
      HashMix(static_cast<RandomSeedT>(38272328), seed);
144
  mt19937 data_seed_bit_gen(base_data_seed);
145
  RandomGenerator data_seed_gen(&data_seed_bit_gen);
146

147
  for (TaskSpec& task : *task_collection->mutable_tasks()) {
148
    task.clear_param_seeds();
149
    task.clear_data_seeds();
150
    for (IntegerT i = 0; i < task.num_tasks(); i++) {
151
      task.add_param_seeds(param_seed_gen.UniformRandomSeed());
152
    }
153
    for (IntegerT i = 0; i < task.num_tasks(); i++) {
154
      task.add_data_seeds(data_seed_gen.UniformRandomSeed());
155
    }
156
  }
157
}
158

159
}  // namespace automl_zero
160

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.