# CelestialSurveyor
# 114 lines · 5.2 KB
# SlowFast model implementation adapted from https://github.com/xuzheyuan624/slowfast-keras, with some modifications.
from keras.layers import (Conv3D, BatchNormalization, ReLU, Add, MaxPool3D, GlobalAveragePooling3D,
                          Concatenate, Dropout, Dense, Lambda, Input)
from keras.models import Model
from keras import Sequential
7
def Conv_BN_ReLU(planes, kernel_size, strides=(1, 1, 1), padding='same', use_bias=False):
    """Build a ``Conv3D -> BatchNormalization -> ReLU`` stack as one ``Sequential``.

    Args:
        planes: Number of filters of the 3D convolution.
        kernel_size: Kernel size forwarded to ``Conv3D``.
        strides: Strides forwarded to ``Conv3D``.
        padding: Padding mode forwarded to ``Conv3D``.
        use_bias: Whether the convolution carries a bias term.

    Returns:
        A new, not-yet-called ``Sequential`` holding the three layers.
    """
    return Sequential([
        Conv3D(planes, kernel_size, strides=strides, padding=padding, use_bias=use_bias),
        BatchNormalization(),
        ReLU(),
    ])
15
def bottleneck(x, planes, stride=1, downsample=None, head_conv=1, use_bias=False):
    """Apply one residual bottleneck block to ``x``.

    The main branch is: head convolution (``planes`` filters) -> (1, 3, 3)
    convolution with stride ``(1, stride, stride)`` -> 1x1x1 convolution to
    ``planes * 4`` filters -> batch normalisation. The shortcut branch is
    added to it and the sum goes through a ReLU.

    Args:
        x: Input tensor.
        planes: Filter count of the first two convolutions; the block
            outputs ``planes * 4`` filters.
        stride: Stride applied on the last two non-channel axes by the
            (1, 3, 3) convolution.
        downsample: Optional callable applied to the shortcut before the
            addition. Presumably required whenever the shortcut's shape
            differs from the main branch output -- confirm against callers.
        head_conv: ``1`` for a 1x1x1 head convolution, ``3`` for a
            (3, 1, 1) head convolution.
        use_bias: Whether the convolutions carry a bias term.

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: If ``head_conv`` is neither 1 nor 3.
    """
    # Resolve the head kernel first so an invalid head_conv fails before any
    # layer is instantiated.
    if head_conv == 1:
        head_kernel = 1
    elif head_conv == 3:
        head_kernel = (3, 1, 1)
    else:
        raise ValueError('Unsupported head_conv!!!')

    shortcut = x
    out = Conv_BN_ReLU(planes, kernel_size=head_kernel, use_bias=use_bias)(x)
    out = Conv_BN_ReLU(planes, kernel_size=(1, 3, 3), strides=(1, stride, stride), use_bias=use_bias)(out)
    out = Conv3D(planes * 4, kernel_size=1, use_bias=use_bias)(out)
    out = BatchNormalization()(out)
    # The shortcut is projected only after the main branch is built, keeping
    # the same layer call order as the original implementation.
    if downsample is not None:
        shortcut = downsample(shortcut)
    out = Add()([out, shortcut])
    return ReLU()(out)
33
def datalayer(x, stride):
    """Subsample a 5-D tensor along axis 1, keeping every ``stride``-th entry.

    Used as the function of the ``Lambda`` layers that feed the two pathways.
    Axis 1 is presumably the frame/time axis -- confirm against the data
    pipeline.

    Args:
        x: Tensor (or array) with five axes.
        stride: Step of the slice along axis 1.

    Returns:
        ``x`` sliced as ``x[:, ::stride, :, :, :]``.
    """
    return x[:, ::stride, :, :, :]
37
def SlowFast_body(layers, block, dropout=0.5):
    """Assemble the two-pathway SlowFast model with a single sigmoid output.

    The model takes an input of shape ``(None, 64, 64, 1)`` (batch axis
    excluded). The fast pathway sees the input with stride 1 along axis 1,
    the slow pathway with stride 4. Their pooled features are concatenated
    and passed through dropout, two ReLU dense layers (64 and 32 units) and
    a final 1-unit sigmoid layer.

    Args:
        layers: Sequence whose items 0..3 give the number of blocks in each
            of the four stages of both pathways.
        block: Callable that builds one residual block; it is forwarded to
            ``Fast_body`` and ``Slow_body``.
        dropout: Dropout rate applied to the concatenated features.

    Returns:
        A Keras ``Model`` mapping the input tensor to the sigmoid output.
    """
    inputs = Input(shape=(None, 64, 64, 1))
    # Both pathways read the same input; they differ only in the sampling
    # stride along axis 1.
    inputs_fast = Lambda(datalayer, name='data_fast', arguments={'stride': 1})(inputs)
    inputs_slow = Lambda(datalayer, name='data_slow', arguments={'stride': 4})(inputs)

    # The fast pathway is built first because the slow pathway consumes its
    # lateral connections.
    fast, lateral = Fast_body(inputs_fast, layers, block)
    slow = Slow_body(inputs_slow, lateral, layers, block)

    features = Concatenate()([slow, fast])
    features = Dropout(dropout)(features)
    features = Dense(64, activation='relu')(features)
    features = Dense(32, activation='relu')(features)
    out = Dense(1, activation='sigmoid')(features)
    return Model(inputs, out)
51
def Fast_body(x, layers, block):
    """Build the fast pathway and collect its lateral connections.

    The stem is a (5, 3, 3) ``Conv_BN_ReLU`` with 8 filters followed by a
    (1, 3, 3) max-pool. Four residual stages follow; before each stage a
    (5, 1, 1) convolution with stride (4, 1, 1) taps the current feature map
    to produce a lateral tensor for the slow pathway.

    Args:
        x: Input tensor of the fast pathway.
        layers: Sequence whose items 0..3 give the number of blocks per stage.
        block: Callable that builds one residual block (forwarded to
            ``make_layer_fast``).

    Returns:
        A tuple ``(features, lateral)`` where ``features`` is the globally
        average-pooled output of the last stage and ``lateral`` is the list
        of the four lateral tensors, in stage order.
    """
    fast_inplanes = 8
    lateral = []
    x = Conv_BN_ReLU(8, kernel_size=(5, 3, 3), strides=(1, 2, 2))(x)
    x = MaxPool3D(pool_size=(1, 3, 3), strides=(1, 2, 2), padding='same')(x)

    # One row per stage: (lateral filters, stage planes, stage stride).
    stages = (
        (8 * 2, 8, 1),
        (32 * 2, 16, 2),
        (64 * 2, 32, 2),
        (128 * 2, 64, 2),
    )
    for idx, (lateral_filters, planes, stride) in enumerate(stages):
        # The lateral tap is taken *before* the stage, from its input.
        lateral.append(
            Conv3D(lateral_filters, kernel_size=(5, 1, 1), strides=(4, 1, 1),
                   padding='same', use_bias=False)(x)
        )
        x, fast_inplanes = make_layer_fast(x, block, planes, layers[idx], stride=stride,
                                           head_conv=3, fast_inplanes=fast_inplanes)

    x = GlobalAveragePooling3D()(x)
    return x, lateral
72
def Slow_body(x, lateral, layers, block):
    """Build the slow pathway, fusing in the fast pathway's lateral tensors.

    The stem is a (1, 3, 3) ``Conv_BN_ReLU`` with 32 filters followed by a
    (1, 3, 3) max-pool. Before each of the four residual stages the matching
    lateral tensor is concatenated onto the current feature map.

    Args:
        x: Input tensor of the slow pathway.
        lateral: Sequence whose items 0..3 are the lateral tensors, in stage
            order; each must be concatenable with the slow feature map at
            that point.
        layers: Sequence whose items 0..3 give the number of blocks per stage.
        block: Callable that builds one residual block (forwarded to
            ``make_layer_slow``).

    Returns:
        The globally average-pooled output of the last stage.
    """
    # NOTE(review): 80 looks like the bookkeeping value for a 64-filter stem,
    # while the stem below uses 32 filters. The value only feeds the
    # "is a projection shortcut needed" check in make_layer_slow, so it is
    # kept unchanged -- confirm whether it should track the real channel count.
    slow_inplanes = 64 + 64 // 8 * 2
    x = Conv_BN_ReLU(32, kernel_size=(1, 3, 3), strides=(1, 2, 2))(x)
    x = MaxPool3D(pool_size=(1, 3, 3), strides=(1, 2, 2), padding='same')(x)

    # One row per stage: (stage planes, stage stride).
    stages = ((32, 1), (64, 2), (128, 2), (256, 2))
    for idx, (planes, stride) in enumerate(stages):
        x = Concatenate()([x, lateral[idx]])
        x, slow_inplanes = make_layer_slow(x, block, planes, layers[idx], stride=stride,
                                           head_conv=1, slow_inplanes=slow_inplanes)

    x = GlobalAveragePooling3D()(x)
    return x
88
def make_layer_fast(x, block, planes, blocks, stride=1, head_conv=1, fast_inplanes=8, block_expansion=4):
    """Stack one stage of residual blocks for the fast pathway.

    The first block receives ``stride`` and, when the stage changes the
    resolution or the channel count, a projection shortcut
    (1x1x1 ``Conv3D`` + ``BatchNormalization``). The remaining
    ``blocks - 1`` blocks use the block's defaults.

    Args:
        x: Input tensor of the stage.
        block: Callable ``block(x, planes, stride, downsample=..., head_conv=...)``
            returning the block's output tensor.
        planes: Base filter count of the stage.
        blocks: Number of blocks in the stage; the first block is always
            built, even if ``blocks`` is less than 1.
        stride: Stride of the first block on the last two non-channel axes.
        head_conv: Forwarded to every block.
        fast_inplanes: Channel count the caller tracks for ``x``.
        block_expansion: Ratio between a block's output channels and ``planes``.

    Returns:
        A tuple ``(x, fast_inplanes)``: the stage output and the updated
        channel count, ``planes * block_expansion``.
    """
    out_planes = planes * block_expansion
    downsample = None
    if stride != 1 or fast_inplanes != out_planes:
        downsample = Sequential([
            Conv3D(out_planes, kernel_size=1, strides=(1, stride, stride), use_bias=False),
            BatchNormalization(),
        ])
    # When no projection is needed fast_inplanes already equals out_planes,
    # so the channel count after the stage is out_planes in every case.
    fast_inplanes = out_planes

    x = block(x, planes, stride, downsample=downsample, head_conv=head_conv)
    for _ in range(1, blocks):
        x = block(x, planes, head_conv=head_conv)
    return x, fast_inplanes
102
def make_layer_slow(x, block, planes, blocks, stride=1, head_conv=1, slow_inplanes=80, block_expansion=4):
    """Stack one stage of residual blocks for the slow pathway.

    The first block receives ``stride`` and, when ``stride != 1`` or the
    tracked input channel count differs from the stage output, a projection
    shortcut (1x1x1 ``Conv3D`` + ``BatchNormalization``). The remaining
    ``blocks - 1`` blocks use the block's defaults.

    Args:
        x: Input tensor of the stage.
        block: Callable ``block(x, planes, stride, downsample, head_conv=...)``
            returning the block's output tensor.
        planes: Base filter count of the stage.
        blocks: Number of blocks in the stage; the first block is always
            built, even if ``blocks`` is less than 1.
        stride: Stride of the first block on the last two non-channel axes.
        head_conv: Forwarded to every block.
        slow_inplanes: Channel count the caller tracks for ``x``.
        block_expansion: Ratio between a block's output channels and ``planes``.

    Returns:
        A tuple ``(x, slow_inplanes)``: the stage output and the channel
        count the caller should track next, i.e. the stage output channels
        plus ``out // 8 * 2`` for the lateral tensor concatenated afterwards.
    """
    out_planes = planes * block_expansion
    downsample = None
    if stride != 1 or slow_inplanes != out_planes:
        downsample = Sequential([
            Conv3D(out_planes, kernel_size=1, strides=(1, stride, stride), use_bias=False),
            BatchNormalization(),
        ])

    x = block(x, planes, stride, downsample, head_conv=head_conv)
    for _ in range(1, blocks):
        x = block(x, planes, head_conv=head_conv)

    # Account for the lateral channels fused in before the next stage.
    slow_inplanes = out_planes + out_planes // 8 * 2
    return x, slow_inplanes