# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic models for building blocks.

Can be used either for the main prediction task or as building blocks
for pretext and main models.

All batch norm attributes (e.g., self.batch_norm_layers = ...) should contain
'batch_norm' in the object name, because this string is matched for weight
decay purposes (batch norm parameters are excluded).
"""

from typing import Sequence

from flax import linen as nn
import jax
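

# A minimal, illustrative sketch (not part of the original module): the
# docstring's naming convention lets batch norm parameters be excluded from
# weight decay by matching 'batch_norm' in the parameter path. The helper
# name is hypothetical; such a mask could be passed to, e.g.,
# optax.add_decayed_weights(weight_decay, mask=...).
def _weight_decay_mask(params):
  """Returns a pytree of bools: True where weight decay should apply."""
  from flax import traverse_util  # local import to keep the sketch together
  flat = traverse_util.flatten_dict(params)
  mask = {k: not any('batch_norm' in name for name in k) for k in flat}
  return traverse_util.unflatten_dict(mask)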


class BaseMLP(nn.Module):
  """Base MLP with BatchNorm on inputs.

  Uses BatchNorm with running average before features are passed through
  the model.

  ReLU after every layer except the last.
  """

  layer_sizes: Sequence[int]
  training: bool

  def setup(self):
    self.layers = [nn.Dense(layer_size) for layer_size in self.layer_sizes]
    # The input normalization always uses running averages, even in training.
    self.initial_batch_norm_layer = nn.BatchNorm(use_running_average=True)
    self.batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in self.layer_sizes
    ]

  def __call__(self, inputs):
    x = inputs
    x = self.initial_batch_norm_layer(x)
    for i, layer in enumerate(self.layers):
      x = self.batch_norm_layers[i](x)
      x = layer(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x
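

# A minimal usage sketch for BaseMLP (hypothetical shapes; not part of the
# original module). The BatchNorm layers store running statistics in the
# 'batch_stats' collection, which must be marked mutable when training.
def _example_base_mlp():
  x = jax.numpy.ones((8, 16))  # hypothetical batch: 8 examples, 16 features
  model = BaseMLP(layer_sizes=[64, 32], training=True)
  variables = model.init(jax.random.PRNGKey(0), x)
  out, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  return out.shape  # (8, 32)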


class ImixMLP(nn.Module):
  """Base MLP used in iMix.

  A replica of the iMix model architecture.

  See page 17 here: https://arxiv.org/pdf/2010.08887.pdf

  'The output dimensions of layers are (2048-2048-4096-4096-8192),
  where all layers have batch normalization followed by ReLU except
  for the last layer. The last layer activation is maxout (Goodfellow et al.,
  2013) with 4 sets, such that the output dimension is 2048.
  On top of this five-layer MLP, we attach two-layer
  MLP (2048-128) as a projection head.'

  See the code here: https://github.com/kibok90/imix/blob/main/models/mlp.py
  """

  training: bool

  def setup(self):
    self.dense_layers = [nn.Dense(2048, use_bias=False),
                         nn.Dense(2048, use_bias=False),
                         nn.Dense(4096, use_bias=False),
                         nn.Dense(4096, use_bias=False),
                         nn.Dense(8192, use_bias=True)]
    self.initial_batch_norm_layer = nn.BatchNorm(use_running_average=True)
    self.batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in self.dense_layers
    ]

    self.projection_head = nn.Dense(128)

  def __call__(self, inputs):
    x = inputs
    x = self.initial_batch_norm_layer(x)
    for i, layer in enumerate(self.dense_layers):
      x = layer(x)
      if i != len(self.dense_layers) - 1:
        x = self.batch_norm_layers[i](x)
        x = nn.relu(x)
      else:
        # Maxout with 4 sets: group the 8192 features into 2048 groups of 4
        # and keep the elementwise maximum of each group.
        x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
        x = jax.numpy.max(x, axis=-1)
    x = self.projection_head(x)
    return x
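

# An illustrative check of the maxout step above (not part of the original
# module): a feature vector of size 8 collapses to size 2, taking the max
# within each consecutive group of 4.
def _example_maxout():
  x = jax.numpy.arange(8.0).reshape(1, 8)  # [[0., 1., ..., 7.]]
  x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))  # shape (1, 2, 4)
  return jax.numpy.max(x, axis=-1)  # [[3., 7.]]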


class MLP(nn.Module):
  """Base MLP with BatchNorm.

  ReLU after every layer except the last.
  """

  layer_sizes: Sequence[int]
  training: bool

  def setup(self):
    self.layers = [nn.Dense(layer_size) for layer_size in self.layer_sizes]
    self.batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in self.layer_sizes
    ]

  def __call__(self, inputs):
    x = inputs
    for i, layer in enumerate(self.layers):
      x = layer(x)
      if i != len(self.layers) - 1:
        x = self.batch_norm_layers[i](x)
        x = nn.relu(x)
    return x
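

# A minimal sketch (hypothetical shapes; not part of the original module) of
# the train/eval split these modules share: with training=True, batch norm
# statistics are computed from the batch and returned as updated
# 'batch_stats'; with training=False, the stored running averages are reused.
def _example_mlp_modes():
  x = jax.numpy.ones((4, 10))
  train_model = MLP(layer_sizes=[32, 2], training=True)
  variables = train_model.init(jax.random.PRNGKey(0), x)
  _, new_state = train_model.apply(variables, x, mutable=['batch_stats'])
  eval_model = MLP(layer_sizes=[32, 2], training=False)
  eval_variables = {'params': variables['params'], **new_state}
  return eval_model.apply(eval_variables, x)  # shape (4, 2)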


class Resnet(nn.Module):
  """Resnet architecture as defined in Revisiting DL Models for Tabular Data.

  https://arxiv.org/pdf/2106.11959.pdf
  """

  training: bool
  num_blocks: int = 12
  factor: float = 3.402
  dropout_rate_1: float = 0.3612
  dropout_rate_2: float = 0.0
  layer_size: int = 235

  @nn.compact
  def __call__(self, inputs):
    first_projection = nn.Dense(self.layer_size)
    hidden_size = int(self.layer_size * self.factor)
    batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in range(self.num_blocks)
    ]
    linear_first = [
        nn.Dense(hidden_size, use_bias=False)
        for _ in range(self.num_blocks)
    ]
    linear_second = [nn.Dense(self.layer_size) for _ in range(self.num_blocks)]

    dropout_first = [
        nn.Dropout(self.dropout_rate_1, deterministic=(not self.training))
        for _ in range(self.num_blocks)
    ]

    dropout_second = [
        nn.Dropout(self.dropout_rate_2, deterministic=(not self.training))
        for _ in range(self.num_blocks)
    ]

    x = inputs
    x = first_projection(x)
    for i in range(self.num_blocks):
      # Residual block: x + Dropout(Linear(ReLU(Dropout(Linear(BN(x)))))).
      x += dropout_second[i](
          linear_second[i](nn.relu(dropout_first[i](
              linear_first[i](batch_norm_layers[i](x))))))

    return x
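

# A minimal usage sketch for Resnet (hypothetical shapes; not part of the
# original module). With training=True the dropout layers are stochastic, so
# both init and apply need a 'dropout' RNG, and the batch norm statistics
# must be marked mutable.
def _example_resnet():
  x = jax.numpy.ones((4, 10))
  model = Resnet(training=True, num_blocks=2)
  variables = model.init(
      {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}, x)
  out, _ = model.apply(
      variables, x,
      rngs={'dropout': jax.random.PRNGKey(2)},
      mutable=['batch_stats'])
  return out.shape  # (4, 235): the blocks preserve the projected width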