# google-research · Fork 0 · /masking_test.py
# 845 lines · 34.1 KB
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for third_party.google_research.google_research.smug_saliency.masking."""
17
import os
18
import shutil
19
import tempfile
20
from unittest import mock
21

22
from absl import flags
23
from absl.testing import absltest
24
from absl.testing import parameterized
25
import numpy as np
26
import tensorflow.compat.v1 as tf
27
from tensorflow.compat.v1 import gfile
28
import z3
29

30
from smug_saliency import masking
31
from smug_saliency import utils
32

33
FLAGS = flags.FLAGS
34

35
tf.disable_eager_execution()
36

37

38
def _get_z3_var(index):
  """Returns the z3 integer variable named `z3var_<index>`."""
  var_name = f'z3var_{index}'
  return z3.Int(var_name)
40

41

42
def _create_temporary_tf_graph_fully_connected(test_model_path):
  """Builds a small fully connected Keras model and exports it as a SavedModel.

  The model has one hidden layer with 4 ReLU units over a 4-dimensional input,
  followed by a 2-unit softmax output layer. It is compiled but never trained.

  Tensor names in the exported graph:
    input                       - dense_input:0
    weights_first_layer         - dense/MatMul/ReadVariableOp:0
    biases_first_layer          - dense/BiasAdd/ReadVariableOp:0
    first_layer_input           - dense/BiasAdd:0
    first_layer_relu_output     - dense/Relu:0
    final_layer_input           - dense_1/BiasAdd:0
    final_layer_softmax_output  - dense_1/Softmax:0

  Args:
    test_model_path: Directory the SavedModel is written to.

  Returns:
    A (weights, biases) pair of lists holding, in layer order, the kernel and
    bias arrays of each Dense layer.
  """
  test_model = tf.keras.Sequential([
      tf.keras.layers.Dense(units=4, input_shape=(4,), activation='relu'),
      tf.keras.layers.Dense(units=2, activation='softmax')])
  test_model.compile(
      optimizer=tf.keras.optimizers.Adam(1e-3),
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=['accuracy'])
  keras_session = tf.keras.backend.get_session()
  tf.saved_model.simple_save(
      session=keras_session,
      export_dir=test_model_path,
      inputs={'input': test_model.inputs[0]},
      outputs={'output': test_model.outputs[0]})

  # get_weights() is read as interleaved [kernel_0, bias_0, kernel_1, ...];
  # only complete (kernel, bias) pairs are kept.
  params = test_model.get_weights()
  num_pairs = len(params) // 2
  weights = params[0:2 * num_pairs:2]
  biases = params[1:2 * num_pairs:2]
  return weights, biases
73

74

75
def _create_temporary_tf_graph_cnn(test_model_path):
  """Builds a small image CNN and exports it as a SavedModel.

  Architecture: a Conv2D layer with 3 kernels of size (2, 2), stride 1, 'same'
  padding and ReLU over (4, 4, 3) inputs, then Flatten and a 4-unit softmax
  Dense layer. The model is compiled but never trained.

  Tensor names in the exported graph:
    input                       - conv2d_input:0
    weights_first_layer         - conv2d/Conv2D/ReadVariableOp:0
    biases_first_layer          - conv2d/bias/Read/ReadVariableOp:0
    first_layer_input           - conv2d/BiasAdd:0
    first_layer_relu_output     - conv2d/Relu:0
    final_layer_input           - dense/BiasAdd:0
    final_layer_softmax_output  - dense/Softmax:0

  Args:
    test_model_path: Directory the SavedModel is written to.
  """
  # Layers keep their default names (conv2d, flatten, dense) and are created
  # in the same order as before; the tensor names above depend on them.
  model_layers = [
      tf.keras.layers.Conv2D(
          3,
          kernel_size=(2, 2),
          strides=(1, 1),
          padding='same',
          activation='relu',
          input_shape=(4, 4, 3)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(4, activation='softmax'),
  ]
  test_model = tf.keras.Sequential(model_layers)
  test_model.compile(
      optimizer=tf.keras.optimizers.Adam(1e-3),
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=['accuracy'])
  keras_session = tf.keras.backend.get_session()
  tf.saved_model.simple_save(
      session=keras_session,
      export_dir=test_model_path,
      inputs={'input': test_model.inputs[0]},
      outputs={'output': test_model.outputs[0]})
106

107

108
def _create_temporary_tf_graph_text_cnn(test_model_path):
  """Builds a small text CNN and exports it as a SavedModel.

  Architecture: inputs of 5 token ids (vocabulary of 10) embedded into 10
  dimensions, a Conv1D layer with 4 filters of width 3 (stride 1, ReLU),
  global max pooling, and a 1-unit sigmoid Dense output. The model is
  compiled but never trained.

  NOTE(review): the model is compiled with CategoricalCrossentropy although
  its output is a single sigmoid unit; presumably harmless because the model
  is never fitted here.

  Tensor names in the exported graph:
    input                       - input_1:0
    embedding                   - embedding/embedding_lookup/Identity:0
    weights_first_layer         - conv1d/Conv1D/ExpandDims_1:0
    biases_first_layer          - conv1d/BiasAdd/ReadVariableOp:0
    first_layer_input           - conv1d/BiasAdd:0
    first_layer_relu_output     - conv1d/Relu:0
    final_layer_input           - dense/BiasAdd:0
    final_layer_sigmoid_output  - dense/Sigmoid:0

  Args:
    test_model_path: Directory the SavedModel is written to.
  """
  # Created in the same order as before so the default layer names, and
  # therefore the tensor names above, are unchanged.
  model_layers = [
      tf.keras.Input(shape=(5,)),  # max words = 5
      tf.keras.layers.Embedding(10, 10),  # num top words = 10
      tf.keras.layers.Conv1D(
          filters=4, strides=1, kernel_size=3, activation='relu'),
      tf.keras.layers.GlobalMaxPooling1D(),
      tf.keras.layers.Dense(1, activation='sigmoid'),
  ]
  test_model = tf.keras.Sequential(model_layers)
  test_model.compile(
      optimizer=tf.keras.optimizers.Adam(1e-3),
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=['accuracy'])
  keras_session = tf.keras.backend.get_session()
  tf.saved_model.simple_save(
      session=keras_session,
      export_dir=test_model_path,
      inputs={'input': test_model.inputs[0]},
      outputs={'output': test_model.outputs[0]})
137

138

139
class MaskingLibTest(parameterized.TestCase, tf.test.TestCase):
140

141
  def setUp(self):
    """Creates an empty per-test directory to export the SavedModel into."""
    super().setUp()
    scratch_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
    self.test_model_path = os.path.join(scratch_dir, 'checkpoint')
    gfile.MakeDirs(self.test_model_path)
146

147
  def tearDown(self):
    """Deletes the per-test scratch directory created in setUp.

    setUp places the model directory ('checkpoint') inside a fresh
    tempfile.mkdtemp() directory. Removing the parent, not just 'checkpoint',
    keeps an empty temporary directory from leaking after every test.
    """
    shutil.rmtree(os.path.dirname(self.test_model_path))
    super().tearDown()
150

151
  @parameterized.parameters(1, 3)
  def test_encode_input(self, image_channels):
    """Checks _encode_input reproduces the image when the mask bits are 1.

    Every pixel shares one z3 variable. Constraining that variable to 1 must
    make each encoded value equal to the matching pixel of the original image.

    Args:
      image_channels: Number of colour channels of the random test image.
    """
    edge = 2
    image = np.random.rand(edge, edge, image_channels)
    shared_bit = _get_z3_var(index=0)

    # encoded_image is laid out as (image_channels, edge, edge).
    encoded_image = masking._encode_input(
        image=image, z3_mask=[shared_bit] * (edge ** 2))

    solver = z3.Solver()
    solver.add(shared_bit == 1)
    # Reorder the image axes so they line up with encoded_image.
    expected_pixels = masking._reorder(image).reshape(-1)
    encoded_pixels = utils.flatten_nested_lists(encoded_image)
    num_values = image_channels * edge ** 2
    for position in range(num_values):
      solver.add(encoded_pixels[position] == expected_pixels[position])

    self.assertEqual(str(solver.check()), 'sat')
174

175
  def test_formulate_smt_constraints_convolution_layer(self):
    """Checks the SMT encoding of a Conv2D layer yields a satisfiable mask.

    Builds the test CNN and feeds an all-ones image through it. It then passes
    the top_k first-layer activations (ranked by unmasked value) as
    chosen_indices with gamma=0.5, and checks that the optimizer returns
    'sat' and an edge x edge mask.
    """
    with self.test_session():
      # Temporary graphs should be created inside a session. Several graphs
      # are created across these tests; unless each one is built inside its
      # own session, tensor names get unwanted integer suffixes, which would
      # break the by-name tensor lookups below.
      _create_temporary_tf_graph_cnn(self.test_model_path)
    image_edge_length = 4
    image_channels = 3
    # The first convolution layer has 4 * 4 * 3 = 48 neurons.
    top_k = np.random.randint(low=1, high=48)
    image = np.ones((image_edge_length, image_edge_length, image_channels))
    tensor_names = {
        'input': 'conv2d_input:0',
        'first_layer': 'conv2d/BiasAdd:0',
        'first_layer_relu': 'conv2d/Relu:0',
        'logits': 'dense/BiasAdd:0',
        'softmax': 'dense/Softmax:0',
        'weights_layer_1': 'conv2d/Conv2D/ReadVariableOp:0',
        'biases_layer_1': 'conv2d/bias/Read/ReadVariableOp:0'}
    session = utils.restore_model(self.test_model_path)
    try:
      cnn_predictions = session.run(
          tensor_names,
          feed_dict={
              tensor_names['input']: image.reshape(
                  (1, image_edge_length, image_edge_length, image_channels))})
    finally:
      # Only the fetched numpy values are needed below. Close the session
      # right away, and also when run() raises, so it is never leaked.
      session.close()

    # One z3 mask bit per pixel (window_size=1), named
    # mask_<row * image_edge_length + column>, in row-major order.
    z3_mask = [
        z3.Int(f'mask_{mask_id}')
        for mask_id in range(image_edge_length ** 2)
    ]

    first_layer_activations = masking._reorder(masking._remove_batch_axis(
        cnn_predictions['first_layer'])).reshape(-1)
    masked_input = masking._encode_input(image=image, z3_mask=z3_mask)

    z3_optimizer = masking._formulate_smt_constraints_convolution_layer(
        z3_optimizer=utils.ImageOptimizer(
            z3_mask=z3_mask,
            window_size=1,
            edge_length=image_edge_length),
        kernels=masking._reorder(cnn_predictions['weights_layer_1']),
        biases=cnn_predictions['biases_layer_1'],
        chosen_indices=first_layer_activations.argsort()[-top_k:],
        conv_activations=first_layer_activations,
        input_activation_maps=masked_input,
        output_activation_map_shape=(image_edge_length, image_edge_length),
        strides=1,
        padding=(0, 1),
        gamma=0.5)
    mask, result = z3_optimizer.generate_mask()

    self.assertEqual(result, 'sat')
    self.assertEqual(mask.shape, (image_edge_length, image_edge_length))
239

240
  def test_formulate_smt_constraints_convolution_layer_text(self):
    """Checks the SMT encoding of a Conv1D (text) layer is satisfiable.

    Builds the test text CNN and feeds five token ids equal to 1 through it.
    It scales each token's embedding row by one z3 mask bit and passes the
    5 largest first-layer activations as chosen_indices with gamma=0.5. It then
    checks that the optimizer returns 'sat' and a 5-element mask.
    """
    with self.test_session():
      # Temporary graphs should be created inside a session. Several graphs
      # are created across these tests; unless each one is built inside its
      # own session, tensor names get unwanted integer suffixes, which would
      # break the by-name tensor lookups below.
      _create_temporary_tf_graph_text_cnn(self.test_model_path)

    # The first convolution layer has 12 neurons: 4 filters x (5 - 3 + 1)
    # output positions.
    image = np.ones(5)
    tensor_names = {
        'input': 'input_1:0',
        'embedding': 'embedding/embedding_lookup/Identity:0',
        'first_layer': 'conv1d/BiasAdd:0',
        'first_layer_relu': 'conv1d/Relu:0',
        'logits': 'dense/BiasAdd:0',
        'softmax': 'dense/Sigmoid:0',
        'weights_layer_1': 'conv1d/Conv1D/ExpandDims_1:0',
        'biases_layer_1': 'conv1d/BiasAdd/ReadVariableOp:0'}
    session = utils.restore_model(self.test_model_path)
    try:
      cnn_predictions = session.run(
          tensor_names, feed_dict={
              tensor_names['input']: image.reshape(1, 5)})
    finally:
      # Only the fetched numpy values are needed below. Close the session
      # right away, and also when run() raises, so it is never leaked.
      session.close()
    text_embedding = masking._remove_batch_axis(cnn_predictions['embedding'])
    # One z3 mask bit per token; each bit scales that token's embedding row.
    z3_mask = [z3.Int('mask_%d' % i) for i in range(text_embedding.shape[0])]
    masked_input = [
        [z3.ToReal(mask_bit) * value for value in embedding_row]
        for mask_bit, embedding_row in zip(z3_mask, text_embedding)
    ]
    first_layer_activations = masking._reorder(
        masking._remove_batch_axis(cnn_predictions['first_layer'])).reshape(-1)
    z3_optimizer = masking._formulate_smt_constraints_convolution_layer(
        z3_optimizer=utils.TextOptimizer(z3_mask=z3_mask),
        kernels=masking._reshape_kernels(
            kernels=cnn_predictions['weights_layer_1'],
            model_type='text_cnn'),
        biases=cnn_predictions['biases_layer_1'],
        chosen_indices=first_layer_activations.argsort()[-5:],
        conv_activations=first_layer_activations,
        input_activation_maps=[masked_input],
        output_activation_map_shape=masking._get_activation_map_shape(
            activation_maps_shape=cnn_predictions['first_layer'].shape,
            model_type='text_cnn'),
        strides=1,
        padding=(0, 0),
        gamma=0.5)
    mask, result = z3_optimizer.generate_mask()

    self.assertEqual(result, 'sat')
    self.assertEqual(mask.shape, (5,))
291

292
  def test_formulate_smt_constraints_fully_connected_layer(self):
    """Checks that exactly top_k first-layer mask bits end up set.

    The first layer has 4 hidden nodes, all with an original activation of 1,
    encoded as [mask_0, mask_1, mask_2, mask_3]. Constraining top_k masked
    activations relative to gamma times the original should produce a mask
    with exactly top_k bits set to 1 and the rest set to 0.
    """
    edge = 2
    num_nodes = edge ** 2
    top_k = np.random.randint(low=1, high=num_nodes)
    z3_mask = [_get_z3_var(index=i) for i in range(num_nodes)]
    smt_first_layer = [1 * z3.ToReal(bit) for bit in z3_mask]
    nn_first_layer = np.ones(len(smt_first_layer))

    optimizer = masking._formulate_smt_constraints_fully_connected_layer(
        z3_optimizer=utils.ImageOptimizer(
            z3_mask=z3_mask, window_size=1, edge_length=edge),
        smt_first_layer=smt_first_layer,
        nn_first_layer=nn_first_layer,
        top_k=top_k,
        gamma=np.random.rand())
    mask, result = optimizer._optimize()

    self.assertEqual(result, 'sat')
    self.assertEqual(np.sum(mask), top_k)
316

317
  def test_smt_constraints_final_layer(self):
    """Checks the final-layer encoding keeps only the label's mask bit.

    The SMT output is [mask_0, mask_1, mask_2, mask_3]. Requiring the logit at
    label_index to exceed the rest should yield a mask whose single set bit
    is at label_index.
    """
    edge = 2
    num_outputs = edge ** 2
    label_index = np.random.randint(low=0, high=num_outputs)
    z3_mask = [_get_z3_var(index=i) for i in range(num_outputs)]
    smt_output = [1 * z3.ToReal(bit) for bit in z3_mask]

    optimizer = masking._formulate_smt_constraints_final_layer(
        z3_optimizer=utils.ImageOptimizer(
            z3_mask=z3_mask, window_size=1, edge_length=edge),
        smt_output=smt_output,
        delta=np.random.rand(),
        label_index=label_index)
    mask, result = optimizer._optimize()

    self.assertEqual(result, 'sat')
    flat_mask = mask.reshape(-1)
    self.assertEqual(flat_mask[label_index], 1)
    self.assertEqual(np.sum(mask), 1)
338

339
  def test_find_mask_first_layer(self):
    """Runs find_mask_first_layer end to end on the fully connected model.

    Checks the shapes of the unmasked outputs. Also checks that every
    per-solution field holds one entry per requested unique solution.
    """
    with self.test_session():
      # Temporary graphs should be created inside a session. Several graphs
      # are created across these tests; unless each one is built inside its
      # own session, tensor names get unwanted integer suffixes, which would
      # break accessing tensors by name.
      _create_temporary_tf_graph_fully_connected(self.test_model_path)
    num_solutions = 5
    run_params = masking.RunParams(
        model_path=self.test_model_path,
        tensor_names={
            'input': 'dense_input:0',
            'first_layer': 'dense/BiasAdd:0',
            'first_layer_relu': 'dense/Relu:0',
            'softmax': 'dense_1/Softmax:0',
            'logits': 'dense_1/BiasAdd:0',
            'weights_layer_1': 'dense/MatMul/ReadVariableOp:0',
            'biases_layer_1': 'dense/BiasAdd/ReadVariableOp:0',
        },
        image_placeholder_shape=(1, 4),
        model_type='fully_connected',
        padding=(0, 0),
        strides=0,
        activations=None,
        pixel_range=(0, 1))
    result = masking.find_mask_first_layer(
        image=np.random.random((2, 2, 1)),
        run_params=run_params,
        label_index=0,
        score_method='activations',
        window_size=1,
        top_k=4,
        gamma=0.5,
        timeout=600,
        num_unique_solutions=num_solutions)

    expected_shapes = {
        'image': (4,),
        'unmasked_logits': (2,),
        'unmasked_first_layer': (4,),
    }
    for key, shape in expected_shapes.items():
      self.assertEqual(result[key].shape, shape)
    self.assertEqual(result['masks'][0].shape, (4,))
    for key in ('masks', 'masked_first_layer', 'inv_masked_first_layer',
                'masked_images', 'inv_masked_images', 'masked_logits',
                'inv_masked_logits', 'solver_outputs'):
      self.assertLen(result[key], num_solutions)
387

388
  def test_find_mask_full_encoding(self):
    """Runs find_mask_full_encoding end to end on the fully connected model.

    Checks the shapes of the unmasked outputs. Also checks that every
    per-solution field holds one entry per requested unique solution.
    """
    with self.test_session():
      # Temporary graphs should be created inside a session. Several graphs
      # are created across these tests; unless each one is built inside its
      # own session, tensor names get unwanted integer suffixes, which would
      # break accessing tensors by name.
      weights, biases = _create_temporary_tf_graph_fully_connected(
          self.test_model_path)
    num_solutions = 5
    run_params = masking.RunParams(
        model_path=self.test_model_path,
        tensor_names={
            'input': 'dense_input:0',
            'first_layer': 'dense/BiasAdd:0',
            'first_layer_relu': 'dense/Relu:0',
            'softmax': 'dense_1/Softmax:0',
            'logits': 'dense_1/BiasAdd:0',
            'weights_layer_1': 'dense/MatMul/ReadVariableOp:0',
            'biases_layer_1': 'dense/BiasAdd/ReadVariableOp:0',
        },
        image_placeholder_shape=(1, 4),
        model_type='fully_connected',
        padding=(0, 0),
        strides=0,
        activations=['relu', 'linear'],
        pixel_range=(0, 1))
    result = masking.find_mask_full_encoding(
        image=np.zeros((2, 2, 1)),
        weights=weights,
        biases=biases,
        run_params=run_params,
        window_size=1,
        label_index=0,
        delta=0,
        timeout=600,
        num_unique_solutions=num_solutions)

    expected_shapes = {
        'image': (4,),
        'unmasked_logits': (2,),
        'unmasked_first_layer': (4,),
    }
    for key, shape in expected_shapes.items():
      self.assertEqual(result[key].shape, shape)
    self.assertEqual(result['masks'][0].shape, (4,))
    for key in ('masks', 'masked_first_layer', 'inv_masked_first_layer',
                'masked_images', 'inv_masked_images', 'masked_logits',
                'inv_masked_logits', 'solver_outputs'):
      self.assertLen(result[key], num_solutions)
435

436
  def test_find_mask_first_layer_text_cnn(self):
    """Runs find_mask_first_layer end to end on the text CNN.

    Checks the shapes of the unmasked outputs. Also checks that every
    per-solution field holds one entry per requested unique solution.
    """
    with self.test_session():
      # Temporary graphs should be created inside a session. Several graphs
      # are created across these tests; unless each one is built inside its
      # own session, tensor names get unwanted integer suffixes, which would
      # break accessing tensors by name.
      _create_temporary_tf_graph_text_cnn(self.test_model_path)
    num_solutions = 5
    run_params = masking.RunParams(
        model_path=self.test_model_path,
        tensor_names={
            'input': 'input_1:0',
            'embedding': 'embedding/embedding_lookup/Identity:0',
            'first_layer': 'conv1d/BiasAdd:0',
            'first_layer_relu': 'conv1d/Relu:0',
            'logits': 'dense/BiasAdd:0',
            'softmax': 'dense/Sigmoid:0',
            'weights_layer_1': 'conv1d/Conv1D/ExpandDims_1:0',
            'biases_layer_1': 'conv1d/BiasAdd/ReadVariableOp:0',
        },
        image_placeholder_shape=(1, 5),
        model_type='text_cnn',
        padding=(0, 0),
        strides=1,
        activations=None,
        pixel_range=(0, 1))
    result = masking.find_mask_first_layer(
        image=np.zeros(5),
        run_params=run_params,
        window_size=1,
        label_index=0,
        score_method='activations',
        top_k=4,
        gamma=0.5,
        timeout=600,
        num_unique_solutions=num_solutions)

    expected_shapes = {
        'image': (5,),
        'unmasked_logits': (1,),
        'unmasked_first_layer': (12,),
    }
    for key, shape in expected_shapes.items():
      self.assertEqual(result[key].shape, shape)
    self.assertEqual(result['masks'][0].shape, (5,))
    for key in ('masks', 'masked_first_layer', 'inv_masked_first_layer',
                'masked_images', 'inv_masked_images', 'masked_logits',
                'inv_masked_logits', 'solver_outputs'):
      self.assertLen(result[key], num_solutions)
485

486
  def test_find_mask_first_layer_cnn(self):
    """Runs find_mask_first_layer end to end on the image CNN.

    Checks the shapes of the unmasked outputs. Also checks that every
    per-solution field holds one entry per requested unique solution.
    """
    with self.test_session():
      # Temporary graphs should be created inside a session. Several graphs
      # are created across these tests; unless each one is built inside its
      # own session, tensor names get unwanted integer suffixes, which would
      # break accessing tensors by name.
      _create_temporary_tf_graph_cnn(self.test_model_path)
    num_solutions = 5
    image = np.random.random((4, 4, 3))
    run_params = masking.RunParams(
        model_path=self.test_model_path,
        tensor_names={
            'input': 'conv2d_input:0',
            'first_layer': 'conv2d/BiasAdd:0',
            'first_layer_relu': 'conv2d/Relu:0',
            'logits': 'dense/BiasAdd:0',
            'softmax': 'dense/Softmax:0',
            'weights_layer_1': 'conv2d/Conv2D/ReadVariableOp:0',
            'biases_layer_1': 'conv2d/bias/Read/ReadVariableOp:0',
        },
        image_placeholder_shape=(1, 4, 4, 3),
        model_type='cnn',
        padding=(0, 1),
        strides=1,
        activations=None,
        pixel_range=(0, 1))
    result = masking.find_mask_first_layer(
        image=image,
        run_params=run_params,
        window_size=1,
        label_index=0,
        score_method='activations',
        top_k=4,
        gamma=0.5,
        timeout=600,
        num_unique_solutions=num_solutions)

    expected_shapes = {
        'image': (48,),
        'unmasked_logits': (4,),
        'unmasked_first_layer': (48,),
    }
    for key, shape in expected_shapes.items():
      self.assertEqual(result[key].shape, shape)
    self.assertEqual(result['masks'][0].shape, (48,))
    for key in ('masks', 'masked_first_layer', 'inv_masked_first_layer',
                'masked_images', 'inv_masked_images', 'masked_logits',
                'inv_masked_logits', 'solver_outputs'):
      self.assertLen(result[key], num_solutions)
534

535
  @parameterized.parameters(
      ('get_saliency_map', 'activations'),
      ('get_saliency_map', 'integrated_gradients'),
      ('_get_gradients', 'gradients'),
      ('_get_gradients', 'blurred_gradients'),
  )
  def test_sort_indices(self, function_to_be_mocked, score_method):
    """Checks _sort_indices ranks hidden nodes by the mocked priority scores.

    Args:
      function_to_be_mocked: Name of the masking function patched to return
        the priority array.
      score_method: Scoring method passed to _sort_indices.
    """
    # `priority` is reverse engineered so that _sort_indices should return
    # [0, 1, ..., 47], i.e. all 48 hidden nodes in index order.
    priority = np.moveaxis(np.arange(48).reshape((1, 4, 4, 3)), 1, -1)
    score_patch = mock.patch.object(
        masking, function_to_be_mocked, return_value=priority)
    blur_patch = mock.patch.object(
        masking, '_apply_blurring', return_value=mock.MagicMock())
    with score_patch, blur_patch:
      sorted_indices = masking._sort_indices(
          session=mock.MagicMock(),
          image=mock.MagicMock(),
          label_index=0,
          run_params=mock.MagicMock(),
          unmasked_predictions={
              'first_layer': priority,
              'first_layer_relu': priority,
          },
          score_method=score_method)

    np.testing.assert_array_equal(sorted_indices, np.arange(48))
560

561
  def test_get_gradients(self):
    """Checks _get_gradients returns one gradient per input value."""
    input_batch = np.ones((1, 4, 4, 3))
    with self.test_session() as session:
      _create_temporary_tf_graph_cnn(self.test_model_path)
      gradients = masking._get_gradients(
          session=session,
          graph=tf.get_default_graph(),
          features=input_batch,
          label_index=0,
          input_tensor_name='conv2d_input:0',
          output_tensor_name='dense/Softmax:0')

    self.assertEqual(gradients.shape, (1, 4, 4, 3))
574

575
  @parameterized.parameters(
      (('The input image should have 3 dimensions. Shape of the image: '
        r'\(4, 4\)'), np.ones((4, 4))),
      (('The input image should have height == width. '
        r'Shape of the input image: \(4, 5, 1\)'), np.ones((4, 5, 1))),
      (('The color channels of the input image has a value other than 1 or 3. '
        r'Shape of the image: \(4, 4, 2\)'), np.ones((4, 4, 2))),
  )
  def test_verify_image_dimensions(self, error, image):
    """Checks _verify_image_dimensions rejects malformed image shapes.

    Args:
      error: Regex the raised ValueError message must match.
      image: Image array with an invalid shape.
    """
    self.assertRaisesRegex(
        ValueError, error, masking._verify_image_dimensions, image)
586

587
  @parameterized.parameters(
      (np.ones((4, 4)), 'text_cnn',
       r'Invalid mask shape: \(4, 4\). Expected a mask with 1 dimension.'),
      (np.ones(4), 'cnn',
       r'Invalid mask shape: \(4,\). Expected a mask with 2 equal dimensions.'),
      (np.ones(4), 'fully_connected',
       r'Invalid mask shape: \(4,\). Expected a mask with 2 equal dimensions.'),
  )
  def test_verify_mask_dimensions(self, mask, model_type, error):
    """Checks _verify_mask_dimensions rejects masks of the wrong rank.

    Args:
      mask: Mask array whose rank does not suit model_type.
      model_type: Model type the mask is validated against.
      error: Regex the raised ValueError message must match.
    """
    self.assertRaisesRegex(
        ValueError, error, masking._verify_mask_dimensions, mask, model_type)
598

599
  def test_reorder(self):
    """Checks _reorder moves the last axis of a 4-D array to the front."""
    dims = tuple(np.random.randint(low=1, high=10, size=4))
    expected_shape = (dims[-1],) + dims[:-1]

    self.assertEqual(masking._reorder(np.ones(dims)).shape, expected_shape)
604

605
  def test_remove_batch_axis(self):
    """Checks _remove_batch_axis drops a leading batch axis of size 1."""
    # The batch size has to be 1.
    dims = (1,) + tuple(np.random.randint(low=1, high=10, size=3))

    self.assertEqual(
        masking._remove_batch_axis(np.ones(dims)).shape, dims[1:])
611

612
  def test_remove_batch_axis_error(self):
    """Checks _remove_batch_axis rejects arrays whose batch size is not 1."""
    dims = tuple(np.random.randint(low=2, high=10, size=4))
    expected_message = ('The array doesn\'t have the batch dimension as 1. '
                        'Received an array with length along the batch '
                        'dimension: %d' % dims[0])

    with self.assertRaisesRegex(ValueError, expected_message):
      masking._remove_batch_axis(np.ones(dims))
620

621
  def test_process_text_error(self):
    """Checks _process_text rejects text input that is not a 1-D array."""
    self.assertRaisesRegex(
        ValueError,
        ('The text input should be a 1D numpy array. '
         r'Shape of the received input: \(1, 500\)'),
        masking._process_text,
        image=np.ones((1, 500)),
        run_params=None)
626

627
  def test_get_hidden_node_location_image(self):
    """Checks a flat hidden-node index maps back to (channel, row, column).

    Flat indices follow C order over a
    (num_channels, map_size, map_size) activation volume.
    """
    num_channels = 64
    map_size = 112
    true_row = np.random.randint(low=0, high=map_size)
    true_column = np.random.randint(low=0, high=map_size)
    true_channel = np.random.randint(low=0, high=num_channels)
    flat_index = np.ravel_multi_index(
        (true_channel, true_row, true_column),
        (num_channels, map_size, map_size))

    predicted_channel, predicted_row, predicted_column = (
        masking._get_hidden_node_location(
            flattened_index=flat_index,
            num_rows=map_size,
            num_columns=map_size))

    self.assertEqual(
        (true_channel, true_row, true_column),
        (predicted_channel, predicted_row, predicted_column))
647

648
  def test_get_hidden_node_location_text(self):
    """Checks flat-index decoding for a text activation map (one column).

    Flat indices follow C order over a (num_channels, num_rows, 1) volume.
    """
    num_channels = 128
    num_rows, num_columns = 498, 1
    true_row = np.random.randint(low=0, high=num_rows)
    true_channel = np.random.randint(low=0, high=num_channels)
    true_column = 0
    flat_index = np.ravel_multi_index(
        (true_channel, true_row, true_column),
        (num_channels, num_rows, num_columns))

    predicted_channel, predicted_row, predicted_column = (
        masking._get_hidden_node_location(
            flattened_index=flat_index,
            num_rows=num_rows,
            num_columns=num_columns))

    self.assertEqual(
        (true_channel, true_row, true_column),
        (predicted_channel, predicted_row, predicted_column))
668

669
  def test_get_activation_map_shape_image(self):
    """Checks the CNN activation-map shape is (axis 1, axis 2) of the maps."""
    maps_shape = tuple(np.random.randint(low=1, high=10, size=4))
    _, height, width, _ = maps_shape

    map_shape = masking._get_activation_map_shape(maps_shape, model_type='cnn')
    self.assertEqual(map_shape, (height, width))
676

677
  def test_get_activation_map_shape_text(self):
    """Checks the text-CNN activation-map shape is (axis 1 of the maps, 1)."""
    maps_shape = tuple(np.random.randint(low=1, high=10, size=3))
    _, num_positions, _ = maps_shape

    map_shape = masking._get_activation_map_shape(
        maps_shape, model_type='text_cnn')
    self.assertEqual(map_shape, (num_positions, 1))
684

685
  @parameterized.parameters(
      (('Invalid model_type: text. Expected one of - '
        'fully_connected, cnn or text_cnn'), (1, 1, 1), 'text'),
      (r'Invalid activation_maps_shape: \(1, 1, 1\)\.\s*'
       r'Expected length 4', (1, 1, 1), 'cnn'),
      (r'Invalid activation_maps_shape: \(1, 1, 1, 1\)\.\s*'
       r'Expected length 3', (1, 1, 1, 1), 'text_cnn'),
  )
  def test_verify_activation_maps_shape(
      self, error, activation_maps_shape, model_type):
    """Checks _verify_activation_maps_shape rejects invalid inputs.

    Each case is (expected error regex, activation_maps_shape, model_type).
    parameterized passes the case tuples positionally, so the method
    parameters must be declared in exactly that order. Otherwise the regex
    would be fed in as the shape, and every case could pass via the
    invalid-model_type error alone.

    Args:
      error: Regex the raised ValueError message must match.
      activation_maps_shape: Shape passed to the verifier.
      model_type: Model type passed to the verifier.
    """
    # `\.\s*` tolerates the message having or not having a space after the
    # shape's trailing period.
    with self.assertRaisesRegex(ValueError, error):
      masking._verify_activation_maps_shape(activation_maps_shape, model_type)
697

698
  @parameterized.parameters(('cnn', (3, 0, 1, 2)),
                            ('text_cnn', (3, 1, 2, 0)))
  def test_reshape_kernels(self, model_type, reshaped_dimensions):
    """Checks _reshape_kernels permutes kernel axes in the expected order.

    Args:
      model_type: Model type passed to _reshape_kernels.
      reshaped_dimensions: reshaped_dimensions[i] is the index of the input
        kernel axis expected to become output axis i.
    """
    kernel_shape = tuple(np.random.randint(low=1, high=10, size=4))

    reshaped_kernel = masking._reshape_kernels(
        kernels=np.ones(kernel_shape), model_type=model_type)

    expected_shape = tuple(kernel_shape[axis] for axis in reshaped_dimensions)
    self.assertEqual(reshaped_kernel.shape, expected_shape)
712

713
  @parameterized.parameters(
      ('integrated_gradients', (4, 4, 3)),
      ('integrated_gradients_black_white_baselines', (4, 4, 3)),
      ('xrai', (4, 4)))
  def test_get_saliency_map(self, saliency_method, saliency_map_shape):
    """get_saliency_map returns a map of the expected shape for each method."""
    # Build the temporary graph inside its own test session. This test code
    # creates several graphs; without a separate session per graph, tensor
    # names pick up unwanted integer suffixes and the by-name tensor lookups
    # below would break.
    with self.test_session():
      _create_temporary_tf_graph_cnn(self.test_model_path)

    saliency_map = masking.get_saliency_map(
        session=utils.restore_model(self.test_model_path),
        features=np.random.rand(4, 4, 3),
        saliency_method=saliency_method,
        label=0,
        input_tensor_name='conv2d_input:0',
        output_tensor_name='dense/Softmax:0')

    self.assertEqual(saliency_map.shape, saliency_map_shape)
735

736
  def test_get_no_minimization_mask(self):
    """The mask is the union of the receptive fields of selected hidden nodes.

    The model session and the masking helpers are mocked; _sort_indices is
    mocked to return [0, 21, 10], so the mask is built from those hidden nodes.
    """
    mock_session = mock.MagicMock()
    mock_session.run.return_value = {
        # Every hidden node has a receptive field of 2 x 2
        'weights_layer_1': np.ones((1, 2, 2, 1)),
        'first_layer_relu': np.ones((1, 4, 4, 1)),
        'first_layer': np.ones((1, 4, 4, 1)),
    }
    mock_run_params = mock.MagicMock(
        strides=1,
        padding=(1, 1),
        image_placeholder_shape=(4, 4),
        model_type='cnn',
        pixel_range=(0, 1))

    patch_restore = mock.patch.object(
        utils, 'restore_model', return_value=mock_session)
    patch_helpers = mock.patch.multiple(
        masking,
        get_saliency_map=mock.MagicMock(return_value=mock.MagicMock()),
        _reorder=mock.MagicMock(return_value=np.ones(32)),
        _sort_indices=mock.MagicMock(return_value=[0, 21, 10]),
        _remove_batch_axis=mock.MagicMock(return_value=mock.MagicMock()))
    with patch_restore, patch_helpers:
      mask = masking.get_no_minimization_mask(
          image=np.ones((4, 4)),
          label_index=0,
          top_k=4,
          run_params=mock_run_params,
          sum_attributions=False)

    # Receptive field of hidden node 0 on the padded 6 x 6 image:
    # 1 1 0 0 0 0
    # 1 1 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    #
    # Receptive field of hidden node 21 on the padded image:
    # 0 0 0 0 0 0
    # 0 1 1 0 0 0
    # 0 1 1 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    #
    # Receptive field of hidden node 10 on the padded image:
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 1 1 0 0
    # 0 0 1 1 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    #
    # The no-minimization mask is the union of these fields. Stripping the
    # (1, 1) padding from the padded image then yields:
    # 1 1 0 0
    # 1 1 1 0
    # 0 1 1 0
    # 0 0 0 0
    expected_mask = [[1, 1, 0, 0],
                     [1, 1, 1, 0],
                     [0, 1, 1, 0],
                     [0, 0, 0, 0]]
    np.testing.assert_allclose(mask, expected_mask)
805

806
  def test_get_no_minimization_mask_text(self):
    """Text variant: the mask is the union of 1-D receptive fields.

    The model session and the masking helpers are mocked; _sort_indices is
    mocked to return [0, 16], so the mask is built from those hidden nodes.
    """
    mock_session = mock.MagicMock()
    mock_session.run.return_value = {
        # Every hidden node has a receptive field of 3
        'weights_layer_1': np.ones((1, 3, 10, 12)),
        'first_layer_relu': np.ones((1, 11, 1)),
        'first_layer': np.ones((1, 11, 1)),
    }
    mock_run_params = mock.MagicMock(
        strides=1,
        padding=(1, 2),
        image_placeholder_shape=(1, 10),
        model_type='text_cnn')

    patch_restore = mock.patch.object(
        utils, 'restore_model', return_value=mock_session)
    patch_helpers = mock.patch.multiple(
        masking,
        get_saliency_map=mock.MagicMock(return_value=mock.MagicMock()),
        _reorder=mock.MagicMock(return_value=np.ones(20)),
        _sort_indices=mock.MagicMock(return_value=[0, 16]),
        _remove_batch_axis=mock.MagicMock(return_value=mock.MagicMock()))
    with patch_restore, patch_helpers:
      mask = masking.get_no_minimization_mask(
          image=np.ones(10),
          label_index=0,
          top_k=4,
          run_params=mock_run_params,
          sum_attributions=False)

    # Receptive field of hidden node 0 on the padded text:
    # 1 1 1 0 0 0 0 0 0 0 0 0 0
    # Receptive field of hidden node 16 (activation map 1, position 5, both
    # 0-indexed) on the padded text:
    # 0 0 0 0 0 1 1 1 0 0 0 0 0
    # Taking their union and removing the (1, 2) padding gives the result below.
    expected_mask = [1, 1, 0, 0, 1, 1, 1, 0, 0, 0]
    np.testing.assert_allclose(mask, expected_mask)
843

844
if __name__ == '__main__':
  # Run this module's tests through the absl test runner.
  absltest.main()
846

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.