{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TwOwlnLuRZXu"
},
"outputs": [],
"source": [
"# Copyright 2021 The Google Research Authors.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JaCqYOye9DLX"
},
"source": [
"# NOTE\n",
"#### Make sure that this notebook is using the `smug` kernel.\n",
"#### We use the Inception_v1 model from TF-Slim here. Our paper used a slightly different variant of this model without batch norm, so visualizations and results may differ. At the bottom, we also show examples for Inception_v3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iuaZMP90GCRd"
},
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from PIL import Image\n",
"import saliency\n",
"import tensorflow.compat.v1 as tf\n",
"import tensorflow_hub as hub\n",
"import tf_slim as slim\n",
"tf.disable_eager_execution()\n",
"\n",
"if not os.path.exists('models/research/slim'):\n",
"  !git clone https://github.com/tensorflow/models/\n",
"\n",
"if not os.path.exists('inception_v1_2016_08_28.tar.gz'):\n",
"  !wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz\n",
"  !tar -xvzf inception_v1_2016_08_28.tar.gz\n",
"\n",
"old_cwd = os.getcwd()\n",
"os.chdir('models/research/slim')\n",
"from nets import inception_v1\n",
"os.chdir(old_cwd)\n",
"\n",
"os.chdir('../')\n",
"from smug_saliency import masking\n",
"from smug_saliency import utils\n",
"os.chdir('smug_saliency/')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "10LcAJ_2VjmN"
},
"outputs": [],
"source": [
"run_params_inception_v1 = masking.RunParams(**{\n",
"    'model_type': 'cnn',\n",
"\n",
"    # The following parameters pertain to the pre-trained model.\n",
"\n",
"    # model_path is the path to the frozen tensorflow graph. It usually\n",
"    # has a '.pb' extension. Such a graph can be loaded with the\n",
"    # utils.restore_model function. If a frozen model is unavailable,\n",
"    # model_path is set to '' and a custom load_model function should be\n",
"    # used instead, for example restore_inception_v1 (below).\n",
"    'model_path': '',\n",
"    'image_placeholder_shape': (1, 224, 224, 3),\n",
"    'padding': (2, 3),\n",
"    'strides': 2,\n",
"    'activations': None,\n",
"    # Range of input pixel values expected by the model.\n",
"    'pixel_range': (0, 1),\n",
"    # Find the appropriate tensor names by printing the tf ops using\n",
"    # restore_inception_v1.\n",
"    'tensor_names': {\n",
"        'input': 'Placeholder:0',\n",
"        'first_layer': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Conv2D:0',\n",
"        'first_layer_relu': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Relu:0',\n",
"        'logits': 'InceptionV1/Logits/SpatialSqueeze:0',\n",
"        'softmax': 'InceptionV1/Logits/Predictions/Softmax:0',\n",
"        'weights_layer_1': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Conv2D/ReadVariableOp:0',\n",
"    }\n",
"})\n",
"\n",
"def restore_inception_v1(model_path='./inception_v1.ckpt',\n",
"                         print_ops=False):\n",
"  \"\"\"Restores a tensorflow model from a checkpoint and returns it.\n",
"\n",
"  Args:\n",
"    model_path: string, path to a tensorflow model checkpoint.\n",
"    print_ops: bool, prints operations in a tensorflow graph if true.\n",
"\n",
"  Returns:\n",
"    session: tf.Session, tensorflow session with the loaded neural network.\n",
"    graph: tensorflow graph corresponding to the tensorflow session.\n",
"  \"\"\"\n",
"  graph = tf.Graph()\n",
"  with graph.as_default():\n",
"    images = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))\n",
"    with slim.arg_scope(inception_v1.inception_v1_arg_scope()):\n",
"      _, end_points = inception_v1.inception_v1(\n",
"          images, is_training=False, num_classes=1001)\n",
"\n",
"    # Restore the checkpoint.\n",
"    session = tf.Session(graph=graph)\n",
"    saver = tf.train.Saver()\n",
"    saver.restore(session, model_path)\n",
"\n",
"  # Find the appropriate tensor names by printing the tf ops.\n",
"  # These tensor names are required to construct run_params.\n",
"  if print_ops:\n",
"    for op in graph.get_operations():\n",
"      print('name:', op.name)\n",
"      print('inputs:')\n",
"      for ip in op.inputs:\n",
"        print(ip)\n",
"      print('outputs:', op.outputs)\n",
"      print('----\\n')\n",
"  return session, graph"
]
},
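{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do have a frozen `.pb` graph instead of a checkpoint, a loader along the following lines can stand in for the custom `restore_inception_v1` above. This is a minimal sketch using standard TF1 graph-def loading; `load_frozen_graph` is a hypothetical name, and the actual `utils.restore_model` helper may differ in its details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_frozen_graph(model_path):\n",
"  \"\"\"Sketch: loads a frozen graph ('.pb') and returns (session, graph).\"\"\"\n",
"  graph = tf.Graph()\n",
"  with graph.as_default():\n",
"    graph_def = tf.GraphDef()\n",
"    # Read the serialized GraphDef and import it into a fresh graph.\n",
"    with tf.gfile.GFile(model_path, 'rb') as f:\n",
"      graph_def.ParseFromString(f.read())\n",
"    tf.import_graph_def(graph_def, name='')\n",
"  return tf.Session(graph=graph), graph"
]
},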
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-EKQE-hrFUmV"
},
"outputs": [],
"source": [
"# Print the names of the tensors so as to construct\n",
"# run_params_inception_v1.tensor_names.\n",
"restore_inception_v1(print_ops=True)"
]
},
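{
"cell_type": "markdown",
"metadata": {},
"source": [
"The full op printout above is long. A quicker way to locate candidate tensor names is to filter the operations by a substring. This is a small sketch using the same graph accessors as `restore_inception_v1`; the substrings below are the ones that end up in `run_params_inception_v1.tensor_names`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Filter graph ops by name substring instead of scanning the full printout.\n",
"session, graph = restore_inception_v1()\n",
"for op in graph.get_operations():\n",
"  if 'Conv2d_1a_7x7' in op.name or 'Softmax' in op.name:\n",
"    # Tensor names are the op name plus an output index, e.g. ':0'.\n",
"    print(op.name, [t.shape for t in op.outputs])\n",
"session.close()"
]
},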
{
"cell_type": "markdown",
"metadata": {
"id": "KCm2bFUG7mDr"
},
"source": [
"#### Ensure that the first layer weights and biases are indeed correct"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nhtrhm1OFUmj"
},
"outputs": [],
"source": [
"def verify_first_layer_conv_weights(run_params, restore_model):\n",
"  \"\"\"Convolves the input with the first layer kernels using a nested\n",
"  for loop and checks that the result matches the model's first layer\n",
"  output.\"\"\"\n",
"  image = utils.process_model_input(\n",
"      np.random.random(run_params.image_placeholder_shape[1:]),\n",
"      run_params.pixel_range)\n",
"  session, _ = restore_model()\n",
"  output_tensors = session.run(\n",
"      run_params.tensor_names,\n",
"      feed_dict={run_params.tensor_names['input']: [image]})\n",
"  session.close()\n",
"  if 'biases_layer_1' in run_params.tensor_names:\n",
"    kernel_biases = output_tensors['biases_layer_1']\n",
"  else:\n",
"    kernel_biases = np.zeros(output_tensors['weights_layer_1'].shape[-1])\n",
"\n",
"  # Computes the convolution using a nested for loop.\n",
"  for_loop_convolution = utils.smt_convolution(\n",
"      input_activation_maps=np.moveaxis(image, -1, 0),\n",
"      kernels=output_tensors['weights_layer_1'],\n",
"      kernel_biases=kernel_biases,\n",
"      padding=run_params.padding,\n",
"      strides=run_params.strides)\n",
"  for_loop_convolution = np.moveaxis(\n",
"      np.array(for_loop_convolution), 0, -1)\n",
"  if np.mean(np.abs(output_tensors['first_layer'][0]\n",
"                    - for_loop_convolution)) \u003e 1e-6:\n",
"    print('The supplied tensor names are wrong.')\n",
"    assert False\n",
"  else:\n",
"    print('Tensor names in run_params are consistent.')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mV60QtFXFUmj",
"outputId": "96a0d8a8-91ad-461a-ece1-0f3702c50e66"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor names in run_params are consistent.\n"
]
}
],
"source": [
"verify_first_layer_conv_weights(run_params_inception_v1,\n",
"                                restore_inception_v1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FqgvHJgRFUmk"
},
"outputs": [],
"source": [
"def _get_saliency_maps(image, run_params, restore_model,\n",
"                       top_k=3000, window_size=3):\n",
"  \"\"\"Computes the SMUG, SMUG_BASE and IG saliency maps for an image.\"\"\"\n",
"  tf.reset_default_graph()\n",
"  image = utils.process_model_input(image, run_params.pixel_range)\n",
"  restored_sess, restored_graph = restore_model()\n",
"  input_tensor = restored_graph.get_tensor_by_name(\n",
"      run_params.tensor_names['input'])\n",
"  # Explain the label the model actually predicts for this image.\n",
"  label_index = np.argmax(restored_sess.run(\n",
"      run_params.tensor_names['softmax'],\n",
"      feed_dict={input_tensor: [image]}))\n",
"  ig_saliency_map = saliency.core.VisualizeImageGrayscale(\n",
"      masking.get_saliency_map(\n",
"          session=restored_sess,\n",
"          features=image,\n",
"          saliency_method='integrated_gradients',\n",
"          label=label_index,\n",
"          input_tensor_name=run_params.tensor_names['input'],\n",
"          output_tensor_name=run_params.tensor_names['softmax'],\n",
"          graph=restored_graph))\n",
"  restored_sess, restored_graph = restore_model()\n",
"  no_minimization_mask = utils.scale_saliency_map(\n",
"      masking.get_no_minimization_mask(\n",
"          image=image,\n",
"          label_index=label_index,\n",
"          run_params=run_params,\n",
"          top_k=top_k,\n",
"          session=restored_sess,\n",
"          graph=restored_graph),\n",
"      method='smug')\n",
"  restored_sess, restored_graph = restore_model()\n",
"  result = masking.find_mask_first_layer(\n",
"      image=image,\n",
"      label_index=label_index,\n",
"      run_params=run_params,\n",
"      window_size=window_size,\n",
"      score_method='integrated_gradients',\n",
"      top_k=top_k,\n",
"      gamma=0.0,\n",
"      timeout=3600,\n",
"      session=restored_sess,\n",
"      graph=restored_graph)\n",
"  smug_mask = result['masks'][0].reshape(\n",
"      run_params.image_placeholder_shape)[0, :, :, 0]\n",
"  return (smug_mask * no_minimization_mask, no_minimization_mask,\n",
"          ig_saliency_map)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OyfzQUFyFUmk"
},
"outputs": [],
"source": [
"def _get_saliency_params(image, saliency_map, run_params, restore_model):\n",
"  \"\"\"Returns the saliency score and crop mask for a saliency map.\"\"\"\n",
"  tf.reset_default_graph()\n",
"  session, _ = restore_model()\n",
"  saliency_score = utils.calculate_saliency_score(\n",
"      run_params=run_params,\n",
"      image=image,\n",
"      saliency_map=saliency_map,\n",
"      session=session)\n",
"  if saliency_score is None:\n",
"    return None, None\n",
"  return (saliency_score['saliency_score'],\n",
"          saliency_score['crop_mask'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aJL5fQycFUmk"
},
"outputs": [],
"source": [
"def plot_saliency_maps(image, run_params, restore_model, window_size,\n",
"                       show_bounding_box=False):\n",
"  smug_saliency, no_minimization_saliency, ig_saliency = _get_saliency_maps(\n",
"      image=image, restore_model=restore_model, run_params=run_params,\n",
"      window_size=window_size)\n",
"  smug_saliency_score, smug_crop_mask = _get_saliency_params(\n",
"      image, smug_saliency, run_params, restore_model)\n",
"  (no_minimization_saliency_score,\n",
"   no_minimization_crop_mask) = _get_saliency_params(\n",
"       image, no_minimization_saliency, run_params, restore_model)\n",
"  ig_saliency_score, ig_crop_mask = _get_saliency_params(\n",
"      image, ig_saliency, run_params, restore_model)\n",
"  if smug_saliency_score is None or no_minimization_saliency_score is None:\n",
"    return\n",
"  fig = plt.figure(figsize=(10, 10))\n",
"  fig.add_subplot(2, 2, 1)\n",
"  plt.imshow(image)\n",
"  plt.title('image')\n",
"  utils.remove_ticks()\n",
"\n",
"  fig.add_subplot(2, 2, 2)\n",
"  plt.imshow(smug_saliency, cmap='RdBu_r')\n",
"  plt.title(f'SMUG score:{smug_saliency_score:.2f}')\n",
"  if show_bounding_box:\n",
"    utils.show_bounding_box(smug_crop_mask)\n",
"  utils.remove_ticks()\n",
"\n",
"  fig.add_subplot(2, 2, 3)\n",
"  plt.imshow(no_minimization_saliency, cmap='RdBu_r')\n",
"  plt.title(f'SMUG_BASE score:{no_minimization_saliency_score:.2f}')\n",
"  if show_bounding_box:\n",
"    utils.show_bounding_box(no_minimization_crop_mask)\n",
"  utils.remove_ticks()\n",
"\n",
"  fig.add_subplot(2, 2, 4)\n",
"  plt.imshow(ig_saliency, cmap='RdBu_r')\n",
"  plt.title(f'IG score:{ig_saliency_score:.2f}')\n",
"  if show_bounding_box:\n",
"    utils.show_bounding_box(ig_crop_mask)\n",
"  utils.remove_ticks()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGje_qFhFUml"
},
"outputs": [],
"source": [
"# Load the 224x224 test image and paste it onto a white 299x299 canvas.\n",
"# Inception_v1 consumes the 224x224 crop, while Inception_v3 (below)\n",
"# expects 299x299 inputs.\n",
"image = np.array(Image.open(open('tabby.jpg', 'rb')))\n",
"tabby = (255 * np.ones((299, 299, 3))).astype(int)\n",
"tabby[:224, :224, :3] = image\n",
"print(tabby.shape)\n",
"plt.imshow(tabby)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p7zImFbMFUml"
},
"outputs": [],
"source": [
"plot_saliency_maps(tabby[:224, :224, :],\n",
"                   run_params_inception_v1,\n",
"                   restore_inception_v1,\n",
"                   window_size=4)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2qtVR1bC9ndp"
},
"source": [
"### Inception v3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "42Momsw9FUmm",
"scrolled": false
},
"outputs": [],
"source": [
"# Note that most of the IG attributions lie at the edge of the cat,\n",
"# while SMUG and SMUG_BASE highlight the facial features of the cat.\n",
"# This observation is explained in sec. 5.2\n",
"# of https://arxiv.org/pdf/2006.16322.pdf\n",
"\n",
"run_params_inception_v3 = masking.RunParams(**{\n",
"    'model_path': '',\n",
"    'image_placeholder_shape': (1, 299, 299, 3),\n",
"    'model_type': 'cnn',\n",
"    'padding': (0, 0),\n",
"    'strides': 2,\n",
"    'activations': None,\n",
"    'pixel_range': (-1, 1),\n",
"    # Find the appropriate tensor names by printing the tf ops in\n",
"    # restore_inception_v3.\n",
"    'tensor_names': {\n",
"        # The natural input tensor to inception v3 would be\n",
"        # 'module/hub_input/images:0', which expects pixel values in\n",
"        # (0, 1); the network rescales them to (-1, 1), and the rescaled\n",
"        # image is denoted by the tensor 'module/hub_input/Sub:0'.\n",
"        # Because masking.find_mask_first_layer assumes that the\n",
"        # convolution is performed directly on the input image without\n",
"        # any rescaling, we feed the input to the network via the\n",
"        # 'module/hub_input/Sub:0' tensor instead.\n",
"        'input': 'module/hub_input/Sub:0',\n",
"        'first_layer': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D:0',\n",
"        'first_layer_relu': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Relu:0',\n",
"        'logits': 'module/InceptionV3/Logits/SpatialSqueeze:0',\n",
"        'softmax': 'module/InceptionV3/Predictions/Softmax:0',\n",
"        'weights_layer_1': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D/ReadVariableOp:0',\n",
"    }\n",
"})\n",
"\n",
"def restore_inception_v3(model_path=('https://tfhub.dev/google/imagenet/'\n",
"                                     'inception_v3/classification/1'),\n",
"                         print_ops=False):\n",
"  \"\"\"Loads an inception v3 model from tensorflow hub and returns it.\n",
"\n",
"  Args:\n",
"    model_path: string, path (or URL) of a tensorflow hub module.\n",
"    print_ops: bool, prints operations in a tensorflow graph if true.\n",
"\n",
"  Returns:\n",
"    session: tf.Session, tensorflow session with the loaded neural network.\n",
"    graph: tensorflow graph corresponding to the tensorflow session.\n",
"  \"\"\"\n",
"  graph = tf.Graph()\n",
"  session = tf.Session(graph=graph)\n",
"  with graph.as_default():\n",
"    hub.Module(model_path)\n",
"    session.run(tf.global_variables_initializer())\n",
"    session.run(tf.tables_initializer())\n",
"\n",
"  # Find the appropriate tensor names by printing the tf ops.\n",
"  # These tensor names are required to construct run_params.\n",
"  if print_ops:\n",
"    for op in graph.get_operations():\n",
"      print('name:', op.name)\n",
"      print('inputs:')\n",
"      for ip in op.inputs:\n",
"        print(ip)\n",
"      print('outputs:', op.outputs)\n",
"      print('----\\n')\n",
"  return session, graph\n",
"\n",
"restore_inception_v3(print_ops=True)"
]
},
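{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the v3 graph is fed through the pre-scaled `module/hub_input/Sub:0` tensor, raw pixels must first be mapped into the model's declared `pixel_range`, which this notebook does via `utils.process_model_input`. The following sanity check is a minimal sketch that assumes only the call signature already used above; it simply prints the processed value range so you can confirm it matches `(-1, 1)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (sketch): processed inputs should lie within the model's\n",
"# declared pixel_range, here (-1, 1) for inception v3.\n",
"processed = utils.process_model_input(\n",
"    np.random.randint(0, 256, size=(299, 299, 3)),\n",
"    run_params_inception_v3.pixel_range)\n",
"print('min:', np.min(processed), 'max:', np.max(processed))"
]
},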
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v7y3T_-hFUmm"
},
"outputs": [],
"source": [
"verify_first_layer_conv_weights(run_params_inception_v3,\n",
"                                restore_inception_v3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Uz77-0w29uPF"
},
"outputs": [],
"source": [
"plot_saliency_maps(tabby,\n",
"                   run_params_inception_v3,\n",
"                   restore_inception_v3,\n",
"                   window_size=3)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "image_saliency.ipynb",
"provenance": [
{
"file_id": "1AhdqlDFsFrs3ctHx-Mz2N03KPEKx5pVW",
"timestamp": 1616167425955
}
]
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}