google-research
440 lines · 15.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# Copyright 2020 Google Inc. All Rights Reserved.
17#
18# Licensed under the Apache License, Version 2.0 (the "License");
19# you may not use this file except in compliance with the License.
20# You may obtain a copy of the License at
21#
22# http://www.apache.org/licenses/LICENSE-2.0
23#
24# Unless required by applicable law or agreed to in writing, software
25# distributed under the License is distributed on an "AS IS" BASIS,
26# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27# See the License for the specific language governing permissions and
28# limitations under the License.
29# ==============================================================================
30
31"""Utility functions for training the MentorNet models."""
32
33import collections
34
35import numpy as np
36import tensorflow as tf
37import tensorflow.contrib.slim as slim
38import tensorflow_probability as tfp
39
40
# Architecture hyperparameters for the MentorNet network: the two embedding
# sizes, the label vocabulary size, and the width of the hidden FC layer.
MentorNetNetHParams = collections.namedtuple(
    'MentorNetNetHParams',
    ['label_embedding_size', 'epoch_embedding_size',
     'num_label_embedding', 'num_fc_nodes'])
44
45
def summarize_data_utilization(v, tf_global_step, batch_size, epsilon=0.001):
  """Summarizes the samples of non-zero weights during training.

  Accumulates, in a non-trainable counter, how many sample weights have been
  non-zero so far, and reports the running fraction (capped at 1) as a
  scalar summary, together with the sum of the current batch's weights.

  Args:
    v: a tensor [batch_size, 1] represents the sample weights.
    tf_global_step: the tensor of the current global step.
    batch_size: an integer batch_size
    epsilon: the rounding error. If the weight is smaller than epsilon then
      set it to zero.

  Returns:
    data_util: a scalar tensor of data utilization (fraction of samples,
      seen so far, whose weight exceeds epsilon; clipped to at most 1).
  """
  # Running count of non-zero weights, persisted across steps.
  nonzero_v = tf.get_variable('data_util/nonzero_v', [],
                              initializer=tf.zeros_initializer(),
                              trainable=False,
                              dtype=tf.float32)

  # Weights at or below epsilon round down to exactly zero.
  rounded_v = tf.maximum(v - epsilon, tf.to_float(0))

  # Log data utilization
  nonzero_v = tf.assign_add(nonzero_v, tf.count_nonzero(
      rounded_v, dtype=tf.float32))

  # NOTE(review): the denominator uses global_step + 2 rather than + 1 —
  # presumably to offset the zero-based step; confirm the intended offset.
  data_util = (nonzero_v) / tf.to_float(batch_size) / (
      tf.to_float(tf_global_step) + 2)
  data_util = tf.minimum(data_util, 1)
  tf.stop_gradient(data_util)

  slim.summaries.add_scalar_summary(data_util, 'data_util/data_util')
  slim.summaries.add_scalar_summary(tf.reduce_sum(v), 'data_util/batch_sum_v')
  return data_util
79
80
def parse_dropout_rate_list(str_list):
  """Parse a comma-separated string to a list.

  The format follows [dropout rate, epoch_num]+ and the result is a list of
  100 dropout rates, one per epoch.

  Args:
    str_list: the input string list, alternating rates and epoch counts.

  Returns:
    result: the converted list of 100 per-epoch dropout rates.
  """
  str_list = np.array(str_list)
  # Even positions hold rates, odd positions hold epoch counts.
  values = [float(t) for t in str_list[0::2]]
  indexes = [int(t) for t in str_list[1::2]]

  assert len(values) == len(indexes) and np.sum(indexes) == 100
  for rate in values:
    assert rate >= 0.0 and rate <= 1.0

  result = []
  for rate, count in zip(values, indexes):
    result.extend([rate] * count)
  return result
107
108
def mentornet_nn(input_features,
                 label_embedding_size=2,
                 epoch_embedding_size=5,
                 num_label_embedding=2,
                 num_fc_nodes=20):
  """The neural network form of the MentorNet.

  An implementation of the mentornet. The details are in:
  Jiang, Lu, et al. "MentorNet: Learning Data-Driven Curriculum for Very Deep
  Neural Networks on Corrupted Labels." ICML. 2018.

  Args:
    input_features: a [batch_size, 4] tensor. Each dimension corresponds to
      0: loss, 1: loss difference to the moving average, 2: label and
      3: epoch, where epoch is an integer between 0 and 99 (the first and
      the last epoch).
    label_embedding_size: the embedding size for the label feature.
    epoch_embedding_size: the embedding size for the epoch feature.
    num_label_embedding: the number of different labels.
    num_fc_nodes: number of hidden nodes in the fc layer.

  Returns:
    v: [batch_size, 1] weight vector (pre-sigmoid logits).
  """
  batch_size = int(input_features.get_shape()[0])

  # Split the four input columns into separate [batch_size, 1] tensors.
  losses = tf.reshape(input_features[:, 0], [-1, 1])
  loss_diffs = tf.reshape(input_features[:, 1], [-1, 1])
  labels = tf.to_int32(tf.reshape(input_features[:, 2], [-1, 1]))
  epochs = tf.to_int32(tf.reshape(input_features[:, 3], [-1, 1]))
  # Clamp the epoch index into the embedding table's valid range [0, 99].
  epochs = tf.minimum(epochs, tf.ones([batch_size, 1], dtype=tf.int32) * 99)

  if len(losses.get_shape()) <= 1:
    num_steps = 1
  else:
    num_steps = int(losses.get_shape()[1])

  # Variable names below ('Variable', 'Variable_1', ...) must match the
  # names stored in released MentorNet checkpoints.
  with tf.variable_scope('mentornet', reuse=tf.AUTO_REUSE):
    label_embedding = tf.get_variable(
        'label_embedding', [num_label_embedding, label_embedding_size])
    epoch_embedding = tf.get_variable(
        'epoch_embedding', [100, epoch_embedding_size])

    # The loss and loss-diff form a length-1 sequence fed to a tiny BiLSTM.
    lstm_inputs = tf.stack([losses, loss_diffs], axis=1)
    lstm_inputs = tf.squeeze(lstm_inputs)
    lstm_inputs = [lstm_inputs]

    forward_cell = tf.contrib.rnn.BasicLSTMCell(1, forget_bias=0.0)
    backward_cell = tf.contrib.rnn.BasicLSTMCell(1, forget_bias=0.0)

    _, out_state_fw, out_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
        forward_cell,
        backward_cell,
        inputs=lstm_inputs,
        dtype=tf.float32,
        sequence_length=np.ones(batch_size) * num_steps)

    label_inputs = tf.squeeze(tf.nn.embedding_lookup(label_embedding, labels))
    epoch_inputs = tf.squeeze(tf.nn.embedding_lookup(epoch_embedding, epochs))

    # Concatenate both LSTM final states with the two embeddings.
    h = tf.concat([out_state_fw[0], out_state_bw[0]], 1)
    feat = tf.concat([label_inputs, epoch_inputs, h], 1)
    feat_dim = int(feat.get_shape()[1])

    # Hidden FC layer: tanh(feat @ W + b).
    fc_1 = tf.add(
        tf.matmul(
            feat,
            tf.get_variable(
                'Variable',
                initializer=tf.random_normal([feat_dim, num_fc_nodes]))),
        tf.get_variable(
            'Variable_1', initializer=tf.random_normal([num_fc_nodes])))
    fc_1 = tf.nn.tanh(fc_1)
    # Output layer with linear activation
    # NOTE(review): the parenthesization adds the bias 'Variable_3' to the
    # weight matrix *inside* tf.matmul, i.e. fc_1 @ (W + b) rather than
    # fc_1 @ W + b. Released checkpoints were trained with this form, so it
    # is left unchanged — confirm the intent before "fixing".
    out_layer = tf.matmul(
        fc_1,
        tf.get_variable(
            'Variable_2', initializer=tf.random_normal([num_fc_nodes, 1])) +
        tf.get_variable('Variable_3', initializer=tf.random_normal([1])))
  return out_layer
187
188
def loss_thresholding_function(loss, para_lambda=0.75):
  """The simplest MentorNet is a loss thresholding function.

  Samples whose loss is at or above the para_lambda percentile of the
  mini-batch receive weight 0; all other samples receive weight 1.

  Args:
    loss: [batch_size, 1] the loss vector.
    para_lambda: the age parameter, in [0,1], indicates the percentile in a
      mini-batch, where 1 indicates selecting all samples and 0 is selecting
      zero samples.

  Returns:
    v: [batch_size, 1] weight vector.
  """
  assert 0 <= para_lambda <= 1
  with tf.variable_scope('mentornet/thresholding'):
    keep_weights = tf.ones(tf.shape(loss), tf.float32)
    drop_weights = tf.zeros(tf.shape(loss), tf.float32)

    # Replace this with loss moving average to get better results.
    threshold = tfp.stats.percentile(loss, int(para_lambda * 100))
    threshold = tf.reshape(threshold, [1])

    # At or above the percentile -> dropped (0); below -> kept (1).
    weights = tf.where(loss >= threshold, drop_weights, keep_weights)
    v = tf.reshape(weights, [-1, 1], name='v')
  return v
214
215
def mentornet(epoch,
              loss,
              labels,
              loss_p_percentile,
              example_dropout_rates,
              burn_in_epoch=18,
              fixed_epoch_after_burn_in=False,
              loss_moving_average_decay=0.5,
              mentornet_net_hparams=None,
              avg_name='cumulative',
              debug=False):
  """The MentorNet to train with the StudentNet.

  The details are in:
  Jiang, Lu, et al. "MentorNet: Learning Data-Driven Curriculum for Very Deep
  Neural Networks on Corrupted Labels." ICML. 2018.

  Args:
    epoch: a tensor [batch_size, 1] representing the training percentage.
      Each epoch is an integer between 0 and 99. NOTE(review): the
      tf.cond/tf.scalar_mul uses below appear to require a scalar epoch
      tensor rather than [batch_size, 1] — confirm against callers.
    loss: a tensor [batch_size, 1] representing the sample loss.
    labels: a tensor [batch_size, 1] representing the label. Every label is
      set to 0 in the current version.
    loss_p_percentile: a 1-d tensor of size 100, where each element is the
      p-percentile at that epoch to compute the moving average.
    example_dropout_rates: a 1-d tensor of size 100, where each element is
      the dropout rate at that epoch. Dropping out means the probability of
      setting sample weights to zeros proposed in Liang, Junwei, et al.
      "Learning to Detect Concepts from Webly-Labeled Video Data." IJCAI.
      2016.
    burn_in_epoch: the number of burn_in_epoch. In the first burn_in_epoch,
      all samples have 1.0 weights.
    fixed_epoch_after_burn_in: whether to fix the epoch after the burn-in.
    loss_moving_average_decay: the decay factor to compute the moving
      average.
    mentornet_net_hparams: mentornet hyperparameters.
    avg_name: name of the loss moving average variable.
    debug: whether to print the weight information for debugging purposes.

  Returns:
    v: [batch_size, 1] weight vector.
  """
  with tf.variable_scope('mentor_inputs'):
    loss_moving_avg = tf.get_variable(
        avg_name, [], initializer=tf.zeros_initializer(), trainable=False)

    if not fixed_epoch_after_burn_in:
      cur_epoch = epoch
    else:
      # Freeze the epoch feature at burn_in_epoch once the burn-in is over.
      cur_epoch = tf.to_int32(tf.minimum(epoch, burn_in_epoch))

    v_ones = tf.ones(tf.shape(loss), tf.float32)
    v_zeros = tf.zeros(tf.shape(loss), tf.float32)
    # During burn-in, upper_bound forces every weight to 1 (via tf.maximum
    # below); afterwards it is all zeros and has no effect.
    upper_bound = tf.cond(cur_epoch < (burn_in_epoch - 1), lambda: v_ones,
                          lambda: v_zeros)

    this_dropout_rate = tf.squeeze(
        tf.nn.embedding_lookup(example_dropout_rates, cur_epoch))
    this_percentile = tf.squeeze(
        tf.nn.embedding_lookup(loss_p_percentile, cur_epoch))

    percentile_loss = tf.contrib.distributions.percentile(
        loss, this_percentile * 100)
    percentile_loss = tf.convert_to_tensor(percentile_loss)

    # Exponential moving average of the per-batch percentile loss.
    loss_moving_avg = loss_moving_avg.assign(
        loss_moving_avg * loss_moving_average_decay +
        (1 - loss_moving_average_decay) * percentile_loss)

    # Bug fix: all three summaries previously shared the single tag
    # '{}/percentile_loss', so cur_epoch and loss_moving_avg collided with
    # (and obscured) the percentile_loss summary. Each gets its own tag.
    slim.summaries.add_scalar_summary(percentile_loss,
                                      '{}/percentile_loss'.format(avg_name))
    slim.summaries.add_scalar_summary(cur_epoch,
                                      '{}/cur_epoch'.format(avg_name))
    slim.summaries.add_scalar_summary(loss_moving_avg,
                                      '{}/loss_moving_avg'.format(avg_name))

    ones = tf.ones([tf.shape(loss)[0], 1], tf.float32)

    # Broadcast the scalar epoch and moving average across the batch.
    epoch_vec = tf.scalar_mul(tf.to_float(cur_epoch), ones)
    lossdiff = loss - tf.scalar_mul(loss_moving_avg, ones)

  # Assemble the [batch_size, 4] MentorNet input: loss, loss diff, label,
  # epoch — the layout mentornet_nn expects.
  input_data = tf.squeeze(tf.stack([loss, lossdiff, labels, epoch_vec], 1))
  hparams = mentornet_net_hparams
  if hparams:
    v = tf.nn.sigmoid(
        mentornet_nn(
            input_data,
            label_embedding_size=hparams.label_embedding_size,
            epoch_embedding_size=hparams.epoch_embedding_size,
            num_label_embedding=hparams.num_label_embedding,
            num_fc_nodes=hparams.num_fc_nodes),
        name='v')
  else:
    v = tf.nn.sigmoid(mentornet_nn(input_data), name='v')
  # Force select all samples in the first burn_in_epochs
  v = tf.maximum(v, upper_bound, 'v_bound')

  # Randomly zero out a this_dropout_rate fraction of the weights.
  v_dropout = tf.py_func(probabilistic_sample,
                         [v, this_dropout_rate, 'random'], tf.float32)
  v_dropout = tf.reshape(v_dropout, [-1, 1], name='v_dropout')

  # Print information in the debug mode.
  if debug:
    v_dropout = tf.Print(
        v_dropout,
        data=[cur_epoch, loss_moving_avg, percentile_loss],
        summarize=64,
        message='epoch, loss_moving_avg, percentile_loss')
    v_dropout = tf.Print(
        v_dropout, data=[lossdiff], summarize=64, message='loss_diff')
    v_dropout = tf.Print(v_dropout, data=[v], summarize=64, message='v')
    v_dropout = tf.Print(
        v_dropout, data=[v_dropout], summarize=64, message='v_dropout')
  return v_dropout
328
329
def mentor_mix_up(x, l, v, beta_param):
  """MentorMix method.

  Mixes each sample with a partner drawn with probability proportional to
  softmax of the MentorNet weights, using a Beta-distributed mixing factor.

  Args:
    x: the input image batch [batch_size, H, W, C]
    l: the label batch [batch_size, num_of_class]
    v: mentornet weights
    beta_param: the parameter to sample the weight average. Non-positive
      values disable mixing and return the inputs unchanged.

  Returns:
    result: The mixed images and label batches.
  """
  if beta_param <= 0:
    return x, l

  v_flat = tf.reshape(v, [-1])
  # Partner indices: higher-weight samples are more likely to be chosen.
  dist = tfp.distributions.Categorical(probs=tf.nn.softmax(v_flat))
  idx = dist.sample(tf.shape(x)[0])

  x2 = tf.gather(x, idx)
  l2 = tf.gather(l, idx)

  # Consistency fix: use tfp.distributions.Beta (matching the Categorical
  # above) instead of the deprecated tf.distributions.Beta alias.
  mix = tfp.distributions.Beta(beta_param,
                               beta_param).sample([tf.shape(x)[0], 1, 1, 1])

  # Bias the mix toward whichever sample MentorNet trusts (v >= 0.5 keeps
  # the anchor dominant; otherwise the partner dominates).
  mix = tf.maximum(mix, 1 - mix)
  mix = tf.where(v_flat >= 0.5, mix, 1 - mix)

  xmix = x * mix + x2 * (1 - mix)
  lmix = l * mix[:, :, 0, 0] + l2 * (1 - mix[:, :, 0, 0])
  # No gradients flow through the mixing machinery.
  v = tf.stop_gradient(v)
  v_flat = tf.stop_gradient(v_flat)
  xmix = tf.stop_gradient(xmix)
  lmix = tf.stop_gradient(lmix)
  return xmix, lmix
364
365
def probabilistic_sample(v, rate=0.5, mode='binary'):
  """Implement the sampling techniques.

  Pure-numpy helper (invoked through tf.py_func in mentornet): it creates
  no TensorFlow graph state, so the former tf.variable_scope wrapper —
  which declared no variables or ops — was removed as a no-op. This also
  makes the function usable and testable without TensorFlow.

  Args:
    v: [batch_size, 1] the weight column vector (numpy array).
    rate: in [0,1]. 0 indicates using all samples and 1 indicates
      using zero samples.
    mode: a string. One of the following 1) actual returns the actual
      sampling; 2) binary returns binary weights; 3) random performs random
      sampling.

  Returns:
    v: [batch_size, 1] float32 weight vector; unsampled entries are zero.
  """
  assert rate >= 0 and rate <= 1
  epsilon = 1e-5
  p = np.reshape(np.copy(v), -1)
  # Number of samples kept after dropping a `rate` fraction.
  num_kept = int(p.shape[0] * (1 - rate))
  if mode == 'random':
    # Uniform sampling without replacement; weights are ignored.
    ids = np.random.choice(p.shape[0], num_kept, replace=False)
  else:
    # Weighted sampling. Epsilon avoids 1) an all-zero distribution and
    # 2) zero-loss samples never being selected.
    p += epsilon
    p /= np.sum(p)
    ids = np.random.choice(p.shape[0], num_kept, p=p, replace=False)
  result = np.zeros(v.shape, dtype=np.float32)
  if mode == 'binary':
    result[ids, 0] = 1
  else:
    result[ids, 0] = v[ids, 0]
  return result
398
399
def get_mentornet_network_hyperparameter(checkpoint_path):
  """Get MentorNet network configuration from the checkpoint file.

  Infers the embedding and FC-layer sizes from the shapes of the variables
  stored in the checkpoint, falling back to the mentornet_nn defaults for
  any variable that is absent.

  Args:
    checkpoint_path: the file path to restore MentorNet.

  Returns:
    a named tuple MentorNetNetHParams.
  """
  if checkpoint_path and tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

  reader = tf.train.load_checkpoint(checkpoint_path)
  var_to_shape_map = reader.get_variable_to_shape_map()

  # Defaults mirror the keyword defaults of mentornet_nn.
  label_embedding_size = 2
  epoch_embedding_size = 5
  num_label_embedding = 2
  num_fc_nodes = 20

  # Read sizes from the variable-shape map directly instead of materializing
  # entire tensors via reader.get_tensor (the label embedding was previously
  # loaded twice just to inspect its shape).
  key = 'mentornet/epoch_embedding'
  if key in var_to_shape_map:
    epoch_embedding_size = var_to_shape_map[key][1]

  key = 'mentornet/label_embedding'
  if key in var_to_shape_map:
    num_label_embedding = var_to_shape_map[key][0]
    label_embedding_size = var_to_shape_map[key][1]

  # FC layer.
  key = 'mentornet/Variable'
  if key in var_to_shape_map:
    num_fc_nodes = var_to_shape_map[key][1]

  hparams = MentorNetNetHParams(
      label_embedding_size=label_embedding_size,
      epoch_embedding_size=epoch_embedding_size,
      num_label_embedding=num_label_embedding,
      num_fc_nodes=num_fc_nodes)

  return hparams
441