google-research
845 lines · 34.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Tests for third_party.google_research.google_research.smug_saliency.masking."""
import os
import shutil
import tempfile
from unittest import mock

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import gfile
import z3

from smug_saliency import masking
from smug_saliency import utils

FLAGS = flags.FLAGS

# The masking code builds and inspects TF1-style graphs by tensor name, so
# eager execution must be disabled for the whole module.
tf.disable_eager_execution()
37
def _get_z3_var(index):
  """Returns a fresh z3 integer variable named 'z3var_<index>'."""
  return z3.Int('z3var_' + str(index))
41
def _create_temporary_tf_graph_fully_connected(test_model_path):
  """Saves a small fully-connected test model and returns its parameters.

  Creates a test model with 1 hidden layer with 4 nodes, and an output layer
  with softmax activation function and 2 nodes, then exports it as a
  SavedModel to test_model_path.

  Args:
    test_model_path: str, directory to which the SavedModel is written.

  Returns:
    A (weights, biases) tuple of lists of numpy arrays, one entry per layer.
  """
  test_model = tf.keras.Sequential([
      tf.keras.layers.Dense(units=4, input_shape=(4,), activation='relu'),
      tf.keras.layers.Dense(units=2, activation='softmax')])
  test_model.compile(
      optimizer=tf.keras.optimizers.Adam(1e-3),
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=['accuracy'])
  # Name of the tensors in the graph:
  # input - dense_input:0
  # weights_first_layer - dense/MatMul/ReadVariableOp:0
  # biases_first_layer - dense/BiasAdd/ReadVariableOp:0
  # first_layer_input - dense/BiasAdd:0
  # first_layer_relu_output - dense/Relu:0
  # final_layer_input - dense_1/BiasAdd:0
  # final_layer_softmax_output - dense_1/Softmax:0
  tf.saved_model.simple_save(
      session=tf.keras.backend.get_session(),
      export_dir=test_model_path,
      inputs={'input': test_model.inputs[0]},
      outputs={'output': test_model.outputs[0]})

  # Keras interleaves kernels and biases in get_weights(); split them into
  # parallel per-layer lists.
  weights = []
  biases = []
  weights_and_biases = test_model.get_weights()
  for i in range(len(weights_and_biases) // 2):
    weights.append(weights_and_biases[2 * i])
    biases.append(weights_and_biases[2 * i + 1])
  return weights, biases
74
def _create_temporary_tf_graph_cnn(test_model_path):
  """Saves a small image-CNN test model as a SavedModel.

  Creates a test model with 1 conv layer with 3 kernels shaped (2, 2), and an
  output layer with 4 nodes and softmax activation function.

  Args:
    test_model_path: str, directory to which the SavedModel is written.
  """
  test_model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          3,
          kernel_size=(2, 2),
          strides=(1, 1),
          padding='same',
          activation='relu',
          input_shape=(4, 4, 3)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(4, activation='softmax')])
  test_model.compile(
      optimizer=tf.keras.optimizers.Adam(1e-3),
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=['accuracy'])
  # Name of the tensors in the graph:
  # input - conv2d_input:0
  # weights_first_layer - conv2d/Conv2D/ReadVariableOp:0
  # biases_first_layer - conv2d/bias/Read/ReadVariableOp:0
  # first_layer_input - conv2d/BiasAdd:0
  # first_layer_relu_output - conv2d/Relu:0
  # final_layer_input - dense/BiasAdd:0
  # final_layer_softmax_output - dense/Softmax:0
  tf.saved_model.simple_save(
      session=tf.keras.backend.get_session(),
      export_dir=test_model_path,
      inputs={'input': test_model.inputs[0]},
      outputs={'output': test_model.outputs[0]})
107
def _create_temporary_tf_graph_text_cnn(test_model_path):
  """Saves a small text-CNN test model as a SavedModel.

  Creates a test model with 1 conv layer with 4 kernels shaped (3, 10),
  and an output layer with 1 node and sigmoid activation function.

  Args:
    test_model_path: str, directory to which the SavedModel is written.
  """
  test_model = tf.keras.Sequential([
      tf.keras.Input(shape=(5,)),  # max words = 5
      tf.keras.layers.Embedding(10, 10),  # num top words = 10
      tf.keras.layers.Conv1D(
          filters=4, strides=1, kernel_size=3, activation='relu'),
      tf.keras.layers.GlobalMaxPooling1D(),
      tf.keras.layers.Dense(1, activation='sigmoid')])
  test_model.compile(
      optimizer=tf.keras.optimizers.Adam(1e-3),
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=['accuracy'])

  # Name of the tensors in the graph:
  # input - input_1:0
  # embedding - embedding/embedding_lookup/Identity:0
  # weights_first_layer - conv1d/Conv1D/ExpandDims_1:0
  # biases_first_layer - conv1d/BiasAdd/ReadVariableOp:0
  # first_layer_input - conv1d/BiasAdd:0
  # first_layer_relu_output - conv1d/Relu:0
  # final_layer_input - dense/BiasAdd:0
  # final_layer_softmax_output - dense/Sigmoid:0
  tf.saved_model.simple_save(
      session=tf.keras.backend.get_session(),
      export_dir=test_model_path,
      inputs={'input': test_model.inputs[0]},
      outputs={'output': test_model.outputs[0]})
138
class MaskingLibTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for smug_saliency.masking."""

  def setUp(self):
    super().setUp()
    # Fresh SavedModel directory per test; removed again in tearDown.
    self.test_model_path = os.path.join(
        tempfile.mkdtemp(dir=FLAGS.test_tmpdir), 'checkpoint')
    gfile.MakeDirs(self.test_model_path)

  def tearDown(self):
    shutil.rmtree(self.test_model_path)
    super().tearDown()

  @parameterized.parameters(1, 3)
  def test_encode_input(self, image_channels):
    # Creates a random image and checks if the encoded image, after multiplying
    # the mask bits (set to 1) is the same as the original image.
    image_edge_length = 2
    image = np.random.rand(image_edge_length, image_edge_length, image_channels)
    z3_var = _get_z3_var(index=0)

    # encoded_image has dimensions
    # (image_channels, image_edge_length, image_edge_length)
    encoded_image = masking._encode_input(
        image=image,
        z3_mask=[z3_var for _ in range(image_edge_length ** 2)])
    solver = z3.Solver()
    solver.add(z3_var == 1)
    # Swap the axes of the image so that it has the same dimensions as the
    # encoded image.
    image = masking._reorder(image).reshape(-1)
    encoded_image = utils.flatten_nested_lists(encoded_image)
    for i in range(image_channels * image_edge_length ** 2):
      solver.add(encoded_image[i] == image[i])

    self.assertEqual(str(solver.check()), 'sat')

  def test_formulate_smt_constraints_convolution_layer(self):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      _create_temporary_tf_graph_cnn(self.test_model_path)
      image_edge_length = 4
      image_channels = 3
      # The 1st convolution layer has 48 neurons.
      top_k = np.random.randint(low=1, high=48)
      image = np.ones((image_edge_length, image_edge_length, image_channels))
      tensor_names = {
          'input': 'conv2d_input:0',
          'first_layer': 'conv2d/BiasAdd:0',
          'first_layer_relu': 'conv2d/Relu:0',
          'logits': 'dense/BiasAdd:0',
          'softmax': 'dense/Softmax:0',
          'weights_layer_1': 'conv2d/Conv2D/ReadVariableOp:0',
          'biases_layer_1': 'conv2d/bias/Read/ReadVariableOp:0'}
      session = utils.restore_model(self.test_model_path)
      cnn_predictions = session.run(
          tensor_names,
          feed_dict={
              tensor_names['input']: image.reshape(
                  (1, image_edge_length, image_edge_length, image_channels))})

      z3_mask = []
      mask_id_to_var = {}
      for row in range(image_edge_length):
        for column in range(image_edge_length):
          mask_id = image_edge_length * row + column
          if mask_id in mask_id_to_var.keys():
            z3_var = mask_id_to_var[mask_id]
          else:
            mask_name = f'mask_{mask_id}'
            z3_var = z3.Int(mask_name)
            # Cache by mask_id — the same key the lookup above uses.
            mask_id_to_var[mask_id] = z3_var
          z3_mask.append(z3_var)

      first_layer_activations = masking._reorder(masking._remove_batch_axis(
          cnn_predictions['first_layer'])).reshape(-1)
      masked_input = masking._encode_input(image=image, z3_mask=z3_mask)

      z3_optimizer = masking._formulate_smt_constraints_convolution_layer(
          z3_optimizer=utils.ImageOptimizer(
              z3_mask=z3_mask,
              window_size=1,
              edge_length=image_edge_length),
          kernels=masking._reorder(cnn_predictions['weights_layer_1']),
          biases=cnn_predictions['biases_layer_1'],
          chosen_indices=first_layer_activations.argsort()[-top_k:],
          conv_activations=first_layer_activations,
          input_activation_maps=masked_input,
          output_activation_map_shape=(image_edge_length, image_edge_length),
          strides=1,
          padding=(0, 1),
          gamma=0.5)
      mask, result = z3_optimizer.generate_mask()

      self.assertEqual(result, 'sat')
      self.assertEqual(mask.shape, (image_edge_length, image_edge_length))
      session.close()

  def test_formulate_smt_constraints_convolution_layer_text(self):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      _create_temporary_tf_graph_text_cnn(self.test_model_path)

      # The 1st convolution layer has 12 neurons.
      image = np.ones(5)
      tensor_names = {
          'input': 'input_1:0',
          'embedding': 'embedding/embedding_lookup/Identity:0',
          'first_layer': 'conv1d/BiasAdd:0',
          'first_layer_relu': 'conv1d/Relu:0',
          'logits': 'dense/BiasAdd:0',
          'softmax': 'dense/Sigmoid:0',
          'weights_layer_1': 'conv1d/Conv1D/ExpandDims_1:0',
          'biases_layer_1': 'conv1d/BiasAdd/ReadVariableOp:0'}
      session = utils.restore_model(self.test_model_path)
      cnn_predictions = session.run(
          tensor_names, feed_dict={
              tensor_names['input']: image.reshape(1, 5)})
      text_embedding = masking._remove_batch_axis(cnn_predictions['embedding'])
      z3_mask = [z3.Int('mask_%d' % i) for i in range(text_embedding.shape[0])]
      masked_input = []
      for mask_bit, embedding_row in zip(z3_mask, text_embedding):
        masked_input.append([z3.ToReal(mask_bit) * i for i in embedding_row])
      first_layer_activations = masking._reorder(
          masking._remove_batch_axis(
              cnn_predictions['first_layer'])).reshape(-1)
      z3_optimizer = masking._formulate_smt_constraints_convolution_layer(
          z3_optimizer=utils.TextOptimizer(z3_mask=z3_mask),
          kernels=masking._reshape_kernels(
              kernels=cnn_predictions['weights_layer_1'],
              model_type='text_cnn'),
          biases=cnn_predictions['biases_layer_1'],
          chosen_indices=first_layer_activations.argsort()[-5:],
          conv_activations=first_layer_activations,
          input_activation_maps=[masked_input],
          output_activation_map_shape=masking._get_activation_map_shape(
              activation_maps_shape=cnn_predictions['first_layer'].shape,
              model_type='text_cnn'),
          strides=1,
          padding=(0, 0),
          gamma=0.5)
      mask, result = z3_optimizer.generate_mask()

      self.assertEqual(result, 'sat')
      self.assertEqual(mask.shape, (5,))
      session.close()

  def test_formulate_smt_constraints_fully_connected_layer(self):
    # For a neural network with 4 hidden nodes in the first layer, with the
    # original first layer activations = 1, and the SMT encoding of
    # the first hidden nodes - [mask_0, mask_1, mask_2, mask_3]. For
    # masked_activation > delta * original (k such constraints), only k mask
    # bits should be set to 1 and the others to 0.
    image_edge_length = 2
    top_k = np.random.randint(low=1, high=image_edge_length ** 2)
    z3_mask = [_get_z3_var(index=i) for i in range(image_edge_length ** 2)]
    smt_first_layer = [1 * z3.ToReal(i) for i in z3_mask]
    nn_first_layer = np.ones(len(smt_first_layer))

    z3_optimizer = utils.ImageOptimizer(
        z3_mask=z3_mask, window_size=1, edge_length=image_edge_length)
    z3_optimizer = masking._formulate_smt_constraints_fully_connected_layer(
        z3_optimizer=z3_optimizer,
        smt_first_layer=smt_first_layer,
        nn_first_layer=nn_first_layer,
        top_k=top_k,
        gamma=np.random.rand())
    mask, result = z3_optimizer._optimize()

    self.assertEqual(result, 'sat')
    self.assertEqual(np.sum(mask), top_k)

  def test_smt_constraints_final_layer(self):
    # The SMT encoding of the final layer - [mask_0, mask_1, mask_2, mask_3].
    # For logit_label_index > rest, the mask_bit at label_index should be set
    # to 1.
    image_edge_length = 2
    label_index = np.random.randint(low=0, high=image_edge_length ** 2)
    z3_mask = [_get_z3_var(index=i) for i in range(image_edge_length ** 2)]
    smt_output = [1 * z3.ToReal(i) for i in z3_mask]

    z3_optimizer = utils.ImageOptimizer(
        z3_mask=z3_mask, window_size=1, edge_length=image_edge_length)
    z3_optimizer = masking._formulate_smt_constraints_final_layer(
        z3_optimizer=z3_optimizer,
        smt_output=smt_output,
        delta=np.random.rand(),
        label_index=label_index)
    mask, result = z3_optimizer._optimize()

    self.assertEqual(result, 'sat')
    self.assertEqual(mask.reshape(-1)[label_index], 1)
    self.assertEqual(np.sum(mask), 1)

  def test_find_mask_first_layer(self):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      _create_temporary_tf_graph_fully_connected(self.test_model_path)
      result = masking.find_mask_first_layer(
          image=np.random.random((2, 2, 1)),
          run_params=masking.RunParams(
              **{
                  'model_path': self.test_model_path,
                  'tensor_names': {
                      'input': 'dense_input:0',
                      'first_layer': 'dense/BiasAdd:0',
                      'first_layer_relu': 'dense/Relu:0',
                      'softmax': 'dense_1/Softmax:0',
                      'logits': 'dense_1/BiasAdd:0',
                      'weights_layer_1': 'dense/MatMul/ReadVariableOp:0',
                      'biases_layer_1': 'dense/BiasAdd/ReadVariableOp:0'
                  },
                  'image_placeholder_shape': (1, 4),
                  'model_type': 'fully_connected',
                  'padding': (0, 0),
                  'strides': 0,
                  'activations': None,
                  'pixel_range': (0, 1),
              }),
          label_index=0,
          score_method='activations',
          window_size=1,
          top_k=4,
          gamma=0.5,
          timeout=600,
          num_unique_solutions=5)
      self.assertEqual(result['image'].shape, (4,))
      self.assertEqual(result['unmasked_logits'].shape, (2,))
      self.assertEqual(result['unmasked_first_layer'].shape, (4,))
      self.assertEqual(result['masks'][0].shape, (4,))
      self.assertLen(result['masks'], 5)
      self.assertLen(result['masked_first_layer'], 5)
      self.assertLen(result['inv_masked_first_layer'], 5)
      self.assertLen(result['masked_images'], 5)
      self.assertLen(result['inv_masked_images'], 5)
      self.assertLen(result['masked_logits'], 5)
      self.assertLen(result['inv_masked_logits'], 5)
      self.assertLen(result['solver_outputs'], 5)

  def test_find_mask_full_encoding(self):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      weights, biases = _create_temporary_tf_graph_fully_connected(
          self.test_model_path)
      result = masking.find_mask_full_encoding(
          image=np.zeros((2, 2, 1)),
          weights=weights,
          biases=biases,
          run_params=masking.RunParams(**{
              'model_path': self.test_model_path,
              'tensor_names': {
                  'input': 'dense_input:0',
                  'first_layer': 'dense/BiasAdd:0',
                  'first_layer_relu': 'dense/Relu:0',
                  'softmax': 'dense_1/Softmax:0',
                  'logits': 'dense_1/BiasAdd:0',
                  'weights_layer_1': 'dense/MatMul/ReadVariableOp:0',
                  'biases_layer_1': 'dense/BiasAdd/ReadVariableOp:0'},
              'image_placeholder_shape': (1, 4),
              'model_type': 'fully_connected',
              'padding': (0, 0),
              'strides': 0,
              'activations': ['relu', 'linear'],
              'pixel_range': (0, 1),
          }),
          window_size=1,
          label_index=0,
          delta=0,
          timeout=600,
          num_unique_solutions=5)
      self.assertEqual(result['image'].shape, (4,))
      self.assertEqual(result['unmasked_logits'].shape, (2,))
      self.assertEqual(result['unmasked_first_layer'].shape, (4,))
      self.assertEqual(result['masks'][0].shape, (4,))
      self.assertLen(result['masks'], 5)
      self.assertLen(result['masked_first_layer'], 5)
      self.assertLen(result['inv_masked_first_layer'], 5)
      self.assertLen(result['masked_images'], 5)
      self.assertLen(result['inv_masked_images'], 5)
      self.assertLen(result['masked_logits'], 5)
      self.assertLen(result['inv_masked_logits'], 5)
      self.assertLen(result['solver_outputs'], 5)

  def test_find_mask_first_layer_text_cnn(self):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      _create_temporary_tf_graph_text_cnn(self.test_model_path)
      result = masking.find_mask_first_layer(
          image=np.zeros(5),
          run_params=masking.RunParams(
              **{
                  'model_path': self.test_model_path,
                  'tensor_names': {
                      'input': 'input_1:0',
                      'embedding': 'embedding/embedding_lookup/Identity:0',
                      'first_layer': 'conv1d/BiasAdd:0',
                      'first_layer_relu': 'conv1d/Relu:0',
                      'logits': 'dense/BiasAdd:0',
                      'softmax': 'dense/Sigmoid:0',
                      'weights_layer_1': 'conv1d/Conv1D/ExpandDims_1:0',
                      'biases_layer_1': 'conv1d/BiasAdd/ReadVariableOp:0',
                  },
                  'image_placeholder_shape': (1, 5),
                  'model_type': 'text_cnn',
                  'padding': (0, 0),
                  'strides': 1,
                  'activations': None,
                  'pixel_range': (0, 1),
              }),
          window_size=1,
          label_index=0,
          score_method='activations',
          top_k=4,
          gamma=0.5,
          timeout=600,
          num_unique_solutions=5)
      self.assertEqual(result['image'].shape, (5,))
      self.assertEqual(result['unmasked_logits'].shape, (1,))
      self.assertEqual(result['unmasked_first_layer'].shape, (12,))
      self.assertEqual(result['masks'][0].shape, (5,))
      self.assertLen(result['masks'], 5)
      self.assertLen(result['masked_first_layer'], 5)
      self.assertLen(result['inv_masked_first_layer'], 5)
      self.assertLen(result['masked_images'], 5)
      self.assertLen(result['inv_masked_images'], 5)
      self.assertLen(result['masked_logits'], 5)
      self.assertLen(result['inv_masked_logits'], 5)
      self.assertLen(result['solver_outputs'], 5)

  def test_find_mask_first_layer_cnn(self):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      _create_temporary_tf_graph_cnn(self.test_model_path)
      result = masking.find_mask_first_layer(
          image=np.random.random((4, 4, 3)),
          run_params=masking.RunParams(
              **{
                  'model_path': self.test_model_path,
                  'tensor_names': {
                      'input': 'conv2d_input:0',
                      'first_layer': 'conv2d/BiasAdd:0',
                      'first_layer_relu': 'conv2d/Relu:0',
                      'logits': 'dense/BiasAdd:0',
                      'softmax': 'dense/Softmax:0',
                      'weights_layer_1': 'conv2d/Conv2D/ReadVariableOp:0',
                      'biases_layer_1': 'conv2d/bias/Read/ReadVariableOp:0'
                  },
                  'image_placeholder_shape': (1, 4, 4, 3),
                  'model_type': 'cnn',
                  'padding': (0, 1),
                  'strides': 1,
                  'activations': None,
                  'pixel_range': (0, 1),
              }),
          window_size=1,
          label_index=0,
          score_method='activations',
          top_k=4,
          gamma=0.5,
          timeout=600,
          num_unique_solutions=5)
      self.assertEqual(result['image'].shape, (48,))
      self.assertEqual(result['unmasked_logits'].shape, (4,))
      self.assertEqual(result['unmasked_first_layer'].shape, (48,))
      self.assertEqual(result['masks'][0].shape, (48,))
      self.assertLen(result['masks'], 5)
      self.assertLen(result['masked_first_layer'], 5)
      self.assertLen(result['inv_masked_first_layer'], 5)
      self.assertLen(result['masked_images'], 5)
      self.assertLen(result['inv_masked_images'], 5)
      self.assertLen(result['masked_logits'], 5)
      self.assertLen(result['inv_masked_logits'], 5)
      self.assertLen(result['solver_outputs'], 5)

  @parameterized.parameters(
      ('get_saliency_map', 'activations'),
      ('get_saliency_map', 'integrated_gradients'),
      ('_get_gradients', 'gradients'),
      ('_get_gradients', 'blurred_gradients'),
  )
  def test_sort_indices(self, function_to_be_mocked, score_method):
    # priority array is reverse engineered such that,
    # the output of masking._sort_indices is [0, ..., num_hidden_nodes]
    priority = np.moveaxis(np.arange(48).reshape((1, 4, 4, 3)), 1, -1)
    with mock.patch.object(
        masking, function_to_be_mocked,
        return_value=priority), mock.patch.object(
            masking, '_apply_blurring', return_value=mock.MagicMock()):
      sorted_indices = masking._sort_indices(
          session=mock.MagicMock(),
          image=mock.MagicMock(),
          label_index=0,
          run_params=mock.MagicMock(),
          unmasked_predictions={
              'first_layer': priority,
              'first_layer_relu': priority,},
          score_method=score_method)

    np.testing.assert_array_equal(sorted_indices, np.arange(48))

  def test_get_gradients(self):
    with self.test_session() as session:
      _create_temporary_tf_graph_cnn(self.test_model_path)

      gradients = masking._get_gradients(
          session=session,
          graph=tf.get_default_graph(),
          features=np.ones((1, 4, 4, 3)),
          label_index=0,
          input_tensor_name='conv2d_input:0',
          output_tensor_name='dense/Softmax:0')

      self.assertEqual(gradients.shape, (1, 4, 4, 3))

  @parameterized.parameters(
      (('The input image should have 3 dimensions. Shape of the image: '
        r'\(4, 4\)'), np.ones((4, 4))),
      (('The input image should have height == width. '
        r'Shape of the input image: \(4, 5, 1\)'), np.ones((4, 5, 1))),
      (('The color channels of the input image has a value other than 1 or 3. '
        r'Shape of the image: \(4, 4, 2\)'), np.ones((4, 4, 2))),
  )
  def test_verify_image_dimensions(self, error, image):
    with self.assertRaisesRegex(ValueError, error):
      masking._verify_image_dimensions(image)

  @parameterized.parameters(
      (np.ones((4, 4)), 'text_cnn',
       r'Invalid mask shape: \(4, 4\). Expected a mask with 1 dimension.'),
      (np.ones(4), 'cnn',
       r'Invalid mask shape: \(4,\). Expected a mask with 2 equal dimensions.'),
      (np.ones(4), 'fully_connected',
       r'Invalid mask shape: \(4,\). Expected a mask with 2 equal dimensions.'),
  )
  def test_verify_mask_dimensions(self, mask, model_type, error):
    with self.assertRaisesRegex(ValueError, error):
      masking._verify_mask_dimensions(mask, model_type)

  def test_reorder(self):
    shape = tuple(np.random.randint(low=1, high=10, size=4))

    # _reorder moves the last axis (channels) to the front.
    self.assertEqual(masking._reorder(np.ones(shape)).shape,
                     (shape[3], shape[0], shape[1], shape[2]))

  def test_remove_batch_axis(self):
    # The batch size has to be 1.
    shape = tuple(np.append([1], np.random.randint(low=1, high=10, size=3)))

    self.assertEqual(masking._remove_batch_axis(np.ones(shape)).shape,
                     (shape[1], shape[2], shape[3]))

  def test_remove_batch_axis_error(self):
    # low=2 guarantees the batch dimension is never 1, so the call must raise.
    shape = tuple(np.random.randint(low=2, high=10, size=4))

    with self.assertRaisesRegex(
        ValueError, ('The array doesn\'t have the batch dimension as 1. '
                     'Received an array with length along the batch '
                     'dimension: %d' % shape[0])):
      masking._remove_batch_axis(np.ones(shape))

  def test_process_text_error(self):
    with self.assertRaisesRegex(ValueError,
                                ('The text input should be a 1D numpy array. '
                                 r'Shape of the received input: \(1, 500\)')):
      masking._process_text(image=np.ones((1, 500)), run_params=None)

  def test_get_hidden_node_location_image(self):
    num_channels = 64
    output_activation_map_size = 112
    flattened_indices = np.arange(
        num_channels * output_activation_map_size ** 2).reshape(
            num_channels, output_activation_map_size,
            output_activation_map_size)
    true_row = np.random.randint(low=0, high=output_activation_map_size)
    true_column = np.random.randint(low=0, high=output_activation_map_size)
    true_channel = np.random.randint(low=0, high=num_channels)

    (predicted_channel, predicted_row,
     predicted_column) = masking._get_hidden_node_location(
         flattened_index=flattened_indices[true_channel][true_row][true_column],
         num_rows=output_activation_map_size,
         num_columns=output_activation_map_size)

    self.assertEqual(true_channel, predicted_channel)
    self.assertEqual(true_row, predicted_row)
    self.assertEqual(true_column, predicted_column)

  def test_get_hidden_node_location_text(self):
    num_channels = 128
    output_activation_map_shape = (498, 1)
    flattened_indices = np.arange(
        num_channels * output_activation_map_shape[0]).reshape(
            num_channels, output_activation_map_shape[0],
            output_activation_map_shape[1])
    true_row = np.random.randint(low=0, high=output_activation_map_shape[0])
    true_channel = np.random.randint(low=0, high=num_channels)
    # Text activation maps have a single column.
    true_column = 0

    (predicted_channel, predicted_row,
     predicted_column) = masking._get_hidden_node_location(
         flattened_index=flattened_indices[true_channel][true_row][true_column],
         num_rows=output_activation_map_shape[0],
         num_columns=output_activation_map_shape[1])

    self.assertEqual(true_channel, predicted_channel)
    self.assertEqual(true_row, predicted_row)
    self.assertEqual(true_column, predicted_column)

  def test_get_activation_map_shape_image(self):
    activation_maps_shape = tuple(np.random.randint(low=1, high=10, size=4))

    self.assertEqual(
        masking._get_activation_map_shape(
            activation_maps_shape, model_type='cnn'),
        (activation_maps_shape[1], activation_maps_shape[2]))

  def test_get_activation_map_shape_text(self):
    activation_maps_shape = tuple(np.random.randint(low=1, high=10, size=3))

    self.assertEqual(
        masking._get_activation_map_shape(
            activation_maps_shape, model_type='text_cnn'),
        (activation_maps_shape[1], 1))

  @parameterized.parameters(
      (('Invalid model_type: text. Expected one of - '
        'fully_connected, cnn or text_cnn'), (1, 1, 1), 'text'),
      (r'Invalid activation_maps_shape: \(1, 1, 1\).Expected length 4.',
       (1, 1, 1), 'cnn'),
      (r'Invalid activation_maps_shape: \(1, 1, 1, 1\).Expected length 3.',
       (1, 1, 1, 1), 'text_cnn'),
  )
  def test_verify_activation_maps_shape(
      self, error, activation_maps_shape, model_type):
    # NOTE(review): the parameter order previously read
    # (activation_maps_shape, model_type, error), which silently mismatched
    # the parameter tuples above; reordered to match (error, shape, type).
    with self.assertRaisesRegex(ValueError, error):
      masking._verify_activation_maps_shape(activation_maps_shape, model_type)

  @parameterized.parameters(('cnn', (3, 0, 1, 2)),
                            ('text_cnn', (3, 1, 2, 0)))
  def test_reshape_kernels(self, model_type, reshaped_dimensions):
    activation_maps_shape = tuple(np.random.randint(low=1, high=10, size=4))

    reshaped_kernel = masking._reshape_kernels(
        kernels=np.ones(activation_maps_shape),
        model_type=model_type)

    self.assertEqual(reshaped_kernel.shape,
                     (activation_maps_shape[reshaped_dimensions[0]],
                      activation_maps_shape[reshaped_dimensions[1]],
                      activation_maps_shape[reshaped_dimensions[2]],
                      activation_maps_shape[reshaped_dimensions[3]]))

  @parameterized.parameters(
      ('integrated_gradients', (4, 4, 3)),
      ('integrated_gradients_black_white_baselines', (4, 4, 3)),
      ('xrai', (4, 4)))
  def test_get_saliency_map(self, saliency_method, saliency_map_shape):
    with self.test_session():
      # Temporary graphs should be created inside a session. Notice multiple
      # graphs are being created in this particular code. So, if each graph
      # isn't created inside a separate session, the tensor names will have
      # unwanted integer suffices, which then would cause problems while
      # accessing tensors by name.
      _create_temporary_tf_graph_cnn(self.test_model_path)
      self.assertEqual(
          masking.get_saliency_map(
              session=utils.restore_model(self.test_model_path),
              features=np.random.rand(4, 4, 3),
              saliency_method=saliency_method,
              label=0,
              input_tensor_name='conv2d_input:0',
              output_tensor_name='dense/Softmax:0',
          ).shape,
          saliency_map_shape)

  def test_get_no_minimization_mask(self):
    mock_session = mock.MagicMock()
    mock_session.run.return_value = {
        # Every hidden node has a receptive field of 2 x 2
        'weights_layer_1': np.ones((1, 2, 2, 1)),
        'first_layer_relu': np.ones((1, 4, 4, 1)),
        'first_layer': np.ones((1, 4, 4, 1)),
    }
    mock_run_params = mock.MagicMock()
    mock_run_params.strides = 1
    mock_run_params.padding = (1, 1)
    mock_run_params.image_placeholder_shape = (4, 4)
    mock_run_params.model_type = 'cnn'
    mock_run_params.pixel_range = (0, 1)

    with mock.patch.object(
        utils, 'restore_model',
        return_value=mock_session), mock.patch.object(
            masking, 'get_saliency_map',
            return_value=mock.MagicMock()), mock.patch.object(
                masking, '_reorder',
                return_value=np.ones(32)), mock.patch.object(
                    masking, '_sort_indices',
                    return_value=[0, 21, 10]), mock.patch.object(
                        masking, '_remove_batch_axis',
                        return_value=mock.MagicMock()):
      mask = masking.get_no_minimization_mask(
          image=np.ones((4, 4)),
          label_index=0,
          top_k=4,
          run_params=mock_run_params,
          sum_attributions=False)
    # The receptive field of hidden node indexed 0 on the padded image -
    # 1 1 0 0 0 0
    # 1 1 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    #
    # The receptive field of hidden node indexed 21 on the padded image -
    # 0 0 0 0 0 0
    # 0 1 1 0 0 0
    # 0 1 1 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    #
    # The receptive field of hidden node indexed 10 on the padded image -
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 1 1 0 0
    # 0 0 1 1 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    #
    # Union of all these masks gives us the no minimisation mask.
    # Then, the padding (1, 1) is removed from the padded image and we get
    # the output -
    # 1 1 0 0
    # 1 1 1 0
    # 0 1 1 0
    # 0 0 0 0

    np.testing.assert_allclose(mask, [[1, 1, 0, 0],
                                      [1, 1, 1, 0],
                                      [0, 1, 1, 0],
                                      [0, 0, 0, 0]])

  def test_get_no_minimization_mask_text(self):
    mock_session = mock.MagicMock()
    mock_session.run.return_value = {
        # Every hidden node has a receptive field of 3
        'weights_layer_1': np.ones((1, 3, 10, 12)),
        'first_layer_relu': np.ones((1, 11, 1)),
        'first_layer': np.ones((1, 11, 1)),
    }
    mock_run_params = mock.MagicMock()
    mock_run_params.strides = 1
    mock_run_params.padding = (1, 2)
    mock_run_params.image_placeholder_shape = (1, 10)
    mock_run_params.model_type = 'text_cnn'
    with mock.patch.object(
        utils, 'restore_model',
        return_value=mock_session), mock.patch.object(
            masking, 'get_saliency_map',
            return_value=mock.MagicMock()), mock.patch.object(
                masking, '_reorder',
                return_value=np.ones(20)), mock.patch.object(
                    masking, '_sort_indices',
                    return_value=[0, 16]), mock.patch.object(
                        masking, '_remove_batch_axis',
                        return_value=mock.MagicMock()):
      mask = masking.get_no_minimization_mask(
          image=np.ones(10),
          label_index=0,
          top_k=4,
          run_params=mock_run_params,
          sum_attributions=False)
    # The receptive field of hidden node indexed 0 on the text -
    # 1 1 1 0 0 0 0 0 0 0 0 0 0
    # The receptive field of hidden node indexed 16
    # (activation map indexed 1, 5th position) on the text -
    # 0 0 0 0 0 1 1 1 0 0 0 0 0
    # We get the below result after taking their union and removing the
    # padding.
    np.testing.assert_allclose(mask, [1, 1, 0, 0, 1, 1, 1, 0, 0, 0])
844if __name__ == '__main__':845absltest.main()846