google-research
202 lines · 7.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main program to run student-mentor training."""
17
18from absl import app
19from absl import flags
20import tensorflow as tf
21
22from student_mentor_dataset_cleaning.training.loss.triplet_loss import TripletLoss
23from student_mentor_dataset_cleaning.training.trainers import trainer
24from student_mentor_dataset_cleaning.training.trainers import trainer_triplet
25
FLAGS = flags.FLAGS

# Training-loop hyper-parameters (validated by verify_arguments()).
flags.DEFINE_integer('mini_batch_size', 32, 'Mini-batch size')
flags.DEFINE_integer('max_iteration_count', 20, 'Maximum iteration count')
flags.DEFINE_integer('student_epoch_count', 30,
                     'Maximum number of training epochs for the student.')
flags.DEFINE_integer('mentor_epoch_count', 30,
                     'Maximum number of training epochs for the mentor.')
flags.DEFINE_string(
    'mode', 'softmax',
    'Training mode to use. Options are "softmax" (default) and "triplet". '
    'Softmax mode will always use the MNIST dataset and ignore "csv_path".'
)

# Output locations; passed to the trainer in both modes.
flags.DEFINE_string('save_dir', '', 'Path to model save dir.')
flags.DEFINE_string('tensorboard_log_dir', '', 'Path to tensorboard log dir.')

# Dataset location; only run_triplet() passes these on to its trainer.
flags.DEFINE_string('train_dataset_dir', '',
                    'Path to the training dataset dir.')
flags.DEFINE_string('csv_path', '', 'Path to the training dataset dataframe.')

# NOTE(review): not read anywhere in this file -- confirm whether another
# module consumes it.
flags.DEFINE_string('student_initial_model', '',
                    'Path to the student model initialization.')

# In this file the value is only read by run_triplet(), as TripletLoss's
# `embedding_size`. NOTE(review): `delg_global_features` is not defined in this
# file.
flags.DEFINE_integer(
    'delg_embedding_layer_dim', 2048,
    'Size of the FC whitening layer (embedding layer). Used only if '
    'delg_global_features:True.')
48'Size of the FC whitening layer (embedding layer). Used only if'
49'delg_global_features:True.')
50
51
def verify_arguments():
  """Verifies the validity of the command-line arguments.

  The checks raise explicit exceptions instead of using `assert`, so they
  still run when Python is started with -O (which strips assertions).

  Raises:
    app.UsageError: If `mini_batch_size`, `max_iteration_count`,
      `student_epoch_count` or `mentor_epoch_count` is not positive, or if
      `mode` is neither 'softmax' nor 'triplet'.
  """
  positive_flag_names = ('mini_batch_size', 'max_iteration_count',
                         'student_epoch_count', 'mentor_epoch_count')
  for flag_name in positive_flag_names:
    value = getattr(FLAGS, flag_name)
    if value <= 0:
      raise app.UsageError(f'`{flag_name}` must be positive, got {value}.')

  if FLAGS.mode not in ('softmax', 'triplet'):
    raise app.UsageError('`mode` must be either `softmax` or `triplet`, got '
                         f'{FLAGS.mode!r}.')
63
64
def run_softmax():
  """Runs the program in softmax mode using the MNIST dataset.

  Builds a fully-connected student classifier (28x28 input -> 128 ReLU units
  -> 10 logits) and a one-hidden-layer mentor with a single sigmoid output.
  Both are compiled with Adam and handed to `trainer.train` together with the
  batch-size, iteration, epoch, save-dir and TensorBoard flags.

  Returns:
    The `(student, mentor)` pair returned by `trainer.train`.
  """
  tf.compat.v1.enable_eager_execution()

  student = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28), name='student_flatten'),
      tf.keras.layers.Dense(128, activation='relu', name='student_hidden0'),
      tf.keras.layers.Dense(10, name='student_output')
  ])
  student_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='StudentAdam')
  # Top-1 .. top-4 accuracies (in that order), then the cross-entropy metric.
  student_metrics = [
      tf.keras.metrics.SparseTopKCategoricalAccuracy(
          k=k, name=f'student_top_{k}_categorical_accuracy')
      for k in range(1, 5)
  ]
  student_metrics.append(
      tf.keras.metrics.SparseCategoricalCrossentropy(
          name='student_categorical_crossentropy', from_logits=True))
  student.compile(
      optimizer=student_optimizer,
      # `student_output` has no activation, so the loss consumes raw logits.
      loss=tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True, name='student_categorical_crossentropy_loss'),
      metrics=student_metrics)

  # The mentor's input width is the student's total parameter count
  # (28*28*128 + 128 + 128*10 + 10 = 101770). Presumably `trainer.train`
  # feeds the mentor one value per student parameter (e.g. per-example
  # gradients) -- TODO confirm. Deriving it keeps both models in sync if the
  # student architecture changes.
  mentor_input_dim = student.count_params()
  mentor = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(
          input_shape=(mentor_input_dim,), name='mentor_flatten'),
      tf.keras.layers.Dense(50, activation='relu', name='mentor_hidden0'),
      tf.keras.layers.Dense(1, activation='sigmoid', name='mentor_output')
  ])
  # Named distinctly from the student's optimizer so the two are
  # distinguishable in logs and saved configs.
  mentor_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='MentorAdam')
  mentor.compile(
      optimizer=mentor_optimizer,
      loss=tf.keras.losses.BinaryCrossentropy(
          name='mentor_binary_crossentropy_loss'),
      metrics=[
          tf.keras.metrics.BinaryAccuracy(
              name='mentor_binary_accuracy', threshold=0.5),
          tf.keras.metrics.BinaryCrossentropy(
              name='mentor_binary_crossentropy'),
          tf.keras.metrics.FalseNegatives(name='mentor_false_negatives'),
          tf.keras.metrics.FalsePositives(name='mentor_false_positives'),
          tf.keras.metrics.TrueNegatives(name='mentor_true_negatives'),
          tf.keras.metrics.TruePositives(name='mentor_true_positives')
      ])

  student, mentor = trainer.train(student, mentor, FLAGS.mini_batch_size,
                                  FLAGS.max_iteration_count,
                                  FLAGS.student_epoch_count,
                                  FLAGS.mentor_epoch_count, FLAGS.save_dir,
                                  FLAGS.tensorboard_log_dir)
  return student, mentor
131
132
def run_triplet():
  """Runs the program in triplet mode on the provided CSV dataset.

  The student is a ResNet152V2 backbone initialized with ImageNet weights,
  with no classification head and global average pooling. It is compiled with
  a `TripletLoss`. The mentor is a one-hidden-layer network with a single
  sigmoid output. Both models are handed to `trainer_triplet.train` together
  with the dataset (`train_dataset_dir`, `csv_path`), batch-size, iteration,
  epoch, save-dir and TensorBoard flags.

  Returns:
    The `(student, mentor)` pair returned by `trainer_triplet.train`.
  """
  tf.compat.v1.enable_eager_execution()

  student = tf.keras.applications.ResNet152V2(
      include_top=False,
      weights='imagenet',
      input_shape=[321, 321, 3],
      pooling='avg')

  student_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='StudentAdam')
  # NOTE(review): `embedding_size` presumably has to match the student's
  # pooled output width (2048 for ResNet152V2, which is also the flag's
  # default) -- confirm in TripletLoss.
  student.compile(
      optimizer=student_optimizer,
      loss=TripletLoss(
          embedding_size=FLAGS.delg_embedding_layer_dim, train_ratio=1.0))

  # NOTE(review): 104000 is not derived anywhere in this file. It must match
  # the length of the per-example vector trainer_triplet feeds the mentor;
  # confirm there.
  mentor = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(104000,), name='mentor_flatten'),
      tf.keras.layers.Dense(50, activation='relu', name='mentor_hidden0'),
      tf.keras.layers.Dense(1, activation='sigmoid', name='mentor_output')
  ])
  # Named distinctly from the student's optimizer so the two are
  # distinguishable in logs and saved configs.
  mentor_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='MentorAdam')
  mentor.compile(
      optimizer=mentor_optimizer,
      loss=tf.keras.losses.BinaryCrossentropy(
          name='mentor_binary_crossentropy_loss'),
      metrics=[
          tf.keras.metrics.BinaryAccuracy(
              name='mentor_binary_accuracy', threshold=0.5),
          tf.keras.metrics.BinaryCrossentropy(
              name='mentor_binary_crossentropy'),
          tf.keras.metrics.FalseNegatives(name='mentor_false_negatives'),
          tf.keras.metrics.FalsePositives(name='mentor_false_positives'),
          tf.keras.metrics.TrueNegatives(name='mentor_true_negatives'),
          tf.keras.metrics.TruePositives(name='mentor_true_positives')
      ])

  student, mentor = trainer_triplet.train(
      student, mentor, FLAGS.mini_batch_size, FLAGS.max_iteration_count,
      FLAGS.student_epoch_count, FLAGS.mentor_epoch_count,
      FLAGS.train_dataset_dir, FLAGS.csv_path, FLAGS.save_dir,
      FLAGS.tensorboard_log_dir)
  return student, mentor
188
189
def main(argv):
  """Entry point invoked by `app.run`: validates flags, then dispatches on mode.

  Args:
    argv: Positional command-line arguments remaining after absl flag parsing.
      Anything beyond the first element is rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')
  verify_arguments()

  # Any `mode` that is not a key here is silently ignored, just like an
  # if/elif chain without an else branch.
  mode_runners = {
      'softmax': run_softmax,
      'triplet': run_triplet,
  }
  runner = mode_runners.get(FLAGS.mode)
  if runner is not None:
    runner()


if __name__ == '__main__':
  app.run(main)
203