# google-research
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines common flags for training RED-ACE models.

Importing this module registers all RED-ACE command-line flags with absl;
values are read elsewhere via flags.FLAGS.
"""

from absl import flags

# Input data and checkpoint/model locations.
flags.DEFINE_string('train_file', None,
                    'Path to the tfrecord file for training.')
flags.DEFINE_string('eval_file', None,
                    'Path to the tfrecord file for evaluation.')
flags.DEFINE_string(
    'init_checkpoint',
    None,
    ('Path to a pre-trained BERT checkpoint or to a previously trained model'
     ' checkpoint that the current training job will further fine-tune.'),
)
flags.DEFINE_string(
    'model_dir',
    None,
    'Directory where the model weights and summaries are stored.',
)

# Input preprocessing.
flags.DEFINE_integer(
    'max_seq_length',
    128,
    ('The maximum total input sequence length after tokenization. '
     'Sequences longer than this will be truncated, and sequences shorter '
     'than this will be padded.'),
)

# Dataset sizes and batching.
flags.DEFINE_integer('num_train_examples', 32,
                     'Total size of training dataset.')
flags.DEFINE_integer('num_eval_examples', 32,
                     'Total size of evaluation dataset.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Total batch size for evaluation.')
flags.DEFINE_integer('num_train_epochs', 100,
                     'Total number of training epochs to perform.')

# Optimizer schedule.
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
# NOTE(review): warmup_steps is declared as a float flag although it is
# semantically a step count — kept as float for backward compatibility with
# existing launch configs; confirm before tightening to DEFINE_integer.
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')

# Logging and checkpointing cadence.
flags.DEFINE_integer('log_steps', 1000,
                     'Interval of steps between logging of batch level stats.')
flags.DEFINE_integer('steps_per_loop', 1000, 'Steps per loop.')
flags.DEFINE_integer('keep_checkpoint_max', 3,
                     'How many checkpoints to keep around during training.')
flags.DEFINE_integer(
    'mini_epochs_per_epoch', 1,
    'Only has an effect for values >= 2. This flag enables more frequent '
    'checkpointing + evaluation on the validation set than done by default. '
    'This is achieved by reporting to TF an epoch size that is '
    '"flag value"-times smaller than the true value.')

# Output and tokenization.
flags.DEFINE_string('output_file', None, 'Path to the output file.')
flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
flags.DEFINE_integer(
    'predict_batch_size', 32,
    'Batch size for the prediction of insertion and tagging models.')
flags.DEFINE_bool(
    'split_on_punc',
    True,
    'Whether to split on punctuation characters during tokenization.',
)
flags.DEFINE_string('redace_config', None, 'Path to the RED-ACE config file.')
flags.DEFINE_string(
    'special_glue_string_for_joining_sources',
    ' ',
    ('String that is used to join multiple source strings of a given example'
     ' into one string. Optional.'),
)

# Prediction flags.
flags.DEFINE_string(
    'predict_input_file', None,
    # Fixed missing space: the original two fragments joined as "tocompute".
    'Path to the input file containing examples for which to '
    'compute predictions.')
flags.DEFINE_string('predict_output_file', None,
                    'Path to the output file for predictions.')

# Training flags.
flags.DEFINE_bool(
    'use_weighted_labels', True,
    'Whether different labels were given different weights. Primarily used to '
    'increase the importance of rare tags.')

flags.DEFINE_string('test_file', None, 'Path to the test file.')

# Checkpoint-selection metric; 'latest' disables selection and keeps all.
flags.DEFINE_enum(
    'validation_checkpoint_metric',
    None,
    ['bleu', 'exact_match', 'latest', 'tag_accuracy'],
    ('Which metric should be used when choosing the best checkpoint. If'
     ' latest, then all checkpoints are saved.'),
)