google-research

Форк
0
70 строк · 2.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Sign language sequence tagging keras model."""
17

18
import tensorflow as tf
19

20
from sign_language_detection.args import FLAGS
21

22

23
def input_size():
24
  """Calculate the size of the input pose by desired components."""
25
  points = 0
26
  if 'pose_keypoints_2d' in FLAGS.input_components:
27
    points += 25
28
  if 'face_keypoints_2d' in FLAGS.input_components:
29
    points += 70
30
  if 'hand_left_keypoints_2d' in FLAGS.input_components:
31
    points += 21
32
  if 'hand_right_keypoints_2d' in FLAGS.input_components:
33
    points += 21
34
  return points
35

36

37
def get_model():
38
  """Create keras sequential model following the hyperparameters."""
39

40
  model = tf.keras.Sequential(name='tgt')
41

42
  # model.add(SequenceMasking())  # Mask padded sequences
43
  model.add(tf.keras.layers.Dropout(
44
      FLAGS.input_dropout))  # Random feature dropout
45

46
  # Add LSTM
47
  for _ in range(FLAGS.encoder_layers):
48
    rnn = tf.keras.layers.LSTM(FLAGS.hidden_size, return_sequences=True)
49
    if FLAGS.encoder_bidirectional:
50
      rnn = tf.keras.layers.Bidirectional(rnn)
51
    model.add(rnn)
52

53
  # Project and normalize to labels space
54
  model.add(tf.keras.layers.Dense(2, activation='softmax'))
55

56
  return model
57

58

59
def build_model():
60
  """Apply input shape, loss, optimizer, and metric to the model."""
61
  model = get_model()
62
  model.build(input_shape=(None, None, input_size()))
63
  model.compile(
64
      loss='sparse_categorical_crossentropy',
65
      optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate),
66
      metrics=['accuracy'],
67
  )
68
  model.summary()
69

70
  return model
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.