google-research
70 строк · 2.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sign language sequence tagging Keras model."""

import tensorflow as tf

from sign_language_detection.args import FLAGS


def input_size():
  """Return the number of keypoints in the selected pose components.

  Every known component name present in FLAGS.input_components contributes
  a fixed keypoint count: 25 for 'pose_keypoints_2d', 70 for
  'face_keypoints_2d', and 21 each for 'hand_left_keypoints_2d' and
  'hand_right_keypoints_2d'. Unknown names contribute nothing, so an empty
  selection yields 0.

  Returns:
    int, the summed keypoint count of the selected components.
  """
  # NOTE(review): selection is a plain `in` test, so this assumes
  # input_components is a collection of component names; if it were a single
  # string, the test would match substrings instead. TODO confirm.
  keypoints_per_component = (
      ('pose_keypoints_2d', 25),
      ('face_keypoints_2d', 70),
      ('hand_left_keypoints_2d', 21),
      ('hand_right_keypoints_2d', 21),
  )
  selected = FLAGS.input_components
  return sum(count for name, count in keypoints_per_component
             if name in selected)


def get_model():
  """Create the (unbuilt) Keras sequential tagger described by FLAGS.

  Layer stack, in order:
    1. Dropout on the input features with rate FLAGS.input_dropout.
    2. FLAGS.encoder_layers LSTM layers with FLAGS.hidden_size units, each
       returning the full sequence. Each LSTM is wrapped in Bidirectional
       when FLAGS.encoder_bidirectional is truthy.
    3. A Dense layer with 2 units and softmax activation. Because the
       recurrent layers keep the time axis, it yields a distribution over
       2 classes at every timestep.

  Returns:
    A tf.keras.Sequential named 'tgt'. No input shape is attached here, so
    the caller must build it before use.
  """
  model = tf.keras.Sequential(name='tgt')

  # NOTE(review): a SequenceMasking layer was left disabled here, so padded
  # timesteps (if batches are padded) pass through the recurrent layers
  # unmasked. Confirm this is intended.
  stack = [tf.keras.layers.Dropout(FLAGS.input_dropout)]

  for _ in range(FLAGS.encoder_layers):
    lstm = tf.keras.layers.LSTM(FLAGS.hidden_size, return_sequences=True)
    stack.append(
        tf.keras.layers.Bidirectional(lstm)
        if FLAGS.encoder_bidirectional else lstm)

  # Project each timestep onto the 2-label space and normalize with softmax.
  stack.append(tf.keras.layers.Dense(2, activation='softmax'))

  for layer in stack:
    model.add(layer)
  return model


def build_model():
  """Create, build, and compile the tagger, then print its summary.

  Steps:
    - Take the model returned by get_model().
    - Build it for 3-D inputs whose last dimension is input_size(). The
      batch and sequence-length dimensions are left unspecified (None).
    - Compile it with:
      - sparse categorical cross-entropy loss (integer class-index targets);
      - an Adam optimizer using FLAGS.learning_rate;
      - an accuracy metric.
    - Call model.summary(), which prints the layer table as a side effect.

  Returns:
    The built and compiled tf.keras.Sequential model.
  """
  model = get_model()

  # Keep batch size and sequence length dynamic; only the feature size is fixed.
  feature_dim = input_size()
  model.build(input_shape=(None, None, feature_dim))

  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  model.compile(
      optimizer=optimizer,
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'],
  )

  model.summary()
  return model