google-research

Форк
0
92 строки · 3.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
r"""Utility functions."""
17

18
import numpy as np
19
import tensorflow as tf
20

21

22
def maybe_one_hot(labels, depth):
23
  """Convert categorical labels to one-hot, if needed.
24

25
  Args:
26
    labels: A `Tensor` containing labels.
27
    depth: An integer specifying the depth of one-hot represention (number of
28
      classes).
29

30
  Returns:
31
    One-hot labels.
32
  """
33
  if len(labels.shape) > 1:
34
    return labels
35
  else:
36
    return tf.one_hot(labels, depth=depth)
37

38

39
def get_smoothed_labels(labels, preds, smoothing_weights):
40
  """Smoothen the labels."""
41
  smoothing_weights = tf.reshape(smoothing_weights, [-1, 1])
42
  return labels * smoothing_weights + preds * (1. - smoothing_weights)
43

44

45
def mixup(images,
46
          labels,
47
          num_classes,
48
          mixup_alpha,
49
          mixing_weights=None,
50
          mixing_probs=None):
51
  """Mixup with mixing weights and probabilities.
52

53
  Args:
54
    images: A `Tensor` containing batch of images.
55
    labels: A `Tensor` containing batch of labels.
56
    num_classes: Number of classes.
57
    mixup_alpha: Parameter of Beta distribution for sampling mixing ratio
58
      (applicable for regular mixup).
59
    mixing_weights: A `Tensor` of size [batch_size] specifying mixing weights.
60
    mixing_probs: A `Tensor` of size [batch_size] specifying probabilities for
61
      sampling images for imixing.
62

63
  Returns:
64
    Minibatch of mixed up images and labels.
65
  """
66

67
  images = images.numpy()
68
  labels = maybe_one_hot(labels, num_classes).numpy()
69
  num_examples = images.shape[0]
70
  mixing_ratios_im = np.random.beta(
71
      mixup_alpha, mixup_alpha, size=(num_examples, 1, 1, 1))
72
  mixing_ratios_lab = np.reshape(mixing_ratios_im, [num_examples, 1])
73
  if mixing_probs is None:
74
    mixing_indices = np.random.permutation(num_examples)
75
  else:
76
    mixing_probs = np.round(mixing_probs, 5)
77
    mixing_probs = mixing_probs / np.sum(mixing_probs)
78
    mixing_indices = np.random.choice(
79
        num_examples, size=num_examples, replace=True, p=mixing_probs)
80
  if mixing_weights is not None:
81
    mixing_ratios_im = mixing_weights / (
82
        mixing_weights + mixing_weights[mixing_indices])
83
    mixing_ratios_im = np.reshape(mixing_ratios_im, [-1, 1, 1, 1])
84
    # mix labels in same proportions
85
    mixing_ratios_lab = np.reshape(mixing_ratios_im, [num_examples, 1])
86
  images = (
87
      images * mixing_ratios_im + images[mixing_indices] *
88
      (1. - mixing_ratios_im))
89
  labels = (
90
      labels * mixing_ratios_lab + labels[mixing_indices] *
91
      (1. - mixing_ratios_lab))
92
  return images, labels
93

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.