google-research

Форк
0
71 строка · 2.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implement ConvNet architecture."""
17

18
import tensorflow as tf
19

20
from extreme_memorization import alignment
21

22

23
class ConvNet(tf.keras.Model):
24
  """Neural network used to train CIFAR10.
25

26
  Uses Keras to define convolutional and fully connected layers.
27
  """
28

29
  def __init__(self, num_labels=10):
30
    super(ConvNet, self).__init__()
31
    self.num_labels = num_labels
32
    self.conv1 = tf.keras.layers.Conv2D(
33
        32, 5, padding='same', activation='relu')
34
    self.max_pool1 = tf.keras.layers.MaxPooling2D((3, 3), (2, 2),
35
                                                  padding='valid')
36
    self.conv2 = tf.keras.layers.Conv2D(
37
        64, 5, padding='valid', activation='relu')
38
    self.max_pool2 = tf.keras.layers.MaxPooling2D((3, 3), (2, 2),
39
                                                  padding='valid')
40
    self.flatten = tf.keras.layers.Flatten()
41
    self.fc1 = tf.keras.layers.Dense(1024, activation='relu', name='hidden')
42
    self.fc2 = tf.keras.layers.Dense(num_labels, name='top')
43

44
  def call(self, x, labels, training=False, step=0):
45
    """Used to perform a forward pass."""
46
    # Assume channels_last
47
    input_shape = [32, 32, 3]
48
    x = tf.keras.layers.Reshape(
49
        target_shape=input_shape, input_shape=(32 * 32 * 3,))(
50
            x)
51

52
    if tf.keras.backend.image_data_format() == 'channels_first':
53
      # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
54
      # This provides a large performance boost on GPU. See
55
      # https://www.tensorflow.org/performance/performance_guide#data_formats
56
      x = tf.transpose(a=x, perm=[0, 3, 1, 2])
57

58
    x = self.max_pool1(self.conv1(x))
59
    x = self.max_pool2(self.conv2(x))
60
    x = self.flatten(x)
61
    x = self.fc1(x)
62

63
    alignment.plot_class_alignment(
64
        x,
65
        labels,
66
        self.num_labels,
67
        step,
68
        tf_summary_key='representation_alignment')
69

70
    x = self.fc2(x)
71
    return x
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.