# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2020 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for training the MentorNet models."""

import collections

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow_probability as tfp


MentorNetNetHParams = collections.namedtuple(
    'MentorNetNetHParams', 'label_embedding_size, epoch_embedding_size, '
    'num_label_embedding, num_fc_nodes')
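
# A small usage sketch: constructing the hparams tuple with the default values
# assumed by mentornet_nn below.
#
#   hparams = MentorNetNetHParams(
#       label_embedding_size=2, epoch_embedding_size=5,
#       num_label_embedding=2, num_fc_nodes=20)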


def summarize_data_utilization(v, tf_global_step, batch_size, epsilon=0.001):
  """Summarizes the samples of non-zero weights during training.

  Args:
    v: a tensor [batch_size, 1] representing the sample weights.
    tf_global_step: the tensor of the current global step.
    batch_size: the integer batch size.
    epsilon: the rounding threshold. Weights smaller than epsilon are treated
      as zero.
  Returns:
    data_util: a tensor of data utilization.
  """
  nonzero_v = tf.get_variable('data_util/nonzero_v', [],
                              initializer=tf.zeros_initializer(),
                              trainable=False,
                              dtype=tf.float32)

  rounded_v = tf.maximum(v - epsilon, tf.to_float(0))

  # Log data utilization.
  nonzero_v = tf.assign_add(nonzero_v, tf.count_nonzero(
      rounded_v, dtype=tf.float32))

  data_util = (nonzero_v) / tf.to_float(batch_size) / (
      tf.to_float(tf_global_step) + 2)
  data_util = tf.minimum(data_util, 1)
  data_util = tf.stop_gradient(data_util)

  slim.summaries.add_scalar_summary(data_util, 'data_util/data_util')
  slim.summaries.add_scalar_summary(tf.reduce_sum(v), 'data_util/batch_sum_v')
  return data_util
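
# A minimal usage sketch (hypothetical names; assumes a TF1 graph in which
# `weights` is the [batch_size, 1] output of mentornet below):
#
#   global_step = tf.train.get_or_create_global_step()
#   data_util = summarize_data_utilization(weights, global_step, batch_size=32)
#
# The returned scalar approximates the fraction of samples with non-zero
# weight, averaged over training so far, and is also exported as the
# 'data_util/data_util' summary.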


def parse_dropout_rate_list(str_list):
  """Parses a list of strings into a per-epoch dropout rate list.

  The input follows the format [dropout_rate, epoch_num]+ and the result is a
  list of 100 dropout rates, one per epoch.

  Args:
    str_list: the input list of strings, alternating dropout rates and epoch
      counts.
  Returns:
    result: the converted list.
  """
  str_list = np.array(str_list)
  values = str_list[np.arange(0, len(str_list), 2)]
  indexes = str_list[np.arange(1, len(str_list), 2)]

  values = [float(t) for t in values]
  indexes = [int(t) for t in indexes]

  assert len(values) == len(indexes) and np.sum(indexes) == 100
  for t in values:
    assert 0.0 <= t <= 1.0

  result = []
  for t in range(len(str_list) // 2):
    result.extend([values[t]] * indexes[t])
  return result
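
# For example (a sketch of the expected input/output; values are made up):
#
#   rates = parse_dropout_rate_list(['0.5', '20', '0.05', '80'])
#   # rates has 100 entries: 0.5 for epochs 0-19 and 0.05 for epochs 20-99.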


def mentornet_nn(input_features,
                 label_embedding_size=2,
                 epoch_embedding_size=5,
                 num_label_embedding=2,
                 num_fc_nodes=20):
  """The neural network form of the MentorNet.

  An implementation of MentorNet. The details are in:
  Jiang, Lu, et al. "MentorNet: Learning Data-Driven Curriculum for Very Deep
  Neural Networks on Corrupted Labels." ICML. 2018.

  Args:
    input_features: a [batch_size, 4] tensor. Each dimension corresponds to
      0: loss, 1: loss difference to the moving average, 2: label and 3: epoch,
      where epoch is an integer between 0 and 99 (the first and the last epoch).
    label_embedding_size: the embedding size for the label feature.
    epoch_embedding_size: the embedding size for the epoch feature.
    num_label_embedding: the number of different labels.
    num_fc_nodes: number of hidden nodes in the fc layer.
  Returns:
    v: [batch_size, 1] weight vector.
  """
  batch_size = int(input_features.get_shape()[0])

  losses = tf.reshape(input_features[:, 0], [-1, 1])
  loss_diffs = tf.reshape(input_features[:, 1], [-1, 1])
  labels = tf.to_int32(tf.reshape(input_features[:, 2], [-1, 1]))
  epochs = tf.to_int32(tf.reshape(input_features[:, 3], [-1, 1]))
  epochs = tf.minimum(epochs, tf.ones([batch_size, 1], dtype=tf.int32) * 99)

  if len(losses.get_shape()) <= 1:
    num_steps = 1
  else:
    num_steps = int(losses.get_shape()[1])

  with tf.variable_scope('mentornet', reuse=tf.AUTO_REUSE):
    label_embedding = tf.get_variable(
        'label_embedding', [num_label_embedding, label_embedding_size])
    epoch_embedding = tf.get_variable(
        'epoch_embedding', [100, epoch_embedding_size])

    lstm_inputs = tf.stack([losses, loss_diffs], axis=1)
    lstm_inputs = tf.squeeze(lstm_inputs)
    lstm_inputs = [lstm_inputs]

    forward_cell = tf.contrib.rnn.BasicLSTMCell(1, forget_bias=0.0)
    backward_cell = tf.contrib.rnn.BasicLSTMCell(1, forget_bias=0.0)

    _, out_state_fw, out_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
        forward_cell,
        backward_cell,
        inputs=lstm_inputs,
        dtype=tf.float32,
        sequence_length=np.ones(batch_size) * num_steps)

    label_inputs = tf.squeeze(tf.nn.embedding_lookup(label_embedding, labels))
    epoch_inputs = tf.squeeze(tf.nn.embedding_lookup(epoch_embedding, epochs))

    h = tf.concat([out_state_fw[0], out_state_bw[0]], 1)
    feat = tf.concat([label_inputs, epoch_inputs, h], 1)
    feat_dim = int(feat.get_shape()[1])

    fc_1 = tf.add(
        tf.matmul(
            feat,
            tf.get_variable(
                'Variable',
                initializer=tf.random_normal([feat_dim, num_fc_nodes]))),
        tf.get_variable(
            'Variable_1', initializer=tf.random_normal([num_fc_nodes])))
    fc_1 = tf.nn.tanh(fc_1)
    # Output layer with linear activation: out = fc_1 * W + b. Note the bias
    # is added after the matmul.
    out_layer = tf.add(
        tf.matmul(
            fc_1,
            tf.get_variable(
                'Variable_2', initializer=tf.random_normal([num_fc_nodes, 1]))),
        tf.get_variable('Variable_3', initializer=tf.random_normal([1])))
    return out_layer
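
# A minimal graph-construction sketch (assumes a TF1 session; `features` is a
# hypothetical [batch_size, 4] tensor of loss, loss diff, label, and epoch;
# the batch size must be static because mentornet_nn reads it from the shape):
#
#   features = tf.placeholder(tf.float32, shape=[32, 4])
#   raw_scores = mentornet_nn(features)   # [32, 1] logits
#   weights = tf.nn.sigmoid(raw_scores)   # sample weights in (0, 1)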


def loss_thresholding_function(loss, para_lambda=0.75):
  """The simplest MentorNet: a loss thresholding function.

  Args:
    loss: [batch_size, 1] the loss vector.
    para_lambda: the age parameter in [0, 1]. It indicates the percentile
      within a mini-batch, where 1 selects all samples and 0 selects none.
  Returns:
    v: [batch_size, 1] weight vector.
  """
  assert 0 <= para_lambda <= 1
  with tf.variable_scope('mentornet/thresholding'):

    one_weights = tf.ones(tf.shape(loss), tf.float32)
    zero_weights = tf.zeros(tf.shape(loss), tf.float32)

    percentile_loss = tfp.stats.percentile(
        loss, int(para_lambda * 100))
    # Replace this with a loss moving average to get better results.
    percentile_loss = tf.reshape(percentile_loss, [1])

    weights = tf.where(loss >= percentile_loss, zero_weights, one_weights)
    v = tf.reshape(weights, [-1, 1], name='v')
  return v
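
# A worked example of the thresholding rule (values are made up): with
# para_lambda=0.75, the 75th-percentile loss within the mini-batch becomes the
# threshold, so roughly the 25% of samples with the largest losses receive
# weight 0 and the rest receive weight 1.
#
#   v = loss_thresholding_function(loss, para_lambda=0.75)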


def mentornet(epoch,
              loss,
              labels,
              loss_p_percentile,
              example_dropout_rates,
              burn_in_epoch=18,
              fixed_epoch_after_burn_in=False,
              loss_moving_average_decay=0.5,
              mentornet_net_hparams=None,
              avg_name='cumulative',
              debug=False):
  """The MentorNet to train with the StudentNet.

  The details are in:
  Jiang, Lu, et al. "MentorNet: Learning Data-Driven Curriculum for Very Deep
  Neural Networks on Corrupted Labels." ICML. 2018.

  Args:
    epoch: a tensor [batch_size, 1] representing the training progress, where
      each epoch is an integer between 0 and 99 (the first and the last epoch).
    loss: a tensor [batch_size, 1] representing the sample loss.
    labels: a tensor [batch_size, 1] representing the label. Every label is set
      to 0 in the current version.
    loss_p_percentile: a 1-d tensor of size 100, where each element is the
      p-percentile at that epoch used to compute the moving average.
    example_dropout_rates: a 1-d tensor of size 100, where each element is the
      dropout rate at that epoch. Dropping out means setting sample weights to
      zero with the given probability, as proposed in Liang, Junwei, et al.
      "Learning to Detect Concepts from Webly-Labeled Video Data." IJCAI. 2016.
    burn_in_epoch: the number of burn-in epochs. In the first burn_in_epoch
      epochs, all samples have weight 1.0.
    fixed_epoch_after_burn_in: whether to fix the epoch after the burn-in.
    loss_moving_average_decay: the decay factor to compute the moving average.
    mentornet_net_hparams: mentornet hyperparameters.
    avg_name: name of the loss moving average variable.
    debug: whether to print the weight information for debugging purposes.

  Returns:
    v: [batch_size, 1] weight vector.
  """
  with tf.variable_scope('mentor_inputs'):
    loss_moving_avg = tf.get_variable(
        avg_name, [], initializer=tf.zeros_initializer(), trainable=False)

    if not fixed_epoch_after_burn_in:
      cur_epoch = epoch
    else:
      cur_epoch = tf.to_int32(tf.minimum(epoch, burn_in_epoch))

    v_ones = tf.ones(tf.shape(loss), tf.float32)
    v_zeros = tf.zeros(tf.shape(loss), tf.float32)
    upper_bound = tf.cond(cur_epoch < (burn_in_epoch - 1), lambda: v_ones,
                          lambda: v_zeros)

    this_dropout_rate = tf.squeeze(
        tf.nn.embedding_lookup(example_dropout_rates, cur_epoch))
    this_percentile = tf.squeeze(
        tf.nn.embedding_lookup(loss_p_percentile, cur_epoch))

    percentile_loss = tf.contrib.distributions.percentile(
        loss, this_percentile * 100)
    percentile_loss = tf.convert_to_tensor(percentile_loss)

    loss_moving_avg = loss_moving_avg.assign(
        loss_moving_avg * loss_moving_average_decay +
        (1 - loss_moving_average_decay) * percentile_loss)

    slim.summaries.add_scalar_summary(percentile_loss,
                                      '{}/percentile_loss'.format(avg_name))
    slim.summaries.add_scalar_summary(cur_epoch,
                                      '{}/cur_epoch'.format(avg_name))
    slim.summaries.add_scalar_summary(loss_moving_avg,
                                      '{}/loss_moving_avg'.format(avg_name))

    ones = tf.ones([tf.shape(loss)[0], 1], tf.float32)

    epoch_vec = tf.scalar_mul(tf.to_float(cur_epoch), ones)
    lossdiff = loss - tf.scalar_mul(loss_moving_avg, ones)

  input_data = tf.squeeze(tf.stack([loss, lossdiff, labels, epoch_vec], 1))
  hparams = mentornet_net_hparams
  if hparams:
    v = tf.nn.sigmoid(
        mentornet_nn(
            input_data,
            label_embedding_size=hparams.label_embedding_size,
            epoch_embedding_size=hparams.epoch_embedding_size,
            num_label_embedding=hparams.num_label_embedding,
            num_fc_nodes=hparams.num_fc_nodes),
        name='v')
  else:
    v = tf.nn.sigmoid(mentornet_nn(input_data), name='v')
  # Force-select all samples in the first burn_in_epoch epochs.
  v = tf.maximum(v, upper_bound, 'v_bound')

  v_dropout = tf.py_func(probabilistic_sample,
                         [v, this_dropout_rate, 'random'], tf.float32)
  v_dropout = tf.reshape(v_dropout, [-1, 1], name='v_dropout')

  # Print information in the debug mode.
  if debug:
    v_dropout = tf.Print(
        v_dropout,
        data=[cur_epoch, loss_moving_avg, percentile_loss],
        summarize=64,
        message='epoch, loss_moving_avg, percentile_loss')
    v_dropout = tf.Print(
        v_dropout, data=[lossdiff], summarize=64, message='loss_diff')
    v_dropout = tf.Print(v_dropout, data=[v], summarize=64, message='v')
    v_dropout = tf.Print(
        v_dropout, data=[v_dropout], summarize=64, message='v_dropout')
  return v_dropout
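
# A minimal wiring sketch (hypothetical names and values; `sample_loss` is a
# [batch_size, 1] per-example loss tensor and `epoch_id` is an int32 tensor
# holding the current epoch, per the docstring above):
#
#   loss_p_percentile = tf.constant([0.7] * 100, tf.float32)
#   dropout_rates = tf.constant(
#       parse_dropout_rate_list(['0.5', '20', '0.05', '80']), tf.float32)
#   zero_labels = tf.zeros_like(sample_loss)
#   v = mentornet(epoch_id, sample_loss, zero_labels,
#                 loss_p_percentile, dropout_rates)
#   weighted_loss = tf.reduce_mean(sample_loss * tf.stop_gradient(v))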


def mentor_mix_up(x, l, v, beta_param):
  """The MentorMix method.

  Args:
    x: the input image batch [batch_size, H, W, C].
    l: the label batch [batch_size, num_of_class].
    v: the MentorNet weights [batch_size, 1].
    beta_param: the Beta distribution parameter used to sample the mixing
      weight.
  Returns:
    xmix, lmix: the mixed image and label batches.
  """
  if beta_param <= 0:
    return x, l

  v_flat = tf.reshape(v, [-1])
  dist = tfp.distributions.Categorical(probs=tf.nn.softmax(v_flat))
  idx = dist.sample(tf.shape(x)[0])

  x2 = tf.gather(x, idx)
  l2 = tf.gather(l, idx)

  mix = tf.distributions.Beta(beta_param,
                              beta_param).sample([tf.shape(x)[0], 1, 1, 1])

  mix = tf.maximum(mix, 1 - mix)
  mix = tf.where(v_flat >= 0.5, mix, 1 - mix)

  xmix = x * mix + x2 * (1 - mix)
  lmix = l * mix[:, :, 0, 0] + l2 * (1 - mix[:, :, 0, 0])
  v = tf.stop_gradient(v)
  v_flat = tf.stop_gradient(v_flat)
  xmix = tf.stop_gradient(xmix)
  lmix = tf.stop_gradient(lmix)
  return xmix, lmix
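
# Intuition and a usage sketch: `mix` ~ Beta(beta_param, beta_param) broadcasts
# over each image; tf.maximum(mix, 1 - mix) keeps the factor >= 0.5, and the
# tf.where flips it so a sample with MentorNet weight >= 0.5 dominates its own
# mix, while a low-weight sample defers to its sampled partner. Hypothetical
# call, with `images`, `one_hot_labels`, and `v` from the training loop:
#
#   xmix, lmix = mentor_mix_up(images, one_hot_labels, v, beta_param=8.0)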


def probabilistic_sample(v, rate=0.5, mode='binary'):
  """Implements the sampling techniques.

  Args:
    v: [batch_size, 1] the weight column vector.
    rate: in [0, 1]. 0 indicates using all samples and 1 indicates using zero
      samples.
    mode: a string. 'actual' samples in proportion to the weights and keeps
      the actual weight values; 'binary' samples in proportion to the weights
      and returns binary weights; 'random' samples uniformly at random and
      keeps the actual weight values.
  Returns:
    result: [batch_size, 1] weight vector.
  """
  assert 0 <= rate <= 1
  epsilon = 1e-5
  with tf.variable_scope('mentornet/prob_sampling'):
    p = np.copy(v)
    p = np.reshape(p, -1)
    if mode == 'random':
      ids = np.random.choice(
          p.shape[0], int(p.shape[0] * (1 - rate)), replace=False)
    else:
      # Add epsilon so that 1) an all-zero weight vector is still a valid
      # distribution and 2) zero-weight samples can still be selected.
      p += epsilon
      p /= np.sum(p)
      ids = np.random.choice(
          p.shape[0], int(p.shape[0] * (1 - rate)), p=p, replace=False)
    result = np.zeros(v.shape, dtype=np.float32)
    if mode == 'binary':
      result[ids, 0] = 1
    else:
      result[ids, 0] = v[ids, 0]
    return result
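
# A small numpy example (made-up weights; note this function runs outside the
# TF graph, which is why mentornet wraps it in tf.py_func):
#
#   v = np.array([[0.9], [0.1], [0.5], [0.0]], dtype=np.float32)
#   probabilistic_sample(v, rate=0.5, mode='binary')
#   # Selects int(4 * (1 - 0.5)) = 2 rows, chosen in proportion to their
#   # weights, and sets the selected entries to 1; the rest stay 0.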


def get_mentornet_network_hyperparameter(checkpoint_path):
  """Gets the MentorNet network configuration from the checkpoint file.

  Args:
    checkpoint_path: the file path to restore MentorNet.

  Returns:
    a named tuple MentorNetNetHParams.
  """
  if checkpoint_path and tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

  reader = tf.train.load_checkpoint(checkpoint_path)
  var_to_shape_map = reader.get_variable_to_shape_map()

  label_embedding_size = 2
  epoch_embedding_size = 5
  num_label_embedding = 2
  num_fc_nodes = 20

  key = 'mentornet/epoch_embedding'
  if key in var_to_shape_map:
    epoch_embedding_size = reader.get_tensor(key).shape[1]

  key = 'mentornet/label_embedding'
  if key in var_to_shape_map:
    num_label_embedding = reader.get_tensor(key).shape[0]
    label_embedding_size = reader.get_tensor(key).shape[1]

  # FC layer.
  key = 'mentornet/Variable'
  if key in var_to_shape_map:
    num_fc_nodes = reader.get_tensor(key).shape[1]

  hparams = MentorNetNetHParams(
      label_embedding_size=label_embedding_size,
      epoch_embedding_size=epoch_embedding_size,
      num_label_embedding=num_label_embedding,
      num_fc_nodes=num_fc_nodes)
  return hparams
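
# A minimal usage sketch (assumes `ckpt_dir` points at a hypothetical
# directory containing a trained MentorNet checkpoint, and `input_features`
# is a [batch_size, 4] tensor as described in mentornet_nn):
#
#   hparams = get_mentornet_network_hyperparameter(ckpt_dir)
#   v = tf.nn.sigmoid(mentornet_nn(
#       input_features,
#       label_embedding_size=hparams.label_embedding_size,
#       epoch_embedding_size=hparams.epoch_embedding_size,
#       num_label_embedding=hparams.num_label_embedding,
#       num_fc_nodes=hparams.num_fc_nodes))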