google-research

Форк
0
170 строк · 6.4 Кб
1
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common flags for trainer and decoder."""
17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

FLAGS = flags.FLAGS

# Flags for hparam sweeps
flags.DEFINE_float("l0_norm_weight", None, "weight for l0 norm.")
flags.DEFINE_integer("l0_weight_start", None, "weight start for l0 norm.")
flags.DEFINE_integer("l0_weight_diff", None, "weight diff for l0 norm.")
flags.DEFINE_float("dkl_weight", None, "weight for dkl norm.")
flags.DEFINE_integer("dkl_weight_start", None, "weight start for dkl norm.")
flags.DEFINE_integer("dkl_weight_diff", None, "weight diff for dkl norm.")
flags.DEFINE_string("dkl_weight_fn", None, "dkl weight curve.")
flags.DEFINE_float("target_sparsity", None, "sparsity for mp.")
flags.DEFINE_integer("begin_pruning_step", None, "start step for mp.")
flags.DEFINE_integer("end_pruning_step", None, "end step for mp.")
flags.DEFINE_integer("pruning_frequency", None, "frequency of mp steps.")
flags.DEFINE_string("regularization", None, "what regularization to use.")
flags.DEFINE_float("clip_log_alpha", None, "clip limit for log alphas.")
flags.DEFINE_integer("nbins", None, "number of bins for mp histogram.")

# For scratch-e, scratch-b, and lottery ticket experiments
flags.DEFINE_string("load_masks_from", None,
                    "Checkpoint to load trained mask from.")
flags.DEFINE_string("load_weights_from", None,
                    "Checkpoint to load trained non-mask from.")
flags.DEFINE_float("initial_sparsity", None,
                   "Initial sparsity for scratch-* experiments.")

# For constant parameter curves
flags.DEFINE_integer("hidden_size", None, "Hidden size of the Transformer.")
flags.DEFINE_integer("filter_size", None, "Filter size of the Transformer.")
flags.DEFINE_integer("num_heads", None, "Number of heads in the Transformer.")

# For imbalanced pruning experiments
flags.DEFINE_float("embedding_sparsity", None,
                   "Sparsity fraction for embedding matrix")
75

76

77
def update_argv(argv):
  """Forward any explicitly-set sweep flags to argv as "--hp_*" arguments.

  Every module-level flag that the user set (i.e. whose value is not None)
  is appended to ``argv`` as the pair ``"--hp_<flag>"``, ``"<value>"``.
  The ``--regularization`` flag is special-cased: instead of being forwarded
  directly, it expands into a preset of dropout / label-smoothing overrides.

  Args:
    argv: A list of command-line argument strings. Mutated in place.

  Returns:
    The same ``argv`` list, extended with the forwarded arguments.

  Raises:
    ValueError: If ``--regularization`` has an unrecognized value.
  """

  def _forward(name, value):
    # Emit a flag/value pair in the "--hp_<name> <value>" form consumed
    # downstream.
    argv.append("--hp_{}".format(name))
    argv.append("{}".format(value))

  # Pass-through flags forwarded before the regularization handling; order
  # matches the original hand-written sequence.
  for name in ("l0_norm_weight", "l0_weight_start", "l0_weight_diff",
               "dkl_weight", "dkl_weight_start", "dkl_weight_diff",
               "dkl_weight_fn", "target_sparsity", "begin_pruning_step",
               "end_pruning_step", "pruning_frequency"):
    value = getattr(FLAGS, name)
    if value is not None:
      _forward(name, value)

  # Hyperparameter overrides implied by each regularization preset.
  # "dropout+label_smoothing" is the baseline configuration, so it adds
  # nothing; the "*moredropout*" presets crank up prepostprocess dropout
  # (0.3 is the transformer_big setting).
  regularization_overrides = {
      "none": (
          ("layer_prepostprocess_dropout", "0.0"),
          ("attention_dropout", "0.0"),
          ("relu_dropout", "0.0"),
          ("label_smoothing", "0.0"),
      ),
      "label_smoothing": (
          ("layer_prepostprocess_dropout", "0.0"),
          ("attention_dropout", "0.0"),
          ("relu_dropout", "0.0"),
      ),
      "dropout+label_smoothing": (),
      "moredropout+label_smoothing": (
          ("layer_prepostprocess_dropout", "0.3"),
      ),
      "muchmoredropout+label_smoothing": (
          ("layer_prepostprocess_dropout", "0.5"),
      ),
  }
  if FLAGS.regularization is not None:
    if FLAGS.regularization not in regularization_overrides:
      raise ValueError("Invalid value of regularization flags: {}"
                       .format(FLAGS.regularization))
    for name, value in regularization_overrides[FLAGS.regularization]:
      _forward(name, value)

  # Pass-through flags forwarded after the regularization handling.
  for name in ("clip_log_alpha", "nbins", "load_masks_from",
               "load_weights_from", "initial_sparsity", "hidden_size",
               "filter_size", "num_heads", "embedding_sparsity"):
    value = getattr(FLAGS, name)
    if value is not None:
      _forward(name, value)

  return argv
171

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.