# google-research
# 71 lines · 2.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements the ConvNet architecture used to train on CIFAR-10."""

import tensorflow as tf

from extreme_memorization import alignment


class ConvNet(tf.keras.Model):
  """Neural network used to train CIFAR10.

  Uses Keras to define convolutional and fully connected layers:
  conv(32 filters, 5x5, 'same', relu) -> maxpool(3x3, stride 2, 'valid') ->
  conv(64 filters, 5x5, 'valid', relu) -> maxpool(3x3, stride 2, 'valid') ->
  flatten -> dense(1024, relu, name='hidden') -> dense(num_labels, name='top').

  Attributes:
    num_labels: Number of output units of the final ('top') Dense layer.
    image_shape: (height, width, channels) tuple that every input example is
      reshaped to at the start of `call`.
  """

  def __init__(self, num_labels=10, image_shape=(32, 32, 3)):
    """Creates all layers of the model.

    Args:
      num_labels: Number of output units of the final Dense layer.
      image_shape: (height, width, channels) of one input image, in
        channels_last order. Defaults to the CIFAR-10 shape (32, 32, 3). Each
        example passed to `call` must contain exactly
        height * width * channels values.
    """
    super(ConvNet, self).__init__()
    self.num_labels = num_labels
    self.image_shape = tuple(image_shape)
    self.conv1 = tf.keras.layers.Conv2D(
        32, 5, padding='same', activation='relu')
    self.max_pool1 = tf.keras.layers.MaxPooling2D((3, 3), (2, 2),
                                                  padding='valid')
    self.conv2 = tf.keras.layers.Conv2D(
        64, 5, padding='valid', activation='relu')
    self.max_pool2 = tf.keras.layers.MaxPooling2D((3, 3), (2, 2),
                                                  padding='valid')
    self.flatten = tf.keras.layers.Flatten()
    self.fc1 = tf.keras.layers.Dense(1024, activation='relu', name='hidden')
    self.fc2 = tf.keras.layers.Dense(num_labels, name='top')

  def call(self, x, labels, training=False, step=0):
    """Used to perform a forward pass.

    Args:
      x: Batch of input images in channels_last order. Each example is
        reshaped to `self.image_shape`, so it may be given flattened
        (height * width * channels values) or already shaped.
      labels: Passed unchanged to `alignment.plot_class_alignment` together
        with the activations of the 'hidden' Dense layer.
      training: Not used by this method.
      step: Passed to `alignment.plot_class_alignment` as its step argument.

    Returns:
      Output of the final 'top' Dense layer (no activation), with
      `num_labels` values per example.
    """
    # Use a plain reshape op instead of constructing a new
    # tf.keras.layers.Reshape on every call: layers should be created once in
    # __init__, and the `input_shape` argument previously passed here is
    # ignored when a layer is called directly on a tensor. -1 keeps the batch
    # dimension, matching Reshape's behavior.
    x = tf.reshape(x, [-1] + list(self.image_shape))

    if tf.keras.backend.image_data_format() == 'channels_first':
      # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
      # This provides a large performance boost on GPU. See
      # https://www.tensorflow.org/performance/performance_guide#data_formats
      # The conv/pool layers above are built without an explicit data_format,
      # so they follow the same global image_data_format() setting.
      x = tf.transpose(a=x, perm=[0, 3, 1, 2])

    x = self.max_pool1(self.conv1(x))
    x = self.max_pool2(self.conv2(x))
    x = self.flatten(x)
    x = self.fc1(x)

    # Hands the 'hidden' activations to the alignment helper under the summary
    # key below. NOTE(review): this runs on every call, including when
    # training=False — confirm that is intended.
    alignment.plot_class_alignment(
        x,
        labels,
        self.num_labels,
        step,
        tf_summary_key='representation_alignment')

    x = self.fc2(x)
    return x