google-research

Форк
0
95 строк · 3.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Trains a feature-column model on Criteo Kaggle data."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
from absl import app
23
from absl import flags
24
from absl import logging
25

26
import tensorflow.compat.v2 as tf
27
from uq_benchmark_2019 import experiment_utils
28
from uq_benchmark_2019.criteo import data_lib
29
from uq_benchmark_2019.criteo import hparams_lib
30
from uq_benchmark_2019.criteo import models_lib
31

32
FLAGS = flags.FLAGS
33

34

35
def _declare_flags():
36
  """Declare flags; not invoked when this module is imported as a library."""
37
  flags.DEFINE_enum('method', None, models_lib.METHODS,
38
                    'Name of modeling method.')
39
  flags.DEFINE_string('output_dir', None, 'Output directory.')
40
  flags.DEFINE_integer('test_level', 0, 'Testing level.')
41
  flags.DEFINE_integer('train_epochs', 1, 'Number of epochs for training.')
42
  flags.DEFINE_integer('task', 0, 'Task number.')
43

44

45
def run(method, output_dir, num_epochs, fake_data=False, fake_training=False):
46
  """Trains a model and records its predictions on configured datasets.
47

48
  Args:
49
    method: Modeling method to experiment with.
50
    output_dir: Directory to record the trained model and output stats.
51
    num_epochs: Number of training epochs.
52
    fake_data: If true, use fake data.
53
    fake_training: If true, train for a trivial number of steps.
54
  Returns:
55
    Trained Keras model.
56
  """
57
  tf.io.gfile.makedirs(output_dir)
58
  data_config_train = data_lib.DataConfig(split='train', fake_data=fake_data)
59
  data_config_valid = data_lib.DataConfig(split='valid', fake_data=fake_data)
60

61
  hparams = hparams_lib.get_tuned_hparams(method, parameterization='C')
62
  model_opts = hparams_lib.model_opts_from_hparams(hparams, method,
63
                                                   parameterization='C',
64
                                                   fake_training=fake_training)
65

66
  experiment_utils.record_config(model_opts, output_dir+'/model_options.json')
67

68
  model = models_lib.build_and_train_model(
69
      model_opts, data_config_train, data_config_valid,
70
      output_dir=output_dir,
71
      num_epochs=num_epochs,
72
      fake_training=fake_training)
73

74
  logging.info('Saving model to output_dir.')
75
  model.save_weights(output_dir + '/model.ckpt')
76
  # TODO(yovadia): Looks like Keras save_model does not work with Python3.
77
  # (e.g. see b/129323565).
78
  # experiment_utils.save_model(model, output_dir)
79
  return model
80

81

82
def main(argv):
83
  if len(argv) > 1:
84
    raise app.UsageError('Too many command-line arguments.')
85
  run(FLAGS.method,
86
      FLAGS.output_dir,
87
      num_epochs=FLAGS.train_epochs,
88
      fake_data=FLAGS.test_level > 1,
89
      fake_training=FLAGS.test_level > 0)
90

91
if __name__ == '__main__':
92

93
  tf.enable_v2_behavior()
94
  _declare_flags()
95
  app.run(main)
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.