google-research
95 строк · 3.3 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Trains a feature-column model on Criteo Kaggle data."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22from absl import app23from absl import flags24from absl import logging25
26import tensorflow.compat.v2 as tf27from uq_benchmark_2019 import experiment_utils28from uq_benchmark_2019.criteo import data_lib29from uq_benchmark_2019.criteo import hparams_lib30from uq_benchmark_2019.criteo import models_lib31
# Shared absl flag container; the flags read from it in main() are only
# defined once _declare_flags() has been called.
FLAGS = flags.FLAGS

34
def _declare_flags():
  """Declares the command-line flags used when this file runs as a script.

  The definitions live in a function, rather than at module level, so that
  they are not registered when this module is imported as a library.

  `method` and `output_dir` default to None but are passed straight into
  `run`, so they are marked as required: omitting either one is reported as
  a flag error at parse time instead of a None reaching the training code.
  """
  flags.DEFINE_enum('method', None, models_lib.METHODS,
                    'Name of modeling method.')
  flags.DEFINE_string('output_dir', None, 'Output directory.')
  flags.DEFINE_integer('test_level', 0, 'Testing level.')
  flags.DEFINE_integer('train_epochs', 1, 'Number of epochs for training.')
  flags.DEFINE_integer('task', 0, 'Task number.')
  # Validation runs when absl parses argv (i.e. inside app.run), so this
  # only affects script invocations, never library imports.
  flags.mark_flags_as_required(['method', 'output_dir'])

44
def run(method, output_dir, num_epochs, fake_data=False, fake_training=False):
  """Builds and trains a model, then writes its options and weights to disk.

  Args:
    method: Modeling method to experiment with.
    output_dir: Directory that receives `model_options.json` and the
      `model.ckpt` weights; it is created before anything is written.
    num_epochs: Number of training epochs.
    fake_data: If true, use fake data.
    fake_training: If true, train for a trivial number of steps.

  Returns:
    Trained Keras model.
  """
  tf.io.gfile.makedirs(output_dir)

  # Built in the order train, valid.
  train_config, valid_config = [
      data_lib.DataConfig(split=split_name, fake_data=fake_data)
      for split_name in ('train', 'valid')
  ]

  # The same parameterization is used both to look up the tuned
  # hyperparameters and to turn them into model options.
  parameterization = 'C'
  tuned_hparams = hparams_lib.get_tuned_hparams(
      method, parameterization=parameterization)
  model_opts = hparams_lib.model_opts_from_hparams(
      tuned_hparams, method,
      parameterization=parameterization,
      fake_training=fake_training)

  options_path = output_dir + '/model_options.json'
  experiment_utils.record_config(model_opts, options_path)

  trained_model = models_lib.build_and_train_model(
      model_opts, train_config, valid_config,
      output_dir=output_dir,
      num_epochs=num_epochs,
      fake_training=fake_training)

  logging.info('Saving model to output_dir.')
  checkpoint_path = output_dir + '/model.ckpt'
  trained_model.save_weights(checkpoint_path)
  # TODO(yovadia): Looks like Keras save_model does not work with Python3
  # (e.g. see b/129323565), so only the weights are saved here rather than
  # calling experiment_utils.save_model(model, output_dir).
  return trained_model

81
def main(argv):
  """Script entry point: checks argv, then calls `run` with the flag values.

  Args:
    argv: Command-line arguments remaining after flag parsing; only the
      program name is allowed.

  Raises:
    app.UsageError: If any positional argument follows the program name.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError('Too many command-line arguments.')

  method = FLAGS.method
  output_dir = FLAGS.output_dir
  train_epochs = FLAGS.train_epochs
  test_level = FLAGS.test_level

  # test_level > 0 turns on fake training; test_level > 1 additionally
  # switches to fake data.
  run(method,
      output_dir,
      num_epochs=train_epochs,
      fake_data=test_level > 1,
      fake_training=test_level > 0)

if __name__ == '__main__':
  # Script-only setup: none of this runs when the module is imported.
  tf.enable_v2_behavior()
  # Flags must be defined before app.run() parses the command line.
  _declare_flags()
  app.run(main)