google-research

Форк
0
/
redace_flags.py 
105 строк · 4.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Defines common flags for training RED-ACE models."""
17

18
from absl import flags
19

20
flags.DEFINE_string('train_file', None,
21
                    'Path to the tfrecord file for training.')
22
flags.DEFINE_string('eval_file', None,
23
                    'Path to the tfrecord file for evaluation.')
24
flags.DEFINE_string(
25
    'init_checkpoint',
26
    None,
27
    ('Path to a pre-trained BERT checkpoint or a to previously trained model'
28
     ' checkpoint that the current training job will further fine-tune.'),
29
)
30
flags.DEFINE_string(
31
    'model_dir',
32
    None,
33
    'Directory where the model weights and summaries are stored.',
34
)
35
flags.DEFINE_integer(
36
    'max_seq_length',
37
    128,
38
    ('The maximum total input sequence length after tokenization. '
39
     'Sequences longer than this will be truncated, and sequences shorter '
40
     'than this will be padded.'),
41
)
42
flags.DEFINE_integer('num_train_examples', 32,
43
                     'Total size of training dataset.')
44
flags.DEFINE_integer('num_eval_examples', 32,
45
                     'Total size of evaluation dataset.')
46
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
47
flags.DEFINE_integer('eval_batch_size', 32, 'Total batch size for evaluation.')
48
flags.DEFINE_integer('num_train_epochs', 100,
49
                     'Total number of training epochs to perform.')
50
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
51
flags.DEFINE_float('warmup_steps', 10000,
52
                   'Warmup steps for Adam weight decay optimizer.')
53
flags.DEFINE_integer('log_steps', 1000,
54
                     'Interval of steps between logging of batch level stats.')
55
flags.DEFINE_integer('steps_per_loop', 1000, 'Steps per loop.')
56
flags.DEFINE_integer('keep_checkpoint_max', 3,
57
                     'How many checkpoints to keep around during training.')
58
flags.DEFINE_integer(
59
    'mini_epochs_per_epoch', 1,
60
    'Only has an effect for values >= 2. This flag enables more frequent '
61
    'checkpointing + evaluation on the validation set than done by default. '
62
    'This is achieved by reporting to TF an epoch size that is '
63
    '"flag value"-times smaller than the true value.')
64

65
flags.DEFINE_string('output_file', None, 'Path to the output file.')
66
flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
67
flags.DEFINE_integer(
68
    'predict_batch_size', 32,
69
    'Batch size for the prediction of insertion and tagging models.')
70
flags.DEFINE_bool(
71
    'split_on_punc',
72
    True,
73
    'Whether to split on punctuation characters during tokenization.',
74
)
75
flags.DEFINE_string('redace_config', None, 'Path to the RED-ACE config file.')
76
flags.DEFINE_string(
77
    'special_glue_string_for_joining_sources',
78
    ' ',
79
    ('String that is used to join multiple source strings of a given example'
80
     ' into one string. Optional.'),
81
)
82

83
# Prediction flags.
84
flags.DEFINE_string(
85
    'predict_input_file', None,
86
    'Path to the input file containing examples for which to'
87
    'compute predictions.')
88
flags.DEFINE_string('predict_output_file', None,
89
                    'Path to the output file for predictions.')
90

91
# Training flags.
92
flags.DEFINE_bool(
93
    'use_weighted_labels', True,
94
    'Whether different labels were given different weights. Primarly used to '
95
    'increase the importance of rare tags.')
96

97
flags.DEFINE_string('test_file', None, 'Path to the test file.')
98

99
flags.DEFINE_enum(
100
    'validation_checkpoint_metric',
101
    None,
102
    ['bleu', 'exact_match', 'latest', 'tag_accuracy'],
103
    ('Which metric should be used when choosing the best checkpoint. If'
104
     ' latest,then all checkpoints are saved.'),
105
)
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.