// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Messages describing the tasks to evaluate.
syntax = "proto2";

package automl_zero;

// The set of tasks to evaluate.
message TaskCollection {
  // Specifications of the tasks. Must be non-empty.
  repeated TaskSpec tasks = 1;
}

// How a task is evaluated; selected through TaskSpec.eval_type.
//
// NOTE(review): values 2 and 3 are unused here. If they belonged to removed
// entries, consider reserving them -- confirm against the file history.
enum EvalType {
  // Zero value, which is what an unset field reads as. Named as invalid.
  INVALID_EVAL_TYPE = 0;
  RMS_ERROR = 1;
  ACCURACY = 4;
}

// Describes a task of a given kind. One specification can stand for several
// tasks (see `num_tasks`).
message TaskSpec {
  // Length of each features vector. The same value also fixes the size of
  // all vectors and matrices in the memory.
  optional int32 features_size = 13;

  // How many unique training examples there are. Required.
  optional int32 num_train_examples = 1;

  // How many times the training examples are repeated, to mimic multiple
  // epochs over a fixed training set.
  optional int32 num_train_epochs = 21 [default = 1];

  // Required.
  // NOTE(review): presumably the number of validation examples, by analogy
  // with `num_train_examples` -- confirm.
  optional int32 num_valid_examples = 2;

  // How many tasks use this specification.
  optional int32 num_tasks = 3;

  // Seeds for the features.
  //  - With n > 0 entries, the entries seed the first n tasks and the
  //    remaining tasks get seeds incrementing from the last entry.
  //  - With no entries, the default seeds are used.
  // See the FillTasks function in task_util.cc for more details.
  // TODO(crazydonkey): make sure the random seed is never 0.
  repeated uint32 data_seeds = 4;

  // Seeds for the parameters that determine the labels function. The rules
  // given for `data_seeds` apply here as well.
  // TODO(crazydonkey): make sure the random seed is never 0.
  repeated uint32 param_seeds = 5;

  // Required. Which EvalType values are allowed depends on the `task_type`
  // case that is set.
  optional EvalType eval_type = 28;

  oneof task_type {
    // Linear regression task.
    ScalarLinearRegressionTaskSpec scalar_linear_regression_task = 6;

    // Non-linear regression task generated by a 2-layer NN.
    Scalar2LayerNNRegressionTaskSpec scalar_2layer_nn_regression_task = 7;

    // Binary classification task generated by randomly projecting MNIST or
    // CIFAR-10 to lower dimensions.
    ProjectedBinaryClassificationTask projected_binary_classification_task =
        24;

    // Tasks that are useful for tests.
    UnitTestFixedTask unit_test_fixed_task = 40;
    UnitTestZerosTaskSpec unit_test_zeros_task = 45;
    UnitTestOnesTaskSpec unit_test_ones_task = 46;
    UnitTestIncrementTaskSpec unit_test_increment_task = 47;
  }

  // Used for final evaluation.
  optional int32 num_test_examples = 18;
}

// Activation choices.
// NOTE(review): no field in this file references this enum; it is presumably
// used by code outside this file -- confirm before changing it.
enum ActivationType {
  // Zero value, which is what an unset field of this type reads as.
  RELU = 0;
  TANH = 1;
}

// Selects the linear regression task in TaskSpec.task_type. Has no fields.
message ScalarLinearRegressionTaskSpec {}

// Selects the non-linear regression task generated by a 2-layer NN in
// TaskSpec.task_type. Has no fields.
message Scalar2LayerNNRegressionTaskSpec {}

// A projected binary classification task. These use pre-generated datasets.
//
// The following TaskSpec fields are restricted to the given values:
//   eval_type: ACCURACY.
//   num_train_examples: an integer in (0, 8000].
//   num_valid_examples: an integer in (0, 1000].
//   num_test_examples: an integer in (0, 1000].
//   param_seeds: not used, so any value is accepted.
//
// Supported choices for dataset_name, features_size, min_supported_data_seed
// and max_supported_data_seed:
//   | dataset_name | features_size | min/max_supported_data_seed |
//   |--------------|---------------|-----------------------------|
//   | mnist        | 16            | 0 / 100                     |
//   | cifar10      | 16            | 0 / 100                     |
//
// Meta-train / meta-validation / meta-test split:
//   Since some positive-negative pairs are held out, you can use all the
//   seeds during search (meta-train) and use the held-out pairs in model
//   selection and evaluation (meta-validation and meta-test).
//   Among all 45 possible pairs, we recommend that the following 9 randomly
//   selected pairs be held out for meta-validation and meta-test:
//   (4, 6), (3, 5), (8, 9), (3, 8), (0, 9), (2, 9), (1, 8), (3, 6), (0, 5).
//   If transferring to the original feature size is used as final evaluation
//   (meta-test), you can use all the held-out pairs as meta-validation.
//   If no transferring is used, you can use the first 4 pairs as
//   meta-validation and the remaining 5 pairs as meta-test.
message ProjectedBinaryClassificationTask {
  // IDs of the positive and negative classes. Specify either:
  //   (1) both of them, in which case the given positive and negative
  //       classes are used; or
  //   (2) neither of them, in which case the positive and negative classes
  //       are chosen randomly based on the data_seed.
  //
  // Both values should be integers in [0, 9], with `positive_class` smaller
  // than `negative_class`.
  optional int32 positive_class = 1;
  optional int32 negative_class = 2;

  // Name of the dataset to use. Currently "mnist" and "cifar10" are
  // supported.
  optional string dataset_name = 3;

  // There are two possible sources for the projected data:
  //   (1) `dataset`: the data is saved in this proto, i.e., all the features
  //       and labels of the train/validation/test sets are stored in the
  //       `dataset` field.
  //   (2) `path`: the path to the folder containing all the serialized data.
  oneof task_source {
    string path = 4;
    ScalarLabelDataset dataset = 5;
  }

  // Pairs to hold out when randomizing the dataset.
  repeated ClassPair held_out_pairs = 6;

  // Minimum (inclusive) and maximum (exclusive) data seeds supported in the
  // sstable that saves the dumped projected dataset.
  //
  // Only seeds in the range given in the table in this message's comment are
  // supported. The seed is obtained by mapping `data_seed` into the range:
  //   seed = data_seed % (max_supported_data_seed - min_supported_data_seed)
  //          + min_supported_data_seed
  //   (1) When the dataset is not randomized, the specified `positive_class`
  //       and `negative_class` are used.
  //   (2) When the dataset is randomized, i.e., when `positive_class` and
  //       `negative_class` are not set, the `data_seed` is also used to
  //       randomly select the positive and negative classes.
  // NOTE(review): the default upper bound below (10) differs from the 100
  // listed in the table -- confirm that this is intended.
  // TODO(crazydonkey): make sure the random seed is never 0.
  optional int32 min_supported_data_seed = 7 [default = 0];
  optional int32 max_supported_data_seed = 8 [default = 10];
}

// A pair of class IDs, one positive and one negative. Used as the element
// type of ProjectedBinaryClassificationTask.held_out_pairs.
message ClassPair {
  optional int32 positive_class = 1;
  optional int32 negative_class = 2;
}

// Training, validation and test examples, stored as features and labels.
//
// NOTE(review): each split keeps features and labels in two parallel lists,
// presumably matched by index -- confirm against the code that fills them.
//
// The label lists use packed encoding. proto2 leaves repeated scalars
// unpacked unless asked, which spends a tag on every element. Packed encoding
// is wire-compatible with the unpacked form: parsers accept either encoding
// for a repeated scalar field, so previously serialized data still parses.
message ScalarLabelDataset {
  // Training split.
  repeated FeatureVector train_features = 1;
  repeated float train_labels = 2 [packed = true];

  // Validation split.
  repeated FeatureVector valid_features = 3;
  repeated float valid_labels = 4 [packed = true];

  // Test split.
  repeated FeatureVector test_features = 5;
  repeated float test_labels = 6 [packed = true];
}

// The features of one example, as a list of floats. Used as the element type
// of the *_features fields of ScalarLabelDataset.
message FeatureVector {
  // Packed encoding stores the floats as one length-delimited record instead
  // of one tagged value per element (the proto2 default). Parsers accept
  // either encoding for a repeated scalar field, so data written before this
  // option was added still parses.
  repeated float features = 1 [packed = true];
}

// A task whose data is given explicitly at construction time.
// Useful for unit tests.
message UnitTestFixedTask {
  // Training examples.
  repeated UnitTestFixedTaskVector train_features = 1;
  repeated UnitTestFixedTaskVector train_labels = 2;

  // Validation examples.
  repeated UnitTestFixedTaskVector valid_features = 3;
  repeated UnitTestFixedTaskVector valid_labels = 4;

  // Test examples.
  repeated UnitTestFixedTaskVector test_features = 5;
  repeated UnitTestFixedTaskVector test_labels = 6;

  // Field number 7 must not be used by any field of this message.
  reserved 7;
}

// A list of doubles. Used as the element type of both the features and the
// labels fields of UnitTestFixedTask.
message UnitTestFixedTaskVector {
  repeated double elements = 1;
}

// Selects TaskSpec.unit_test_zeros_task, one of the tasks that TaskSpec
// lists as useful for tests. Has no fields.
message UnitTestZerosTaskSpec {}

// Selects TaskSpec.unit_test_ones_task, one of the tasks that TaskSpec
// lists as useful for tests. Has no fields.
message UnitTestOnesTaskSpec {}

// Parameters for TaskSpec.unit_test_increment_task, one of the tasks that
// TaskSpec lists as useful for tests.
message UnitTestIncrementTaskSpec {
  // Reads as 1.0 when unset.
  // NOTE(review): how the increment is applied is not shown in this file.
  optional double increment = 1 [default = 1.0];
}