google-research
241 lines · 8.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Main script for dense/sparse inference."""
17import sys
18import time
19from absl import app
20from absl import flags
21from absl import logging
22import numpy as np
23import tensorflow.compat.v1 as tf
24
25from sgk import driver
26from sgk.mbv1 import config
27from sgk.mbv1 import mobilenet_builder
28
# Crop padding for ImageNet preprocessing.
CROP_PADDING = 32

# Mean & stddev for ImageNet preprocessing.
# Standard ImageNet per-channel statistics, scaled to the [0, 255] pixel range.
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]

FLAGS = flags.FLAGS

# FIX: main() dispatches only on "imagenet" and "benchmark"; the help string
# previously advertised a nonexistent "examples" mode. The default is kept
# unchanged so a no-flag invocation still falls through to the usage error.
flags.DEFINE_string("runmode", "examples",
                    "Running mode: imagenet or benchmark.")

flags.DEFINE_string("ckpt_dir", "/tmp/ckpt/", "Checkpoint folders")

flags.DEFINE_integer("num_images", 5000, "Number of images to eval.")

flags.DEFINE_string("imagenet_glob", None, "ImageNet eval image glob.")

flags.DEFINE_string("imagenet_label", None, "ImageNet eval label file path.")

flags.DEFINE_float("width", 1.0, "Width for MobileNetV1 model.")

flags.DEFINE_float("sparsity", 0.0, "Sparsity for MobileNetV1 model.")

flags.DEFINE_bool("fuse_bnbr", False, "Whether to fuse batch norm, bias, relu.")

flags.DEFINE_integer("inner_steps", 1000, "Benchmark steps for inner loop.")

flags.DEFINE_integer("outer_steps", 100, "Benchmark steps for outer loop.")

# Disable TF2.
tf.disable_v2_behavior()
61
62
class InferenceDriver(driver.Driver):
  """Custom inference driver for MBV1.

  Builds a MobileNetV1 model (dense or sparse, per `cfg`) and supports
  ImageNet top-1/top-5 evaluation and a throughput benchmark.
  """

  def __init__(self, cfg):
    # Inference runs one image at a time at the standard 224x224 ImageNet
    # resolution; the classifier has 1000 classes.
    super(InferenceDriver, self).__init__(batch_size=1, image_size=224)
    self.num_classes = 1000
    self.cfg = cfg  # Model configuration passed to mobilenet_builder.

  def build_model(self, features):
    """Builds the model graph and returns per-class softmax probabilities.

    Args:
      features: image `Tensor` in NHWC layout; values are expected in the
        [0, 255] pixel range (MEAN_RGB/STDDEV_RGB are in that scale).

    Returns:
      A squeezed `Tensor` of softmax probabilities over the classes.
    """
    with tf.device("gpu"):
      # Transpose the input features from NHWC to NCHW.
      features = tf.transpose(features, [0, 3, 1, 2])

      # Apply image preprocessing: normalize with ImageNet channel
      # statistics (broadcast over the H and W dimensions).
      features -= tf.constant(MEAN_RGB, shape=[3, 1, 1], dtype=features.dtype)
      features /= tf.constant(STDDEV_RGB, shape=[3, 1, 1], dtype=features.dtype)

      logits = mobilenet_builder.build_model(features, cfg=self.cfg)
      probs = tf.nn.softmax(logits)
      return tf.squeeze(probs)

  def preprocess_fn(self, image_bytes, image_size):
    """Preprocesses the given image for evaluation.

    Args:
      image_bytes: `Tensor` representing an image binary of arbitrary size.
      image_size: image size.

    Returns:
      A preprocessed image `Tensor`.
    """
    shape = tf.image.extract_jpeg_shape(image_bytes)
    image_height = shape[0]
    image_width = shape[1]

    # Center-crop size chosen so that, after resizing to `image_size`, the
    # crop corresponds to the shorter image side minus a CROP_PADDING-sized
    # margin (standard ImageNet eval preprocessing).
    padded_center_crop_size = tf.cast(
        ((image_size / (image_size + CROP_PADDING)) *
         tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)

    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
    crop_window = tf.stack([
        offset_height, offset_width, padded_center_crop_size,
        padded_center_crop_size
    ])

    # Decode only the crop window, then bicubic-resize to the model input.
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
    image = tf.image.resize_bicubic([image], [image_size, image_size])[0]

    image = tf.reshape(image, [image_size, image_size, 3])
    # NOTE(review): resize_bicubic already yields a float tensor, and
    # convert_image_dtype does not rescale float->float32 inputs, so this is
    # likely a no-op kept for safety — confirm before removing.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return image

  def run_inference(self, ckpt_dir, image_files, labels):
    """Restores the model and runs inference over `image_files`.

    Args:
      ckpt_dir: str. Checkpoint directory to restore the model from.
      labels: ground-truth labels aligned with `image_files`; forwarded to
        self.build_dataset (accuracy itself is computed by the caller).
      image_files: list of image file paths to evaluate, one per step.

    Returns:
      A tuple (prediction_idx, prediction_prob): per image, the top-5 class
      indices and their probabilities.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      images, labels = self.build_dataset(image_files, labels)
      probs = self.build_model(images)
      if isinstance(probs, tuple):
        probs = probs[0]

      self.restore_model(sess, ckpt_dir)

      prediction_idx = []
      prediction_prob = []
      for i in range(len(image_files)):
        # Run inference.
        out_probs = sess.run(probs)

        # Sort class indices by descending probability and keep the top 5.
        idx = np.argsort(out_probs)[::-1]
        prediction_idx.append(idx[:5])
        prediction_prob.append([out_probs[pid] for pid in idx[:5]])

        if i % 1000 == 0:
          # logging.error so progress is visible under the ERROR verbosity
          # that main() sets.
          logging.error("Processed %d images.", i)

      # Return the top 5 predictions (idx and prob) for each image.
      return prediction_idx, prediction_prob

  def imagenet(self, ckpt_dir, imagenet_eval_glob, imagenet_eval_label,
               num_images):
    """Eval ImageNet images and report top1/top5 accuracy.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      imagenet_eval_glob: str. File path glob for all eval images.
      imagenet_eval_label: str. File path for eval label.
      num_images: int. Number of images to eval: -1 means eval the whole
        dataset.

    Returns:
      A tuple (top1, top5) for top1 and top5 accuracy.
    """
    # One integer label per line, aligned with the sorted filename list below.
    imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
    imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
    if num_images < 0:
      num_images = len(imagenet_filenames)
    image_files = imagenet_filenames[:num_images]
    labels = imagenet_val_labels[:num_images]

    pred_idx, _ = self.run_inference(ckpt_dir, image_files, labels)
    top1_cnt, top5_cnt = 0.0, 0.0
    for i, label in enumerate(labels):
      # `label in <top-k slice>` is a bool; adding it accumulates hit counts.
      top1_cnt += label in pred_idx[i][:1]
      top5_cnt += label in pred_idx[i][:5]
      if i % 100 == 0:
        print("Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%".format(
            i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
        sys.stdout.flush()
    top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
    print("Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%".format(top1, top5))
    return top1, top5

  def benchmark(self, ckpt_dir, outer_steps=100, inner_steps=1000):
    """Run repeatedly on dummy data to benchmark inference.

    Args:
      ckpt_dir: str. Checkpoint directory to restore the model from.
      outer_steps: int. Number of timed outer iterations (session runs).
      inner_steps: int. Model invocations per outer iteration, executed
        inside a tf.while_loop to amortize per-session-run overhead.
    """
    # Turn off Grappler optimizations.
    options = {"disable_meta_optimizer": True}
    tf.config.optimizer.set_experimental_options(options)

    # Run only the model body (no data pipeline) on device.
    features = tf.zeros([1, 3, self.image_size, self.image_size],
                        dtype=tf.float32)

    # Create the model outside the loop body.
    model = mobilenet_builder.mobilenet_generator(self.cfg)

    # Call the model once to initialize the variables. Note that
    # this should never execute.
    dummy_iteration = model(features)

    # Run the function body in a loop to amortize session overhead.
    loop_index = tf.zeros([], dtype=tf.int32)
    initial_probs = tf.zeros([self.num_classes])

    def loop_cond(idx, _):
      # Iterate `inner_steps` times.
      return tf.less(idx, tf.constant(inner_steps, dtype=tf.int32))

    def loop_body(idx, _):
      logits = model(features)
      probs = tf.squeeze(tf.nn.softmax(logits))
      return idx + 1, probs

    # parallel_iterations=1 serializes the loop so the timing measures
    # back-to-back model executions; no gradients are needed.
    benchmark_op = tf.while_loop(
        loop_cond,
        loop_body, [loop_index, initial_probs],
        parallel_iterations=1,
        back_prop=False)

    with tf.Session() as sess:
      self.restore_model(sess, ckpt_dir)
      fps = []
      for idx in range(outer_steps):
        start_time = time.time()
        sess.run(benchmark_op)
        elapsed_time = time.time() - start_time
        fps.append(inner_steps / elapsed_time)
        logging.error("Iterations %d processed %f FPS.", idx, fps[-1])
      # Skip the first iteration where all the setup and allocation happens.
      fps = np.asarray(fps[1:])
      logging.error("Mean, Std, Max, Min throughput = %f, %f, %f, %f",
                    np.mean(fps), np.std(fps), fps.max(), fps.min())
223
224
def main(_):
  """Entry point: builds the configured model and dispatches on runmode."""
  logging.set_verbosity(logging.ERROR)
  cfg_cls = config.get_config(FLAGS.width, FLAGS.sparsity)
  cfg = cfg_cls(FLAGS.fuse_bnbr)
  drv = InferenceDriver(cfg)

  if FLAGS.runmode == "imagenet":
    drv.imagenet(FLAGS.ckpt_dir, FLAGS.imagenet_glob, FLAGS.imagenet_label,
                 FLAGS.num_images)
  elif FLAGS.runmode == "benchmark":
    drv.benchmark(FLAGS.ckpt_dir, FLAGS.outer_steps, FLAGS.inner_steps)
  else:
    # FIX: report the actual (invalid) value — the old message claimed no
    # runmode was specified even when one was given.
    logging.error("Unknown runmode %r: must be 'benchmark' or 'imagenet'.",
                  FLAGS.runmode)


if __name__ == "__main__":
  app.run(main)
242