google-research
60 строк · 1.8 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implement a 2-Layer MLP."""

import tensorflow as tf

from extreme_memorization import alignment


class MLP(tf.keras.Model):
  """Simple 2-layer MLP built from two bias-free Dense layers.

  The hidden representation is passed, together with the labels, to
  `alignment.plot_class_alignment` under the summary key
  "representation_alignment" before the top layer is applied.
  """

  def __init__(self,
               num_units,
               stddev,
               activation_fn=tf.nn.relu,
               custom_init=False,
               num_labels=10):
    """Creates the hidden and top Dense layers.

    Args:
      num_units: Number of units in the hidden layer.
      stddev: Standard deviation of the zero-mean normal initializer for the
        hidden kernel. Only read when `custom_init` is True.
      activation_fn: Activation applied to the hidden layer's output.
      custom_init: If True, the hidden kernel is drawn from
        RandomNormal(mean=0.0, stddev=stddev); otherwise the Keras Dense
        default ("glorot_uniform") is used.
      num_labels: Number of units in the top layer; also forwarded to
        `alignment.plot_class_alignment` in `call`.
    """
    super().__init__(name="MLP")
    self.custom_init = custom_init
    self.num_labels = num_labels

    if custom_init:
      kernel_initializer = tf.keras.initializers.RandomNormal(
          mean=0.0, stddev=stddev)
    else:
      # Keras' own default for Dense, written out so a single constructor
      # call below serves both branches without changing behavior.
      kernel_initializer = "glorot_uniform"

    self.hidden = tf.keras.layers.Dense(
        num_units,
        activation=activation_fn,
        kernel_initializer=kernel_initializer,
        use_bias=False,
        name="Dense")
    # No activation argument: the top layer produces linear outputs.
    self.top = tf.keras.layers.Dense(num_labels, use_bias=False, name="Top")

  def call(self, input_, labels, training=False, step=0):
    """Runs the forward pass and reports the hidden-layer alignment.

    Args:
      input_: Input batch; all non-batch dimensions are flattened before the
        hidden layer.
      labels: Labels for the batch, forwarded unchanged to
        `alignment.plot_class_alignment` (expected format is defined there,
        not checked here).
      training: Unused by this method.
      step: Step value forwarded to `alignment.plot_class_alignment`.

    Returns:
      The top layer's (unactivated) output, with `num_labels` units in its
      last dimension.
    """
    # NOTE(review): a fresh, weightless Flatten layer is built on every call.
    # It is kept here rather than in __init__ because assigning it as an
    # attribute would add it to the model's tracked `layers`.
    x = tf.keras.layers.Flatten()(input_)
    x = self.hidden(x)

    # Any return value is discarded; the forward pass continues with the same
    # tensor. Judging by `tf_summary_key`, this presumably writes a TF summary
    # -- confirm in extreme_memorization.alignment.
    alignment.plot_class_alignment(
        x,
        labels,
        self.num_labels,
        step,
        tf_summary_key="representation_alignment")

    return self.top(x)