# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to run AssembleNet variants on a random video, without object input."""
import json
18

19
from absl import app
20
from absl import flags
21

22
import numpy as np
23

24
import tensorflow as tf  # tf
25

26
from assemblenet import assemblenet
27
from assemblenet import assemblenet_plus
28
from assemblenet import assemblenet_plus_lite
29

30
from assemblenet import model_structures
31

32

33
# Command-line flags.
#
# main() in this script reads num_frames, num_classes and assemblenet_mode, and
# overwrites model_structure / model_edge_weights before building the network.
# The other flags are not read directly in this script; presumably the
# assemblenet modules consume them through FLAGS -- confirm before removing.

# Numeric precision and input/output sizes.
flags.DEFINE_string(
    name='precision',
    default='float32',
    help='Precision to use; one of: {bfloat16, float32}.')
flags.DEFINE_integer(
    name='num_frames',
    default=64,
    help='Number of frames to use.')
flags.DEFINE_integer(
    name='num_classes',
    default=157,
    help='Number of classes. 157 is for Charades')

# Which AssembleNet variant to build, and its (JSON-string) structure.
flags.DEFINE_string(
    name='assemblenet_mode',
    default='assemblenet',
    help='"assemblenet" or "assemblenet_plus" or "assemblenet_plus_lite"')
flags.DEFINE_string(
    name='model_structure',
    default='[-1,1]',
    help='AssembleNet model structure in the string format.')
flags.DEFINE_string(
    name='model_edge_weights',
    default='[]',
    help='AssembleNet model structure connection weights in the string format.')

# Attention, regularization and output pooling.
flags.DEFINE_string(
    name='attention_mode',
    default=None,
    help='"peer" or "self" or None')
flags.DEFINE_float(
    name='dropout_keep_prob',
    default=None,
    help='Keep ratio for dropout.')
# NOTE: the flag name is misspelled ("preditions"). It is kept as-is because
# renaming it would break existing command lines and readers of
# FLAGS.max_pool_preditions.
flags.DEFINE_bool(
    name='max_pool_preditions',
    default=False,
    help=('Use max-pooling on predictions instead of mean pooling on features. '
          'It helps if you have more than 32 frames.'))

# Object-input options for AssembleNet++. This script runs without object
# input and does not consult these flags itself.
flags.DEFINE_bool(
    name='use_object_input',
    default=False,
    help='Whether to use object input for AssembleNet++ or not')
flags.DEFINE_integer(
    name='num_object_classes',
    default=151,
    help=('Number of object classes, when using object inputs. '
          '151 is for ADE-20k'))

FLAGS = flags.FLAGS
64

65

66
def _build_network():
  """Builds the network selected by --assemblenet_mode.

  As a side effect, overwrites FLAGS.model_structure and
  FLAGS.model_edge_weights with the structure for the chosen mode. Any values
  passed on the command line for those two flags are therefore replaced.

  Returns:
    A `(network, flatten_frames)` tuple. `network` is the model callable
    returned by the selected assemblenet builder. `flatten_frames` is True when
    that variant is fed frames folded into the batch axis, i.e. a 4-D
    (batch * frames, height, width, 3) tensor instead of the 5-D video.

  Raises:
    app.UsageError: if --assemblenet_mode is not one of the supported modes.
  """
  mode = FLAGS.assemblenet_mode

  if mode == 'assemblenet_plus_lite':
    FLAGS.model_structure = json.dumps(model_structures.asnp_lite_structure)
    FLAGS.model_edge_weights = json.dumps(
        model_structures.asnp_lite_structure_weights)
    network = assemblenet_plus_lite.assemblenet_plus_lite(
        num_layers=[3, 5, 11, 7],
        num_classes=FLAGS.num_classes,
        data_format='channels_last')
    return network, False

  # Fail loudly on typos instead of silently running plain AssembleNet.
  if mode not in ('assemblenet', 'assemblenet_plus'):
    raise app.UsageError(
        'Unknown --assemblenet_mode %r; expected one of: assemblenet, '
        'assemblenet_plus, assemblenet_plus_lite.' % (mode,))

  # Both remaining modes use the AssembleNet-50 structure. For AssembleNet++
  # we deliberately use model_structures.asn50_structure instead of
  # full_asnp50_structure. With asn50_structure it essentially becomes
  # AssembleNet++ without objects, only requiring RGB inputs (optical flow is
  # computed inside the model).
  FLAGS.model_structure = json.dumps(model_structures.asn50_structure)
  FLAGS.model_edge_weights = json.dumps(model_structures.asn_structure_weights)

  if mode == 'assemblenet_plus':
    network = assemblenet_plus.assemblenet_plus(
        assemblenet_depth=50,
        num_classes=FLAGS.num_classes,
        data_format='channels_last')
  else:
    network = assemblenet.assemblenet_v1(
        assemblenet_depth=50,
        num_classes=FLAGS.num_classes,
        data_format='channels_last')
  return network, True


def main(_):
  """Builds the selected AssembleNet variant and runs it once on a random video.

  Prints the resulting logits and their per-row argmax.
  """
  batch_size = 2
  image_size = 256

  # 5-D video input: (batch, frames, height, width, RGB channels).
  vid_placeholder = tf.placeholder(
      tf.float32,
      (batch_size, FLAGS.num_frames, image_size, image_size, 3))

  network, flatten_frames = _build_network()

  # Keep the placeholder separate from the model input. That way the feed
  # below always targets the graph's real input, even when the frames have to
  # be folded into the batch axis.
  model_input = vid_placeholder
  if flatten_frames:
    model_input = tf.reshape(
        vid_placeholder,
        [batch_size * FLAGS.num_frames, image_size, image_size, 3])

  # The model function takes the inputs and is_training.
  outputs = network(model_input, False)

  # Generate a random video to run on, shaped like the placeholder.
  # This should be replaced by a real video.
  vid = np.random.rand(
      batch_size, FLAGS.num_frames, image_size, image_size, 3).astype(
          np.float32)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    logits = sess.run(outputs, feed_dict={vid_placeholder: vid})
    print(logits)
    print(np.argmax(logits, axis=1))
119

120

121
if __name__ == '__main__':
  # app.run parses the command line into FLAGS, then calls main(argv).
  app.run(main)
