# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for the forward pass (symbolic and decimal) of a neural network.

Given an image and a trained neural network, this code builds an SMT encoding
of the forward pass of the network and then employs the z3 solver to learn a
mask for the inputs given the weights.
"""
import collections
import io
import math
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import skimage.draw as draw
import sklearn.metrics as metrics
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
import z3


tf.disable_eager_execution()


class OptimizerBase:
  """Creates a solver by using the z3 solver.

  Attributes:
    z3_mask: list, contains mask bits as z3 vars.
    mask_sum: z3.ExprRef, sum of boolean mask bits.
    minimal_mask_sum: int, the minimum value of mask_sum which satisfies the
        SMT constraints.
    solver: z3.Optimize, minimizes mask_sum subject to the SMT constraints.

  Subclasses should define the generate_mask method.
  """

  def __init__(self, z3_mask):
    """Initializer.

    Args:
      z3_mask: list, contains mask bits as z3 vars.
    """
    self.z3_mask = z3_mask
    self.mask_sum = 0
    self.solver = z3.Optimize()
    for mask in self.z3_mask:
      self.solver.add(z3.Or(mask == 1, mask == 0))
      self.mask_sum += mask
    self.minimal_mask_sum = self.solver.minimize(self.mask_sum)

  def _optimize(self):
    """Solves the SMT constraints and returns the solution as a numpy array.

    Returns:
      z3_mask: float numpy array with shape (num_mask_variables,).
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    result = str(self.solver.check())
    z3_mask = np.zeros(len(self.z3_mask))
    if result == 'sat':
      # A model is only available when the constraints are satisfiable.
      z3_assignment = self.solver.model()
      for var in z3_assignment.decls():
        z3_mask[int(str(var).split('_')[1])] = int(str(z3_assignment[var]))

      # Block the currently found solution so that for every call of optimize,
      # a unique mask is found.
      block = [var() != z3_assignment[var] for var in z3_assignment]
      self.solver.add(z3.Or(block))
    return z3_mask, result

  def generate_mask(self):
    """Constructs the mask with the same shape as that of the data.

    Returns:
      mask: float numpy array.
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    raise NotImplementedError('Must be implemented by subclass.')

  def generator(self, num_unique_solutions):
    """Generates solutions from the optimizer.

    If the number of unique solutions is smaller than num_unique_solutions,
    the rest of the solutions are unsat.

    Args:
      num_unique_solutions: int, number of unique solutions you want to sample.

    Yields:
      mask: float numpy array.
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    for _ in range(num_unique_solutions):
      yield self.generate_mask()


class TextOptimizer(OptimizerBase):
  """Creates a solver for text by using the z3 solver."""

  def __init__(self, z3_mask):
    """Initializer.

    Args:
      z3_mask: list, contains mask bits as z3 vars.
    """
    super().__init__(z3_mask=z3_mask)

  def generate_mask(self):
    """Constructs the mask with the same shape as that of the data.

    Returns:
      mask: float numpy array with shape (num_mask_variables,).
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    # This method explicitly requires a masking variable for each input word
    # to the neural network. If a mask bit covers multiple words, then the
    # function has to be appropriately modified.
    return self._optimize()


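# Illustrative usage sketch (an assumption, not part of the original module):
# the mask bits must be z3 integer variables named 'mask_<index>', because
# _optimize() recovers each bit's position by parsing the variable name. The
# constraint added below is hypothetical; real callers derive constraints
# from the network's forward pass.
def _demo_text_optimizer():
  z3_mask = [z3.Int('mask_%d' % i) for i in range(4)]
  optimizer = TextOptimizer(z3_mask=z3_mask)
  # Hypothetical constraint: keep at least one of the first two words.
  optimizer.solver.add(z3.Or(z3_mask[0] == 1, z3_mask[1] == 1))
  # Returns a minimal mask, e.g. (array([1., 0., 0., 0.]), 'sat').
  return optimizer.generate_mask()

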
class ImageOptimizer(OptimizerBase):
  """Creates a solver by using the z3 solver.

  Attributes:
    edge_length: int, side length of the 2D array (image) whose pixels are to
        be masked.
    window_size: int, side length of the square mask.
  """

  def __init__(self, z3_mask, window_size, edge_length):
    """Initializer.

    Args:
      z3_mask: list, contains mask bits as z3 vars.
      window_size: int, side length of the square mask.
      edge_length: int, side length of the 2D array (image) whose pixels are to
          be masked.
    """
    super().__init__(z3_mask=z3_mask)
    self.edge_length = edge_length
    self.window_size = window_size

  def generate_mask(self):
    """Constructs a 2D mask with the same shape as that of the image.

    Returns:
      mask: float numpy array with shape (edge_length, edge_length).
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    z3_mask, result = self._optimize()
    mask = np.zeros((self.edge_length, self.edge_length))
    num_masks_along_row = math.ceil(self.edge_length / self.window_size)
    for row in range(self.edge_length):
      for column in range(self.edge_length):
        mask_id = (
            num_masks_along_row * (row // self.window_size)) + (
                column // self.window_size)
        mask[row][column] = z3_mask[mask_id]
    return mask, result


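# Illustrative sketch (an assumption, not part of the original module): with
# edge_length=4 and window_size=2 there are ceil(4 / 2) ** 2 = 4 mask
# variables, and generate_mask() broadcasts each variable over its 2x2
# window:
#
#   mask_0 mask_0 mask_1 mask_1
#   mask_0 mask_0 mask_1 mask_1
#   mask_2 mask_2 mask_3 mask_3
#   mask_2 mask_2 mask_3 mask_3
def _demo_image_optimizer():
  z3_mask = [z3.Int('mask_%d' % i) for i in range(4)]
  optimizer = ImageOptimizer(z3_mask=z3_mask, window_size=2, edge_length=4)
  # Hypothetical constraint: force the bottom-right window to stay on.
  optimizer.solver.add(z3_mask[3] == 1)
  return optimizer.generate_mask()  # mask has shape (4, 4)

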
def restore_model(model_path):
  """Restores a frozen tensorflow model into a tf session and returns it.

  Args:
    model_path: string, path to a tensorflow frozen graph.

  Returns:
    A tensorflow session.
  """
  session = tf.Session()
  tf.saved_model.loader.load(session, ['serve'], model_path)
  return session


def zero_pad(activation_map, padding):
  """Appends rows and columns of zeros on all four sides of the image.

  Args:
    activation_map: list of list of z3.ExprRef, activation map to be 0-padded.
    padding: tuple, (number of rows/columns of zeros added on the top / left
        side of the image, number of rows/columns of zeros added on the
        bottom / right side of the image).

  Returns:
    list of list of z3.ExprRef, 0 padded activation map.
  """
  num_rows = len(activation_map)
  num_columns = len(activation_map[0])

  # padded_activation_map has shape (num_padded_rows, num_padded_columns).
  padded_activation_map = []
  for _ in range(num_rows + padding[0] + padding[1]):
    padded_activation_map.append([0] * (num_columns + padding[0] + padding[1]))

  for i in range(num_rows):
    for j in range(num_columns):
      padded_activation_map[padding[0] + i][padding[0] +
                                            j] = activation_map[i][j]
  return padded_activation_map


def dot_product(input_activation_map, input_activation_map_row,
                input_activation_map_column, sliced_kernel):
  """Convolution operation for a convolution kernel and a patch in the image.

  Performs convolution on a square patch of the input_activation_map with
  (input_activation_map_row, input_activation_map_column) and
  (input_activation_map_row + kernel_rows - 1,
   input_activation_map_column + kernel_columns - 1) as the diagonal vertices.

  Args:
    input_activation_map: list of list of z3.ExprRef with dimensions
        (input_activation_map_size, input_activation_map_size).
    input_activation_map_row: int, row in the activation map for which the
        convolution is being performed.
    input_activation_map_column: int, column in the activation map for which
        convolution is being performed.
    sliced_kernel: numpy array with shape (kernel_rows, kernel_columns),
        2d slice of a kernel along input_channel.

  Returns:
    z3.ExprRef, dot product of the convolution kernel and a patch in the image.
  """
  convolution = 0
  for i in range(sliced_kernel.shape[0]):
    for j in range(sliced_kernel.shape[1]):
      convolution += (
          input_activation_map
          [input_activation_map_row + i][input_activation_map_column + j]
          * sliced_kernel[i][j])
  return convolution


def smt_convolution(input_activation_maps, kernels, kernel_biases, padding,
                    strides):
  """Performs convolution on symbolic inputs.

  Args:
    input_activation_maps: list of list of z3.ExprRef with dimensions
        (input_channels, input_activation_map_size, input_activation_map_size),
        input activation maps.
    kernels: numpy array with shape
        (kernel_size, kernel_size, input_channels, output_channels),
        weights of the convolution layer.
    kernel_biases: numpy array with shape (output_channels,), biases of the
        convolution layer.
    padding: tuple, (number of rows/columns of zeros padded on the top / left
        side of the image, number of rows/columns of zeros padded on the
        bottom / right side of the image).
    strides: int, number of pixel shifts over the input matrix.

  Returns:
    list of list of list of z3.ExprRef with dimensions (output_channels,
        output_activation_map_size, output_activation_map_size), convolutions.

  Raises:
    ValueError: If input_channels is inconsistent across
        input_activation_maps and kernels, or output_channels is inconsistent
        across kernels and kernel_biases, or padding is not a tuple of
        size 2.
  """
  if len(input_activation_maps) != kernels.shape[2]:
    raise ValueError(
        'Input channels in inputs and kernels are not equal. Number of input '
        'channels in input: %d and kernels: %d' % (
            len(input_activation_maps), kernels.shape[2]))
  if not isinstance(padding, tuple) or len(padding) != 2:
    raise ValueError(
        'Padding should be a tuple with 2 dimensions. Input padding: %s' %
        padding)
  if kernels.shape[3] != kernel_biases.shape[0]:
    raise ValueError(
        'Output channels in kernels and biases are not equal. Number of output '
        'channels in kernels: %d and biases: %d' % (
            kernels.shape[3], kernel_biases.shape[0]))
  padded_input_activation_maps = []

  # Reshape the kernels to
  # (output_channels, kernel_size, kernel_size, input_channels).
  kernels = np.moveaxis(kernels, -1, 0)
  for input_activation_map in input_activation_maps:
    padded_input_activation_maps.append(
        zero_pad(
            # (input_activation_map_size, input_activation_map_size)
            activation_map=input_activation_map,
            padding=padding))
  output_activation_maps = []
  output_activation_map_size = len(input_activation_maps[0]) // strides
  # Iterate over output_channels.
  for kernel, kernel_bias in zip(kernels, kernel_biases):
    output_activation_map = np.full(
        (output_activation_map_size, output_activation_map_size),
        kernel_bias).tolist()
    for i in range(output_activation_map_size):
      for j in range(output_activation_map_size):
        for channel_in in range(kernel.shape[-1]):
          output_activation_map[i][j] += dot_product(
              input_activation_map=padded_input_activation_maps[channel_in],
              input_activation_map_row=strides * i,
              input_activation_map_column=strides * j,
              sliced_kernel=kernel[:, :, channel_in])
    output_activation_maps.append(output_activation_map)
  return output_activation_maps


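# Illustrative sketch (an assumption, not part of the original module): a 1x1
# convolution over a single-channel 2x2 grid of symbolic pixels. Every output
# entry is a z3 expression, e.g. outputs[0][0][0] simplifies to 2 * x_00.
def _demo_smt_convolution():
  image = [[[z3.Real('x_%d%d' % (i, j)) for j in range(2)]
            for i in range(2)]]
  kernels = 2.0 * np.ones((1, 1, 1, 1))  # one kernel that doubles its input
  kernel_biases = np.zeros(1)
  return smt_convolution(
      input_activation_maps=image,
      kernels=kernels,
      kernel_biases=kernel_biases,
      padding=(0, 0),
      strides=1)

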
def flatten_nested_lists(activation_maps):
  """Flattens a nested list of depth 3 in a row major order.

  Args:
    activation_maps: list of list of list of z3.ExprRef with dimensions
        (channels, activation_map_size, activation_map_size), activation_maps.

  Returns:
    list of z3.ExprRef.
  """
  flattened_activation_maps = []
  for activation_map in activation_maps:
    for activation_map_row in activation_map:
      flattened_activation_maps.extend(activation_map_row)
  return flattened_activation_maps


def z3_relu(x):
  """Relu activation function.

  max(0, x).

  Args:
    x: z3.ExprRef, z3 Expression.

  Returns:
    z3.ExprRef.
  """
  return z3.If(x > 0, x, 0)


def _verify_lengths(weights, biases, activations):
  """Verifies that the lengths of the weights, biases, and activations match.

  Args:
    weights: list of float numpy array with shape (output_dim, input_dim) and
        length num_layers, weights of the neural network.
    biases: list of float numpy array with shape (output_dim,) and length
        num_layers, biases of the neural network.
    activations: list of string with length num_layers, activations for each
        hidden layer.

  Raises:
    ValueError: If lengths of weights, biases, and activations are not equal.
  """
  if not len(weights) == len(biases) == len(activations):
    raise ValueError('Lengths of weights, biases and activations should be the '
                     'same, but got weights with length %d, biases with length '
                     '%d and activations with length %d' % (
                         len(weights), len(biases), len(activations)))


def smt_forward(features, weights, biases, activations):
  """Forward pass of a neural network with the inputs being symbolic.

  Computes the forward pass of a neural network by looping through the weights
  and the biases in a layerwise manner.

  Args:
    features: list of z3.ExprRef, contains a z3 instance corresponding
        to each pixel of a flattened image.
    weights: list of float numpy array with shape (output_dim, input_dim) and
        length num_layers, weights of the neural network.
    biases: list of float numpy array with shape (output_dim,) and length
        num_layers, biases of the neural network.
    activations: list of string with length num_layers, activations for each
        hidden layer.

  Returns:
    logits: list of z3.ExprRef, output logits.
    hidden_nodes: list of list of z3.ExprRef with dimensions
        (num_layers, output_dim), weighted sum at every hidden neuron.
  """
  _verify_lengths(weights, biases, activations)
  layer_features = [i for i in features]
  hidden_nodes = []
  for layer_weights, layer_bias, layer_activation in zip(
      weights, biases, activations):
    # Values of hidden nodes after activation.
    layer_output = []
    # Values of hidden nodes before activation.
    layer_weighted_sums = []
    for weight_row, bias in zip(layer_weights, layer_bias):
      # Iterating over output_dim.
      intermediate_sum = bias
      for x, weight in zip(layer_features, weight_row):
        # Iterating over input_dim.
        intermediate_sum += weight * x
      layer_weighted_sums.append(intermediate_sum)
      # Apply relu or linear activation function.
      if layer_activation == 'relu':
        layer_output.append(z3_relu(intermediate_sum))
      else:
        layer_output.append(intermediate_sum)
    hidden_nodes.append(layer_weighted_sums)
    layer_features = layer_output
  return layer_features, hidden_nodes


def nn_forward(features, weights, biases, activations):
  """Forward pass of a neural network using matrix multiplication.

  Computes the forward pass of a neural network using matrix multiplication and
  addition by looping through the weights and the biases.

  Args:
    features: float numpy array with shape (num_input_features,),
        image flattened as a 1D vector.
    weights: list of float numpy array with shape (output_dim, input_dim) and
        length num_layers, weights of the neural network.
    biases: list of float numpy array with shape (output_dim,) and length
        num_layers, biases of the neural network.
    activations: list of strings with length num_layers,
        activations for each hidden layer.

  Returns:
    logits: float numpy array with shape (num_labels,).
    hidden_nodes: list of numpy array with shape (output_dim,) and
        length num_layers.
  """
  _verify_lengths(weights, biases, activations)
  hidden_nodes = []
  layer_features = np.copy(features)
  for layer_weights, layer_bias, layer_activation in zip(
      weights, biases, activations):
    layer_output = np.matmul(
        layer_features, layer_weights.transpose()) + layer_bias
    hidden_nodes.append(layer_output)
    if layer_activation == 'relu':
      layer_output = layer_output * (layer_output > 0)
    layer_features = layer_output
  return layer_features, hidden_nodes


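# Illustrative sketch (an assumption, not part of the original module): on the
# same tiny made-up network, the symbolic pass and the decimal pass agree once
# the symbolic inputs are pinned to concrete values.
def _demo_forward_passes():
  weights = [np.array([[1.0, -1.0], [0.5, 0.5]]), np.array([[1.0, 1.0]])]
  biases = [np.zeros(2), np.zeros(1)]
  activations = ['relu', 'linear']
  logits, _ = nn_forward(
      np.array([2.0, 1.0]), weights, biases, activations)  # logits == [2.5]
  x = [z3.Real('x_0'), z3.Real('x_1')]
  smt_logits, _ = smt_forward(x, weights, biases, activations)
  solver = z3.Solver()
  solver.add(x[0] == 2, x[1] == 1, smt_logits[0] == logits[0])
  return solver.check()  # z3.sat -- the two passes agree

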
def convert_pixel_to_2d_indices(edge_length, flattened_pixel_index):
  """Maps an index of an array to its reshaped 2D matrix's rows and columns.

  This function maps the index of an array with length edge_length ** 2 to the
  rows and columns of its reshaped 2D matrix with shape
  (edge_length, edge_length).

  Args:
    edge_length: int, side length of the 2D array (image) whose pixels are to be
        masked.
    flattened_pixel_index: int, flattened pixel index in the image in
        a row major order.

  Returns:
    row_index: int, row index of the 2D array.
    column_index: int, column index of the 2D array.
  """
  return (
      flattened_pixel_index // edge_length, flattened_pixel_index % edge_length)


def convert_pixel_to_mask_index(
    edge_length, window_size, flattened_pixel_index):
  """Maps flattened pixel index to the flattened index of its mask.

  Args:
    edge_length: int, side length of the 2D array (image).
    window_size: int, side length of the square mask.
    flattened_pixel_index: int, flattened pixel index in the image in
        a row major order.

  Returns:
    int, the index of the mask bit in the flattened mask array.
  """
  num_masks_along_row = edge_length // window_size
  num_pixels_per_mask_row = edge_length * window_size
  return (
      num_masks_along_row * (flattened_pixel_index // num_pixels_per_mask_row)
      + (flattened_pixel_index % edge_length) // window_size)


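# Worked example (illustrative, not part of the original module): for a 4x4
# image with 2x2 mask windows, flattened pixel 6 sits at row 1, column 2,
# which falls inside the top-right window, i.e. mask bit 1.
def _demo_pixel_index_mapping():
  assert convert_pixel_to_2d_indices(
      edge_length=4, flattened_pixel_index=6) == (1, 2)
  assert convert_pixel_to_mask_index(
      edge_length=4, window_size=2, flattened_pixel_index=6) == 1

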
def calculate_auc_score(ground_truth, attribution_map):
  """Calculates the AUC of the ROC curve of the attribution map wrt ground truth.

  Args:
    ground_truth: float numpy array, ground truth values.
    attribution_map: float numpy array, attribution map.

  Returns:
    float, AUC of the ROC curve.
  """
  return metrics.roc_auc_score(ground_truth, attribution_map)


def calculate_min_mae_score(ground_truth, attribution_map):
  """Calculates the mean absolute error of the attribution map wrt ground truth.

  Converts the continuous valued attribution map to a binary one by
  thresholding it at multiple values. Entries at or above the threshold are
  set to 1 and the rest are set to 0. Then, it computes the MAE for each such
  mask and returns the best (lowest) score.

  Args:
    ground_truth: int numpy array, ground truth values.
    attribution_map: float numpy array, attribution map.

  Returns:
    float, the mean absolute error.
  """
  thresholds = np.unique(attribution_map)
  thresholds = np.append(
      thresholds[::max(int(round(len(thresholds) / 1000)), 1)], thresholds[-1])
  mae_score = np.inf
  for threshold in thresholds:
    thresholded_attributions = np.zeros_like(attribution_map, dtype=np.int8)
    thresholded_attributions[attribution_map >= threshold] = 1
    mae_score = min(
        mae_score,
        metrics.mean_absolute_error(ground_truth, thresholded_attributions))
  return mae_score


def calculate_max_f1_score(ground_truth, attribution_map):
  """Calculates the F1 score of the attribution map wrt the ground truth.

  Computes the F1 score for a continuous valued attribution map. First,
  it computes precision and recall at multiple thresholds using
  sklearn.metrics.precision_recall_curve(). Then it computes an F1 score for
  each precision-recall pair and returns the max.

  Args:
    ground_truth: int numpy array, ground truth values.
    attribution_map: float numpy array, attribution map.

  Returns:
    float, the F1 score.
  """
  precision, recall, _ = metrics.precision_recall_curve(
      ground_truth, attribution_map)
  # Sklearn's f1_score metric requires both the ground_truth and the
  # attribution_map to be binary valued. So, we compute the precision and
  # recall scores at multiple thresholds and report the best f1 score.
  return np.nanmax(list(
      map(lambda p, r: 2 * (p * r) / (p + r), precision, recall)))


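# Worked example (illustrative, not part of the original module), using the
# values from sklearn's precision_recall_curve documentation: the best F1 over
# all thresholds is 2 * (2/3 * 1) / (2/3 + 1) = 0.8.
def _demo_max_f1():
  ground_truth = np.array([0, 0, 1, 1])
  attribution_map = np.array([0.1, 0.4, 0.35, 0.8])
  return calculate_max_f1_score(ground_truth, attribution_map)  # 0.8

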
def get_mnist_dataset(num_datapoints, split='test'):
  """Loads the MNIST dataset.

  Args:
    num_datapoints: int, number of images to load.
    split: str, One of {'train', 'test'} representing train and test data
      respectively.

  Returns:
    dict,
      * image_ids: list of int, the serial number of each image serialized
          according to its position in the dataset.
      * labels: list of int, class label of each image.
      * images: list of float numpy array with shape (28, 28, 1),
          MNIST images with values between [0, 1].
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()
  dataset = builder.as_dataset()
  data = collections.defaultdict(list)
  for image_id, datapoint in enumerate(tfds.as_numpy(dataset[split])):
    data['images'].append(datapoint['image'] / 255.0)
    data['labels'].append(datapoint['label'])
    data['image_ids'].append(image_id)
    if image_id == num_datapoints - 1:
      break
  return data


def _get_tightest_crop(saliency_map, threshold):
  """Finds the tightest bounding box for a given saliency map.

  For a continuous valued saliency map, finds the tightest bounding box such
  that all the attributions outside the bounding box have a score at or below
  the threshold.

  Args:
    saliency_map: float numpy array with shape (rows, columns), saliency map.
    threshold: float, attribution threshold.

  Returns:
    crop parameters: dict,
      * left: int, index of the left most column of the bounding box.
      * right: int, index of the right most column of the bounding box + 1.
      * top: int, index of the top most row of the bounding box.
      * bottom: int, index of the bottom most row of the bounding box + 1.
    cropped mask: int numpy array with shape (rows, columns), the values within
      the bounding box set to 1.
  """
  non_zero_rows, non_zero_columns = np.asarray(
      saliency_map > threshold).nonzero()
  top = np.min(non_zero_rows)
  bottom = np.max(non_zero_rows) + 1
  left = np.min(non_zero_columns)
  right = np.max(non_zero_columns) + 1
  cropped_mask = np.zeros_like(saliency_map)
  cropped_mask[top: bottom, left: right] = 1
  return {
      'left': left,
      'right': right,
      'top': top,
      'bottom': bottom,
  }, cropped_mask


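# Illustrative sketch (an assumption, not part of the original module): a toy
# saliency map whose only above-threshold entries live in rows 1-2 and
# columns 2-3, so the crop parameters come out as
# {'left': 2, 'right': 4, 'top': 1, 'bottom': 3}.
def _demo_tightest_crop():
  saliency_map = np.zeros((5, 5))
  saliency_map[1:3, 2:4] = 0.9
  return _get_tightest_crop(saliency_map, threshold=0.5)

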
def _check_dimensions(image, saliency_map, model_type):
  """Verifies that the image and the saliency map have proper dimensions.

  Args:
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels), image. Otherwise, float numpy array with shape
      (num_zero_padded_words,), text.
    saliency_map: If model_type = 'cnn', float numpy array with shape (rows,
      columns, channels). Otherwise, float numpy array with shape
      (num_zero_padded_words,), saliency_map.
    model_type: str, One of {'cnn', 'text_cnn'}, model type.

  Raises:
    ValueError:
      If model_type is 'text_cnn' and the image or the saliency map isn't a
      1D array. Or, if model_type is 'cnn' and the image isn't a 3D array or
      the saliency map isn't a 2D array.
  """
  if model_type == 'text_cnn':
    if image.ndim != 1:
      raise ValueError('The text input should be a 1D numpy array. '
                       'Shape of the supplied image: {}'.format(image.shape))
    if saliency_map.ndim != 1:
      raise ValueError(
          'The text saliency map should be a 1D numpy array. '
          'Shape of the supplied Saliency map: {}'.format(saliency_map.shape))
  else:
    if image.ndim != 3:
      raise ValueError(
          'Image should have 3 dimensions. '
          'Shape of the supplied image: {}'.format(image.shape))
    if saliency_map.ndim != 2:
      raise ValueError(
          'Saliency map should have 2 dimensions. '
          'Shape of the supplied Saliency map: {}'.format(saliency_map.shape))


def calculate_saliency_score(
    run_params, image, saliency_map, area_threshold=0.05, session=None):
  """Computes the score for an image using the saliency metric.

  For a continuous valued saliency map, the tightest bounding box is found at
  multiple thresholds and the best score is returned.
  The saliency metric is defined as score(a, p) = log(a') - log(p),
  where a = fraction of the image area occupied by the mask,
        p = confidence of the classifier on the cropped and rescaled image,
        a' = max(area_threshold, a).
  Reference: https://arxiv.org/pdf/1705.07857.pdf

  Args:
    run_params: RunParams with model_path, model_type and tensor_names.
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels) with pixel values between [0, 255], image. Otherwise, float
      numpy array with shape (num_zero_padded_words,), text.
    saliency_map: If model_type = 'cnn', float numpy array with shape (rows,
      columns, channels). Otherwise, float numpy array with shape
      (num_zero_padded_words,), saliency_map.
    area_threshold: float, area_threshold used in the metric.
    session: (default: None) tensorflow session.

  Returns:
    If the saliency_map is all zeros, returns None.
    Else dict,
      * true_label: int, True label of the image.
      * true_confidence: float, Confidence of the classifier on the image.
      * cropped_label: int, Predicted label of the classifier on the cropped
          image.
      * cropped_confidence: float, Confidence of the classifier on the cropped
          image for the true label.
      * crop_mask: int numpy array with shape (rows, columns), the values
          within the bounding box set to 1.
      * saliency_map: float numpy array with shape (rows, columns),
          saliency map.
      * image: float numpy array with shape (rows, columns), image.
      * saliency_score: float, saliency score.
  """
  _check_dimensions(image=image, saliency_map=saliency_map,
                    model_type=run_params.model_type)
  if session is None:
    session = restore_model(run_params.model_path)
  # Sometimes the saliency map consists of all 1s. Hence, a threshold = 0
  # should be present.
  thresholds = np.append(0, np.unique(saliency_map))
  min_score = None
  record = None
  steps = max(int(round(thresholds.size / 100)), 1)
  if run_params.model_type == 'text_cnn':
    steps = 1
  for threshold in thresholds[::steps]:
    if np.sum(saliency_map > threshold) == 0:
      # A bounding box doesn't exist.
      continue
    crop_mask, processed_image = _crop_and_process_image(
        image=image,
        saliency_map=saliency_map,
        threshold=threshold,
        model_type=run_params.model_type)
    eval_record = _evaluate_cropped_image(
        session=session,
        run_params=run_params,
        crop_mask=crop_mask,
        image=image,
        processed_image=processed_image,
        saliency_map=saliency_map,
        area_threshold=area_threshold)
    if min_score is None or eval_record['saliency_score'] < min_score:
      min_score = eval_record['saliency_score']
      record = eval_record
  session.close()
  return record


def _crop_and_process_image(image, saliency_map, threshold, model_type):
  """Crops the image and returns the processed image.

  Args:
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels) with pixel values between [0, 255], image. Otherwise, float
      numpy array with shape (num_zero_padded_words,), text.
    saliency_map: If model_type = 'cnn', float numpy array with shape (rows,
      columns, channels). Otherwise, float numpy array with shape
      (num_zero_padded_words,), saliency_map.
    threshold: float, saliency threshold.
    model_type: str, One of 'cnn' for image or 'text_cnn' for text input.

  Returns:
    crop_mask: If model_type = 'cnn',
        float numpy array with shape (rows, columns, channels), image.
        Otherwise,
        float numpy array with shape (num_zero_padded_words,), text.
    processed_image: If model_type = 'cnn',
        float numpy array with shape (rows, columns, channels), image.
        Otherwise,
        float numpy array with shape (num_zero_padded_words,), text.
  """
  if model_type == 'text_cnn':
    crop_mask = (saliency_map > threshold).astype(int)
    return crop_mask, saliency_map * crop_mask
  else:
    image_shape_original = (image.shape[0], image.shape[1])
    crop_params, crop_mask = _get_tightest_crop(saliency_map=saliency_map,
                                                threshold=threshold)
    cropped_image = image[crop_params['top']:crop_params['bottom'],
                          crop_params['left']:crop_params['right'], :]
    return crop_mask, np.array(
        Image.fromarray(cropped_image.astype(np.uint8)).resize(
            image_shape_original, resample=Image.Resampling.BILINEAR
        )
    )


def process_model_input(image, pixel_range):
  """Scales the input image's pixels so that they lie within pixel_range."""
  # Pixel values are between [0, 1].
  image = normalize_array(image, percentile=100)
  min_pixel_value, max_pixel_value = pixel_range
  # Pixel values are within pixel_range.
  return image * (max_pixel_value - min_pixel_value) + min_pixel_value


def _evaluate_cropped_image(session, run_params, crop_mask, image,
                            processed_image, saliency_map, area_threshold):
  """Computes the saliency metric for a given resized image.

  Args:
    session: tf.Session, tensorflow session.
    run_params: RunParams with tensor_names and pixel_range.
    crop_mask: int numpy array with shape (rows, columns), the values within the
      bounding box set to 1.
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels) with pixel values between [0, 255], image. Otherwise, float
      numpy array with shape (num_zero_padded_words,), text.
    processed_image: float numpy array with shape (cropped_rows,
        cropped_columns, channels), cropped image.
    saliency_map:
      * None if brute_force_fast_saliency_evaluate_masks is using this function.
      * otherwise, float numpy array with shape (rows, columns), saliency map.
    area_threshold: float, area threshold in the saliency metric.

  Returns:
    dict,
      * true_label: int, True label of the image.
      * true_confidence: float, Confidence of the classifier on the image.
      * cropped_label: int, Predicted label of the classifier on the cropped
          image.
      * cropped_confidence: float, Confidence of the classifier on the cropped
          image for the true label.
      * crop_mask: int numpy array with shape (rows, columns), the values
          within the bounding box set to 1.
      * saliency_map:
          * None if brute_force_fast_saliency_evaluate_masks is using this
              function.
          * otherwise, float numpy array with shape (rows, columns), saliency
              map.
      * image: float numpy array with shape (rows, columns), image.
      * saliency_score: float, saliency score.
  """
  if run_params.model_type == 'cnn':
    image = process_model_input(image, run_params.pixel_range)
    processed_image = process_model_input(processed_image,
                                          run_params.pixel_range)
  true_softmax, cropped_softmax = session.run(
      run_params.tensor_names,
      feed_dict={
          run_params.tensor_names['input']: [image, processed_image]}
      )['softmax']
  true_label = np.argmax(true_softmax)
  cropped_confidence = cropped_softmax[true_label]
  if run_params.model_type == 'text_cnn':
    # Sparsity is defined as words in the mask / words in the sentence.
    # Hence, to ignore zero padding we only account for non-zero entries in the
    # input.
    sparsity = np.sum(crop_mask) / np.sum(image != 0)
  else:
    sparsity = np.sum(crop_mask) / crop_mask.size
  score = np.log(max(area_threshold, sparsity)) - np.log(cropped_confidence)
  return {
      'true_label': true_label,
      'true_confidence': np.max(true_softmax),
      'cropped_label': np.argmax(cropped_softmax),
      'cropped_confidence': cropped_confidence,
      'crop_mask': crop_mask,
      'saliency_map': saliency_map,
      'image': image,
      'saliency_score': score,
  }


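# Worked numeric example (illustrative, not part of the original module) of
# the saliency metric computed above: a crop covering 10% of the image
# (sparsity 0.1, above the 0.05 area threshold) on which the classifier keeps
# confidence 0.8 for the true label scores
# log(0.1) - log(0.8) ~= -2.08; lower is better.
def _demo_saliency_score():
  area_threshold, sparsity, cropped_confidence = 0.05, 0.1, 0.8
  return np.log(max(area_threshold, sparsity)) - np.log(cropped_confidence)

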
def _generate_cropped_image(image, grid_size):
  """Generates crop masks and cropped images by dividing the image into a grid.

  Args:
    image: float numpy array with shape (rows, columns, channels), image.
    grid_size: int, size of the grid.

  Yields:
    crop_mask: int numpy array with shape (rows, columns), the values
        within the bounding box set to 1.
    image: float numpy array with shape (cropped_rows, cropped_columns,
        channels), cropped image.
  """
  image_edge_length = image.shape[0]
  scale = image_edge_length / grid_size
  for row_top in range(grid_size):
    for column_left in range(grid_size):
      for row_bottom in range(row_top + 2, grid_size + 1):
        # row_bottom starts from row_top + 2 so that while slicing, we don't
        # end up with a null array.
        for column_right in range(column_left + 2, grid_size + 1):
          crop_mask = np.zeros((image_edge_length, image_edge_length))
          row_slice = slice(int(scale * row_top), int(scale * row_bottom))
          column_slice = slice(int(scale * column_left),
                               int(scale * column_right))
          crop_mask[row_slice, column_slice] = 1
          yield crop_mask, image[row_slice, column_slice, :]


def brute_force_fast_saliency_evaluate_masks(run_params,
                                             image,
                                             grid_size=10,
                                             area_threshold=0.05,
                                             session=None):
  """Finds the best bounding box in an image that optimizes the saliency metric.

  Divides the image into a (grid_size x grid_size) grid. Then evaluates all
  possible bounding boxes formed by choosing any 2 grid points as opposite
  ends of a diagonal.

  Args:
    run_params: RunParams with model_path and tensor_names.
    image: float numpy array with shape (rows, columns, channels) and pixel
        values between [0, 255], image.
    grid_size: int, size of the grid.
    area_threshold: float, area_threshold used in the saliency metric.
    session: tf.Session, (default None) tensorflow session with the loaded
        neural network.

  Returns:
    dict,
      * true_label: int, True label of the image.
      * true_confidence: float, Confidence of the classifier on the image.
      * cropped_label: int, Predicted label of the classifier on the cropped
          image.
      * cropped_confidence: float, Confidence of the classifier on the cropped
          image for the true label.
      * crop_mask: int numpy array with shape (rows, columns), the values
          within the bounding box set to 1.
      * saliency_map: None.
      * image: float numpy array with shape (rows, columns), image.
      * saliency_score: float, saliency score.
  """
  if session is None:
    session = restore_model(run_params.model_path)
  min_score = None
  record = None
  for crop_mask, cropped_image in _generate_cropped_image(image, grid_size):
    eval_record = _evaluate_cropped_image(
        session=session,
        run_params=run_params,
        crop_mask=crop_mask,
        image=image,
        processed_image=np.array(
            Image.fromarray(cropped_image.astype(np.uint8)).resize(
                run_params.image_placeholder_shape[1:-1],
                resample=Image.Resampling.BILINEAR,
            )
        ),
        saliency_map=None,
        area_threshold=area_threshold,
    )
    if min_score is None or eval_record['saliency_score'] < min_score:
      min_score = eval_record['saliency_score']
      record = eval_record
  session.close()
  return record


def remove_ticks():
  """Removes ticks from the axes."""
  plt.tick_params(
      axis='both',  # changes apply to both the x-axis and the y-axis
      which='both',  # both major and minor ticks are affected
      bottom=False,  # ticks along the bottom edge are off
      top=False,  # ticks along the top edge are off
      left=False,  # ticks along the left edge are off
      right=False,  # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)


def show_bounding_box(mask, left_offset=0, top_offset=0, linewidth=3,
                      edgecolor='lime'):
  """Given a mask, shows the tightest rectangle capturing it.

  Args:
    mask: numpy array with shape (rows, columns), a mask.
    left_offset: int, shift the bounding box left by these many pixels.
    top_offset: int, shift the bounding box up by these many pixels.
    linewidth: int, line width of the bounding box.
    edgecolor: string, color of the bounding box.
  """
  ax = plt.gca()
  params, _ = _get_tightest_crop(mask, 0)
  ax.add_patch(patches.Rectangle(
      (params['left'] - left_offset, params['top'] - top_offset),
      params['right'] - params['left'],
      params['bottom'] - params['top'],
      linewidth=linewidth, edgecolor=edgecolor, facecolor='none'))


def normalize_array(array, percentile=99):
  """Normalizes saliency maps for visualization.

  Args:
    array: numpy array, a saliency map.
    percentile: int, the minimum value of the array and the value at this
      percentile are mapped to 0 and 1 respectively.

  Returns:
    numpy array with same shape as the input array, the normalized saliency
    map.
  """
  return (array - array.min()) / (
      np.percentile(array, percentile) - array.min())


def _verify_saliency_map_shape(saliency_map):
  """Checks that the saliency map is a 2D array.

  Args:
    saliency_map: numpy array with shape (rows, columns), a saliency map.

  Raises:
    ValueError: If the saliency map isn't a 2D array.
  """
  if saliency_map.ndim != 2:
    raise ValueError('The saliency map should be a 2D numpy array '
                     'but the received shape is {}'.format(saliency_map.shape))


def scale_saliency_map(saliency_map, method):
  """Scales saliency maps for visualization.

  For smug and smug base the saliency map is scaled such that the positive
  scores are scaled between 0.5 and 1 (99th percentile maps to 1).
  For other methods the saliency map is scaled between 0 and 1
  (99th percentile maps to 1).

  Args:
    saliency_map: numpy array with shape (rows, columns), a saliency map.
    method: str, saliency method.

  Returns:
    numpy array with shape (rows, columns), the normalized saliency map.
  """
  _verify_saliency_map_shape(saliency_map)
  saliency_map = normalize_array(saliency_map)
  if 'smug' in method:
    # For better visualization, the smug_saliency_map and the
    # no_minimization_saliency_map are scaled between [0.5, 1] instead of the
    # usual [0, 1]. Note that doing such a scaling doesn't affect the
    # saliency score in any way as the relative ordering between the pixels
    # is preserved.
    saliency_map[saliency_map > 0] = 0.5 + 0.5 * saliency_map[saliency_map > 0]
  return saliency_map


def visualize_saliency_map(saliency_map, title=''):
  """Grayscale visualization of the saliency map.

  Args:
    saliency_map: numpy array with shape (rows, columns), a saliency map.
    title: str, title of the saliency map.
  """
  _verify_saliency_map_shape(saliency_map)
  plt.imshow(saliency_map, cmap=plt.cm.gray, vmin=0, vmax=1)  # pytype: disable=module-attr
  plt.title(title)
  remove_ticks()