google-research

Форк
0
214 строк · 7.8 Кб
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Kept on proto2: several fields in this file declare explicit
// `[default = ...]` values, which proto3 does not allow.
syntax = "proto2";

package automl_zero;

// The set of tasks to evaluate.
message TaskCollection {
  // Task specifications; must contain at least one entry. Each entry may
  // expand into several tasks (see TaskSpec.num_tasks).
  repeated TaskSpec tasks = 1;
}

// Metric used to evaluate a task; selected via TaskSpec.eval_type.
// NOTE(review): values lack an enum-name prefix (e.g. EVAL_TYPE_ACCURACY);
// renaming them would break generated code, so they are kept as-is.
enum EvalType {
  // Placeholder default for an unset eval_type; not a real metric.
  INVALID_EVAL_TYPE = 0;

  // Root-mean-square error.
  RMS_ERROR = 1;

  // Classification accuracy. This is the value required by
  // ProjectedBinaryClassificationTask.
  ACCURACY = 4;

  // NOTE(review): numbers 2 and 3 are unused; if they were ever assigned,
  // consider `reserved 2, 3;` to prevent accidental reuse.
}

// Encodes information about a task of a given kind.
message TaskSpec {
  // Size of each features vector. This also sets the size of all vectors and
  // matrices in the memory.
  optional int32 features_size = 13;

  // Number of unique training examples.
  optional int32 num_train_examples = 1;  // Required.

  // Number of times the training examples will be repeated to mimic multiple
  // epochs over a fixed training set.
  optional int32 num_train_epochs = 21 [default = 1];

  // Number of validation examples.
  optional int32 num_valid_examples = 2;  // Required.

  // Number of tasks with this specification.
  optional int32 num_tasks = 3;

  // Seeds for the features. If data_seeds has n elements (n > 0), they will
  // be used as the seeds for the first n tasks; the seeds for the remaining
  // tasks will be obtained by incrementing from the last seed in data_seeds.
  // If data_seeds is empty, the default seeds will be used. See the FillTasks
  // function in task_util.cc for more details.
  // TODO(crazydonkey): make sure the random seed is never 0.
  repeated uint32 data_seeds = 4;

  // Seeds for the parameters that determine the labels function. The same
  // rules as for data_seeds apply.
  // TODO(crazydonkey): make sure the random seed is never 0.
  repeated uint32 param_seeds = 5;

  // Metric used to evaluate the task. See the comments on the task_type
  // messages for the allowed EvalType values (e.g.
  // ProjectedBinaryClassificationTask requires ACCURACY).
  optional EvalType eval_type = 28;  // Required.

  // The kind of task. As a oneof, at most one member can be set.
  oneof task_type {
    // Linear regression task.
    ScalarLinearRegressionTaskSpec scalar_linear_regression_task = 6;

    // Non-linear regression task generated by a 2-layer NN.
    Scalar2LayerNNRegressionTaskSpec scalar_2layer_nn_regression_task = 7;

    // Binary classification task generated by randomly projecting MNIST or
    // CIFAR-10 to lower dimensions.
    ProjectedBinaryClassificationTask projected_binary_classification_task = 24;

    // Useful tasks for tests.
    UnitTestFixedTask unit_test_fixed_task = 40;
    UnitTestZerosTaskSpec unit_test_zeros_task = 45;
    UnitTestOnesTaskSpec unit_test_ones_task = 46;
    UnitTestIncrementTaskSpec unit_test_increment_task = 47;
  }

  // Number of test examples, used for final evaluation.
  optional int32 num_test_examples = 18;
}

// Type of activation function.
// NOTE(review): this enum appears unused in this file. Its zero value RELU
// (the implicit default) carries real meaning, and its values lack an
// enum-name prefix; both are kept because changing them would break
// existing users.
enum ActivationType {
  RELU = 0;
  TANH = 1;
}

// Options for the linear regression task
// (TaskSpec.scalar_linear_regression_task). Currently has no fields.
message ScalarLinearRegressionTaskSpec {}

// Options for the non-linear regression task generated by a 2-layer NN
// (TaskSpec.scalar_2layer_nn_regression_task). Currently has no fields.
message Scalar2LayerNNRegressionTaskSpec {}

// A projected binary classification task. These use pre-generated datasets.
// The following TaskSpec fields are restricted to the given values:
//   eval_type: ACCURACY.
//   num_train_examples: the value should be an integer in (0, 8000].
//   num_valid_examples: the value should be an integer in (0, 1000].
//   num_test_examples: the value should be an integer in (0, 1000].
//   param_seeds: not used, so their values do not matter.
// Below are the supported choices for dataset_name, features_size,
// min_supported_data_seed and max_supported_data_seed:
// |dataset_name|features_size|min/max_supported_data_seed|
// |------------|-------------|---------------------------|
// |mnist       |16           |0 / 100                    |
// |cifar10     |16           |0 / 100                    |
//
// Meta-train / meta-validation / meta-test split:
// Since some positive-negative pairs are held out, you can use all the seeds
// during search (meta-train) and use the held-out pairs in model selection
// and evaluation (meta-validation and meta-test).
// Among all 45 possible pairs, we recommend that the following
// 9 randomly selected pairs be held out for meta-validation and meta-test:
// (4, 6), (3, 5), (8, 9), (3, 8), (0, 9), (2, 9), (1, 8), (3, 6), (0, 5).
// If transferring to the original feature size is used as final evaluation
// (meta-test), you can use all the held-out pairs as meta-validation.
// If no transferring is used, you can use the first 4 pairs as
// meta-validation and the remaining 5 pairs as meta-test.
message ProjectedBinaryClassificationTask {
  // Below are the IDs for the positive and negative classes. You should
  // either specify:
  // (1) both of them, in which case the given positive and negative classes
  //     will be used; or
  // (2) neither of them, in which case the positive and negative classes
  //     will be randomly chosen based on the data_seed.
  //
  // Both values should be integers in [0, 9], with `positive_class` smaller
  // than `negative_class`.
  optional int32 positive_class = 1;
  optional int32 negative_class = 2;

  // Name of the dataset to use; currently "mnist" and "cifar10" are
  // supported.
  optional string dataset_name = 3;

  // There are two possible sources for the projected data:
  // (1) the data is saved in this proto, i.e., all the features and labels
  //     of the train/validation/test sets are saved in the `dataset` field;
  // (2) if `path` is set, it will be used as the path to the folder
  //     containing all the serialized data.
  oneof task_source {
    string path = 4;
    ScalarLabelDataset dataset = 5;
  }

  // Pairs to hold out when randomizing the dataset.
  repeated ClassPair held_out_pairs = 6;

  // Minimum (incl.) and maximum (excl.) data seeds supported in the sstable
  // that saves the dumped projected dataset.
  //
  // Only seeds in the range specified in the table above are supported.
  // The seed is obtained by mapping `data_seed` into the range with
  //     seed = (data_seed % (max_supported_data_seed-min_supported_data_seed) +
  //             min_supported_data_seed)
  // (1) when the dataset is not randomized, the specified `positive_class`
  //     and `negative_class` will be used;
  // (2) when the dataset is randomized, i.e., when `positive_class` and
  //     `negative_class` are not set, the `data_seed` is also used to
  //     randomly select the positive and negative classes.
  // NOTE(review): the default maximum (10) is narrower than the 100 listed in
  // the table above; confirm which range the stored data actually covers.
  // TODO(crazydonkey): make sure the random seed is never 0.
  optional int32 min_supported_data_seed = 7 [default = 0];
  optional int32 max_supported_data_seed = 8 [default = 10];
}

// A pair of class IDs, e.g. an entry of
// ProjectedBinaryClassificationTask.held_out_pairs.
message ClassPair {
  // ID of the positive class.
  optional int32 positive_class = 1;

  // ID of the negative class.
  optional int32 negative_class = 2;
}

// Explicit features and scalar labels for the train/validation/test splits,
// e.g. as carried in ProjectedBinaryClassificationTask.dataset.
// NOTE(review): features and labels are parallel repeated fields; presumably
// `*_labels[i]` is the label of `*_features[i]`. Confirm against the reader.
message ScalarLabelDataset {
  // Training examples.
  repeated FeatureVector train_features = 1;
  repeated float train_labels = 2;

  // Validation examples.
  repeated FeatureVector valid_features = 3;
  repeated float valid_labels = 4;

  // Test examples.
  repeated FeatureVector test_features = 5;
  repeated float test_labels = 6;
}

// A single feature vector stored as a list of floats.
message FeatureVector {
  // Feature values. NOTE(review): presumably of length TaskSpec.features_size.
  repeated float features = 1;
}

// A task whose data is specified explicitly during construction.
// Useful for unit tests.
message UnitTestFixedTask {
  // Training examples.
  repeated UnitTestFixedTaskVector train_features = 1;
  repeated UnitTestFixedTaskVector train_labels = 2;

  // Validation examples.
  repeated UnitTestFixedTaskVector valid_features = 3;
  repeated UnitTestFixedTaskVector valid_labels = 4;

  // Test examples.
  repeated UnitTestFixedTaskVector test_features = 5;
  repeated UnitTestFixedTaskVector test_labels = 6;

  // Field number 7 is reserved and must not be reused.
  reserved 7;
}

// A vector of doubles holding one feature vector or label of a
// UnitTestFixedTask.
message UnitTestFixedTaskVector {
  repeated double elements = 1;
}

// Options for the unit-test "zeros" task (TaskSpec.unit_test_zeros_task).
// Currently has no fields.
message UnitTestZerosTaskSpec {}

// Options for the unit-test "ones" task (TaskSpec.unit_test_ones_task).
// Currently has no fields.
message UnitTestOnesTaskSpec {}

// Options for the unit-test "increment" task
// (TaskSpec.unit_test_increment_task).
message UnitTestIncrementTaskSpec {
  // Amount to increment by; defaults to 1.0.
  optional double increment = 1 [default = 1.0];
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.