google-research

Форк
0
202 строки · 7.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Main program to run student-mentor training."""
17

18
from absl import app
19
from absl import flags
20
import tensorflow as tf
21

22
from student_mentor_dataset_cleaning.training.loss.triplet_loss import TripletLoss
23
from student_mentor_dataset_cleaning.training.trainers import trainer
24
from student_mentor_dataset_cleaning.training.trainers import trainer_triplet
25

26
FLAGS = flags.FLAGS

# Training-schedule flags; both `run_softmax` and `run_triplet` forward these
# to their trainer.
flags.DEFINE_integer('mini_batch_size', 32, 'Mini-batch size')
flags.DEFINE_integer('max_iteration_count', 20, 'Maximum iteration count')
flags.DEFINE_integer('student_epoch_count', 30,
                     'Maximum number of training epochs for the student.')
flags.DEFINE_integer('mentor_epoch_count', 30,
                     'Maximum number of training epochs for the mentor.')

# Selects which of the two run functions `main` dispatches to.
flags.DEFINE_string(
    'mode', 'softmax',
    'Training mode to use. Options are "softmax" (default) and "triplet". '
    'Softmax mode will always use the MNIST dataset and ignore "csv_path".'
)

# Output locations, forwarded to the trainer in both modes.
flags.DEFINE_string('save_dir', '', 'Path to model save dir.')
flags.DEFINE_string('tensorboard_log_dir', '', 'Path to tensorboard log dir.')

# Dataset inputs; in this file only `run_triplet` reads them.
flags.DEFINE_string('train_dataset_dir', '',
                    'Path to the training dataset dir.')
flags.DEFINE_string('csv_path', '', 'Path to the training dataset dataframe.')

# NOTE(review): `student_initial_model` is not read by any function visible in
# this file -- confirm whether another module consumes it.
flags.DEFINE_string('student_initial_model', '',
                    'Path to the student model initialization.')

# Only `run_triplet` reads this flag (as the TripletLoss embedding size).
# The help text is built from adjacent string literals, so the first piece
# must end with a space to keep the words from running together.
flags.DEFINE_integer(
    'delg_embedding_layer_dim', 2048,
    'Size of the FC whitening layer (embedding layer). Used only if '
    'delg_global_features:True.')
50

51

52
def verify_arguments():
  """Verifies the validity of the command-line arguments.

  The checks run in a fixed order and stop at the first failure:
  `mini_batch_size`, `max_iteration_count`, `student_epoch_count` and
  `mentor_epoch_count` must each be strictly positive, and `mode` must be
  either `softmax` or `triplet`.

  Raises:
    ValueError: If any of the checks above fails.
  """
  # Explicit raises are used instead of `assert` statements because assertions
  # are removed when Python runs with -O, which would silently disable all of
  # this validation.
  positive_flag_names = (
      'mini_batch_size',
      'max_iteration_count',
      'student_epoch_count',
      'mentor_epoch_count',
  )
  for flag_name in positive_flag_names:
    if not getattr(FLAGS, flag_name) > 0:
      raise ValueError('`{}` must be positive.'.format(flag_name))

  if FLAGS.mode not in ('softmax', 'triplet'):
    raise ValueError('`mode` must be either `softmax` or `triplet`')
63

64

65
def run_softmax():
  """Runs the program in softmax mode using the MNIST dataset.

  Builds and compiles a small dense student classifier (28x28 inputs, 10
  output logits) and a binary mentor network, then hands both to
  `trainer.train` together with the batch-size, iteration-count, epoch-count
  and output-directory flags. No dataset is loaded in this function;
  presumably `trainer.train` takes care of that.
  """

  tf.compat.v1.enable_eager_execution()

  # Student: flatten -> 128 ReLU units -> 10 raw logits. There is no softmax
  # layer, which is why the loss and the crossentropy metric below are created
  # with from_logits=True.
  student = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28), name='student_flatten'),
      tf.keras.layers.Dense(128, activation='relu', name='student_hidden0'),
      tf.keras.layers.Dense(10, name='student_output')
  ])
  student_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='StudentAdam')
  # Top-1 through top-4 accuracy, named
  # 'student_top_<k>_categorical_accuracy'.
  student_top_k_accuracies = [
      tf.keras.metrics.SparseTopKCategoricalAccuracy(
          k=k, name='student_top_{}_categorical_accuracy'.format(k))
      for k in range(1, 5)
  ]
  student.compile(
      optimizer=student_optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True, name='student_categorical_crossentropy_loss'),
      metrics=student_top_k_accuracies + [
          tf.keras.metrics.SparseCategoricalCrossentropy(
              name='student_categorical_crossentropy', from_logits=True)
      ])

  # Mentor: binary classifier with a single sigmoid output.
  # 101770 equals the student's parameter count
  # (784*128 + 128 + 128*10 + 10), so the mentor presumably consumes one value
  # per student parameter -- confirm against `trainer.train`.
  mentor = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(101770,), name='mentor_flatten'),
      tf.keras.layers.Dense(50, activation='relu', name='mentor_hidden0'),
      tf.keras.layers.Dense(1, activation='sigmoid', name='mentor_output')
  ])
  # The mentor's optimizer gets its own name so it is not confused with the
  # student's 'StudentAdam' optimizer.
  mentor_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='MentorAdam')
  mentor.compile(
      optimizer=mentor_optimizer,
      loss=tf.keras.losses.BinaryCrossentropy(
          name='mentor_binary_crossentropy_loss'),
      metrics=[
          tf.keras.metrics.BinaryAccuracy(
              name='mentor_binary_accuracy', threshold=0.5),
          tf.keras.metrics.BinaryCrossentropy(
              name='mentor_binary_crossentropy'),
          tf.keras.metrics.FalseNegatives(name='mentor_false_negatives'),
          tf.keras.metrics.FalsePositives(name='mentor_false_positives'),
          tf.keras.metrics.TrueNegatives(name='mentor_true_negatives'),
          tf.keras.metrics.TruePositives(name='mentor_true_positives')
      ])

  # The return value of `train` is not needed here, so it is discarded.
  trainer.train(student, mentor, FLAGS.mini_batch_size,
                FLAGS.max_iteration_count,
                FLAGS.student_epoch_count,
                FLAGS.mentor_epoch_count, FLAGS.save_dir,
                FLAGS.tensorboard_log_dir)
131

132

133
def run_triplet():
  """Runs the program in triplet mode on the provided CSV dataset.

  Builds a ResNet152V2 student compiled with `TripletLoss` and a binary
  mentor network, then hands both to `trainer_triplet.train` together with
  the batch-size, iteration-count and epoch-count flags, the dataset
  directory and CSV path, and the output-directory flags.
  """

  tf.compat.v1.enable_eager_execution()

  # Student: ResNet152V2 backbone without its classification head, created
  # with weights='imagenet' and global average pooling on the output.
  # NOTE(review): confirm that the pooled output width agrees with
  # `delg_embedding_layer_dim`, which is passed to TripletLoss below.
  student = tf.keras.applications.ResNet152V2(
      include_top=False,
      weights='imagenet',
      input_shape=[321, 321, 3],
      pooling='avg')

  student_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='StudentAdam')
  student.compile(
      optimizer=student_optimizer,
      loss=TripletLoss(
          embedding_size=FLAGS.delg_embedding_layer_dim, train_ratio=1.0))

  # Mentor: binary classifier with a single sigmoid output.
  # NOTE(review): the origin of the 104000-wide input is not visible in this
  # file -- confirm against `trainer_triplet.train`.
  mentor = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(104000,), name='mentor_flatten'),
      tf.keras.layers.Dense(50, activation='relu', name='mentor_hidden0'),
      tf.keras.layers.Dense(1, activation='sigmoid', name='mentor_output')
  ])
  # The mentor's optimizer gets its own name so it is not confused with the
  # student's 'StudentAdam' optimizer.
  mentor_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-07,
      amsgrad=False,
      name='MentorAdam')
  mentor.compile(
      optimizer=mentor_optimizer,
      loss=tf.keras.losses.BinaryCrossentropy(
          name='mentor_binary_crossentropy_loss'),
      metrics=[
          tf.keras.metrics.BinaryAccuracy(
              name='mentor_binary_accuracy', threshold=0.5),
          tf.keras.metrics.BinaryCrossentropy(
              name='mentor_binary_crossentropy'),
          tf.keras.metrics.FalseNegatives(name='mentor_false_negatives'),
          tf.keras.metrics.FalsePositives(name='mentor_false_positives'),
          tf.keras.metrics.TrueNegatives(name='mentor_true_negatives'),
          tf.keras.metrics.TruePositives(name='mentor_true_positives')
      ])

  # The return value of `train` is not needed here, so it is discarded.
  trainer_triplet.train(
      student, mentor, FLAGS.mini_batch_size, FLAGS.max_iteration_count,
      FLAGS.student_epoch_count, FLAGS.mentor_epoch_count,
      FLAGS.train_dataset_dir, FLAGS.csv_path, FLAGS.save_dir,
      FLAGS.tensorboard_log_dir)
188

189

190
def main(argv):
  """Validates the command line and starts the training mode chosen by `mode`.

  Args:
    argv: Argument list handed over by `app.run`; it may hold at most one
      element.

  Raises:
    app.UsageError: If `argv` holds more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  verify_arguments()

  # Map each supported mode to its runner; a mode that is not listed here
  # results in no action.
  runners_by_mode = {
      'softmax': run_softmax,
      'triplet': run_triplet,
  }
  selected_runner = runners_by_mode.get(FLAGS.mode)
  if selected_runner is not None:
    selected_runner()
199

200

201
# Only start training when this file is executed as a script; `main` is handed
# to `app.run`, which is responsible for invoking it.
if __name__ == '__main__':
  app.run(main)
203

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.