google-research

Форк
0
/
run_asn_with_object.py 
100 строк · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Script to run AssembleNet++ with objects."""
17
import json
18

19
from absl import app
20
from absl import flags
21

22
import numpy as np
23

24
import tensorflow as tf  # tf
25

26
from assemblenet import assemblenet_plus
27

28
from assemblenet import model_structures
29

30

31
# Command-line flags. Only num_frames, num_classes and num_object_classes are
# read directly in this script; model_structure and model_edge_weights are
# overwritten in main(). The remaining flags are presumably consumed by the
# assemblenet package through the global FLAGS object -- TODO confirm.

# Numeric precision and input/output sizes.
flags.DEFINE_string(
    'precision',
    'float32',
    'Precision to use; one of: {bfloat16, float32}.')
flags.DEFINE_integer(
    'num_frames',
    64,
    'Number of frames to use.')
flags.DEFINE_integer(
    'num_classes',
    157,
    'Number of classes. 157 is for Charades')

# Architecture selection and connectivity.
flags.DEFINE_string(
    'assemblenet_mode',
    'assemblenet_plus',
    '"assemblenet" or "assemblenet_plus" or '
    '"assemblenet_plus_lite"')
flags.DEFINE_string(
    'model_structure',
    '[-1,1]',
    'AssembleNet model structure in the string format.')
flags.DEFINE_string(
    'model_edge_weights',
    '[]',
    'AssembleNet model structure connection weights in the string format.')
flags.DEFINE_string(
    'attention_mode',
    'peer',
    '"peer" or "self" or None')

# Regularisation and prediction pooling.
flags.DEFINE_float(
    'dropout_keep_prob',
    None,
    'Keep ratio for dropout.')
# NOTE: the flag name is misspelled ("preditions"). It is kept as-is because
# renaming it would change the command-line interface.
flags.DEFINE_bool(
    'max_pool_preditions',
    False,
    'Use max-pooling on predictions instead of mean pooling on features. '
    'It helps if you have more than 32 frames.')

# Object-input options.
flags.DEFINE_bool(
    'use_object_input',
    True,
    'Whether to use object input for AssembleNet++ or not')
flags.DEFINE_integer(
    'num_object_classes',
    151,
    'Number of object classes, when using object inputs. '
    '151 is for ADE-20k')


FLAGS = flags.FLAGS
62

63

64
def main(_):
  """Builds AssembleNet++ with object input and runs it once on random data.

  Creates float32 placeholders for the video frames and the object input,
  overrides the model-structure flags with the full_asnp50 structure and its
  edge weights, builds the network in inference mode, feeds random arrays
  through it and prints the resulting logits followed by their argmax along
  axis 1.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  clips_per_batch = 2
  frame_size = 256

  # Frames of every clip in the batch share the leading dimension.
  frames_total = clips_per_batch * FLAGS.num_frames
  video_shape = (frames_total, frame_size, frame_size, 3)
  object_shape = (
      frames_total, frame_size, frame_size, FLAGS.num_object_classes)

  vid_placeholder = tf.placeholder(tf.float32, video_shape)
  object_placeholder = tf.placeholder(tf.float32, object_shape)
  input_placeholder = (vid_placeholder, object_placeholder)

  # Both video and object inputs are fed, so the full_asnp50_structure is
  # used. The flags are overwritten here regardless of command-line values.
  FLAGS.model_structure = json.dumps(model_structures.full_asnp50_structure)
  FLAGS.model_edge_weights = json.dumps(
      model_structures.full_asnp_structure_weights)

  network = assemblenet_plus.assemblenet_plus(
      assemblenet_depth=50,
      num_classes=FLAGS.num_classes,
      data_format='channels_last')

  # The model function takes the inputs and is_training (False: inference).
  outputs = network(input_placeholder, False)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Random stand-ins for a real video and its object input; replace these
    # with real data.
    vid = np.random.rand(*video_shape)
    obj = np.random.rand(*object_shape)

    feeds = {vid_placeholder: vid, object_placeholder: obj}
    logits = sess.run(outputs, feed_dict=feeds)
    print(logits)
    print(np.argmax(logits, axis=1))
97

98

99
if __name__ == '__main__':
  # absl parses the command-line flags and then invokes main.
  app.run(main)
101

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.