{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TwOwlnLuRZXu"
      },
      "outputs": [],
      "source": [
        "# Copyright 2021 The Google Research Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JaCqYOye9DLX"
      },
      "source": [
        "# NOTE\n",
        "#### Make sure that this notebook is using the `smug` kernel.\n",
        "#### We use the Inception_v1 model from TF-Slim here. Our paper used a slightly different variant of this model without batch norm, so visualizations and results may differ. At the bottom, we also show examples for Inception_v3."
      ]
    },
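    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The imports below assume that the `smug_saliency` package (from the google-research repository) lives in the parent directory and that its dependencies are installed in the `smug` kernel. A minimal, hedged setup sketch follows; the repository layout and the requirements file name are assumptions, so adapt them to your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hedged setup sketch (commented out): clone the google-research\n",
        "# repository, which contains smug_saliency, install its dependencies,\n",
        "# and register a 'smug' Jupyter kernel. The requirements file name is\n",
        "# an assumption; check the smug_saliency directory for the actual one.\n",
        "# !git clone https://github.com/google-research/google-research.git\n",
        "# !pip install -r google-research/smug_saliency/requirements.txt\n",
        "# !python -m ipykernel install --user --name smug"
      ]
    },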
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iuaZMP90GCRd"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "import saliency\n",
        "import tensorflow.compat.v1 as tf\n",
        "import tensorflow_hub as hub\n",
        "import tf_slim as slim\n",
        "tf.disable_eager_execution()\n",
        "\n",
        "if not os.path.exists('models/research/slim'):\n",
        "  !git clone https://github.com/tensorflow/models/\n",
        "\n",
        "if not os.path.exists('inception_v1_2016_08_28.tar.gz'):\n",
        "  !wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz\n",
        "  !tar -xvzf inception_v1_2016_08_28.tar.gz\n",
        "\n",
        "old_cwd = os.getcwd()\n",
        "os.chdir('models/research/slim')\n",
        "from nets import inception_v1\n",
        "os.chdir(old_cwd)\n",
        "\n",
        "os.chdir('../')\n",
        "from smug_saliency import masking\n",
        "from smug_saliency import utils\n",
        "os.chdir('smug_saliency/')"
      ]
    },
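    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional sanity check: the cells below rely on TF1-style graphs and sessions, so eager execution must be disabled by the `tf.disable_eager_execution()` call above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sanity check: eager execution should be disabled for the TF1-style\n",
        "# session/graph code in the rest of this notebook to work.\n",
        "print('TensorFlow version:', tf.__version__)\n",
        "print('Eager execution enabled:', tf.executing_eagerly())  # expect False"
      ]
    },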
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "10LcAJ_2VjmN"
      },
      "outputs": [],
      "source": [
        "run_params_inception_v1 = masking.RunParams(**{\n",
        "  'model_type': 'cnn',\n",
        "  \n",
        "  # The following parameters pertain to the pre-trained model.\n",
        "  \n",
        "  # model_path is the path to the frozen tensorflow graph. It usually\n",
        "  # has a '.pb' extension. To load such a graph, the utils.restore_model\n",
        "  # function can be used. If a frozen model is unavailable, model_path\n",
        "  # is set to '' and a custom load_model function should be used\n",
        "  # instead, for example restore_inception_v1 (below).\n",
        "  'model_path': '',\n",
        "  'image_placeholder_shape': (1, 224, 224, 3),\n",
        "  'padding': (2, 3),\n",
        "  'strides': 2,\n",
        "  'activations': None,\n",
        "  # Range of input pixel values expected by the model.\n",
        "  'pixel_range': (0, 1),\n",
        "  # Find the appropriate tensor names by printing the tf ops using\n",
        "  # restore_inception_v1.\n",
        "  'tensor_names': {\n",
        "    'input': 'Placeholder:0',\n",
        "    'first_layer': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Conv2D:0',\n",
        "    'first_layer_relu': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Relu:0',\n",
        "    'logits': 'InceptionV1/Logits/SpatialSqueeze:0',\n",
        "    'softmax': 'InceptionV1/Logits/Predictions/Softmax:0',\n",
        "    'weights_layer_1': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Conv2D/ReadVariableOp:0',\n",
        "  }\n",
        "})\n",
        "\n",
        "def restore_inception_v1(model_path='./inception_v1.ckpt',\n",
        "                         print_ops=False):\n",
        "  \"\"\"Restores a tensorflow model from a checkpoint and returns it.\n",
        "\n",
        "  Args:\n",
        "    model_path: string, path to a tensorflow checkpoint.\n",
        "    print_ops: bool, prints operations in a tensorflow graph if true.\n",
        "\n",
        "  Returns:\n",
        "    session: tf.Session, tensorflow session with the loaded neural network.\n",
        "    graph: tensorflow graph corresponding to the tensorflow session.\n",
        "  \"\"\"\n",
        "  graph = tf.Graph()\n",
        "  with graph.as_default():\n",
        "    images = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))\n",
        "    with slim.arg_scope(inception_v1.inception_v1_arg_scope()):\n",
        "      _, end_points = inception_v1.inception_v1(images, is_training=False, num_classes=1001)\n",
        "\n",
        "      # Restore the checkpoint.\n",
        "      session = tf.Session(graph=graph)\n",
        "      saver = tf.train.Saver()\n",
        "      saver.restore(session, model_path)\n",
        "\n",
        "  # Find the appropriate tensor names by printing the tf ops.\n",
        "  # These tensor names are required to construct run_params.\n",
        "  if print_ops:\n",
        "    for op in graph.get_operations():\n",
        "      print('name:', op.name)\n",
        "      print('inputs:')\n",
        "      for ip in op.inputs:\n",
        "        print(ip)\n",
        "      print('outputs:', op.outputs)\n",
        "      print('----\\n')\n",
        "  return session, graph"
      ]
    },
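    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As noted in the comment above, a frozen '.pb' graph could instead be loaded with `utils.restore_model`. The guarded sketch below is an assumption about that call (mirroring the `(session, graph)` return convention of the `restore_*` helpers in this notebook); it is a no-op here because `model_path` is `''`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hedged sketch: load a frozen graph when a '.pb' path is configured.\n",
        "# The exact signature of utils.restore_model is an assumption; adjust\n",
        "# to the actual API in smug_saliency.utils.\n",
        "if run_params_inception_v1.model_path:\n",
        "  session, graph = utils.restore_model(run_params_inception_v1.model_path)"
      ]
    },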
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-EKQE-hrFUmV"
      },
      "outputs": [],
      "source": [
        "# Print the names of the tensors in order to construct\n",
        "# run_params_inception_v1.tensor_names.\n",
        "restore_inception_v1(print_ops=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KCm2bFUG7mDr"
      },
      "source": [
        "#### Ensure that the first layer weights and biases are indeed correct"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nhtrhm1OFUmj"
      },
      "outputs": [],
      "source": [
        "def verify_first_layer_conv_weights(run_params, restore_model):\n",
        "  \"\"\"Recomputes the first-layer convolution with a nested for loop\n",
        "  and checks that it matches the model's first-layer output.\"\"\"\n",
        "  image = utils.process_model_input(\n",
        "    np.random.random(run_params.image_placeholder_shape[1:]),\n",
        "    run_params.pixel_range)\n",
        "  session, _ = restore_model()\n",
        "  output_tensors = session.run(\n",
        "    run_params.tensor_names,\n",
        "    feed_dict={run_params.tensor_names['input']: [image]})\n",
        "  session.close()\n",
        "  if 'biases_layer_1' in run_params.tensor_names:\n",
        "    kernel_biases = output_tensors['biases_layer_1']\n",
        "  else:\n",
        "    kernel_biases = np.zeros(output_tensors['weights_layer_1'].shape[-1])\n",
        "  \n",
        "  # Computes the convolution using a nested for loop.\n",
        "  for_loop_convolution = utils.smt_convolution(\n",
        "      input_activation_maps=np.moveaxis(image, -1, 0),\n",
        "      kernels=output_tensors['weights_layer_1'],\n",
        "      kernel_biases=kernel_biases,\n",
        "      padding=run_params.padding,\n",
        "      strides=run_params.strides)\n",
        "  for_loop_convolution = np.moveaxis(\n",
        "      np.array(for_loop_convolution), 0, -1)\n",
        "  if np.mean(np.abs(output_tensors['first_layer'][0]\n",
        "                    - for_loop_convolution)) \u003e 1e-6:\n",
        "    print('The supplied tensor names are wrong.')\n",
        "    assert False\n",
        "  else:\n",
        "    print('Tensor names in run_params are consistent.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mV60QtFXFUmj",
        "outputId": "96a0d8a8-91ad-461a-ece1-0f3702c50e66"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Tensor names in run_params are consistent.\n"
          ]
        }
      ],
      "source": [
        "verify_first_layer_conv_weights(run_params_inception_v1,\n",
        "                                restore_inception_v1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FqgvHJgRFUmk"
      },
      "outputs": [],
      "source": [
        "def _get_saliency_maps(image, run_params, restore_model,\n",
        "                       top_k=3000, window_size=3):\n",
        "    tf.reset_default_graph()\n",
        "    image = utils.process_model_input(image, run_params.pixel_range)\n",
        "    restored_sess, restored_graph = restore_model()\n",
        "    input_tensor = restored_graph.get_tensor_by_name(\n",
        "        run_params.tensor_names['input'])\n",
        "    label_index = np.argmax(restored_sess.run(\n",
        "        run_params.tensor_names['softmax'],\n",
        "        feed_dict={input_tensor: [image]}))\n",
        "    ig_saliency_map = saliency.core.VisualizeImageGrayscale(\n",
        "        masking.get_saliency_map(\n",
        "            session=restored_sess,\n",
        "            features=image,\n",
        "            saliency_method='integrated_gradients',\n",
        "            label=label_index,\n",
        "            input_tensor_name=run_params.tensor_names['input'],\n",
        "            output_tensor_name=run_params.tensor_names['softmax'],\n",
        "            graph=restored_graph))\n",
        "    restored_sess, restored_graph = restore_model()\n",
        "    no_minimization_mask = utils.scale_saliency_map(\n",
        "        masking.get_no_minimization_mask(\n",
        "            image=image,\n",
        "            label_index=label_index,\n",
        "            run_params=run_params,\n",
        "            top_k=top_k,\n",
        "            session=restored_sess,\n",
        "            graph=restored_graph),\n",
        "        method='smug')\n",
        "    restored_sess, restored_graph = restore_model()\n",
        "    result = masking.find_mask_first_layer(\n",
        "        image=image,\n",
        "        label_index=label_index,\n",
        "        run_params=run_params,\n",
        "        window_size=window_size,\n",
        "        score_method='integrated_gradients',\n",
        "        top_k=top_k,\n",
        "        gamma=0.0,\n",
        "        timeout=3600,\n",
        "        session=restored_sess,\n",
        "        graph=restored_graph)\n",
        "    smug_mask = result['masks'][0].reshape(\n",
        "      run_params.image_placeholder_shape)[0, :, :, 0]\n",
        "    return (smug_mask * no_minimization_mask, no_minimization_mask,\n",
        "            ig_saliency_map)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OyfzQUFyFUmk"
      },
      "outputs": [],
      "source": [
        "def _get_saliency_params(image, saliency_map, run_params, restore_model):\n",
        "  tf.reset_default_graph()\n",
        "  session, _ = restore_model()\n",
        "  saliency_score = utils.calculate_saliency_score(\n",
        "    run_params=run_params,\n",
        "    image=image,\n",
        "    saliency_map=saliency_map,\n",
        "    session=session)\n",
        "  if saliency_score is None:\n",
        "    return None, None\n",
        "  return (saliency_score['saliency_score'],\n",
        "          saliency_score['crop_mask'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aJL5fQycFUmk"
      },
      "outputs": [],
      "source": [
        "def plot_saliency_maps(image, run_params, restore_model, window_size,\n",
        "                       show_bounding_box=False):\n",
        "    smug_saliency, no_minimization_saliency, ig_saliency = _get_saliency_maps(\n",
        "      image=image, restore_model=restore_model, run_params=run_params,\n",
        "      window_size=window_size)\n",
        "    smug_saliency_score, smug_crop_mask = _get_saliency_params(\n",
        "        image, smug_saliency, run_params, restore_model)\n",
        "    (no_minimization_saliency_score,\n",
        "     no_minimization_crop_mask) = _get_saliency_params(\n",
        "        image, no_minimization_saliency, run_params, restore_model)\n",
        "    ig_saliency_score, ig_crop_mask = _get_saliency_params(\n",
        "        image, ig_saliency, run_params, restore_model)\n",
        "    if smug_saliency_score is None or no_minimization_saliency_score is None:\n",
        "        return\n",
        "    fig = plt.figure(figsize=(10, 10))\n",
        "    fig.add_subplot(2, 2, 1)\n",
        "    plt.imshow(image)\n",
        "    plt.title('image')\n",
        "    utils.remove_ticks()\n",
        "\n",
        "    fig.add_subplot(2, 2, 2)\n",
        "    plt.imshow(smug_saliency, cmap='RdBu_r')\n",
        "    plt.title(f'SMUG score:{smug_saliency_score:.2f}')\n",
        "    if show_bounding_box:\n",
        "      utils.show_bounding_box(smug_crop_mask)\n",
        "    utils.remove_ticks()\n",
        "\n",
        "    fig.add_subplot(2, 2, 3)\n",
        "    plt.imshow(no_minimization_saliency, cmap='RdBu_r')\n",
        "    plt.title(f'SMUG_BASE score:{no_minimization_saliency_score:.2f}')\n",
        "    if show_bounding_box:\n",
        "      utils.show_bounding_box(no_minimization_crop_mask)\n",
        "    utils.remove_ticks()\n",
        "\n",
        "    fig.add_subplot(2, 2, 4)\n",
        "    plt.imshow(ig_saliency, cmap='RdBu_r')\n",
        "    plt.title(f'IG score:{ig_saliency_score:.2f}')\n",
        "    if show_bounding_box:\n",
        "      utils.show_bounding_box(ig_crop_mask)\n",
        "    utils.remove_ticks()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jGje_qFhFUml"
      },
      "outputs": [],
      "source": [
        "image = np.array(Image.open('tabby.jpg'))\n",
        "# Pad the 224x224 image onto a white 299x299 canvas (uint8 for imshow).\n",
        "tabby = (255 * np.ones((299, 299, 3))).astype(np.uint8)\n",
        "tabby[:224, :224, :3] = image\n",
        "print(tabby.shape)\n",
        "plt.imshow(tabby)"
      ]
    },
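    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If `tabby.jpg` is not available locally, any RGB image resized to 224x224 can stand in. A hedged sketch follows; the file name `my_image.jpg` is a hypothetical placeholder."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hedged fallback (commented out): build the same padded input from an\n",
        "# arbitrary local image. 'my_image.jpg' is a hypothetical path.\n",
        "# image = np.array(Image.open('my_image.jpg').convert('RGB').resize((224, 224)))\n",
        "# tabby = (255 * np.ones((299, 299, 3))).astype(np.uint8)\n",
        "# tabby[:224, :224, :3] = image"
      ]
    },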
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p7zImFbMFUml"
      },
      "outputs": [],
      "source": [
        "plot_saliency_maps(tabby[:224, :224, :],\n",
        "                   run_params_inception_v1,\n",
        "                   restore_inception_v1,\n",
        "                   window_size=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2qtVR1bC9ndp"
      },
      "source": [
        "### Inception v3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "42Momsw9FUmm",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# Note that most of the IG attributions lie at the edge of the cat,\n",
        "# while SMUG and SMUG_BASE highlight the facial features of the cat.\n",
        "# This observation is explained in sec. 5.2\n",
        "# of https://arxiv.org/pdf/2006.16322.pdf\n",
        "\n",
        "run_params_inception_v3 = masking.RunParams(**{\n",
        "  'model_path': '',\n",
        "  'image_placeholder_shape': (1, 299, 299, 3),\n",
        "  'model_type': 'cnn',\n",
        "  'padding': (0, 0),\n",
        "  'strides': 2,\n",
        "  'activations': None,\n",
        "  'pixel_range': (-1, 1),\n",
        "  # Find the appropriate tensor names by printing the tf ops in\n",
        "  # restore_inception_v3.\n",
        "  'tensor_names': {\n",
        "    # Ideally, the input tensor to inception v3 would be\n",
        "    # 'module/hub_input/images:0'. Instead we choose the tensor\n",
        "    # 'module/hub_input/Sub:0': the model takes pixel values in (0, 1)\n",
        "    # and scales them to (-1, 1) before feeding them to the subsequent\n",
        "    # network; the scaled image is denoted by 'module/hub_input/Sub:0'.\n",
        "    # Because masking.find_mask_first_layer assumes that the convolution\n",
        "    # is performed directly on the input image without any rescaling,\n",
        "    # we feed input to the network via the 'module/hub_input/Sub:0'\n",
        "    # tensor.\n",
        "    'input': 'module/hub_input/Sub:0',\n",
        "    'first_layer': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D:0',\n",
        "    'first_layer_relu': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Relu:0',\n",
        "    'logits': 'module/InceptionV3/Logits/SpatialSqueeze:0',\n",
        "    'softmax': 'module/InceptionV3/Predictions/Softmax:0',\n",
        "    'weights_layer_1': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D/ReadVariableOp:0',\n",
        "  }\n",
        "})\n",
        "\n",
        "def restore_inception_v3(model_path=('https://tfhub.dev/google/imagenet/'\n",
        "                                     'inception_v3/classification/1'),\n",
        "                         print_ops=False):\n",
        "  \"\"\"Restores a tensorflow model from TF-Hub and returns it.\n",
        "\n",
        "  Args:\n",
        "    model_path: string, URL of a TF-Hub module.\n",
        "    print_ops: bool, prints operations in a tensorflow graph if true.\n",
        "\n",
        "  Returns:\n",
        "    session: tf.Session, tensorflow session with the loaded neural network.\n",
        "    graph: tensorflow graph corresponding to the tensorflow session.\n",
        "  \"\"\"\n",
        "  graph = tf.Graph()\n",
        "  session = tf.Session(graph=graph)\n",
        "  with graph.as_default():\n",
        "    hub.Module(model_path)\n",
        "    session.run(tf.global_variables_initializer())\n",
        "    session.run(tf.tables_initializer())\n",
        "\n",
        "  # Find the appropriate tensor names by printing the tf ops.\n",
        "  # These tensor names are required to construct run_params.\n",
        "  if print_ops:\n",
        "    for op in graph.get_operations():\n",
        "      print('name:', op.name)\n",
        "      print('inputs:')\n",
        "      for ip in op.inputs:\n",
        "        print(ip)\n",
        "      print('outputs:', op.outputs)\n",
        "      print('----\\n')\n",
        "  return session, graph\n",
        "\n",
        "restore_inception_v3(print_ops=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v7y3T_-hFUmm"
      },
      "outputs": [],
      "source": [
        "verify_first_layer_conv_weights(run_params_inception_v3,\n",
        "                                restore_inception_v3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uz77-0w29uPF"
      },
      "outputs": [],
      "source": [
        "plot_saliency_maps(tabby,\n",
        "                   run_params_inception_v3,\n",
        "                   restore_inception_v3,\n",
        "                   window_size=3)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "image_saliency.ipynb",
      "provenance": [
        {
          "file_id": "1AhdqlDFsFrs3ctHx-Mz2N03KPEKx5pVW",
          "timestamp": 1616167425955
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}