# google-research — 122 lines · 4.6 KB (repository listing header, not part of the source file)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Script to run AssembleNets without object input."""

import json

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf  # tf

from assemblenet import assemblenet
from assemblenet import assemblenet_plus
from assemblenet import assemblenet_plus_lite
from assemblenet import model_structures
32
# Command-line configuration for the demo run.  Flag names and help strings
# are part of the script's public CLI and are kept exactly as upstream
# (including the historical typo in 'max_pool_preditions').
flags.DEFINE_string('precision', 'float32',
                    'Precision to use; one of: {bfloat16, float32}.')
flags.DEFINE_integer('num_frames', 64, 'Number of frames to use.')

flags.DEFINE_integer('num_classes', 157,
                     'Number of classes. 157 is for Charades')

flags.DEFINE_string('assemblenet_mode', 'assemblenet',
                    '"assemblenet" or "assemblenet_plus" or "assemblenet_plus_lite"')  # pylint: disable=line-too-long

flags.DEFINE_string('model_structure', '[-1,1]',
                    'AssembleNet model structure in the string format.')
flags.DEFINE_string(
    'model_edge_weights', '[]',
    'AssembleNet model structure connection weights in the string format.')

flags.DEFINE_string('attention_mode', None, '"peer" or "self" or None')

flags.DEFINE_float('dropout_keep_prob', None, 'Keep ratio for dropout.')
flags.DEFINE_bool(
    'max_pool_preditions', False,
    'Use max-pooling on predictions instead of mean pooling on features. It helps if you have more than 32 frames.')  # pylint: disable=line-too-long

flags.DEFINE_bool('use_object_input', False,
                  'Whether to use object input for AssembleNet++ or not')  # pylint: disable=line-too-long
flags.DEFINE_integer('num_object_classes', 151,
                     'Number of object classes, when using object inputs. 151 is for ADE-20k')  # pylint: disable=line-too-long

FLAGS = flags.FLAGS
65
def main(_):
  """Builds the AssembleNet variant selected by flags and runs it once.

  The variant is chosen via --assemblenet_mode.  Note that the structure
  flags are overwritten here with a preset from `model_structures`, so any
  user-supplied --model_structure / --model_edge_weights values are ignored
  by this script.  A random video is fed through the network and the
  resulting logits (and argmax classes) are printed.

  NOTE(review): this uses the TF1 graph API (tf.placeholder / tf.Session /
  tf.global_variables_initializer); it will not run under TF2 eager mode
  without tf.compat.v1 — confirm the intended TensorFlow version.

  Args:
    _: unused positional argument supplied by app.run.
  """
  batch_size = 2
  image_size = 256

  # Input: a batch of videos as (batch, frames, height, width, RGB).
  vid_placeholder = tf.placeholder(
      tf.float32,
      (batch_size, FLAGS.num_frames, image_size, image_size, 3))

  if FLAGS.assemblenet_mode == 'assemblenet_plus_lite':
    FLAGS.model_structure = json.dumps(model_structures.asnp_lite_structure)
    FLAGS.model_edge_weights = json.dumps(model_structures.asnp_lite_structure_weights)  # pylint: disable=line-too-long

    network = assemblenet_plus_lite.assemblenet_plus_lite(
        num_layers=[3, 5, 11, 7],
        num_classes=FLAGS.num_classes,
        data_format='channels_last')
  else:
    # Non-lite variants consume frames as a flat batch of images.
    vid_placeholder = tf.reshape(
        vid_placeholder,
        [batch_size * FLAGS.num_frames, image_size, image_size, 3])

    # Both remaining variants share the same structure preset, so it is set
    # once here.  Here, we are using model_structures.asn50_structure for
    # AssembleNet++ instead of full_asnp50_structure.  By using
    # asn50_structure, it essentially becomes AssembleNet++ without objects,
    # only requiring RGB inputs (and optical flow to be computed inside the
    # model).
    FLAGS.model_structure = json.dumps(model_structures.asn50_structure)
    FLAGS.model_edge_weights = json.dumps(model_structures.asn_structure_weights)  # pylint: disable=line-too-long

    if FLAGS.assemblenet_mode == 'assemblenet_plus':
      network = assemblenet_plus.assemblenet_plus(
          assemblenet_depth=50,
          num_classes=FLAGS.num_classes,
          data_format='channels_last')
    else:
      network = assemblenet.assemblenet_v1(
          assemblenet_depth=50,
          num_classes=FLAGS.num_classes,
          data_format='channels_last')

  # The model function takes the inputs and is_training.
  outputs = network(vid_placeholder, False)

  with tf.Session() as sess:
    # Generate a random video to run on.
    # This should be replaced by a real video.
    vid = np.random.rand(*vid_placeholder.shape)
    sess.run(tf.global_variables_initializer())
    logits = sess.run(outputs, feed_dict={vid_placeholder: vid})
    print(logits)
    print(np.argmax(logits, axis=1))
120
if __name__ == '__main__':
  # absl parses the command-line flags before invoking main.
  app.run(main)