google-research
170 lines · 6.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Common flags for trainer and decoder."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl import flags
22
23FLAGS = flags.FLAGS
24
# Hyperparameter-sweep flags. Each one defaults to None; only flags that
# were explicitly set are forwarded as "--hp_*" overrides by update_argv.
flags.DEFINE_float("l0_norm_weight", None, "weight for l0 norm.")
flags.DEFINE_integer("l0_weight_start", None, "weight start for l0 norm.")
flags.DEFINE_integer("l0_weight_diff", None, "weight diff for l0 norm.")
flags.DEFINE_float("dkl_weight", None, "weight for dkl norm.")
flags.DEFINE_integer("dkl_weight_start", None, "weight start for dkl norm.")
flags.DEFINE_integer("dkl_weight_diff", None, "weight diff for dkl norm.")
flags.DEFINE_string("dkl_weight_fn", None, "dkl weight curve.")
flags.DEFINE_float("target_sparsity", None, "sparsity for mp.")
flags.DEFINE_integer("begin_pruning_step", None, "start step for mp.")
flags.DEFINE_integer("end_pruning_step", None, "end step for mp.")
flags.DEFINE_integer("pruning_frequency", None, "frequency of mp steps.")
flags.DEFINE_string("regularization", None, "what regularization to use.")
flags.DEFINE_float("clip_log_alpha", None, "clip limit for log alphas.")
flags.DEFINE_integer("nbins", None, "number of bins for mp histogram.")

# Flags for scratch-e, scratch-b, and lottery ticket experiments.
flags.DEFINE_string("load_masks_from", None,
                    "Checkpoint to load trained mask from.")
flags.DEFINE_string("load_weights_from", None,
                    "Checkpoint to load trained non-mask from.")
flags.DEFINE_float("initial_sparsity", None,
                   "Initial sparsity for scratch-* experiments.")

# Flags for constant parameter curves.
flags.DEFINE_integer("hidden_size", None, "Hidden size of the Transformer.")
flags.DEFINE_integer("filter_size", None, "Filter size of the Transformer.")
flags.DEFINE_integer("num_heads", None, "Number of heads in the Transformer.")

# Flags for imbalanced pruning experiments.
flags.DEFINE_float("embedding_sparsity", None,
                   "Sparsity fraction for embedding matrix")
75
76
# Sweep flags forwarded verbatim (as "--hp_<name> <value>") when set,
# in the order the original hand-written chain emitted them, before the
# regularization preset is expanded.
_FORWARDED_FLAGS_HEAD = (
    "l0_norm_weight",
    "l0_weight_start",
    "l0_weight_diff",
    "dkl_weight",
    "dkl_weight_start",
    "dkl_weight_diff",
    "dkl_weight_fn",
    "target_sparsity",
    "begin_pruning_step",
    "end_pruning_step",
    "pruning_frequency",
)

# Sweep flags forwarded verbatim after the regularization preset.
_FORWARDED_FLAGS_TAIL = (
    "clip_log_alpha",
    "nbins",
    "load_masks_from",
    "load_weights_from",
    "initial_sparsity",
    "hidden_size",
    "filter_size",
    "num_heads",
    "embedding_sparsity",
)

# argv overrides implied by each --regularization preset. The baseline
# hparams already enable dropout + label smoothing, so the
# "dropout+label_smoothing" preset needs no overrides; the "more"/"muchmore"
# presets crank up layer_prepostprocess_dropout (like transformer_big).
_REGULARIZATION_OVERRIDES = {
    "none": [
        "--hp_layer_prepostprocess_dropout", "0.0",
        "--hp_attention_dropout", "0.0",
        "--hp_relu_dropout", "0.0",
        "--hp_label_smoothing", "0.0",
    ],
    "label_smoothing": [
        "--hp_layer_prepostprocess_dropout", "0.0",
        "--hp_attention_dropout", "0.0",
        "--hp_relu_dropout", "0.0",
    ],
    "dropout+label_smoothing": [],
    "moredropout+label_smoothing": [
        "--hp_layer_prepostprocess_dropout", "0.3",
    ],
    "muchmoredropout+label_smoothing": [
        "--hp_layer_prepostprocess_dropout", "0.5",
    ],
}


def _append_flag(argv, name):
  """Appends "--hp_<name> <value>" to argv if FLAGS.<name> was set."""
  value = getattr(FLAGS, name)
  if value is not None:
    argv.append("--hp_{}".format(name))
    argv.append("{}".format(value))


def update_argv(argv):
  """Update the arguments.

  Forwards every sweep flag that was explicitly set as a "--hp_<name>"
  hparam override, and expands the --regularization preset into the
  dropout/label-smoothing overrides it implies.

  Args:
    argv: list of command-line argument strings; extended in place.

  Returns:
    The same argv list, extended with the hparam overrides.

  Raises:
    ValueError: if FLAGS.regularization is set to an unknown preset.
  """
  for name in _FORWARDED_FLAGS_HEAD:
    _append_flag(argv, name)
  if FLAGS.regularization is not None:
    try:
      argv.extend(_REGULARIZATION_OVERRIDES[FLAGS.regularization])
    except KeyError:
      # Same message the original elif chain raised for unknown presets.
      raise ValueError("Invalid value of regularization flags: {}"
                       .format(FLAGS.regularization))
  for name in _FORWARDED_FLAGS_TAIL:
    _append_flag(argv, name)
  return argv
171