# Repository: google-research
# 92 lines · 3.0 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16r"""Utility functions."""
17
18import numpy as np
19import tensorflow as tf
20
21
def maybe_one_hot(labels, depth):
  """Return one-hot labels, encoding them only when they are rank-1 or lower.

  Labels whose shape has more than one dimension are treated as already
  encoded and are returned untouched.

  Args:
    labels: A `Tensor` containing labels.
    depth: An integer specifying the depth of the one-hot representation
      (number of classes).

  Returns:
    `labels` itself when it has more than one dimension, otherwise the result
    of `tf.one_hot(labels, depth=depth)`.
  """
  already_encoded = len(labels.shape) > 1
  if already_encoded:
    return labels
  return tf.one_hot(labels, depth=depth)
37
38
def get_smoothed_labels(labels, preds, smoothing_weights):
  """Blend labels with predictions using per-example smoothing weights.

  Args:
    labels: Labels to be smoothed.
    preds: Predictions blended into the labels.
    smoothing_weights: Weights given to `labels`; reshaped to a column
      (`[-1, 1]`) so that each weight applies to one row.

  Returns:
    `labels * w + preds * (1 - w)`, with `w` the reshaped smoothing weights.
  """
  label_weight = tf.reshape(smoothing_weights, [-1, 1])
  pred_weight = 1. - label_weight
  return labels * label_weight + preds * pred_weight
43
44
def mixup(images,
          labels,
          num_classes,
          mixup_alpha,
          mixing_weights=None,
          mixing_probs=None):
  """Mixup with mixing weights and probabilities.

  Every example in the batch is blended with a partner example from the same
  batch: `mixed = ratio * example + (1 - ratio) * partner`. The same
  per-example ratio is applied to the images and to the (one-hot) labels.

  Args:
    images: A `Tensor` containing a batch of images. It must expose `.numpy()`;
      the leading axis is the batch axis and any rank >= 1 is accepted.
    labels: A `Tensor` containing a batch of labels. It is passed through
      `maybe_one_hot` and then converted with `.numpy()`.
    num_classes: Number of classes.
    mixup_alpha: Parameter of the Beta(alpha, alpha) distribution used for
      sampling the mixing ratio (applicable for regular mixup).
    mixing_weights: Optional array-like (NumPy array, eager `Tensor`, list) of
      size [batch_size] specifying mixing weights. When given, the ratio for
      example `i` paired with partner `j` is `w_i / (w_i + w_j)`, replacing the
      Beta-sampled ratio.
    mixing_probs: Optional array-like of size [batch_size] specifying
      probabilities for sampling partner images for mixing (with replacement).
      The values are rounded to 5 decimals and renormalized to sum to one.
      When `None`, partners are given by a random permutation of the batch.

  Returns:
    Minibatch of mixed up images and labels, as the arrays obtained from
    `.numpy()` after blending.
  """
  images = images.numpy()
  labels = maybe_one_hot(labels, num_classes).numpy()
  num_examples = images.shape[0]

  # Sampled unconditionally (even when `mixing_weights` overrides it below) so
  # the global NumPy random stream is consumed in the same order in all cases.
  mixing_ratios = np.random.beta(mixup_alpha, mixup_alpha, size=num_examples)

  if mixing_probs is None:
    mixing_indices = np.random.permutation(num_examples)
  else:
    # `np.asarray` lets eager tensors and plain sequences be used here.
    mixing_probs = np.round(np.asarray(mixing_probs), 5)
    mixing_probs = mixing_probs / np.sum(mixing_probs)
    mixing_indices = np.random.choice(
        num_examples, size=num_examples, replace=True, p=mixing_probs)

  if mixing_weights is not None:
    # Convert first: indexing with an integer array (below) is a NumPy
    # operation and is not supported on the tensor types directly.
    mixing_weights = np.reshape(np.asarray(mixing_weights), [-1])
    # NOTE: a pair whose two weights sum to zero yields a NaN ratio.
    mixing_ratios = mixing_weights / (
        mixing_weights + mixing_weights[mixing_indices])

  # Give the ratios one trailing singleton axis per non-batch axis so they
  # broadcast against images/labels of any rank; labels are mixed in the same
  # proportions as the images.
  mixing_ratios_im = np.reshape(
      mixing_ratios, (num_examples,) + (1,) * (images.ndim - 1))
  mixing_ratios_lab = np.reshape(
      mixing_ratios, (num_examples,) + (1,) * (labels.ndim - 1))

  images = (
      images * mixing_ratios_im +
      images[mixing_indices] * (1. - mixing_ratios_im))
  labels = (
      labels * mixing_ratios_lab +
      labels[mixing_indices] * (1. - mixing_ratios_lab))
  return images, labels
93