google-research
68 строк · 2.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Generate a JSON file containing a dictionary mapping names to configs.

This manages sampling configs for both optimizers and tasks.

The sampled tasks are not directly part of the task dataset, because not all
sampled configs will be feasible and/or will train.
"""
23import collections24from absl import app25from absl import flags26
27from task_set import registry28from task_set.optimizers import all_optimizers # pylint: disable=unused-import29from task_set.tasks import all_tasks # pylint: disable=unused-import30from task_set.tasks import utils31import tensorflow.compat.v1 as tf32
# Sampler selection: main() raises ValueError unless exactly one of
# --task_sampler / --optimizer_sampler is set.
flags.DEFINE_string("task_sampler", None, "Module to use for sampling tasks")
flags.DEFINE_string("optimizer_sampler", None,
                    "Module to use for sampling optimizers")

# Destination for the JSON mapping of sample names to configs.
flags.DEFINE_string("output_file", None, "Output location for sampling.")
flags.mark_flag_as_required("output_file")

# Number of seeds (0 .. num_samples - 1) passed to the chosen sampler.
flags.DEFINE_integer("num_samples", 100, "Number of samples to select.")
FLAGS = flags.FLAGS
43
def main(_):
  """Samples configs from one registered sampler and writes them as JSON.

  Exactly one of --task_sampler or --optimizer_sampler must be set. The
  selected sampler is called once per seed in range(--num_samples). Each
  result is stored under the key "<sampler name>_seed<seed>" as the pair
  (config, sampler name). The resulting ordered mapping is serialized with
  utils.pretty_json_dumps and written to --output_file.

  Args:
    _: Unused positional argument supplied by app.run.

  Raises:
    ValueError: If both sampler flags are set, or if neither is.
  """
  task_sampler = FLAGS.task_sampler
  optimizer_sampler = FLAGS.optimizer_sampler

  if task_sampler and optimizer_sampler:
    raise ValueError("Only specify one sampler!")
  if not (task_sampler or optimizer_sampler):
    raise ValueError("Must specify either task_sampler or optimizer_sampler!")

  # Choose the registry that matches whichever sampler flag was supplied.
  if task_sampler:
    sampler_name, source_registry = task_sampler, registry.task_registry
  else:
    sampler_name, source_registry = (optimizer_sampler,
                                     registry.optimizers_registry)
  sampler = source_registry.get_sampler(sampler_name)

  # One entry per seed, in seed order: the sampled config paired with the
  # name of the sampler that produced it.
  samples = collections.OrderedDict(
      ("%s_seed%d" % (sampler_name, seed), (sampler(seed), sampler_name))
      for seed in range(FLAGS.num_samples))

  with tf.gfile.GFile(FLAGS.output_file, "w") as out_file:
    out_file.write(utils.pretty_json_dumps(samples).encode("utf-8"))
66
if __name__ == "__main__":
  # app.run parses the absl command-line flags before invoking main, so
  # required flags such as --output_file are checked up front.
  app.run(main)