google-research

Форк
0
60 строк · 1.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implement a 2-Layer MLP."""
17

18
import tensorflow as tf
19

20
from extreme_memorization import alignment
21

22

23
class MLP(tf.keras.Model):
24
  """Simple 2-Layer MLP."""
25

26
  def __init__(self,
27
               num_units,
28
               stddev,
29
               activation_fn=tf.nn.relu,
30
               custom_init=False,
31
               num_labels=10):
32
    super(MLP, self).__init__(name="MLP")
33
    self.custom_init = custom_init
34
    self.num_labels = num_labels
35
    if custom_init:
36
      self.hidden = tf.keras.layers.Dense(
37
          num_units,
38
          activation=activation_fn,
39
          kernel_initializer=tf.keras.initializers.RandomNormal(
40
              mean=0.0, stddev=stddev),
41
          use_bias=False,
42
          name="Dense")
43
    else:
44
      self.hidden = tf.keras.layers.Dense(
45
          num_units, activation=activation_fn, use_bias=False, name="Dense")
46
    self.top = tf.keras.layers.Dense(num_labels, use_bias=False, name="Top")
47

48
  def call(self, input_, labels, training=False, step=0):
49
    x = tf.keras.layers.Flatten()(input_)
50
    x = self.hidden(x)
51

52
    alignment.plot_class_alignment(
53
        x,
54
        labels,
55
        self.num_labels,
56
        step,
57
        tf_summary_key="representation_alignment")
58

59
    x = self.top(x)
60
    return x
61

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.