google-research / train_inner.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Train a task with a given optimizer and monitor results when done.

Example usage:
```
  binary_path/train_inner --task_name="mlp_family_seed12" \
  --optimizer_name="adam8p_wide_grid_seed21" \
  --output_directory="/disk2/tmp/optfolder" \
  --alsologtostderr
```
"""
import json
import os
import time
from typing import Dict, Text, Tuple

from absl import app
from absl import flags
from absl import logging

import dataclasses

import numpy as np

from task_set import datasets
from task_set import registry
from task_set.optimizers import all_optimizers  # pylint: disable=unused-import
from task_set.tasks import all_tasks  # pylint: disable=unused-import
from task_set.tasks import base
import tensorflow.compat.v1 as tf


FLAGS = flags.FLAGS

flags.DEFINE_string("optimizer_name", None, "Name of optimizer to run.")
flags.DEFINE_string("task_name", None, "Name of task to run.")
flags.DEFINE_integer("training_steps", 10000,
                     "Number of training steps to run.")
flags.DEFINE_integer("eval_every_n", 200, "Number of steps between each eval.")
flags.DEFINE_integer("replica", 0, "Replica of run.")

flags.DEFINE_string("output_directory", None,
                    "Training directory to save summaries/checkpoints.")

NamedTensorDict = Dict[Text, tf.Tensor]
FourTensorTuple = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
FourDictTensorTuple = Tuple[NamedTensorDict, NamedTensorDict, NamedTensorDict,
                            NamedTensorDict]
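# Note: length-4 tuples in this file follow the split order used by
# compute_averaged_loss below: (train, valid_inner, valid_outer, test).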


def compute_averaged_loss(
    task: base.BaseTask,
    params: NamedTensorDict,
    num_batches: int = 2,
    with_metrics: bool = False
) -> Tuple[FourTensorTuple, FourDictTensorTuple]:
  """Computes inner-task loss and metrics with num_batches mini-batches.

  Returns values for each of the four splits (train, valid-inner, valid-outer,
  test). For each split, num_batches evaluations are performed and the results
  are averaged (both losses, and optionally metrics).

  Args:
    task: Task used to compute the loss.
    params: Parameters for the task.
    num_batches: Number of batches to compute averages over.
    with_metrics: bool; additionally compute and return averages over aux
      metrics.

  Returns:
    losses: length-4 tuple containing loss values for each split.
    metrics: length-4 tuple containing dictionaries of metrics for each split.
      If with_metrics is False, these dictionaries are empty.
  """

  def compute_loss(split):
    inner_loss_and_maybe_aux = task.call_split(
        params, split, with_metrics=with_metrics)
    if not with_metrics:
      inner_loss_and_maybe_aux = inner_loss_and_maybe_aux, {}

    inner_loss_and_maybe_aux = inner_loss_and_maybe_aux  # type: Tuple[tf.Tensor, Dict[Text, tf.Tensor]]
    return inner_loss_and_maybe_aux

  # Run a forward pass to get a dictionary with metrics of the right dtype.
  # This is needed ahead of time, before tf.map_fn is called.
  # Because we are in graph mode, this does not incur overhead.
  _, tmp_aux = compute_loss(datasets.Split.TRAIN)
  dummy_aux = {
      k: tf.zeros(shape=[num_batches], dtype=v.dtype)
      for k, v in tmp_aux.items()
  }
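
  # tf.map_fn infers its output structure from `elems` when no dtype is
  # given, so the calls below pair a dummy float range with dummy_aux to
  # mirror the (loss, metrics-dict) structure that compute_loss returns.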
  splits = [
      datasets.Split.TRAIN, datasets.Split.VALID_INNER,
      datasets.Split.VALID_OUTER, datasets.Split.TEST
  ]

  return_losses = []
  return_metrics = []
  for split in splits:
    # pylint: disable=cell-var-from-loop
    losses, metrics = tf.map_fn(lambda _: compute_loss(split),
                                (tf.to_float(tf.range(num_batches)), dummy_aux))
    # pylint: enable=cell-var-from-loop
    avg_loss = tf.reduce_mean(losses)
    avg_metric = {k: tf.reduce_mean(v) for k, v in metrics.items()}

    return_losses.append(avg_loss)
    return_metrics.append(avg_metric)

  return tuple(return_losses), tuple(return_metrics)


@dataclasses.dataclass(frozen=True)
class GraphEndpoints:
  """Class containing endpoints used for inner-training."""
  train_op: tf.Operation
  global_step: tf.Tensor
  init_op: tf.Operation
  test_loss: tf.Tensor
  valid_inner_loss: tf.Tensor
  valid_outer_loss: tf.Tensor
  train_loss: tf.Tensor


def build_training_graph(
    task_name: Text,
    optimizer_name: Text,
    num_batchs_per_evaluation: int = 5) -> GraphEndpoints:
  """Builds the TensorFlow graph.

  Args:
    task_name: Name of task to build.
    optimizer_name: Name of the optimizer to use.
    num_batchs_per_evaluation: Number of batches to use when running a
      single evaluation op. Note, this op is run multiple times per evaluation
      by the training code.

  Returns:
    A GraphEndpoints dataclass containing the TensorFlow tensors and
    operations used for training.
  """

  global_step = tf.train.get_or_create_global_step()

  task_mod = registry.task_registry.get_instance(task_name)
  params = task_mod.current_params()
  loss = task_mod.call_split(params, datasets.Split.TRAIN)
  opt = registry.optimizers_registry.get_instance(optimizer_name)

  train_op = opt.minimize(
      loss, var_list=list(params.values()), global_step=global_step)

  train_op = tf.group(train_op, name="train_op")
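  # Grouping gives the final training op a stable, fetchable name in the graph.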

  (train_loss, valid_inner_loss, valid_outer_loss,
   test_loss), _ = compute_averaged_loss(task_mod, params,
                                         num_batchs_per_evaluation)

  init_op = tf.initialize_variables(task_mod.get_variables())
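  # (tf.initialize_variables is the deprecated compat.v1 alias of
  # tf.variables_initializer; init_op re-initializes only the task's own
  # variables, and train() runs it once after the global initializer.)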

  return GraphEndpoints(
      train_op=train_op,
      global_step=global_step,
      init_op=init_op,
      test_loss=test_loss,
      valid_inner_loss=valid_inner_loss,
      valid_outer_loss=valid_outer_loss,
      train_loss=train_loss)


def train(
    train_log_dir: Text,
    task_name: Text,
    optimizer_name: Text,
    training_steps: int = 10000,
    eval_every_n: int = 200,
    minibatch_per_evaluation: int = 50,
    parallel_evaluations: int = 5,
) -> Text:
  """Trains a model and monitors results.

  This function trains the task specified by task_name using the optimizer
  from optimizer_name. It writes two files, result and time_per_step, to
  train_log_dir for later processing.

  Args:
    train_log_dir: Directory to write summaries out to.
    task_name: Name of task to train.
    optimizer_name: Name of the optimizer to train with.
    training_steps: Number of training steps to perform.
    eval_every_n: Number of steps to run between each evaluation.
    minibatch_per_evaluation: Number of minibatches to run per evaluation.
    parallel_evaluations: Number of minibatches to run in parallel in graph.
      Must cleanly divide into minibatch_per_evaluation.

  Returns:
    The resulting learning curves encoded as a json string.
  """

  if minibatch_per_evaluation % parallel_evaluations != 0:
    raise ValueError("minibatch_per_evaluation must be divisible by "
                     "parallel_evaluations")
  tf.gfile.MakeDirs(train_log_dir)

  g = build_training_graph(task_name, optimizer_name, parallel_evaluations)

  state = {"losses": {}, "time_per_step": []}

  config = tf.ConfigProto(
      intra_op_parallelism_threads=20, inter_op_parallelism_threads=20)

  with tf.Session(config=config) as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(tf.get_collection(tf.GraphKeys.LOCAL_INIT_OP))
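    # (initialize_all_variables is the deprecated compat.v1 alias of
    # global_variables_initializer; the LOCAL_INIT_OP collection holds any
    # registered local initializers such as local variables or tables.)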

    step = sess.run(g.global_step)

    logging.info("Running init op")
    sess.run(g.init_op)

    while step <= training_steps:
      if step % eval_every_n == 0 or step == training_steps:
        logging.info("Evaluating %d", step)
        losses = []
        for _ in range(minibatch_per_evaluation // parallel_evaluations):
          tr, vai, vao, te = sess.run([
              g.train_loss, g.valid_inner_loss, g.valid_outer_loss, g.test_loss
          ])
          losses.append((tr, vai, vao, te))
        state["losses"][str(step)] = [float(np.mean(x)) for x in zip(*losses)]
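        # zip(*losses) transposes the per-batch (train, valid_inner,
        # valid_outer, test) tuples into four per-split lists, so each logged
        # entry is the mean loss of each split at this step.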

      # Only log the first 10 steps to avoid flooding the info logs.
      if step < 10:
        logging.info("Running train_op %d (only logging first 10)", step)
      start_time = time.time()
      _, step = sess.run([g.train_op, g.global_step])
      state["time_per_step"].append(time.time() - start_time)

  # Compute aggregate values for timing information.
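  # mean_last_half averages only the second half of steps, which should be
  # less affected by one-time startup and warm-up costs.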
  mean_time = np.mean(state["time_per_step"])
  num_steps = len(state["time_per_step"])
  mean_time_last_half = np.mean(state["time_per_step"][num_steps // 2:])
  median_time = np.median(state["time_per_step"])
  time_per_step = json.dumps({
      "mean_time": mean_time,
      "median_time": median_time,
      "mean_last_half": mean_time_last_half,
  })

  result = json.dumps(state["losses"])
  with tf.gfile.GFile(os.path.join(train_log_dir, "result"), "w") as f:
    f.write(result.encode("utf-8"))

  with tf.gfile.GFile(os.path.join(train_log_dir, "time_per_step"), "w") as f:
    f.write(time_per_step.encode("utf-8"))

  return result


def main(_):
  if not FLAGS.optimizer_name:
    raise ValueError("Must pass `optimizer_name`")
  if not FLAGS.task_name:
    raise ValueError("Must pass `task_name`")
  if not FLAGS.output_directory:
    raise ValueError("Must pass `output_directory`")
  train_log_dir = os.path.join(FLAGS.output_directory, FLAGS.task_name,
                               FLAGS.optimizer_name, str(FLAGS.training_steps),
                               str(FLAGS.replica))
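  # E.g., with the flags from the module docstring this resolves to
  # /disk2/tmp/optfolder/mlp_family_seed12/adam8p_wide_grid_seed21/10000/0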
  train(
      train_log_dir=train_log_dir,
      optimizer_name=FLAGS.optimizer_name,
      task_name=FLAGS.task_name,
      training_steps=FLAGS.training_steps,
      eval_every_n=FLAGS.eval_every_n)


if __name__ == "__main__":
  app.run(main)