# google-research / task_set — train_inner (file-viewer metadata removed: 292 lines, 9.6 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16r"""Train a task with a given optimizer and monitor results when done.
17
18Example usage:
19```
20binary_path/train_inner --task_name="mlp_family_seed12" \
21--optimizer_name="adam8p_wide_grid_seed21" \
22--output_directory="/disk2/tmp/optfolder" \
23--alsologtostderr
24```
25"""
26import json27import os28import time29from typing import Dict, Text, Tuple30
31from absl import app32from absl import flags33from absl import logging34
35import dataclasses36
37import numpy as np38
39from task_set import datasets40from task_set import registry41from task_set.optimizers import all_optimizers # pylint: disable=unused-import42from task_set.tasks import all_tasks # pylint: disable=unused-import43from task_set.tasks import base44import tensorflow.compat.v1 as tf45
46
FLAGS = flags.FLAGS

# Command-line configuration. `optimizer_name`, `task_name`, and
# `output_directory` have no defaults and are validated in main().
flags.DEFINE_string("optimizer_name", None, "Name of optimizer to run.")
flags.DEFINE_string("task_name", None, "Name of task to run.")
flags.DEFINE_integer("training_steps", 10000,
                     "Number of training steps to run.")
flags.DEFINE_integer("eval_every_n", 200, "Number of steps between each eval.")
flags.DEFINE_integer("replica", 0, "Replica of run.")

flags.DEFINE_string("output_directory", None,
                    "Training directory to save summaries/checkpoints.")

# Type aliases used for readability in this module's signatures/comments.
NamedTensorDict = Dict[Text, tf.Tensor]
FourTensorTuple = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
FourDictTensorTuple = Tuple[NamedTensorDict, NamedTensorDict, NamedTensorDict,
                            NamedTensorDict]
64
def compute_averaged_loss(task,
                          params,
                          num_batches = 2,
                          with_metrics = False):
  """Averages inner-task loss (and optional metrics) over mini-batches.

  For each of the four splits (train, valid-inner, valid-outer, test) this
  evaluates the task `num_batches` times with `tf.map_fn` and averages the
  resulting losses and, when requested, the auxiliary metrics.

  Args:
    task: Task used to compute the loss.
    params: Parameters for the task.
    num_batches: Number of mini-batches averaged per split.
    with_metrics: If True, additionally compute and average aux metrics.

  Returns:
    losses: length-4 tuple with the averaged loss for each split.
    metrics: length-4 tuple of dicts with averaged metrics for each split;
      the dicts are empty when `with_metrics` is False.
  """

  def _one_batch(split):
    """Runs one forward pass, returning a (loss, metrics_dict) pair."""
    outputs = task.call_split(params, split, with_metrics=with_metrics)
    if not with_metrics:
      outputs = outputs, {}
    return outputs  # (tf.Tensor, Dict[Text, tf.Tensor])

  # A single probe pass exposes the metric dict's keys and dtypes, which
  # tf.map_fn needs ahead of time. In graph mode this adds no runtime cost.
  _, probe_aux = _one_batch(datasets.Split.TRAIN)
  dummy_aux = {
      name: tf.zeros(shape=[num_batches], dtype=tensor.dtype)
      for name, tensor in probe_aux.items()
  }

  all_splits = [
      datasets.Split.TRAIN, datasets.Split.VALID_INNER,
      datasets.Split.VALID_OUTER, datasets.Split.TEST
  ]

  avg_losses = []
  avg_metrics = []
  for split in all_splits:
    # Bind the current split as a default argument so each lambda keeps its
    # own value rather than closing over the loop variable.
    batch_losses, batch_metrics = tf.map_fn(
        lambda _, s=split: _one_batch(s),
        (tf.to_float(tf.range(num_batches)), dummy_aux))
    avg_losses.append(tf.reduce_mean(batch_losses))
    avg_metrics.append(
        {name: tf.reduce_mean(vals) for name, vals in batch_metrics.items()})

  return tuple(avg_losses), tuple(avg_metrics)
127
@dataclasses.dataclass(frozen=True)
class GraphEndpoints:
  """Immutable bundle of graph endpoints used for inner-training.

  Produced by `build_training_graph` and consumed by the training loop.
  Field order defines the constructor signature; do not reorder.
  """
  # Grouped op that applies one optimizer update.
  train_op: tf.Operation
  # Global step counter tensor.
  global_step: tf.Tensor
  # Op that initializes the task's variables.
  init_op: tf.Operation
  # Averaged evaluation losses, one per data split.
  test_loss: tf.Tensor
  valid_inner_loss: tf.Tensor
  valid_outer_loss: tf.Tensor
  train_loss: tf.Tensor
139
def build_training_graph(task_name,
                         optimizer_name,
                         num_batchs_per_evaluation = 5):
  """Constructs the TensorFlow graph for training one task/optimizer pair.

  Args:
    task_name: Name of the task to instantiate from the task registry.
    optimizer_name: Name of the optimizer to instantiate.
    num_batchs_per_evaluation: Number of batches averaged by one run of the
      evaluation ops. Note, the training code runs these ops multiple times
      per evaluation.

  Returns:
    A `GraphEndpoints` with the TensorFlow tensors/ops used for training.
  """

  step_counter = tf.train.get_or_create_global_step()

  task = registry.task_registry.get_instance(task_name)
  task_params = task.current_params()
  train_split_loss = task.call_split(task_params, datasets.Split.TRAIN)

  optimizer = registry.optimizers_registry.get_instance(optimizer_name)
  minimize_op = optimizer.minimize(
      train_split_loss,
      var_list=list(task_params.values()),
      global_step=step_counter)
  named_train_op = tf.group(minimize_op, name="train_op")

  # Averaged per-split evaluation losses; metrics are discarded here.
  eval_losses, _ = compute_averaged_loss(task, task_params,
                                         num_batchs_per_evaluation)
  train_loss, valid_inner_loss, valid_outer_loss, test_loss = eval_losses

  # Initializer restricted to this task's own variables; the training loop
  # also runs a global initializer before this op.
  task_init_op = tf.initialize_variables(task.get_variables())

  return GraphEndpoints(
      train_op=named_train_op,
      global_step=step_counter,
      init_op=task_init_op,
      test_loss=test_loss,
      valid_inner_loss=valid_inner_loss,
      valid_outer_loss=valid_outer_loss,
      train_loss=train_loss)
183
def train(
    train_log_dir,
    task_name,
    optimizer_name,
    training_steps = 10000,
    eval_every_n = 200,
    minibatch_per_evaluation = 50,
    parallel_evaluations = 5,
):
  """Train a model and monitor results.

  This function trains a task specified by task_name using the optimizer
  from optimizer_name. It writes 2 files, `result` (learning curves) and
  `time_per_step` (timing aggregates), to train_log_dir for later processing.

  Args:
    train_log_dir: str Directory to write summaries out to.
    task_name: str Name of task to train.
    optimizer_name: Name of the optimizer to train with.
    training_steps: Number of training steps to perform.
    eval_every_n: Number of steps to run between each evaluation.
    minibatch_per_evaluation: Number of minibatches to run per evaluation.
    parallel_evaluations: Number of minibatches to run in parallel in graph.
      Must cleanly divide into minibatch_per_evaluation.

  Returns:
    The resulting learning curves encoded as a json string.

  Raises:
    ValueError: If minibatch_per_evaluation is not a multiple of
      parallel_evaluations.
  """

  if minibatch_per_evaluation % parallel_evaluations != 0:
    # Bug fix: the original message concatenated adjacent literals without a
    # space, rendering as "...divisible byparallel_evaluations".
    raise ValueError("minibatch_per_evaluation must be divisible by "
                     "parallel_evaluations")
  tf.gfile.MakeDirs(train_log_dir)

  g = build_training_graph(task_name, optimizer_name, parallel_evaluations)

  # "losses": maps str(step) -> [train, valid_inner, valid_outer, test].
  # "time_per_step": wall-clock seconds for each train_op run.
  state = {"losses": {}, "time_per_step": []}

  config = tf.ConfigProto(
      intra_op_parallelism_threads=20, inter_op_parallelism_threads=20)

  with tf.Session(config=config) as sess:
    # initialize_all_variables is the TF1 (deprecated) global initializer.
    sess.run(tf.initialize_all_variables())
    sess.run(tf.get_collection(tf.GraphKeys.LOCAL_INIT_OP))

    step = sess.run(g.global_step)

    logging.info("Running init op")
    sess.run(g.init_op)

    while step <= training_steps:
      # Evaluate periodically and always at the final step.
      if step % eval_every_n == 0 or step == training_steps:
        logging.info("Evaluating %d", step)
        losses = []
        # Each sess.run averages `parallel_evaluations` batches in-graph, so
        # this loop accumulates `minibatch_per_evaluation` batches in total.
        for _ in range(minibatch_per_evaluation // parallel_evaluations):
          tr, vai, vao, te = sess.run([
              g.train_loss, g.valid_inner_loss, g.valid_outer_loss, g.test_loss
          ])
          losses.append((tr, vai, vao, te))
        state["losses"][str(step)] = [float(np.mean(x)) for x in zip(*losses)]

      # Only log 10 steps to not flood info logs.
      if step < 10:
        logging.info("Running train_op %d (only logging first 10)", step)
      start_time = time.time()
      _, step = sess.run([g.train_op, g.global_step])
      state["time_per_step"].append(time.time() - start_time)

  # Compute aggregate values for timing information.
  mean_time = np.mean(state["time_per_step"])
  num_steps = len(state["time_per_step"])
  # Second-half mean excludes startup overhead (tracing, warm caches).
  mean_time_last_half = np.mean(state["time_per_step"][num_steps // 2:])
  median_time = np.median(state["time_per_step"])
  time_per_step = json.dumps({
      "mean_time": mean_time,
      "median_time": median_time,
      "mean_last_half": mean_time_last_half,
  })

  result = json.dumps(state["losses"])
  # NOTE(review): bytes are written to GFiles opened in text mode "w"; this
  # relies on TF1 gfile accepting bytes — confirm if upgrading TF.
  with tf.gfile.GFile(os.path.join(train_log_dir, "result"), "w") as f:
    f.write(result.encode("utf-8"))

  with tf.gfile.GFile(os.path.join(train_log_dir, "time_per_step"), "w") as f:
    f.write(time_per_step.encode("utf-8"))

  return result
272
def main(_):
  """Validates required flags, derives the log directory, and trains."""
  # Check required flags in the same order as before; the formatted message
  # is identical to the original per-flag literals.
  for flag_name in ("optimizer_name", "task_name", "output_directory"):
    if not getattr(FLAGS, flag_name):
      raise ValueError("Must pass `%s`" % flag_name)

  # Layout: output_directory/task/optimizer/steps/replica.
  log_dir = os.path.join(FLAGS.output_directory, FLAGS.task_name,
                         FLAGS.optimizer_name, str(FLAGS.training_steps),
                         str(FLAGS.replica))
  train(
      train_log_dir=log_dir,
      optimizer_name=FLAGS.optimizer_name,
      task_name=FLAGS.task_name,
      training_steps=FLAGS.training_steps,
      eval_every_n=FLAGS.eval_every_n)
290
if __name__ == "__main__":
  # absl entry point: parses command-line flags, then invokes main().
  app.run(main)