# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main script for dense/sparse inference."""
import sys
import time
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

from sgk import driver
from sgk.mbv1 import config
from sgk.mbv1 import mobilenet_builder

# Crop padding for ImageNet preprocessing.
CROP_PADDING = 32

# Mean & stddev for ImageNet preprocessing.
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]

FLAGS = flags.FLAGS

flags.DEFINE_string("runmode", "examples",
39
                    "Running mode: examples or imagenet.")
40

41
flags.DEFINE_string("ckpt_dir", "/tmp/ckpt/", "Checkpoint folders")
42

43
flags.DEFINE_integer("num_images", 5000, "Number of images to eval.")
44

45
flags.DEFINE_string("imagenet_glob", None, "ImageNet eval image glob.")
46

47
flags.DEFINE_string("imagenet_label", None, "ImageNet eval label file path.")
48

49
flags.DEFINE_float("width", 1.0, "Width for MobileNetV1 model.")
50

51
flags.DEFINE_float("sparsity", 0.0, "Sparsity for MobileNetV1 model.")
52

53
flags.DEFINE_bool("fuse_bnbr", False, "Whether to fuse batch norm, bias, relu.")
54

55
flags.DEFINE_integer("inner_steps", 1000, "Benchmark steps for inner loop.")
56

57
flags.DEFINE_integer("outer_steps", 100, "Benchmark steps for outer loop.")
58

59
# Disable TF2.
60
tf.disable_v2_behavior()
61

62

63
class InferenceDriver(driver.Driver):
  """Custom inference driver for MBV1."""

  def __init__(self, cfg):
    super(InferenceDriver, self).__init__(batch_size=1, image_size=224)
    self.num_classes = 1000
    self.cfg = cfg

  def build_model(self, features):
    with tf.device("gpu"):
      # Transpose the input features from NHWC to NCHW.
      features = tf.transpose(features, [0, 3, 1, 2])

      # Normalize with the ImageNet channel means and standard deviations.
      features -= tf.constant(MEAN_RGB, shape=[3, 1, 1], dtype=features.dtype)
      features /= tf.constant(STDDEV_RGB, shape=[3, 1, 1], dtype=features.dtype)

      logits = mobilenet_builder.build_model(features, cfg=self.cfg)
      probs = tf.nn.softmax(logits)
      return tf.squeeze(probs)

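  # Note: MEAN_RGB / STDDEV_RGB above are the standard ImageNet statistics
  # scaled to the [0, 255] pixel range, so this normalization expects inputs
  # in [0, 255], which is what preprocess_fn below produces (its final
  # convert_image_dtype is a float32-to-float32 cast and does not rescale).
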
  def preprocess_fn(self, image_bytes, image_size):
    """Preprocesses the given image for evaluation.

    Args:
      image_bytes: `Tensor` representing an image binary of arbitrary size.
      image_size: int. Target spatial size of the output image.

    Returns:
      A preprocessed image `Tensor`.
    """
    shape = tf.image.extract_jpeg_shape(image_bytes)
    image_height = shape[0]
    image_width = shape[1]

    padded_center_crop_size = tf.cast(
        ((image_size / (image_size + CROP_PADDING)) *
         tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)

    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
    crop_window = tf.stack([
        offset_height, offset_width, padded_center_crop_size,
        padded_center_crop_size
    ])

    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
    image = tf.image.resize_bicubic([image], [image_size, image_size])[0]

    image = tf.reshape(image, [image_size, image_size, 3])
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return image

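  # Worked example: for a 640x480 JPEG and image_size=224, the crop size is
  # int(224 / (224 + 32) * 480) = 420, so a centered 420x420 window is
  # cropped and bicubically resized to 224x224 -- the usual "center crop
  # with padding" evaluation transform.
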
  def run_inference(self, ckpt_dir, image_files, labels):
    with tf.Graph().as_default(), tf.Session() as sess:
      images, labels = self.build_dataset(image_files, labels)
      probs = self.build_model(images)
      if isinstance(probs, tuple):
        probs = probs[0]

      self.restore_model(sess, ckpt_dir)

      prediction_idx = []
      prediction_prob = []
      for i in range(len(image_files)):
        # Run inference on the next image.
        out_probs = sess.run(probs)

        # Sort class indices by descending probability and keep the top 5.
        idx = np.argsort(out_probs)[::-1]
        prediction_idx.append(idx[:5])
        prediction_prob.append([out_probs[pid] for pid in idx[:5]])

        if i % 1000 == 0:
          logging.error("Processed %d images.", i)

      # Return the top 5 predictions (idx and prob) for each image.
      return prediction_idx, prediction_prob

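  # Note: each sess.run(probs) evaluates the graph once, which advances the
  # input pipeline built by build_dataset (inherited from driver.Driver) by
  # one image, so the loop above performs one inference per image file.
  # Progress is logged via logging.error so the messages survive the ERROR
  # verbosity set in main().
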
  def imagenet(self, ckpt_dir, imagenet_eval_glob, imagenet_eval_label,
               num_images):
    """Eval ImageNet images and report top1/top5 accuracy.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      imagenet_eval_glob: str. File path glob for all eval images.
      imagenet_eval_label: str. File path for eval label.
      num_images: int. Number of images to eval: -1 means eval the whole
        dataset.

    Returns:
      A tuple (top1, top5) for top1 and top5 accuracy.
    """
    imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
    imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
    if num_images < 0:
      num_images = len(imagenet_filenames)
    image_files = imagenet_filenames[:num_images]
    labels = imagenet_val_labels[:num_images]

    pred_idx, _ = self.run_inference(ckpt_dir, image_files, labels)
    top1_cnt, top5_cnt = 0.0, 0.0
    for i, label in enumerate(labels):
      top1_cnt += label in pred_idx[i][:1]
      top5_cnt += label in pred_idx[i][:5]
      if i % 100 == 0:
        print("Step {}: top1_acc = {:4.2f}%  top5_acc = {:4.2f}%".format(
            i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
        sys.stdout.flush()
    top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
    print("Final: top1_acc = {:4.2f}%  top5_acc = {:4.2f}%".format(top1, top5))
    return top1, top5

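  # The label file is expected to hold one integer class index per line,
  # aligned with the lexicographically sorted order of the files matched by
  # imagenet_eval_glob.
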
  def benchmark(self, ckpt_dir, outer_steps=100, inner_steps=1000):
    """Run repeatedly on dummy data to benchmark inference."""
    # Turn off Grappler optimizations.
    options = {"disable_meta_optimizer": True}
    tf.config.optimizer.set_experimental_options(options)

    # Run only the model body (no data pipeline) on device.
    features = tf.zeros([1, 3, self.image_size, self.image_size],
                        dtype=tf.float32)

    # Create the model outside the loop body.
    model = mobilenet_builder.mobilenet_generator(self.cfg)

    # Call the model once so that its variables are created; this op is
    # built but never run.
    dummy_iteration = model(features)

    # Run the model body in a while_loop to amortize session overhead.
    loop_index = tf.zeros([], dtype=tf.int32)
    initial_probs = tf.zeros([self.num_classes])

    def loop_cond(idx, _):
      return tf.less(idx, tf.constant(inner_steps, dtype=tf.int32))

    def loop_body(idx, _):
      logits = model(features)
      probs = tf.squeeze(tf.nn.softmax(logits))
      return idx + 1, probs

    benchmark_op = tf.while_loop(
        loop_cond,
        loop_body, [loop_index, initial_probs],
        parallel_iterations=1,
        back_prop=False)

    with tf.Session() as sess:
      self.restore_model(sess, ckpt_dir)
      fps = []
      for idx in range(outer_steps):
        start_time = time.time()
        sess.run(benchmark_op)
        elapsed_time = time.time() - start_time
        fps.append(inner_steps / elapsed_time)
        logging.error("Iteration %d: %f FPS.", idx, fps[-1])
      # Skip the first iteration, where one-time setup and allocation happen.
      fps = np.asarray(fps[1:])
      logging.error("Mean, Std, Max, Min throughput = %f, %f, %f, %f",
                    np.mean(fps), np.std(fps), fps.max(), fps.min())
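
  # Timing note: parallel_iterations=1 forces the while_loop iterations to
  # run sequentially and back_prop=False avoids keeping activations, so each
  # sess.run(benchmark_op) measures inner_steps serial forward passes, and
  # throughput is inner_steps / elapsed_time frames per second.

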
def main(_):
  logging.set_verbosity(logging.ERROR)
  cfg_cls = config.get_config(FLAGS.width, FLAGS.sparsity)
  cfg = cfg_cls(FLAGS.fuse_bnbr)
  drv = InferenceDriver(cfg)

  if FLAGS.runmode == "imagenet":
    drv.imagenet(FLAGS.ckpt_dir, FLAGS.imagenet_glob, FLAGS.imagenet_label,
                 FLAGS.num_images)
  elif FLAGS.runmode == "benchmark":
    drv.benchmark(FLAGS.ckpt_dir, FLAGS.outer_steps, FLAGS.inner_steps)
  else:
    logging.error("Must specify --runmode: 'imagenet' or 'benchmark'.")


if __name__ == "__main__":
  app.run(main)
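
# Example invocations (illustrative sketch; the script filename, data paths,
# and the (width, sparsity) pair are hypothetical and must match an existing
# config and checkpoint):
#
#   # Evaluate top1/top5 accuracy on ImageNet:
#   python main.py --runmode=imagenet --ckpt_dir=/tmp/ckpt/ \
#     --imagenet_glob='/data/imagenet/val/*.JPEG' \
#     --imagenet_label=/data/imagenet/val_labels.txt \
#     --width=1.0 --sparsity=0.9
#
#   # Benchmark inference throughput on dummy data:
#   python main.py --runmode=benchmark --ckpt_dir=/tmp/ckpt/ \
#     --width=1.0 --sparsity=0.9 --inner_steps=1000 --outer_steps=100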
