# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic models for building blocks.

Can be used either for the main prediction task or as building blocks
for pretext and main models.

All batch norm attributes (e.g., self.batch_norm_layers = ...) should contain
'batch_norm' in the object name, because this string is matched for weight
decay purposes (batch norm parameters are excluded).
"""

from typing import Sequence

from flax import linen as nn
import jax
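
# Weight-decay masking example (an illustrative sketch, not part of this
# module; optax, the mask helper below, and the 1e-4/1e-3 rates are
# assumptions): the 'batch_norm' naming convention described in the module
# docstring can be matched against parameter paths to exclude batch norm
# parameters from weight decay, e.g.:
#
#   import optax
#   from flax import traverse_util
#
#   def weight_decay_mask(params):
#     flat = traverse_util.flatten_dict(params)
#     mask = {path: not any('batch_norm' in name for name in path)
#             for path in flat}
#     return traverse_util.unflatten_dict(mask)
#
#   tx = optax.chain(
#       optax.masked(optax.add_decayed_weights(1e-4), weight_decay_mask),
#       optax.sgd(learning_rate=1e-3),
#   )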


class BaseMLP(nn.Module):
  """Base MLP with BatchNorm on inputs.

  Uses BatchNorm with running average before features are passed through
  the model.

  ReLU after every layer except the last.
  """

  layer_sizes: Sequence[int]
  training: bool

  def setup(self):
    self.layers = [nn.Dense(layer_size) for layer_size in self.layer_sizes]
    self.initial_batch_norm_layer = nn.BatchNorm(use_running_average=True)
    self.batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in self.layer_sizes
    ]

  def __call__(self, inputs):
    x = inputs
    x = self.initial_batch_norm_layer(x)
    for i, layer in enumerate(self.layers):
      # BatchNorm precedes each Dense layer; ReLU follows all but the last.
      x = self.batch_norm_layers[i](x)
      x = layer(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x


class ImixMLP(nn.Module):
  """Base MLP used in iMix.

  A replica of the iMix model architecture.

  See page 17 here: https://arxiv.org/pdf/2010.08887.pdf

  'The output dimensions of layers are (2048-2048-4096-4096-8192),
  where all layers have batch normalization followed by ReLU except
  for the last layer. The last layer activation is maxout (Goodfellow et al.,
  2013) with 4 sets, such that the output dimension is 2048.
  On top of this five-layer MLP, we attach two-layer
  MLP (2048-128) as a projection head.'

  See the code here: https://github.com/kibok90/imix/blob/main/models/mlp.py
  """

  training: bool

  def setup(self):
    self.dense_layers = [nn.Dense(2048, use_bias=False),
                         nn.Dense(2048, use_bias=False),
                         nn.Dense(4096, use_bias=False),
                         nn.Dense(4096, use_bias=False),
                         nn.Dense(8192, use_bias=True)]
    self.initial_batch_norm_layer = nn.BatchNorm(use_running_average=True)
    self.batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in self.dense_layers
    ]

    self.projection_head = nn.Dense(128)

  def __call__(self, inputs):
    x = inputs
    x = self.initial_batch_norm_layer(x)
    for i, layer in enumerate(self.dense_layers):
      x = layer(x)
      if i != len(self.dense_layers) - 1:
        x = self.batch_norm_layers[i](x)
        x = nn.relu(x)
      else:
        # Maxout with 4 sets: fold the 8192-dim output into (2048, 4) and
        # take the element-wise maximum, giving a 2048-dim representation.
        x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
        x = jax.numpy.max(x, axis=-1)
    x = self.projection_head(x)
    return x


class MLP(nn.Module):
  """Base MLP with BatchNorm.

  ReLU after every layer except the last.
  """

  layer_sizes: Sequence[int]
  training: bool

  def setup(self):
    self.layers = [nn.Dense(layer_size) for layer_size in self.layer_sizes]
    self.batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in self.layer_sizes
    ]

  def __call__(self, inputs):
    x = inputs
    for i, layer in enumerate(self.layers):
      x = layer(x)
      if i != len(self.layers) - 1:
        x = self.batch_norm_layers[i](x)
        x = nn.relu(x)
    return x


class Resnet(nn.Module):
  """Resnet architecture as defined in Revisiting DL Models for Tabular Data.

  https://arxiv.org/pdf/2106.11959.pdf
  """

  training: bool
  num_blocks: int = 12
  factor: float = 3.402
  dropout_rate_1: float = 0.3612
  dropout_rate_2: float = 0.0
  layer_size: int = 235

  @nn.compact
  def __call__(self, inputs):
    first_projection = nn.Dense(self.layer_size)
    hidden_size = int(self.layer_size * self.factor)
    batch_norm_layers = [
        nn.BatchNorm(use_running_average=(not self.training))
        for _ in range(self.num_blocks)
    ]
    linear_first = [
        nn.Dense(hidden_size, use_bias=False)
        for _ in range(self.num_blocks)
    ]
    linear_second = [nn.Dense(self.layer_size) for _ in range(self.num_blocks)]

    dropout_first = [
        nn.Dropout(self.dropout_rate_1, deterministic=(not self.training))
        for _ in range(self.num_blocks)
    ]

    dropout_second = [
        nn.Dropout(self.dropout_rate_2, deterministic=(not self.training))
        for _ in range(self.num_blocks)
    ]

    x = inputs
    x = first_projection(x)
    for i in range(self.num_blocks):
      # Residual block: BatchNorm -> Dense -> Dropout -> ReLU -> Dense ->
      # Dropout, added back onto the skip connection.
      x += dropout_second[i](
          linear_second[i](nn.relu(dropout_first[i](
              linear_first[i](batch_norm_layers[i](x))))))

    return x
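

if __name__ == '__main__':
  # Illustrative smoke test (not part of the original module; shapes and
  # hyperparameters are arbitrary assumptions). It shows how the `training`
  # flag interacts with Flax's BatchNorm and Dropout state: in training mode
  # the 'batch_stats' collection must be marked mutable, and Dropout needs a
  # 'dropout' RNG.
  rng = jax.random.PRNGKey(0)
  x = jax.numpy.ones((8, 16))  # A batch of 8 examples with 16 features.

  mlp = MLP(layer_sizes=(32, 4), training=True)
  mlp_vars = mlp.init(rng, x)
  out, updates = mlp.apply(mlp_vars, x, mutable=['batch_stats'])
  print('MLP output:', out.shape, 'updated collections:', list(updates))

  resnet = Resnet(training=True)
  resnet_vars = resnet.init({'params': rng, 'dropout': rng}, x)
  out, _ = resnet.apply(
      resnet_vars, x, rngs={'dropout': rng}, mutable=['batch_stats'])
  print('Resnet output:', out.shape)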