google-research
100 lines · 3.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Script to run AssembleNet++ with objects."""
17import json18
19from absl import app20from absl import flags21
22import numpy as np23
24import tensorflow as tf # tf25
26from assemblenet import assemblenet_plus27
28from assemblenet import model_structures29
30
# --- Input / output configuration -------------------------------------------
flags.DEFINE_string(
    name='precision',
    default='float32',
    help='Precision to use; one of: {bfloat16, float32}.')
flags.DEFINE_integer(
    name='num_frames',
    default=64,
    help='Number of frames to use.')
flags.DEFINE_integer(
    name='num_classes',
    default=157,
    help='Number of classes. 157 is for Charades')

# --- Architecture selection -------------------------------------------------
# NOTE(review): apart from the flags read directly by main() in this script,
# these are presumably consumed inside the assemblenet package via FLAGS;
# confirm before removing any of them.
flags.DEFINE_string(
    name='assemblenet_mode',
    default='assemblenet_plus',
    help='"assemblenet" or "assemblenet_plus" or "assemblenet_plus_lite"')
# Both structure flags are JSON strings; main() in this script overwrites
# them with the full AssembleNet++-50 structure regardless of user input.
flags.DEFINE_string(
    name='model_structure',
    default='[-1,1]',
    help='AssembleNet model structure in the string format.')
flags.DEFINE_string(
    name='model_edge_weights',
    default='[]',
    help='AssembleNet model structure connection weights in the string '
         'format.')
flags.DEFINE_string(
    name='attention_mode',
    default='peer',
    help='"peer" or "self" or None')

# --- Head / regularization --------------------------------------------------
flags.DEFINE_float(
    name='dropout_keep_prob',
    default=None,
    help='Keep ratio for dropout.')
# The misspelling "preditions" is part of the public flag name; renaming it
# would break command lines and any code reading FLAGS.max_pool_preditions.
flags.DEFINE_bool(
    name='max_pool_preditions',
    default=False,
    help='Use max-pooling on predictions instead of mean pooling on '
         'features. It helps if you have more than 32 frames.')

# --- Object-input options (AssembleNet++) -----------------------------------
flags.DEFINE_bool(
    name='use_object_input',
    default=True,
    help='Whether to use object input for AssembleNet++ or not')
flags.DEFINE_integer(
    name='num_object_classes',
    default=151,
    help='Number of object classes, when using object inputs. '
         '151 is for ADE-20k')

FLAGS = flags.FLAGS
63
def main(_):
  """Builds AssembleNet++ with video and object inputs and runs it once.

  Uses TF1 graph-mode APIs (`tf.placeholder`, `tf.Session`). The network is
  fed uniformly random data in place of a real video, and the resulting
  logits and their per-row argmax are printed.

  Args:
    _: Remaining argv passed in by `app.run`; unused.
  """
  # Fixed demo sizes. The leading dimension folds batch and time together:
  # batch_size clips of FLAGS.num_frames frames each.
  batch_size = 2
  image_size = 256
  num_rows = batch_size * FLAGS.num_frames

  rgb_shape = (num_rows, image_size, image_size, 3)
  # One channel per object class at every spatial position.
  object_shape = (num_rows, image_size, image_size, FLAGS.num_object_classes)

  # Creation order (RGB first, then objects) is kept so graph op names match.
  vid_placeholder = tf.placeholder(tf.float32, rgb_shape)
  object_placeholder = tf.placeholder(tf.float32, object_shape)
  model_inputs = (vid_placeholder, object_placeholder)

  # Both video and object streams are fed, so the full AssembleNet++-50
  # structure is required; this overrides any --model_structure and
  # --model_edge_weights given on the command line.
  FLAGS.model_structure = json.dumps(model_structures.full_asnp50_structure)
  FLAGS.model_edge_weights = json.dumps(
      model_structures.full_asnp_structure_weights)

  network = assemblenet_plus.assemblenet_plus(
      assemblenet_depth=50,
      num_classes=FLAGS.num_classes,
      data_format='channels_last')

  # Second positional argument is is_training; this is inference only.
  outputs = network(model_inputs, False)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Random stand-ins; replace with a real video and its object maps.
    # NOTE(review): np.random.rand yields float64, and at the default flag
    # values the object array alone is roughly 10 GB. Confirm the host
    # has enough memory before running.
    fake_video = np.random.rand(*rgb_shape)
    fake_objects = np.random.rand(*object_shape)

    # Feeding each placeholder directly is what TF does when it flattens
    # the nested (vid, obj) key.
    logits = sess.run(
        outputs,
        feed_dict={vid_placeholder: fake_video,
                   object_placeholder: fake_objects})
    print(logits)
    predicted_classes = np.argmax(logits, axis=1)
    print(predicted_classes)


if __name__ == '__main__':
  # absl parses the command-line flags, then invokes main().
  app.run(main)