google-research
1044 lines · 38.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Functions for the forward pass (symbolic and decimal) of a neural network.
17
Given an image and a trained neural network this code does an SMT encoding of
19the forward pass of the neural network and further, employs z3 solver to
20learn a mask for the inputs given the weights.
21"""
22import collections23import io24import math25import matplotlib.patches as patches26import matplotlib.pyplot as plt27import numpy as np28from PIL import Image29import skimage.draw as draw30import sklearn.metrics as metrics31import tensorflow.compat.v1 as tf32import tensorflow_datasets as tfds33import z334
35
# This module builds TF1-style graphs and feeds them through tf.Session.run
# (see restore_model / _evaluate_cropped_image below), so eager execution is
# disabled at import time.
tf.disable_eager_execution()
38
class OptimizerBase:
  """Creates a solver by using z3 solver.

  Attributes:
    z3_mask: list, contains mask bits as z3 vars.
    mask_sum: z3.ExprRef, sum of boolean mask bits.
    minimal_mask_sum: int, the minimum value of mask_sum which satisfying the
      smt constraints.
    solver: z3.Optimize, minimizes a mask_sum wrt smt constraints.

  Subclasses should define the generate_mask method.
  """

  def __init__(self, z3_mask):
    """Initializer.

    Args:
      z3_mask: list, contains mask bits as z3 vars.
    """
    self.z3_mask = z3_mask
    self.mask_sum = 0
    self.solver = z3.Optimize()
    for mask in self.z3_mask:
      # Constrain every mask variable to be boolean-valued (0 or 1).
      self.solver.add(z3.Or(mask == 1, mask == 0))
      self.mask_sum += mask
    # Objective: prefer the sparsest mask that satisfies the constraints.
    self.minimal_mask_sum = self.solver.minimize(self.mask_sum)

  def _optimize(self):
    """Solves the SMT constraints and returns the solution as a numpy array.

    Returns:
      z3_mask: float numpy array with shape (num_mask_variables,).
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    result = str(self.solver.check())
    z3_mask = np.zeros(len(self.z3_mask))
    if result != 'unknown':
      # NOTE(review): z3's model() raises when the last check() returned
      # 'unsat'; presumably the constraints here are always satisfiable when
      # this branch is taken — verify.
      z3_assignment = self.solver.model()
      for var in z3_assignment.decls():
        # Mask variables are assumed to be named '<prefix>_<index>'; the index
        # after the underscore selects the output slot.
        z3_mask[int(str(var).split('_')[1])] = int(str(z3_assignment[var]))

      # Block the currently found solution so that for every call of optimize,
      # a unique mask is found.
      block = [var() != z3_assignment[var] for var in z3_assignment]
      self.solver.add(z3.Or(block))
    return z3_mask, result

  def generate_mask(self):
    """Constructs the mask with the same shape as that of data.

    Returns:
      mask: float numpy array.
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    raise NotImplementedError('Must be implemented by subclass.')

  def generator(self, num_unique_solutions):
    """Generates solutions from the optimizer.

    If the number of unique solutions is smaller than num_unique_solutions,
    the rest of the solutions are unsat.

    Args:
      num_unique_solutions: int, number of unique solutions you want to sample.

    Yields:
      mask: float numpy array.
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    for _ in range(num_unique_solutions):
      yield self.generate_mask()
111
class TextOptimizer(OptimizerBase):
  """Creates a solver for text by using z3 solver."""

  def __init__(self, z3_mask):
    """Initializer.

    Args:
      z3_mask: list, contains mask bits as z3 vars.
    """
    super().__init__(z3_mask=z3_mask)

  def generate_mask(self):
    """Constructs the mask with the same shape as that of data.

    Returns:
      mask: float numpy array with shape (num_mask_variables,).
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    # This method explicitly requires a masking variable for each input word
    # to the neural network. If a mask bit covers multiple words, then the
    # function has to be appropriately modified.
    return self._optimize()
136
class ImageOptimizer(OptimizerBase):
  """Creates a solver by using z3 solver.

  Attributes:
    edge_length: int, side length of the 2D array (image) whose pixels are to
      be masked.
    window_size: int, side length of the square mask.
  """

  def __init__(self, z3_mask, window_size, edge_length):
    """Initializer.

    Args:
      z3_mask: list, contains mask bits as z3 vars.
      window_size: int, side length of the square mask.
      edge_length: int, side length of the 2D array (image) whose pixels are to
        be masked.
    """
    super().__init__(z3_mask=z3_mask)
    self.edge_length = edge_length
    self.window_size = window_size

  def generate_mask(self):
    """Constructs a 2D mask with the same shape as that of image.

    Returns:
      mask: float numpy array with shape (edge_length, edge_length).
      result: string, returns one of the following: 'sat', 'unsat' or 'unknown'.
    """
    z3_mask, result = self._optimize()
    mask = np.zeros((self.edge_length, self.edge_length))
    # ceil covers edge lengths that are not an exact multiple of window_size
    # (the trailing window in each row/column is then narrower).
    # NOTE(review): convert_pixel_to_mask_index in this file uses
    # edge_length // window_size instead — the two disagree when
    # edge_length % window_size != 0; verify intended behavior.
    num_masks_along_row = math.ceil(self.edge_length / self.window_size)
    for row in range(self.edge_length):
      for column in range(self.edge_length):
        # Every pixel takes the value of the mask window covering it.
        mask_id = (
            num_masks_along_row * (row // self.window_size)) + (
                column // self.window_size)
        mask[row][column] = z3_mask[mask_id]
    return mask, result
177
def restore_model(model_path):
  """Restores a frozen tensorflow model into a tf session and returns it.

  Args:
    model_path: string, path to a tensorflow frozen graph.

  Returns:
    A tensorflow session with the SavedModel loaded under the 'serve' tag.
  """
  session = tf.Session()
  tf.saved_model.loader.load(session, ['serve'], model_path)
  return session
191
def zero_pad(activation_map, padding):
  """Zero-pads an activation map on all four sides.

  Args:
    activation_map: list of list of z3.ExprRef, activation map to be 0-padded.
    padding: tuple, (number of 0-padded layers on the top / left side of the
      image, number of 0-padded layers on the bottom / right side of the
      image).

  Returns:
    list of list of z3.ExprRef, 0 padded activation map.
  """
  num_rows = len(activation_map)
  num_columns = len(activation_map[0])
  total_padding = padding[0] + padding[1]

  # Start from an all-zero map of the padded shape, then copy the original
  # values into the interior, offset by the top/left padding.
  padded = [[0] * (num_columns + total_padding)
            for _ in range(num_rows + total_padding)]
  for row_index, row in enumerate(activation_map):
    for column_index, value in enumerate(row):
      padded[padding[0] + row_index][padding[0] + column_index] = value
  return padded
218
def dot_product(input_activation_map, input_activation_map_row,
                input_activation_map_column, sliced_kernel):
  """Convolution operation for a convolution kernel and a patch in the image.

  Performs convolution on the square patch of input_activation_map whose
  top-left corner is (input_activation_map_row, input_activation_map_column)
  and whose extent matches sliced_kernel.

  Args:
    input_activation_map: list of list of z3.ExprRef with dimensions
      (input_activation_map_size, input_activation_map_size).
    input_activation_map_row: int, row in the activation map for which the
      convolution is being performed.
    input_activation_map_column: int, column in the activation map for which
      convolution is being performed.
    sliced_kernel: numpy array with shape (kernel_rows, kernel_columns),
      2d slice of a kernel along input_channel.

  Returns:
    z3.ExprRef, dot product of the convolution kernel and a patch in the image.
  """
  kernel_rows, kernel_columns = sliced_kernel.shape
  result = 0
  for row_offset in range(kernel_rows):
    # Hoist the row lookup out of the inner loop.
    map_row = input_activation_map[input_activation_map_row + row_offset]
    for column_offset in range(kernel_columns):
      result += (map_row[input_activation_map_column + column_offset]
                 * sliced_kernel[row_offset][column_offset])
  return result
250
def smt_convolution(input_activation_maps, kernels, kernel_biases, padding,
                    strides):
  """Performs convolution on symbolic inputs.

  Args:
    input_activation_maps: list of list of z3.ExprRef with dimensions
      (input_channels, input_activation_map_size, input_activation_map_size),
      input activation maps.
    kernels: numpy array with shape
      (kernel_size, kernel_size, input_channels, output_channels),
      weights of the convolution layer.
    kernel_biases: numpy array with shape (output_channels,), biases of the
      convolution layer.
    padding: tuple, number of layers 0 padded vectors on top/left side of the
      image.
    strides: int, number of pixel shifts over the input matrix.

  Returns:
    list of list of list of z3.ExprRef with dimensions (output_channels,
      output_activation_map_size, output_activation_map_size), convolutions.

  Raises:
    ValueError: If input_channels is inconsistent across
      input_activation_maps and kernels, or output_channels is inconsistent
      across kernels and kernel_biases, or padding is not a tuple, or padding
      isn't a tuple of size 2.
  """
  if len(input_activation_maps) != kernels.shape[2]:
    raise ValueError(
        'Input channels in inputs and kernels are not equal. Number of input '
        'channels in input: %d and kernels: %d' % (
            len(input_activation_maps), kernels.shape[2]))
  if not isinstance(padding, tuple) or len(padding) != 2:
    raise ValueError(
        'Padding should be a tuple with 2 dimensions. Input padding: %s' %
        padding)
  if kernels.shape[3] != kernel_biases.shape[0]:
    raise ValueError(
        'Output channels in kernels and biases are not equal. Number of output '
        'channels in kernels: %d and biases: %d' % (
            kernels.shape[3], kernel_biases.shape[0]))
  padded_input_activation_maps = []

  # reshape the kernels to
  # (output_channels, kernel_size, kernel_size, input_channels)
  kernels = np.moveaxis(kernels, -1, 0)
  for input_activation_map in input_activation_maps:
    padded_input_activation_maps.append(
        zero_pad(
            # (input_activation_map_size, input_activation_map_size)
            activation_map=input_activation_map,
            padding=padding))
  output_activation_maps = []
  # NOTE(review): the output size is derived from the *unpadded* input size,
  # which presumes 'same'-style padding; verify against the callers.
  output_activation_map_size = len(input_activation_maps[0]) // strides
  # Iterate over output_channels.
  for kernel, kernel_bias in zip(kernels, kernel_biases):
    # Seed every output position with the channel bias, then accumulate the
    # per-input-channel dot products on top of it.
    output_activation_map = np.full(
        (output_activation_map_size, output_activation_map_size),
        kernel_bias).tolist()
    for i in range(output_activation_map_size):
      for j in range(output_activation_map_size):
        for channel_in in range(kernel.shape[-1]):
          output_activation_map[i][j] += dot_product(
              input_activation_map=padded_input_activation_maps[channel_in],
              input_activation_map_row=strides * i,
              input_activation_map_column=strides * j,
              sliced_kernel=kernel[:, :, channel_in])
    output_activation_maps.append(output_activation_map)
  return output_activation_maps
321
def flatten_nested_lists(activation_maps):
  """Flattens a nested list of depth 3 in a row major order.

  Args:
    activation_maps: list of list of list of z3.ExprRef with dimensions
      (channels, activation_map_size, activation_map_size), activation_maps.

  Returns:
    list of z3.ExprRef.
  """
  return [
      value
      for activation_map in activation_maps
      for activation_map_row in activation_map
      for value in activation_map_row
  ]
338
def z3_relu(x):
  """Relu activation function.

  max(0, x).

  Args:
    x: z3.ExprRef, z3 Expression.

  Returns:
    z3.ExprRef.
  """
  # Encoded as a symbolic if-then-else so the solver can reason about both
  # branches of the piecewise-linear relu.
  return z3.If(x > 0, x, 0)
352
353def _verify_lengths(weights, biases, activations):354"""Verifies the lengths of the weights, biases, and activations are equal.355
356Args:
357weights: list of float numpy array with shape (output_dim, input_dim) and
358length num_layers, weights of the neural network.
359biases: list of float numpy array with shape (output_dim,) and length
360num_layers, biases of the neural network.
361activations: list of string with length num_layers, activations for each
362hidden layer.
363
364Raises:
365ValueError: If lengths of weights, biases, and activations are not equal.
366"""
367if not len(weights) == len(biases) == len(activations):368raise ValueError('Lengths of weights, biases and activations should be the '369'same, but got weights with length %d biases with length '370'%d activations with length %d' % (371len(weights), len(biases), len(activations)))372
373
def smt_forward(features, weights, biases, activations):
  """Forward pass of a neural network with the inputs being symbolic.

  Propagates a list of symbolic inputs through the network one layer at a
  time, recording the pre-activation (weighted sum) of every hidden neuron.

  Args:
    features: list of z3.ExprRef, contains a z3 instance corresponding
      to each pixel of a flattened image.
    weights: list of float numpy array with shape (output_dim, input_dim) and
      length num_layers, weights of the neural network.
    biases: list of float numpy array with shape (output_dim,) and length
      num_layers, biases of the neural network.
    activations: list of string with length num_layers, activations for each
      hidden layer.

  Returns:
    logits: list of z3.ExprRef, output logits.
    hidden_nodes: list of list of z3.ExprRef with dimensions
      (num_layers, output_dim), weighted sum at every hidden neuron.
  """
  _verify_lengths(weights, biases, activations)
  hidden_nodes = []
  current_layer = list(features)
  for layer_weights, layer_biases, layer_activation in zip(
      weights, biases, activations):
    pre_activations = []
    post_activations = []
    # One output neuron per row of the layer's weight matrix.
    for neuron_weights, neuron_bias in zip(layer_weights, layer_biases):
      weighted_sum = neuron_bias
      for feature, weight in zip(current_layer, neuron_weights):
        weighted_sum += weight * feature
      pre_activations.append(weighted_sum)
      post_activations.append(
          z3_relu(weighted_sum) if layer_activation == 'relu'
          else weighted_sum)
    hidden_nodes.append(pre_activations)
    current_layer = post_activations
  return current_layer, hidden_nodes
421
def nn_forward(features, weights, biases, activations):
  """Forward pass of a neural network using matrix multiplication.

  Computes the forward pass of a neural network using matrix multiplication
  and addition by looping through the weights and the biases.

  Args:
    features: float numpy array with shape (num_input_features,),
      image flattened as a 1D vector.
    weights: list of float numpy array with shape (output_dim, input_dim) and
      length num_layers, weights of the neural network.
    biases: list of float numpy array with shape (output_dim,) and length
      num_layers, biases of the neural network.
    activations: list of strings with length num_layers,
      activations for each hidden layer.

  Returns:
    logits: float numpy array with shape (num_labels,).
    hidden_nodes: list of numpy array with shape (output_dim,) and
      length num_layers, pre-activation values per layer.
  """
  _verify_lengths(weights, biases, activations)
  hidden_nodes = []
  current_features = np.copy(features)
  for layer_weights, layer_bias, layer_activation in zip(
      weights, biases, activations):
    pre_activation = np.matmul(
        current_features, layer_weights.transpose()) + layer_bias
    hidden_nodes.append(pre_activation)
    if layer_activation == 'relu':
      # Zero out negative entries via the 0/1 sign mask, mirroring the
      # symbolic relu used elsewhere in this file.
      current_features = pre_activation * (pre_activation > 0)
    else:
      current_features = pre_activation
  return current_features, hidden_nodes
456
def convert_pixel_to_2d_indices(edge_length, flattened_pixel_index):
  """Maps an index of an array to its reshaped 2D matrix's rows and columns.

  This function maps the index of an array with length edge_length ** 2 to the
  rows and columns of its reshaped 2D matrix with shape
  (edge_length, edge_length), assuming row major order.

  Args:
    edge_length: int, side length of the 2D array (image) whose pixels are to
      be masked.
    flattened_pixel_index: int, flattened pixel index in the image in
      a row major order.

  Returns:
    row_index: int, row index of the 2D array
    column_index: int, column index of the 2D array
  """
  # divmod yields (quotient, remainder) == (row, column) for row-major order.
  return divmod(flattened_pixel_index, edge_length)
476
477def convert_pixel_to_mask_index(478edge_length, window_size, flattened_pixel_index):479"""Maps flattened pixel index to the flattened index of its mask.480
481Args:
482edge_length: int, side length of the 2D array (image).
483window_size: int, side length of the square mask.
484flattened_pixel_index: int, flattened pixel index in the image in
485a row major order.
486
487Returns:
488int, the index of the mask bit in the flattened mask array.
489"""
490num_masks_along_row = edge_length // window_size491num_pixels_per_mask_row = edge_length * window_size492return (493num_masks_along_row * (flattened_pixel_index // num_pixels_per_mask_row)494+ (flattened_pixel_index % edge_length) // window_size)495
496
def calculate_auc_score(ground_truth, attribution_map):
  """Calculates the auc of roc curve of the attribution map wrt ground truth.

  Args:
    ground_truth: float numpy array, ground truth values.
    attribution_map: float numpy array, attribution map.

  Returns:
    float, AUC of the roc curve.
  """
  # Thin wrapper over sklearn; the attribution map is used directly as the
  # continuous score, so no thresholding is needed for AUC.
  return metrics.roc_auc_score(ground_truth, attribution_map)
509
def calculate_min_mae_score(ground_truth, attribution_map):
  """Calculates the mean absolute error of the attribution map wrt ground truth.

  Binarizes the continuous valued attribution map at multiple thresholds
  (entries >= threshold become 1, the rest 0), computes the MAE of each
  binary mask against the ground truth, and returns the best (smallest)
  score.

  Args:
    ground_truth: int numpy array, ground truth values.
    attribution_map: float numpy array, attribution map.

  Returns:
    float, the mean absolute error.
  """
  unique_values = np.unique(attribution_map)
  # Subsample to at most ~1000 thresholds, always keeping the largest value.
  stride = max(int(round(len(unique_values) / 1000)), 1)
  thresholds = np.append(unique_values[::stride], unique_values[-1])
  best_score = np.inf
  for threshold in thresholds:
    binarized = (attribution_map >= threshold).astype(np.int8)
    best_score = min(
        best_score, metrics.mean_absolute_error(ground_truth, binarized))
  return best_score
537
def calculate_max_f1_score(ground_truth, attribution_map):
  """Calculates the F1 score of the attribution map wrt the ground truth.

  Computes f1 score for a continuous valued attribution map. Sklearn's
  f1_score metric requires both the ground_truth and the attribution_map to
  be binary valued, so precision and recall are computed at multiple
  thresholds via sklearn's precision_recall_curve() and the best resulting
  f1 score is returned.

  Args:
    ground_truth: int numpy array, ground truth values.
    attribution_map: float numpy array, attribution map.

  Returns:
    float, the F1 score.
  """
  precision, recall, _ = metrics.precision_recall_curve(
      ground_truth, attribution_map)
  f1_scores = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
  # Thresholds where precision == recall == 0 produce nan; skip them.
  return np.nanmax(f1_scores)
561
562
563
def get_mnist_dataset(num_datapoints, split='test'):
  """Loads the MNIST dataset.

  Args:
    num_datapoints: int, number of images to load.
    split: str, One of {'train', 'test'} representing train and test data
      respectively.

  Returns:
    dict,
      * image_ids: list of int, the serial number of each image serialised
          according to its position in the dataset.
      * labels: list of int, inception logit indices of each image.
      * images: list of float numpy array with shape (28, 28, 1),
          MNIST images with values between [0, 1].
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()
  dataset = builder.as_dataset()
  data = collections.defaultdict(list)
  for image_id, datapoint in enumerate(tfds.as_numpy(dataset[split])):
    # Raw pixels are uint8 in [0, 255]; rescale to [0, 1].
    data['images'].append(datapoint['image'] / 255.0)
    data['labels'].append(datapoint['label'])
    data['image_ids'].append(image_id)
    # NOTE(review): for num_datapoints <= 0 this condition never fires and
    # the whole split is loaded — verify callers always pass a positive count.
    if image_id == num_datapoints - 1:
      break
  return data
592
593def _get_tightest_crop(saliency_map, threshold):594"""Finds the tightest bounding box for a given saliency map.595
596For a continuous valued saliency map, finds the tightest bounding box by
597all the attributions outside the bounding box have a score less than the
598threshold.
599
600Args:
601saliency_map: float numpy array with shape (rows, columns), saliency map.
602threshold: float, attribution threshold.
603
604Returns:
605crop parameters: dict,
606* left: int, index of the left most column of the bounding box.
607* right: int, index of the right most column of the bounding box + 1.
608* top: int, index of the top most row of the bounding box.
609* bottom: int, index of the bottom most row of the bounding box + 1.
610cropped mask: int numpy array with shape (rows, columns), the values within
611the bounding set to 1.
612"""
613non_zero_rows, non_zero_columns = np.asarray(614saliency_map > threshold).nonzero()615top = np.min(non_zero_rows)616bottom = np.max(non_zero_rows) + 1617left = np.min(non_zero_columns)618right = np.max(non_zero_columns) + 1619cropped_mask = np.zeros_like(saliency_map)620cropped_mask[top: bottom, left: right] = 1621return {622'left': left,623'right': right,624'top': top,625'bottom': bottom,626}, cropped_mask627
628
629def _check_dimensions(image, saliency_map, model_type):630"""Verifies the image and saliency map dimensions have proper dimensions.631
632Args:
633image: If model_type = 'cnn', float numpy array with shape (rows, columns,
634channels), image. Otherwise, float numpy array with shape
635(num_zero_padded_words,), text.
636saliency_map: If model_type = 'cnn', float numpy array with shape (rows,
637columns, channels). Otherwise, float numpy array with shape
638(num_zero_padded_words,), saliency_map.
639model_type: str, One of {'cnn', 'text_cnn'}, model type.
640
641Raises:
642ValueError:
643If model_type is 'text_cnn' and image isn't a 3D array or the saliency map
644isn't a 2D array. Or,
645if the model_type is 'cnn' and the image isn't a 1D array or the saliency
646map isn't a 1D array.
647"""
648if model_type == 'text_cnn':649if image.ndim != 1:650raise ValueError('The text input should be a 1D numpy array. '651'Shape of the supplied image: {}'.format(image.shape))652if saliency_map.ndim != 1:653raise ValueError(654'The text saliency map should be a 1D numpy array. '655'Shape of the supplied Saliency map: {}'.format(saliency_map.shape))656else:657if image.ndim != 3:658raise ValueError(659'Image should have 3 dimensions. '660'Shape of the supplied image: {}'.format(image.shape))661if saliency_map.ndim != 2:662raise ValueError(663'Saliency map should have 2 dimensions. '664'Shape of the supplied Saliency map: {}'.format(saliency_map.shape))665
666
def calculate_saliency_score(
    run_params, image, saliency_map, area_threshold=0.05, session=None):
  """Computes the score for an image using the saliency metric.

  For a continuous valued saliency map, the tightest bounding box is found at
  multiple thresholds and the best score is returned.
  The saliency metric is defined as score(a, p) = log(a') - log(p),
  where a = fraction of the image area occupied by the mask,
  p = confidence of the classifier on the cropped and rescaled image,
  a' = max(area_threshold, a).
  Reference: https://arxiv.org/pdf/1705.07857.pdf

  Args:
    run_params: RunParams with model_path, model_type and tensor_names.
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels) with pixel values between [0, 255], image. Otherwise, float
      numpy array with shape (num_zero_padded_words,), text.
    saliency_map: If model_type = 'cnn', float numpy array with shape (rows,
      columns). Otherwise, float numpy array with shape
      (num_zero_padded_words,), saliency_map.
    area_threshold: float, area_threshold used in the metric.
    session: (default: None) tensorflow session.

  Returns:
    if the saliency_map has all 0s returns None
    else dict,
      * true_label: int, True label of the image.
      * true_confidence: float, Confidence of the classifier on the image.
      * cropped_label: int, Predicted label of the classifier on the cropped
          image.
      * cropped_confidence: float, Confidence of the classifier on the cropped
          image for the true label.
      * crop_mask: int numpy array with shape (rows, columns), the values
          within the bounding box set to 1.
      * saliency_map: float numpy array with shape (rows, columns),
          saliency map.
      * image: float numpy array with shape (rows, columns), image.
      * saliency_score: float, saliency score.
  """
  _check_dimensions(image=image, saliency_map=saliency_map,
                    model_type=run_params.model_type)
  if session is None:
    session = restore_model(run_params.model_path)
  # Sometimes the saliency map consists of all 1s. Hence, a threshold = 0
  # should be present.
  thresholds = np.append(0, np.unique(saliency_map))
  min_score = None
  record = None
  # Subsample to roughly 100 thresholds to bound the number of model calls.
  steps = max(int(round(thresholds.size / 100)), 1)
  if run_params.model_type == 'text_cnn':
    steps = 1
  for threshold in thresholds[::steps]:
    if np.sum(saliency_map > threshold) == 0:
      # A bounding box doesn't exist.
      continue
    crop_mask, processed_image = _crop_and_process_image(
        image=image,
        saliency_map=saliency_map,
        threshold=threshold,
        model_type=run_params.model_type)
    eval_record = _evaluate_cropped_image(
        session=session,
        run_params=run_params,
        crop_mask=crop_mask,
        image=image,
        processed_image=processed_image,
        saliency_map=saliency_map,
        area_threshold=area_threshold)
    # Keep the record with the smallest (best) saliency score.
    if min_score is None or eval_record['saliency_score'] < min_score:
      min_score = eval_record['saliency_score']
      record = eval_record
  # NOTE(review): this also closes a caller-supplied session — verify callers
  # do not reuse the session after this call.
  session.close()
  return record
741
def _crop_and_process_image(image, saliency_map, threshold, model_type):
  """Crops the image and returns the processed image.

  Args:
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels) with pixel values between [0, 255], image. Otherwise, float
      numpy array with shape (num_zero_padded_words,), text.
    saliency_map: If model_type = 'cnn', float numpy array with shape (rows,
      columns). Otherwise, float numpy array with shape
      (num_zero_padded_words,), saliency_map.
    threshold: float, saliency threshold.
    model_type: str, One of 'cnn' for image or 'text_cnn' for text input.

  Returns:
    crop_mask: If model_type = 'cnn',
      float numpy array with shape (rows, columns, channels), image.
      Otherwise,
      float numpy array with shape (num_zero_padded_words,), text.
    processed_image: If model_type = 'cnn',
      float numpy array with shape (rows, columns, channels), image.
      Otherwise,
      float numpy array with shape (num_zero_padded_words,), text.
  """
  if model_type == 'text_cnn':
    # For text there is no resizing: masked-out words are simply zeroed.
    crop_mask = (saliency_map > threshold).astype(int)
    return crop_mask, saliency_map * crop_mask
  else:
    image_shape_original = (image.shape[0], image.shape[1])
    crop_params, crop_mask = _get_tightest_crop(saliency_map=saliency_map,
                                                threshold=threshold)
    cropped_image = image[crop_params['top']:crop_params['bottom'],
                          crop_params['left']:crop_params['right'], :]
    # Rescale the crop back to the original image size so the model input
    # shape is unchanged.
    # NOTE(review): PIL's resize takes (width, height) but
    # image_shape_original is (rows, columns); this is only equivalent for
    # square images — verify for non-square inputs.
    return crop_mask, np.array(
        Image.fromarray(cropped_image.astype(np.uint8)).resize(
            image_shape_original, resample=Image.Resampling.BILINEAR
        )
    )
780
def process_model_input(image, pixel_range):
  """Scales the input image's pixels to make it within pixel_range.

  Args:
    image: float numpy array, image to rescale.
    pixel_range: tuple, (min_pixel_value, max_pixel_value) expected by the
      model.

  Returns:
    float numpy array, image with pixel values mapped into pixel_range.
  """
  # pixel values are between [0, 1]
  # NOTE(review): normalize_array is defined elsewhere in this file (outside
  # this chunk); with percentile=100 it presumably rescales into [0, 1] —
  # verify against its definition.
  image = normalize_array(image, percentile=100)
  min_pixel_value, max_pixel_value = pixel_range
  # pixel values are within pixel_range
  return image * (max_pixel_value - min_pixel_value) + min_pixel_value
789
def _evaluate_cropped_image(session, run_params, crop_mask, image,
                            processed_image, saliency_map, area_threshold):
  """Computes the saliency metric for a given resized image.

  Args:
    session: tf.Session, tensorflow session.
    run_params: RunParams with tensor_names and pixel_range.
    crop_mask: int numpy array with shape (rows, columns), the values within
      the bounding box set to 1.
    image: If model_type = 'cnn', float numpy array with shape (rows, columns,
      channels) with pixel values between [0, 255], image. Otherwise, float
      numpy array with shape (num_zero_padded_words,), text.
    processed_image: float numpy array with shape (cropped_rows,
      cropped_columns, channels), cropped image.
    saliency_map:
      * None if brute_force_fast_saliency_evaluate_masks is using this
        function.
      * otherwise, float numpy array with shape (rows, columns), saliency map.
    area_threshold: float, area threshold in the saliency metric.

  Returns:
    dict,
      * true_label: int, True label of the image.
      * true_confidence: float, Confidence of the classifier on the image.
      * cropped_label: int, Predicted label of the classifier on the cropped
          image.
      * cropped_confidence: float, Confidence of the classifier on the cropped
          image for the true label.
      * crop_mask: int numpy array with shape (rows, columns), the values
          within the bounding box set to 1.
      * saliency_map: None or float numpy array with shape (rows, columns),
          saliency map (passed through unchanged).
      * image: float numpy array with shape (rows, columns), image.
      * saliency_score: float, saliency score.
  """
  if run_params.model_type == 'cnn':
    image = process_model_input(image, run_params.pixel_range)
    processed_image = process_model_input(processed_image,
                                          run_params.pixel_range)
  # Both images are fed as one batch of size 2; the fetched 'softmax' entry
  # therefore yields one softmax vector per input.
  true_softmax, cropped_softmax = session.run(
      run_params.tensor_names,
      feed_dict={
          run_params.tensor_names['input']: [image, processed_image]}
  )['softmax']
  true_label = np.argmax(true_softmax)
  cropped_confidence = cropped_softmax[true_label]
  if run_params.model_type == 'text_cnn':
    # Sparsity is defined as words in the mask / words in the sentence.
    # Hence, to ignore zero padding we only account for non-zero entries in the
    # input.
    sparsity = np.sum(crop_mask) / np.sum(image != 0)
  else:
    sparsity = np.sum(crop_mask) / crop_mask.size
  # score(a, p) = log(max(area_threshold, a)) - log(p); lower is better.
  score = np.log(max(area_threshold, sparsity)) - np.log(cropped_confidence)
  return {
      'true_label': true_label,
      'true_confidence': np.max(true_softmax),
      'cropped_label': np.argmax(cropped_softmax),
      'cropped_confidence': cropped_confidence,
      'crop_mask': crop_mask,
      'saliency_map': saliency_map,
      'image': image,
      'saliency_score': score,
  }
856
857def _generate_cropped_image(image, grid_size):858"""Generates crop mask and cropped images by dividing the image into a grid.859
860Args:
861image: float numpy array with shape (rows, columns, channels), image.
862grid_size: int, size of the grid.
863
864Yields:
865crop_mask: int numpy array with shape (rows, columns), the values
866within the bounding set to 1.
867image: float numpy array with shape (cropped_rows, cropped_columns,
868channels), cropped image.
869"""
870image_edge_length = image.shape[0]871scale = image_edge_length / grid_size872for row_top in range(grid_size):873for column_left in range(grid_size):874for row_bottom in range(row_top + 2, grid_size + 1):875# row_bottom starts from row_top + 2 so that while slicing, we don't876# end up with a null array.877for column_right in range(column_left + 2, grid_size + 1):878crop_mask = np.zeros((image_edge_length, image_edge_length))879row_slice = slice(int(scale * row_top), int(scale * row_bottom))880column_slice = slice(int(scale * column_left),881int(scale * column_right))882crop_mask[row_slice, column_slice] = 1883yield crop_mask, image[row_slice, column_slice, :]884
885
def brute_force_fast_saliency_evaluate_masks(run_params,
                                             image,
                                             grid_size=10,
                                             area_threshold=0.05,
                                             session=None):
  """Finds the best bounding box in an image that optimizes the saliency metric.

  Divides the image into (grid_size x grid_size) grid. Then evaluates all
  possible bounding boxes formed by choosing any 2 grid points as opposite
  ends of its diagonal.

  Args:
    run_params: RunParams with model_path and tensor_names.
    image: float numpy array with shape (rows, columns, channels) and pixel
      values between [0, 255], image.
    grid_size: int, size of the grid.
    area_threshold: float, area_threshold used in the saliency metric.
    session: tf.Session, (default None) tensorflow session with the loaded
      neural network. If None, a session is restored from
      run_params.model_path and closed before returning. A caller-supplied
      session is left open — the caller owns its lifecycle.

  Returns:
    dict,
      * true_label: int, True label of the image.
      * true_confidence: float, Confidence of the classifier on the image.
      * cropped_label: int, Predicted label of the classifier on the cropped
          image.
      * cropped_confidence: float, Confidence of the classifier on the cropped
          image for the true label.
      * crop_mask: int numpy array with shape (rows, columns), the values
          within the bounding set to 1.
      * saliency_map: None.
      * image: float numpy array with shape (rows, columns), image.
      * saliency_score: float, saliency score.
  """
  # Only close the session when it is created here; closing a session the
  # caller passed in would break the caller's subsequent use of it.
  owns_session = session is None
  if owns_session:
    session = restore_model(run_params.model_path)
  min_score = None
  record = None
  try:
    # Spatial (rows, columns) part of the placeholder shape, hoisted out of
    # the loop; the resized crop must match the model's input resolution.
    resize_shape = run_params.image_placeholder_shape[1:-1]
    for crop_mask, cropped_image in _generate_cropped_image(image, grid_size):
      eval_record = _evaluate_cropped_image(
          session=session,
          run_params=run_params,
          crop_mask=crop_mask,
          image=image,
          processed_image=np.array(
              Image.fromarray(cropped_image.astype(np.uint8)).resize(
                  resize_shape,
                  resample=Image.Resampling.BILINEAR,
              )
          ),
          saliency_map=None,
          area_threshold=area_threshold,
      )
      # Lower saliency score is better; keep the best record seen so far.
      if min_score is None or eval_record['saliency_score'] < min_score:
        min_score = eval_record['saliency_score']
        record = eval_record
  finally:
    if owns_session:
      session.close()
  return record
944
def remove_ticks():
  """Removes ticks from the axes."""
  tick_options = {
      'axis': 'both',  # changes apply to both the x- and y-axis.
      'which': 'both',  # both major and minor ticks are affected.
      'bottom': False,  # ticks along the bottom edge are off.
      'top': False,  # ticks along the top edge are off.
      'left': False,  # ticks along the left edge are off.
      'right': False,  # ticks along the right edge are off.
      'labelbottom': False,
      'labelleft': False,
  }
  plt.tick_params(**tick_options)
957
def show_bounding_box(mask, left_offset=0, top_offset=0, linewidth=3,
                      edgecolor='lime'):
  """Given a mask, shows the tightest rectangle capturing it.

  Args:
    mask: numpy array with shape (rows, columns), a mask.
    left_offset: int, shift the bounding box left by these many pixels.
    top_offset: int, shift the bounding box top by these many pixels.
    linewidth: int, line width the of the bounding box.
    edgecolor: string, color of the bounding box.
  """
  crop_params, _ = _get_tightest_crop(mask, 0)
  # Rectangle anchor is the top-left corner, shifted by the given offsets.
  anchor = (crop_params['left'] - left_offset,
            crop_params['top'] - top_offset)
  box_width = crop_params['right'] - crop_params['left']
  box_height = crop_params['bottom'] - crop_params['top']
  bounding_box = patches.Rectangle(
      anchor, box_width, box_height,
      linewidth=linewidth, edgecolor=edgecolor, facecolor='none')
  plt.gca().add_patch(bounding_box)
977
def normalize_array(array, percentile=99):
  """Normalizes saliency maps for visualization.

  Shifts the array so its minimum maps to 0 and scales it so the value at
  the given percentile maps to 1 (values above that percentile exceed 1).

  Args:
    array: numpy array, a saliency map.
    percentile: int, the minimum value and the value with this percentile in
      array are scaled between 0 and 1.

  Returns:
    numpy array with same shape as input array, the normalized saliency map.
  """
  lowest = array.min()
  span = np.percentile(array, percentile) - lowest
  return (array - lowest) / span
992
993def _verify_saliency_map_shape(saliency_map):994"""Checks if the shape of the saliency map is a 2D array.995
996Args:
997saliency_map: numpy array with shape (rows, columns), a saliency map.
998
999Raises:
1000ValueError: If the saliency map isn't a 2D array.
1001"""
1002if saliency_map.ndim != 2:1003raise ValueError('The saliency map should be a 2D numpy array '1004'but the received shape is {}'.format(saliency_map.shape))1005
1006
def scale_saliency_map(saliency_map, method):
  """Scales saliency maps for visualization.

  For smug and smug base the saliency map is scaled such that the positive
  scores are scaled between 0.5 and 1 (99th percentile maps to 1).
  For other methods the saliency map is scaled between 0 and 1
  (99th percentile maps to 1).

  Args:
    saliency_map: numpy array with shape (rows, columns), a saliency map.
    method: str, saliency method.

  Returns:
    numpy array with shape (rows, columns), the normalized saliency map.
  """
  _verify_saliency_map_shape(saliency_map)
  scaled_map = normalize_array(saliency_map)
  if 'smug' not in method:
    return scaled_map
  # For better visualization, the smug_saliency_map and the
  # no_minimization_saliency_map are scaled between [0.5, 1] instead of the
  # usual [0, 1]. Note that doing such a scaling doesn't affect the
  # saliency score in any way as the relative ordering between the pixels
  # is preserved.
  positive = scaled_map > 0
  scaled_map[positive] = 0.5 + 0.5 * scaled_map[positive]
  return scaled_map
1033
def visualize_saliency_map(saliency_map, title=''):
  """Grayscale visualization of the saliency map.

  Args:
    saliency_map: numpy array with shape (rows, columns), a saliency map.
    title: str, title of the saliency map.
  """
  _verify_saliency_map_shape(saliency_map)
  display_options = {'cmap': plt.cm.gray, 'vmin': 0, 'vmax': 1}  # pytype: disable=module-attr
  plt.imshow(saliency_map, **display_options)
  plt.title(title)
  remove_ticks()