google-research

Форк
0
/
float_part201_inference.ipynb 
2459 строк · 102.4 Кб
1
{
2
  "cells": [
3
    {
4
      "cell_type": "code",
5
      "execution_count": null,
6
      "metadata": {
7
        "id": "ZMj2kRWO6S6s"
8
      },
9
      "outputs": [],
10
      "source": [
11
        "#@title Copyright 2022 Google LLC, licensed under the Apache License, Version 2.0 (the \"License\")\n",
12
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
13
        "# you may not use this file except in compliance with the License.\n",
14
        "# You may obtain a copy of the License at\n",
15
        "#\n",
16
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
17
        "#\n",
18
        "# Unless required by applicable law or agreed to in writing, software\n",
19
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
20
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
21
        "# See the License for the specific language governing permissions and\n",
22
        "# limitations under the License."
23
      ]
24
    },
25
    {
26
      "cell_type": "code",
27
      "execution_count": null,
28
      "metadata": {
29
        "id": "6zx9fUOP6UcS"
30
      },
31
      "outputs": [],
32
      "source": [
33
        "import glob \n",
34
        "import math\n",
35
        "import sys\n",
36
        "import os\n",
37
        "import cv2\n",
38
        "import glob\n",
39
        "import numpy as np\n",
40
        "import pickle\n",
41
        "import matplotlib.pylab as plt\n",
42
        "import time\n",
43
        "import random\n",
44
        "import math\n",
45
        "import collections\n",
46
        "import queue\n",
47
        "import collections\n",
48
        "import threading\n",
49
        "import functools\n",
50
        "from tqdm.notebook import tqdm\n",
51
        "from typing import Dict, Type, Any, Callable, Union, List, Optional\n",
52
        "import matplotlib.pyplot as plt\n",
53
        "import seaborn as sns\n",
54
        "\n",
55
        "import torch\n",
56
        "from torch import nn\n",
57
        "import torch.optim as optim\n",
58
        "from torch.nn import functional as F\n",
59
        "from torch.autograd import Variable\n",
60
        "from torch.utils.data import Dataset, DataLoader\n",
61
        "import torch.backends.cudnn as cudnn\n",
62
        "from torchinfo import summary\n",
63
        "import torch.utils.model_zoo as model_zoo\n",
64
        "from torch.nn.parallel.data_parallel import DataParallel\n",
65
        "from torch.nn.modules.batchnorm import _BatchNorm\n",
66
        "from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast"
67
      ]
68
    },
69
    {
70
      "cell_type": "markdown",
71
      "metadata": {
72
        "id": "r_ek4Nti6E6t"
73
      },
74
      "source": [
75
        "**Synchronized Batch Norm**\n",
75
        "Citation: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n",
76
        "Implements a synchronized batch norm for multi-GPU training (used here with `DataParallel`)."
78
      ]
79
    },
80
    {
81
      "cell_type": "code",
82
      "execution_count": null,
83
      "metadata": {
84
        "id": "mQhAV42V6E6v"
85
      },
86
      "outputs": [],
87
      "source": [
88
        "class FutureResult(object):\n",
89
        "    \"\"\"A thread-safe future implementation. Used only as one-to-one pipe.\"\"\"\n",
90
        "\n",
91
        "    def __init__(self):\n",
92
        "        self._result = None\n",
93
        "        self._lock = threading.Lock()\n",
94
        "        self._cond = threading.Condition(self._lock)\n",
95
        "\n",
96
        "    def put(self, result):\n",
97
        "        with self._lock:\n",
98
        "            assert self._result is None, 'Previous result has\\'t been fetched.'\n",
99
        "            self._result = result\n",
100
        "            self._cond.notify()\n",
101
        "\n",
102
        "    def get(self):\n",
103
        "        with self._lock:\n",
104
        "            if self._result is None:\n",
105
        "                self._cond.wait()\n",
106
        "\n",
107
        "            res = self._result\n",
108
        "            self._result = None\n",
109
        "            return res\n",
110
        "\n",
111
        "\n",
112
        "_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])\n",
113
        "_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])\n",
114
        "\n",
115
        "\n",
116
        "class SlavePipe(_SlavePipeBase):\n",
117
        "    \"\"\"Pipe for master-slave communication.\"\"\"\n",
118
        "\n",
119
        "    def run_slave(self, msg):\n",
120
        "        self.queue.put((self.identifier, msg))\n",
121
        "        ret = self.result.get()\n",
122
        "        self.queue.put(True)\n",
123
        "        return ret\n",
124
        "\n",
125
        "\n",
126
        "class SyncMaster(object):\n",
127
        "    \"\"\"An abstract `SyncMaster` object.\n",
128
        "    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should\n",
129
        "    call `register(id)` and obtain an `SlavePipe` to communicate with the master.\n",
130
        "    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,\n",
131
        "    and passed to a registered callback.\n",
132
        "    - After receiving the messages, the master device should gather the information and determine to message passed\n",
133
        "    back to each slave devices.\n",
134
        "    \"\"\"\n",
135
        "\n",
136
        "    def __init__(self, master_callback):\n",
137
        "        \"\"\"\n",
138
        "        Args:\n",
139
        "            master_callback: a callback to be invoked after having collected messages from slave devices.\n",
140
        "        \"\"\"\n",
141
        "        self._master_callback = master_callback\n",
142
        "        self._queue = queue.Queue()\n",
143
        "        self._registry = collections.OrderedDict()\n",
144
        "        self._activated = False\n",
145
        "\n",
146
        "    def __getstate__(self):\n",
147
        "        return {'master_callback': self._master_callback}\n",
148
        "\n",
149
        "    def __setstate__(self, state):\n",
150
        "        self.__init__(state['master_callback'])\n",
151
        "\n",
152
        "    def register_slave(self, identifier):\n",
153
        "        \"\"\"\n",
154
        "        Register an slave device.\n",
155
        "        Args:\n",
156
        "            identifier: an identifier, usually is the device id.\n",
157
        "        Returns: a `SlavePipe` object which can be used to communicate with the master device.\n",
158
        "        \"\"\"\n",
159
        "        if self._activated:\n",
160
        "            assert self._queue.empty(), 'Queue is not clean before next initialization.'\n",
161
        "            self._activated = False\n",
162
        "            self._registry.clear()\n",
163
        "        future = FutureResult()\n",
164
        "        self._registry[identifier] = _MasterRegistry(future)\n",
165
        "        return SlavePipe(identifier, self._queue, future)\n",
166
        "\n",
167
        "    def run_master(self, master_msg):\n",
168
        "        \"\"\"\n",
169
        "        Main entry for the master device in each forward pass.\n",
170
        "        The messages were first collected from each devices (including the master device), and then\n",
171
        "        an callback will be invoked to compute the message to be sent back to each devices\n",
172
        "        (including the master device).\n",
173
        "        Args:\n",
174
        "            master_msg: the message that the master want to send to itself. This will be placed as the first\n",
175
        "            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.\n",
176
        "        Returns: the message to be sent back to the master device.\n",
177
        "        \"\"\"\n",
178
        "        self._activated = True\n",
179
        "\n",
180
        "        intermediates = [(0, master_msg)]\n",
181
        "        for i in range(self.nr_slaves):\n",
182
        "            intermediates.append(self._queue.get())\n",
183
        "\n",
184
        "        results = self._master_callback(intermediates)\n",
185
        "        assert results[0][0] == 0, 'The first result should belongs to the master.'\n",
186
        "\n",
187
        "        for i, res in results:\n",
188
        "            if i == 0:\n",
189
        "                continue\n",
190
        "            self._registry[i].result.put(res)\n",
191
        "\n",
192
        "        for i in range(self.nr_slaves):\n",
193
        "            assert self._queue.get() is True\n",
194
        "\n",
195
        "        return results[0][1]\n",
196
        "\n",
197
        "    @property\n",
198
        "    def nr_slaves(self):\n",
199
        "        return len(self._registry)"
200
      ]
201
    },
202
    {
203
      "cell_type": "code",
204
      "execution_count": null,
205
      "metadata": {
206
        "id": "9g7iu4p06E6y"
207
      },
208
      "outputs": [],
209
      "source": [
210
        "def _sum_ft(tensor):\n",
211
        "    \"\"\"sum over the first and last dimention\"\"\"\n",
212
        "    return tensor.sum(dim=0).sum(dim=-1)\n",
213
        "\n",
214
        "\n",
215
        "def _unsqueeze_ft(tensor):\n",
216
        "    \"\"\"add new dementions at the front and the tail\"\"\"\n",
217
        "    return tensor.unsqueeze(0).unsqueeze(-1)\n",
218
        "\n",
219
        "\n",
220
        "_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])\n",
221
        "_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])\n",
222
        "\n",
223
        "\n",
224
        "class _SynchronizedBatchNorm(_BatchNorm):\n",
225
        "    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):\n",
226
        "        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)\n",
227
        "\n",
228
        "        self._sync_master = SyncMaster(self._data_parallel_master)\n",
229
        "\n",
230
        "        self._is_parallel = False\n",
231
        "        self._parallel_id = None\n",
232
        "        self._slave_pipe = None\n",
233
        "\n",
234
        "    def forward(self, input):\n",
235
        "        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.\n",
236
        "        if not (self._is_parallel and self.training):\n",
237
        "            return F.batch_norm(\n",
238
        "                input, self.running_mean, self.running_var, self.weight, self.bias,\n",
239
        "                self.training, self.momentum, self.eps)\n",
240
        "\n",
241
        "        # Resize the input to (B, C, -1).\n",
242
        "        input_shape = input.size()\n",
243
        "        input = input.view(input.size(0), self.num_features, -1)\n",
244
        "\n",
245
        "        # Compute the sum and square-sum.\n",
246
        "        sum_size = input.size(0) * input.size(2)\n",
247
        "        input_sum = _sum_ft(input)\n",
248
        "        input_ssum = _sum_ft(input ** 2)\n",
249
        "\n",
250
        "        # Reduce-and-broadcast the statistics.\n",
251
        "        if self._parallel_id == 0:\n",
252
        "            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))\n",
253
        "        else:\n",
254
        "            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))\n",
255
        "\n",
256
        "        # Compute the output.\n",
257
        "        if self.affine:\n",
258
        "            # MJY:: Fuse the multiplication for speed.\n",
259
        "            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)\n",
260
        "        else:\n",
261
        "            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)\n",
262
        "\n",
263
        "        # Reshape it.\n",
264
        "        return output.view(input_shape)\n",
265
        "\n",
266
        "    def __data_parallel_replicate__(self, ctx, copy_id):\n",
267
        "        self._is_parallel = True\n",
268
        "        self._parallel_id = copy_id\n",
269
        "\n",
270
        "        # parallel_id == 0 means master device.\n",
271
        "        if self._parallel_id == 0:\n",
272
        "            ctx.sync_master = self._sync_master\n",
273
        "        else:\n",
274
        "            self._slave_pipe = ctx.sync_master.register_slave(copy_id)\n",
275
        "\n",
276
        "    def _data_parallel_master(self, intermediates):\n",
277
        "        \"\"\"Reduce the sum and square-sum, compute the statistics, and broadcast it.\"\"\"\n",
278
        "\n",
279
        "        # Always using same \"device order\" makes the ReduceAdd operation faster.\n",
280
        "        # Thanks to:: Tete Xiao (http://tetexiao.com/)\n",
281
        "        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())\n",
282
        "\n",
283
        "        to_reduce = [i[1][:2] for i in intermediates]\n",
284
        "        to_reduce = [j for i in to_reduce for j in i]  # flatten\n",
285
        "        target_gpus = [i[1].sum.get_device() for i in intermediates]\n",
286
        "\n",
287
        "        sum_size = sum([i[1].sum_size for i in intermediates])\n",
288
        "        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)\n",
289
        "        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)\n",
290
        "\n",
291
        "        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)\n",
292
        "\n",
293
        "        outputs = []\n",
294
        "        for i, rec in enumerate(intermediates):\n",
295
        "            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))\n",
296
        "\n",
297
        "        return outputs\n",
298
        "\n",
299
        "    def _compute_mean_std(self, sum_, ssum, size):\n",
300
        "        \"\"\"Compute the mean and standard-deviation with sum and square-sum. This method\n",
301
        "        also maintains the moving average on the master device.\"\"\"\n",
302
        "        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'\n",
303
        "        mean = sum_ / size\n",
304
        "        sumvar = ssum - sum_ * mean\n",
305
        "        unbias_var = sumvar / (size - 1)\n",
306
        "        bias_var = sumvar / size\n",
307
        "\n",
308
        "        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n",
309
        "        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n",
310
        "\n",
311
        "        return mean, bias_var.clamp(self.eps) ** -0.5"
312
      ]
313
    },
314
    {
315
      "cell_type": "code",
316
      "execution_count": null,
317
      "metadata": {
318
        "id": "AqukuKRu6E6y"
319
      },
320
      "outputs": [],
321
      "source": [
322
        "class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):\n",
323
        "    r\"\"\"Applies Batch Normalization over a 4d input that is seen as a mini-batch\n",
324
        "    of 3d inputs\n",
325
        "    .. math::\n",
326
        "        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n",
327
        "    This module differs from the built-in PyTorch BatchNorm2d as the mean and\n",
328
        "    standard-deviation are reduced across all devices during training.\n",
329
        "    For example, when one uses `nn.DataParallel` to wrap the network during\n",
330
        "    training, PyTorch's implementation normalize the tensor on each device using\n",
331
        "    the statistics only on that device, which accelerated the computation and\n",
332
        "    is also easy to implement, but the statistics might be inaccurate.\n",
333
        "    Instead, in this synchronized version, the statistics will be computed\n",
334
        "    over all training samples distributed on multiple devices.\n",
335
        "    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n",
336
        "    as the built-in PyTorch implementation.\n",
337
        "    The mean and standard-deviation are calculated per-dimension over\n",
338
        "    the mini-batches and gamma and beta are learnable parameter vectors\n",
339
        "    of size C (where C is the input size).\n",
340
        "    During training, this layer keeps a running estimate of its computed mean\n",
341
        "    and variance. The running sum is kept with a default momentum of 0.1.\n",
342
        "    During evaluation, this running mean/variance is used for normalization.\n",
343
        "    Because the BatchNorm is done over the `C` dimension, computing statistics\n",
344
        "    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm\n",
345
        "    Args:\n",
346
        "        num_features: num_features from an expected input of\n",
347
        "            size batch_size x num_features x height x width\n",
348
        "        eps: a value added to the denominator for numerical stability.\n",
349
        "            Default: 1e-5\n",
350
        "        momentum: the value used for the running_mean and running_var\n",
351
        "            computation. Default: 0.1\n",
352
        "        affine: a boolean value that when set to ``True``, gives the layer learnable\n",
353
        "            affine parameters. Default: ``True``\n",
354
        "    Shape:\n",
355
        "        - Input: :math:`(N, C, H, W)`\n",
356
        "        - Output: :math:`(N, C, H, W)` (same shape as input)\n",
357
        "    Examples:\n",
358
        "        >>> # With Learnable Parameters\n",
359
        "        >>> m = SynchronizedBatchNorm2d(100)\n",
360
        "        >>> # Without Learnable Parameters\n",
361
        "        >>> m = SynchronizedBatchNorm2d(100, affine=False)\n",
362
        "        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))\n",
363
        "        >>> output = m(input)\n",
364
        "    \"\"\"\n",
365
        "\n",
366
        "    def _check_input_dim(self, input):\n",
367
        "        if input.dim() != 4:\n",
368
        "            raise ValueError('expected 4D input (got {}D input)'\n",
369
        "                             .format(input.dim()))\n",
370
        "        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)"
371
      ]
372
    },
373
    {
374
      "cell_type": "code",
375
      "execution_count": null,
376
      "metadata": {
377
        "id": "3g7Gclzw6um6"
378
      },
379
      "outputs": [],
380
      "source": [
381
        "\"\"\" \n",
382
        "For handling DataParallel with Synchorisned Batch Norm.\n",
383
        "\"\"\"\n",
384
        "class CallbackContext(object):\n",
385
        "    pass\n",
386
        "\n",
387
        "class DataParallelWithCallback(DataParallel):\n",
388
        "    \"\"\"\n",
389
        "    Data Parallel with a replication callback.\n",
390
        "    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n",
391
        "    original `replicate` function.\n",
392
        "    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n",
393
        "    Examples:\n",
394
        "        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n",
395
        "        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n",
396
        "        # sync_bn.__data_parallel_replicate__ will be invoked.\n",
397
        "    \"\"\"\n",
398
        "\n",
399
        "    def replicate(self, module, device_ids):\n",
400
        "        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)\n",
401
        "        execute_replication_callbacks(modules)\n",
402
        "        return modules\n",
403
        "\n",
404
        "\n",
405
        "def execute_replication_callbacks(modules):\n",
406
        "    \"\"\"\n",
407
        "    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n",
408
        "    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n",
409
        "    Note that, as all modules are isomorphism, we assign each sub-module with a context\n",
410
        "    (shared among multiple copies of this module on different devices).\n",
411
        "    Through this context, different copies can share some information.\n",
412
        "    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n",
413
        "    of any slave copies.\n",
414
        "    \"\"\"\n",
415
        "    master_copy = modules[0]\n",
416
        "    nr_modules = len(list(master_copy.modules()))\n",
417
        "    ctxs = [CallbackContext() for _ in range(nr_modules)]\n",
418
        "\n",
419
        "    for i, module in enumerate(modules):\n",
420
        "        for j, m in enumerate(module.modules()):\n",
421
        "            if hasattr(m, '__data_parallel_replicate__'):\n",
422
        "                m.__data_parallel_replicate__(ctxs[j], i)\n",
423
        "\n",
424
        "\n",
425
        "def patch_replication_callback(data_parallel):\n",
426
        "    \"\"\"\n",
427
        "    Monkey-patch an existing `DataParallel` object. Add the replication callback.\n",
428
        "    Useful when you have customized `DataParallel` implementation.\n",
429
        "    Examples:\n",
430
        "        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n",
431
        "        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n",
432
        "        > patch_replication_callback(sync_bn)\n",
433
        "        # this is equivalent to\n",
434
        "        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n",
435
        "        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n",
436
        "    \"\"\"\n",
437
        "\n",
438
        "    assert isinstance(data_parallel, DataParallel)\n",
439
        "\n",
440
        "    old_replicate = data_parallel.replicate\n",
441
        "\n",
442
        "    @functools.wraps(old_replicate)\n",
443
        "    def new_replicate(module, device_ids):\n",
444
        "        modules = old_replicate(module, device_ids)\n",
445
        "        execute_replication_callbacks(modules)\n",
446
        "        return modules\n",
447
        "\n",
448
        "    data_parallel.replicate = new_replicate"
449
      ]
450
    },
451
    {
452
      "cell_type": "markdown",
453
      "metadata": {
454
        "id": "8ULz8sId6E60"
455
      },
456
      "source": [
457
        "**ResNet**\n",
458
        "Basic Implementation of ResNet models."
459
      ]
460
    },
461
    {
462
      "cell_type": "code",
463
      "execution_count": null,
464
      "metadata": {
465
        "id": "X8X-o58z6E61"
466
      },
467
      "outputs": [],
468
      "source": [
469
        "class Bottleneck(nn.Module):\n",
470
        "    expansion = 4\n",
471
        "\n",
472
        "    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):\n",
473
        "        super(Bottleneck, self).__init__()\n",
474
        "        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
475
        "        self.bn1 = BatchNorm(planes)\n",
476
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
477
        "                               dilation=dilation, padding=dilation, bias=False)\n",
478
        "        self.bn2 = BatchNorm(planes)\n",
479
        "        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
480
        "        self.bn3 = BatchNorm(planes * 4)\n",
481
        "        self.relu = nn.ReLU(inplace=True)\n",
482
        "        self.downsample = downsample\n",
483
        "        self.stride = stride\n",
484
        "        self.dilation = dilation\n",
485
        "\n",
486
        "    def forward(self, x):\n",
487
        "        residual = x\n",
488
        "\n",
489
        "        out = self.conv1(x)\n",
490
        "        out = self.bn1(out)\n",
491
        "        out = self.relu(out)\n",
492
        "\n",
493
        "        out = self.conv2(out)\n",
494
        "        out = self.bn2(out)\n",
495
        "        out = self.relu(out)\n",
496
        "\n",
497
        "        out = self.conv3(out)\n",
498
        "        out = self.bn3(out)\n",
499
        "\n",
500
        "        if self.downsample is not None:\n",
501
        "            residual = self.downsample(x)\n",
502
        "\n",
503
        "        out += residual\n",
504
        "        out = self.relu(out)\n",
505
        "\n",
506
        "        return out\n",
507
        "\n",
508
        "class ResNet(nn.Module):\n",
509
        "\n",
510
        "    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):\n",
511
        "        self.inplanes = 64\n",
512
        "        super(ResNet, self).__init__()\n",
513
        "        blocks = [1, 2, 4]\n",
514
        "        if output_stride == 16:\n",
515
        "            strides = [1, 2, 2, 1]\n",
516
        "            dilations = [1, 1, 1, 2]\n",
517
        "        elif output_stride == 8:\n",
518
        "            strides = [1, 2, 1, 1]\n",
519
        "            dilations = [1, 1, 2, 4]\n",
520
        "        else:\n",
521
        "            raise NotImplementedError\n",
522
        "\n",
523
        "        # Modules\n",
524
        "        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n",
525
        "                                bias=False)\n",
526
        "        self.bn1 = BatchNorm(64)\n",
527
        "        self.relu = nn.ReLU(inplace=True)\n",
528
        "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
529
        "\n",
530
        "        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)\n",
531
        "        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)\n",
532
        "        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)\n",
533
        "        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)\n",
534
        "        self._init_weight()\n",
535
        "\n",
536
        "        if pretrained:\n",
537
        "            self._load_pretrained_model()\n",
538
        "\n",
539
        "    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):\n",
540
        "        downsample = None\n",
541
        "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
542
        "            downsample = nn.Sequential(\n",
543
        "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
544
        "                          kernel_size=1, stride=stride, bias=False),\n",
545
        "                BatchNorm(planes * block.expansion),\n",
546
        "            )\n",
547
        "\n",
548
        "        layers = []\n",
549
        "        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))\n",
550
        "        self.inplanes = planes * block.expansion\n",
551
        "        for i in range(1, blocks):\n",
552
        "            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))\n",
553
        "\n",
554
        "        return nn.Sequential(*layers)\n",
555
        "\n",
556
        "    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):\n",
557
        "        downsample = None\n",
558
        "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
559
        "            downsample = nn.Sequential(\n",
560
        "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
561
        "                          kernel_size=1, stride=stride, bias=False),\n",
562
        "                BatchNorm(planes * block.expansion),\n",
563
        "            )\n",
564
        "\n",
565
        "        layers = []\n",
566
        "        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,\n",
567
        "                            downsample=downsample, BatchNorm=BatchNorm))\n",
568
        "        self.inplanes = planes * block.expansion\n",
569
        "        for i in range(1, len(blocks)):\n",
570
        "            layers.append(block(self.inplanes, planes, stride=1,\n",
571
        "                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))\n",
572
        "\n",
573
        "        return nn.Sequential(*layers)\n",
574
        "\n",
575
        "    def forward(self, input):\n",
576
        "        x = self.conv1(input)\n",
577
        "        x = self.bn1(x)\n",
578
        "        x = self.relu(x)\n",
579
        "        x = self.maxpool(x)\n",
580
        "\n",
581
        "        x = self.layer1(x)\n",
582
        "        low_level_feat = x\n",
583
        "        x = self.layer2(x)\n",
584
        "        x = self.layer3(x)\n",
585
        "        x = self.layer4(x)\n",
586
        "        return x, low_level_feat\n",
587
        "\n",
588
        "    def _init_weight(self):\n",
589
        "        for m in self.modules():\n",
590
        "            if isinstance(m, nn.Conv2d):\n",
591
        "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
592
        "                m.weight.data.normal_(0, math.sqrt(2. / n))\n",
593
        "            elif isinstance(m, SynchronizedBatchNorm2d):\n",
594
        "                m.weight.data.fill_(1)\n",
595
        "                m.bias.data.zero_()\n",
596
        "            elif isinstance(m, nn.BatchNorm2d):\n",
597
        "                m.weight.data.fill_(1)\n",
598
        "                m.bias.data.zero_()\n",
599
        "\n",
600
        "    def _load_pretrained_model(self):\n",
601
        "        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')\n",
602
        "        model_dict = {}\n",
603
        "        state_dict = self.state_dict()\n",
604
        "        for k, v in pretrain_dict.items():\n",
605
        "            if k in state_dict:\n",
606
        "                model_dict[k] = v\n",
607
        "        state_dict.update(model_dict)\n",
608
        "        self.load_state_dict(state_dict)\n",
609
        "\n",
610
        "def ResNet101(output_stride, BatchNorm, pretrained=True):\n",
        "    \"\"\"Constructs a ResNet-101 model (Bottleneck blocks, block counts [3, 4, 23, 3]).\n",
        "\n",
        "    Args:\n",
        "        output_stride: forwarded unchanged to ResNet.\n",
        "        BatchNorm: normalization layer class forwarded unchanged to ResNet.\n",
        "        pretrained (bool): If True, returns a model pre-trained on ImageNet\n",
        "    \"\"\"\n",
        "    return ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)\n",
617
        "\n",
618
        "def ResNet50(output_stride, BatchNorm, pretrained=True):\n",
        "    \"\"\"Constructs a ResNet-50 model (Bottleneck blocks, block counts [3, 4, 6, 3]).\n",
        "\n",
        "    Args:\n",
        "        output_stride: forwarded unchanged to ResNet.\n",
        "        BatchNorm: normalization layer class forwarded unchanged to ResNet.\n",
        "        pretrained (bool): If True, returns a model pre-trained on ImageNet.\n",
        "            NOTE(review): ResNet._load_pretrained_model downloads the ResNet-101\n",
        "            checkpoint; confirm it is appropriate for this 50-layer variant.\n",
        "    \"\"\"\n",
        "    model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, pretrained=pretrained)\n",
        "    return model"
625
      ]
626
    },
627
    {
628
      "cell_type": "markdown",
629
      "metadata": {
630
        "id": "H4o1Qb0O66Gu"
631
      },
632
      "source": [
633
        "**FLOAT**\n",
634
        "Inspired by DeepLabV3+."
635
      ]
636
    },
637
    {
638
      "cell_type": "markdown",
639
      "metadata": {
640
        "id": "SwO2XzsH6E61"
641
      },
642
      "source": [
643
        "**Backbone**\n",
644
        "Function returning the ResNet backbone selected by name; extend it to support other backbones."
645
      ]
646
    },
647
    {
648
      "cell_type": "code",
649
      "execution_count": null,
650
      "metadata": {
651
        "id": "QGGl0W2u6E62"
652
      },
653
      "outputs": [],
654
      "source": [
655
        "def build_backbone(backbone, output_stride, BatchNorm, pretrained=True):\n",
        "    \"\"\"Returns the ResNet feature extractor selected by name.\n",
        "\n",
        "    Args:\n",
        "        backbone: 'resnet101' or 'resnet50'.\n",
        "        output_stride: forwarded to the ResNet constructor.\n",
        "        BatchNorm: normalization layer class forwarded to the ResNet constructor.\n",
        "        pretrained: forwarded to the ResNet constructor (default True, as before).\n",
        "            Passing False presumably skips the ImageNet weight download, which is\n",
        "            useful when a full checkpoint is loaded afterwards.\n",
        "\n",
        "    Raises:\n",
        "        NotImplementedError: if `backbone` is not one of the supported names.\n",
        "    \"\"\"\n",
        "    if backbone == 'resnet101':\n",
        "        return ResNet101(output_stride, BatchNorm, pretrained=pretrained)\n",
        "    if backbone == 'resnet50':\n",
        "        return ResNet50(output_stride, BatchNorm, pretrained=pretrained)\n",
        "    raise NotImplementedError(\n",
        "        'Unsupported backbone {!r}; expected one of: resnet101, resnet50'.format(backbone))"
662
      ]
663
    },
664
    {
665
      "cell_type": "markdown",
666
      "metadata": {
667
        "id": "vxL-7au16E63"
668
      },
669
      "source": [
670
        "**ASPP**\n",
671
        "Atrous Spatial Pyramid Pooling module from DeepLabV3+"
672
      ]
673
    },
674
    {
675
      "cell_type": "code",
676
      "execution_count": null,
677
      "metadata": {
678
        "id": "OlaBlUYB6E63"
679
      },
680
      "outputs": [],
681
      "source": [
682
        "class _ASPPModule(nn.Module):\n",
        "    \"\"\"One ASPP branch: (dilated) convolution -> BatchNorm -> ReLU.\"\"\"\n",
        "\n",
        "    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):\n",
        "        super(_ASPPModule, self).__init__()\n",
        "        # Registration order (conv, bn, relu) fixes the state_dict key names and the\n",
        "        # order in which _init_weight visits the submodules.\n",
        "        self.atrous_conv = nn.Conv2d(\n",
        "            inplanes, planes,\n",
        "            kernel_size=kernel_size, stride=1,\n",
        "            padding=padding, dilation=dilation, bias=False,\n",
        "        )\n",
        "        self.bn = BatchNorm(planes)\n",
        "        self.relu = nn.ReLU()\n",
        "        self._init_weight()\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.relu(self.bn(self.atrous_conv(x)))\n",
        "\n",
        "    def _init_weight(self):\n",
        "        \"\"\"Kaiming-normal conv weights; BatchNorm affine params set to weight=1, bias=0.\"\"\"\n",
        "        for module in self.modules():\n",
        "            if isinstance(module, nn.Conv2d):\n",
        "                torch.nn.init.kaiming_normal_(module.weight)\n",
        "            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):\n",
        "                module.weight.data.fill_(1)\n",
        "                module.bias.data.zero_()\n",
708
        "\n",
709
        "class ASPP(nn.Module):\n",
        "    \"\"\"Atrous Spatial Pyramid Pooling head.\n",
        "\n",
        "    Four parallel conv branches (one 1x1, three dilated 3x3) plus an image-level\n",
        "    pooling branch, each producing 256 channels, are concatenated (5 x 256 = 1280\n",
        "    channels) and projected back to 256 channels.\n",
        "\n",
        "    Raises:\n",
        "        NotImplementedError: if output_stride is neither 16 nor 8.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, backbone, output_stride, BatchNorm):\n",
        "        super(ASPP, self).__init__()\n",
        "        # Input channel count expected for each backbone's feature map.\n",
        "        if backbone == 'drn':\n",
        "            inplanes = 512\n",
        "        elif backbone == 'mobilenet':\n",
        "            inplanes = 320\n",
        "        else:\n",
        "            inplanes = 2048\n",
        "\n",
        "        # Dilation rates of the four conv branches for the given output stride.\n",
        "        if output_stride == 16:\n",
        "            rates = [1, 6, 12, 18]\n",
        "        elif output_stride == 8:\n",
        "            rates = [1, 12, 24, 36]\n",
        "        else:\n",
        "            raise NotImplementedError\n",
        "\n",
        "        # The 3x3 branches use padding == dilation so the spatial size is preserved.\n",
        "        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=rates[0], BatchNorm=BatchNorm)\n",
        "        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=rates[1], dilation=rates[1], BatchNorm=BatchNorm)\n",
        "        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=rates[2], dilation=rates[2], BatchNorm=BatchNorm)\n",
        "        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=rates[3], dilation=rates[3], BatchNorm=BatchNorm)\n",
        "\n",
        "        # Image-level branch: global average pool -> 1x1 conv -> BN -> ReLU.\n",
        "        self.global_avg_pool = nn.Sequential(\n",
        "            nn.AdaptiveAvgPool2d((1, 1)),\n",
        "            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),\n",
        "            BatchNorm(256),\n",
        "            nn.ReLU(),\n",
        "        )\n",
        "        # Fuses the five concatenated 256-channel branches.\n",
        "        self.conv1 = nn.Conv2d(5 * 256, 256, 1, bias=False)\n",
        "        self.bn1 = BatchNorm(256)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.dropout = nn.Dropout(0.5)\n",
        "        self._init_weight()\n",
        "\n",
        "    def forward(self, x):\n",
        "        features = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]\n",
        "        pooled = self.global_avg_pool(x)\n",
        "        # Stretch the 1x1 pooled map back to the branch resolution before concatenating.\n",
        "        features.append(F.interpolate(pooled, size=features[3].size()[2:], mode='bilinear', align_corners=True))\n",
        "        fused = torch.cat(features, dim=1)\n",
        "        return self.dropout(self.relu(self.bn1(self.conv1(fused))))\n",
        "\n",
        "    def _init_weight(self):\n",
        "        \"\"\"Kaiming-normal conv weights; BatchNorm affine params set to weight=1, bias=0.\"\"\"\n",
        "        for module in self.modules():\n",
        "            if isinstance(module, nn.Conv2d):\n",
        "                torch.nn.init.kaiming_normal_(module.weight)\n",
        "            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):\n",
        "                module.weight.data.fill_(1)\n",
        "                module.bias.data.zero_()\n",
765
        "\n",
766
        "\n",
767
        "def build_aspp(backbone, output_stride, BatchNorm):\n",
        "    \"\"\"Factory for the ASPP head; all arguments are forwarded to ASPP unchanged.\"\"\"\n",
        "    head = ASPP(backbone, output_stride, BatchNorm)\n",
        "    return head"
769
      ]
770
    },
771
    {
772
      "cell_type": "markdown",
773
      "metadata": {
774
        "id": "pcwtCAK96E64"
775
      },
776
      "source": [
777
        "**Decoder**\n",
778
        "Decoder from DeepLabV3+"
779
      ]
780
    },
781
    {
782
      "cell_type": "code",
783
      "execution_count": null,
784
      "metadata": {
785
        "id": "UPblrkMC6E65"
786
      },
787
      "outputs": [],
788
      "source": [
789
        "class Decoder(nn.Module):\n",
        "    \"\"\"DeepLabV3+ decoder.\n",
        "\n",
        "    Projects the low-level backbone features to 48 channels, upsamples the incoming\n",
        "    features to the same resolution, concatenates them (256 + 48 = 304 channels) and\n",
        "    refines the result into `num_classes` output channels.\n",
        "\n",
        "    Raises:\n",
        "        NotImplementedError: for an unknown backbone name.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_classes, backbone, BatchNorm):\n",
        "        super(Decoder, self).__init__()\n",
        "        # Channel count of the low-level feature map produced by each backbone.\n",
        "        if backbone in ('resnet101', 'resnet50', 'drn'):\n",
        "            low_level_inplanes = 256\n",
        "        elif backbone == 'xception':\n",
        "            low_level_inplanes = 128\n",
        "        elif backbone == 'mobilenet':\n",
        "            low_level_inplanes = 24\n",
        "        else:\n",
        "            raise NotImplementedError\n",
        "\n",
        "        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)\n",
        "        self.bn1 = BatchNorm(48)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.last_conv = nn.Sequential(\n",
        "            nn.Conv2d(256 + 48, 256, kernel_size=3, stride=1, padding=1, bias=False),\n",
        "            BatchNorm(256),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.5),\n",
        "            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),\n",
        "            BatchNorm(256),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),\n",
        "        )\n",
        "        self._init_weight()\n",
        "\n",
        "    def forward(self, x, low_level_feat):\n",
        "        skip = self.relu(self.bn1(self.conv1(low_level_feat)))\n",
        "        x = F.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=True)\n",
        "        return self.last_conv(torch.cat((x, skip), dim=1))\n",
        "\n",
        "    def _init_weight(self):\n",
        "        \"\"\"Kaiming-normal conv weights; BatchNorm affine params set to weight=1, bias=0.\"\"\"\n",
        "        for module in self.modules():\n",
        "            if isinstance(module, nn.Conv2d):\n",
        "                torch.nn.init.kaiming_normal_(module.weight)\n",
        "            elif isinstance(module, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):\n",
        "                module.weight.data.fill_(1)\n",
        "                module.bias.data.zero_()\n",
837
        "\n",
838
        "def build_decoder(num_classes, backbone, BatchNorm):\n",
        "    \"\"\"Factory for the DeepLabV3+ decoder; all arguments are forwarded to Decoder unchanged.\"\"\"\n",
        "    decoder = Decoder(num_classes, backbone, BatchNorm)\n",
        "    return decoder\n"
840
      ]
841
    },
842
    {
843
      "cell_type": "markdown",
844
      "metadata": {},
845
      "source": [
846
        "**DeepLab**\n",
        "Standard DeepLabV3+ model (backbone, ASPP and decoder) for object-level semantic segmentation."
848
      ]
849
    },
850
    {
851
      "cell_type": "code",
852
      "execution_count": null,
853
      "metadata": {
854
        "id": "bJT1YAmZ6E66"
855
      },
856
      "outputs": [],
857
      "source": [
858
        "class DeepLab(nn.Module):\n",
        "    \"\"\"DeepLabV3+ segmentation network: backbone -> ASPP -> decoder -> upsample.\n",
        "\n",
        "    Args:\n",
        "        backbone: backbone name forwarded to build_backbone / build_aspp / build_decoder.\n",
        "        output_stride: output stride for the backbone and ASPP (forced to 8 for 'drn').\n",
        "        num_classes: number of output channels of the decoder.\n",
        "        sync_bn: if truthy use SynchronizedBatchNorm2d, otherwise nn.BatchNorm2d.\n",
        "        freeze_bn: stored as the boolean attribute `self.freeze_bn`; when true the\n",
        "            get_*_lr_params generators yield Conv2d parameters only.\n",
        "\n",
        "    Note: assigning `self.freeze_bn` in __init__ hides the `freeze_bn` method on\n",
        "    every instance (calling `model.freeze_bn()` raises TypeError). Use\n",
        "    `model.freeze_bn_layers()` to put the BatchNorm layers into eval mode.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, backbone='resnet50', output_stride=16, num_classes=21,\n",
        "                 sync_bn=True, freeze_bn=False):\n",
        "        super(DeepLab, self).__init__()\n",
        "        if backbone == 'drn':\n",
        "            output_stride = 8\n",
        "\n",
        "        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d\n",
        "\n",
        "        self.backbone = build_backbone(backbone, output_stride, BatchNorm)\n",
        "        self.aspp = build_aspp(backbone, output_stride, BatchNorm)\n",
        "        self.decoder = build_decoder(num_classes, backbone, BatchNorm)\n",
        "\n",
        "        # Kept as a plain bool (not renamed) for backward compatibility: the\n",
        "        # lr-param generators and possibly external callers read it.\n",
        "        self.freeze_bn = freeze_bn\n",
        "\n",
        "    def forward(self, input):\n",
        "        x, low_level_feat = self.backbone(input)\n",
        "        x = self.aspp(x)\n",
        "        x = self.decoder(x, low_level_feat)\n",
        "        # Resize the logits to the input's spatial size.\n",
        "        return F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)\n",
        "\n",
        "    def freeze_bn_layers(self):\n",
        "        \"\"\"Puts every SynchronizedBatchNorm2d / nn.BatchNorm2d submodule into eval mode.\"\"\"\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):\n",
        "                m.eval()\n",
        "\n",
        "    def freeze_bn(self):\n",
        "        \"\"\"Legacy name for freeze_bn_layers().\n",
        "\n",
        "        Shadowed on instances by the `freeze_bn` flag, but still callable as\n",
        "        `DeepLab.freeze_bn(model)`.\n",
        "        \"\"\"\n",
        "        self.freeze_bn_layers()\n",
        "\n",
        "    def _lr_params(self, modules):\n",
        "        \"\"\"Yields trainable params of Conv2d layers (plus BatchNorm layers unless\n",
        "        `self.freeze_bn` is true) found anywhere inside `modules`, in module order.\"\"\"\n",
        "        if self.freeze_bn:\n",
        "            layer_types = (nn.Conv2d,)\n",
        "        else:\n",
        "            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)\n",
        "        for module in modules:\n",
        "            for _, m in module.named_modules():\n",
        "                if isinstance(m, layer_types):\n",
        "                    for p in m.parameters():\n",
        "                        if p.requires_grad:\n",
        "                            yield p\n",
        "\n",
        "    def get_1x_lr_params(self):\n",
        "        \"\"\"Trainable backbone parameters (the '1x' learning-rate group).\"\"\"\n",
        "        yield from self._lr_params([self.backbone])\n",
        "\n",
        "    def get_10x_lr_params(self):\n",
        "        \"\"\"Trainable ASPP and decoder parameters (the '10x' learning-rate group).\"\"\"\n",
        "        yield from self._lr_params([self.aspp, self.decoder])"
922
      ]
923
    },
924
    {
925
      "cell_type": "markdown",
926
      "metadata": {
927
        "id": "RKXPoK917HBw"
928
      },
929
      "source": [
930
        "**Float**\n",
931
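        "Built from DeepLabV3+ modules: a shared backbone feeds separate ASPP + decoder heads for animate parts, inanimate parts and the `lr` / `fb` outputs (the last two share one ASPP)."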
932
      ]
933
    },
934
    {
935
      "cell_type": "code",
936
      "execution_count": null,
937
      "metadata": {
938
        "id": "96qXQ-wu6E67"
939
      },
940
      "outputs": [],
941
      "source": [
942
        "class Float(nn.Module):\n",
        "    \"\"\"FLOAT network: one shared backbone feeding several DeepLabV3+ ASPP/decoder heads.\n",
        "\n",
        "    Heads: animate parts (`num_anim_classes` channels), inanimate parts\n",
        "    (`num_inanim_classes` channels), and the 3-channel `lr` and `fb` outputs,\n",
        "    which share a single ASPP. Every output is resized to the input's spatial size.\n",
        "\n",
        "    Args:\n",
        "        num_anim_classes: output channels of the animate decoder.\n",
        "        num_inanim_classes: output channels of the inanimate decoder.\n",
        "        backbone: backbone name forwarded to the build_* helpers.\n",
        "        output_stride: output stride for backbone and ASPP (forced to 8 for 'drn').\n",
        "        sync_bn: if truthy use SynchronizedBatchNorm2d, otherwise nn.BatchNorm2d.\n",
        "        freeze_bn: stored as the boolean attribute `self.freeze_bn`; when true the\n",
        "            get_*_lr_params generators yield Conv2d parameters only.\n",
        "\n",
        "    Note: assigning `self.freeze_bn` in __init__ hides the `freeze_bn` method on\n",
        "    every instance (calling `model.freeze_bn()` raises TypeError). Use\n",
        "    `model.freeze_bn_layers()` to put the BatchNorm layers into eval mode.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_anim_classes, num_inanim_classes, backbone='resnet101', output_stride=16,\n",
        "                 sync_bn=True, freeze_bn=False):\n",
        "        super(Float, self).__init__()\n",
        "        if backbone == 'drn':\n",
        "            output_stride = 8\n",
        "\n",
        "        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d\n",
        "\n",
        "        # Registration order below fixes state_dict keys and weight-init order.\n",
        "        self.backbone = build_backbone(backbone, output_stride, BatchNorm)\n",
        "\n",
        "        self.anim_aspp = build_aspp(backbone, output_stride, BatchNorm)\n",
        "        self.anim_decoder = build_decoder(num_anim_classes, backbone, BatchNorm)\n",
        "\n",
        "        self.inanim_aspp = build_aspp(backbone, output_stride, BatchNorm)\n",
        "        self.inanim_decoder = build_decoder(num_inanim_classes, backbone, BatchNorm)\n",
        "\n",
        "        self.lrfb_aspp = build_aspp(backbone, output_stride, BatchNorm)\n",
        "        self.lr_decoder = build_decoder(3, backbone, BatchNorm)\n",
        "        self.fb_decoder = build_decoder(3, backbone, BatchNorm)\n",
        "\n",
        "        # Kept as a plain bool (not renamed) for backward compatibility: the\n",
        "        # lr-param generators and possibly external callers read it.\n",
        "        self.freeze_bn = freeze_bn\n",
        "\n",
        "    def forward(self, input):\n",
        "        \"\"\"Returns (anim_x, inanim_x, lr, fb) logits, each at the input's spatial size.\"\"\"\n",
        "        x, low_level_feat = self.backbone(input)\n",
        "        out_size = input.size()[2:]\n",
        "\n",
        "        def upsample(logits):\n",
        "            return F.interpolate(logits, size=out_size, mode='bilinear', align_corners=True)\n",
        "\n",
        "        anim_x = upsample(self.anim_decoder(self.anim_aspp(x), low_level_feat))\n",
        "        inanim_x = upsample(self.inanim_decoder(self.inanim_aspp(x), low_level_feat))\n",
        "\n",
        "        # The lr and fb decoders share one ASPP output.\n",
        "        lrfb_x = self.lrfb_aspp(x)\n",
        "        lr = upsample(self.lr_decoder(lrfb_x, low_level_feat))\n",
        "        fb = upsample(self.fb_decoder(lrfb_x, low_level_feat))\n",
        "        return anim_x, inanim_x, lr, fb\n",
        "\n",
        "    def freeze_bn_layers(self):\n",
        "        \"\"\"Puts every SynchronizedBatchNorm2d / nn.BatchNorm2d submodule into eval mode.\"\"\"\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):\n",
        "                m.eval()\n",
        "\n",
        "    def freeze_bn(self):\n",
        "        \"\"\"Legacy name for freeze_bn_layers().\n",
        "\n",
        "        Shadowed on instances by the `freeze_bn` flag, but still callable as\n",
        "        `Float.freeze_bn(model)`.\n",
        "        \"\"\"\n",
        "        self.freeze_bn_layers()\n",
        "\n",
        "    def _lr_params(self, modules):\n",
        "        \"\"\"Yields trainable params of Conv2d layers (plus BatchNorm layers unless\n",
        "        `self.freeze_bn` is true) found anywhere inside `modules`, in module order.\"\"\"\n",
        "        if self.freeze_bn:\n",
        "            layer_types = (nn.Conv2d,)\n",
        "        else:\n",
        "            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)\n",
        "        for module in modules:\n",
        "            for _, m in module.named_modules():\n",
        "                if isinstance(m, layer_types):\n",
        "                    for p in m.parameters():\n",
        "                        if p.requires_grad:\n",
        "                            yield p\n",
        "\n",
        "    def get_1x_lr_params(self):\n",
        "        \"\"\"Trainable backbone parameters (the '1x' learning-rate group).\"\"\"\n",
        "        yield from self._lr_params([self.backbone])\n",
        "\n",
        "    def get_10x_lr_params(self):\n",
        "        \"\"\"Trainable parameters of all ASPP and decoder heads (the '10x' learning-rate group).\"\"\"\n",
        "        yield from self._lr_params([self.anim_aspp, self.inanim_aspp, self.lrfb_aspp, self.anim_decoder,\n",
        "                                    self.inanim_decoder, self.lr_decoder, self.fb_decoder])"
1025
      ]
1026
    },
1027
    {
1028
      "cell_type": "markdown",
1029
      "metadata": {
1030
        "id": "uGMosACT7Nm3"
1031
      },
1032
      "source": [
1033
        "**Metric calculators**"
1034
      ]
1035
    },
1036
    {
1037
      "cell_type": "code",
1038
      "execution_count": null,
1039
      "metadata": {
1040
        "id": "lTVL3Vb56E69"
1041
      },
1042
      "outputs": [],
1043
      "source": [
1044
        "class AverageMeter(object):\n",
        "    \"\"\"Running weighted mean of a scalar (e.g. a loss), updated batch by batch.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\"Zeroes the latest value, the running average, sum and count.\"\"\"\n",
        "        self.val = self.avg = self.sum = self.count = 0\n",
        "\n",
        "    def update(self, val, n=1):\n",
        "        \"\"\"Records `val` as observed `n` times and refreshes the running average.\"\"\"\n",
        "        self.val = val\n",
        "        self.sum += val * n\n",
        "        self.count += n\n",
        "        self.avg = self.sum / self.count"
1060
      ]
1061
    },
1062
    {
1063
      "cell_type": "code",
1064
      "execution_count": null,
1065
      "metadata": {
1066
        "id": "9dOLf_bI6E6-"
1067
      },
1068
      "outputs": [],
1069
      "source": [
1070
        "class sqIOUMeter(object):\n",
        "    \"\"\"Used for updatable sqIOU style average (per class) calculation.\n",
        "\n",
        "    `update` adds per-class value and count totals; `avg` is the mean, over classes\n",
        "    whose accumulated count is > 0, of accumulated value / accumulated count.\n",
        "    `avg` is NaN until at least one class has a positive count.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n_classes):\n",
        "        self.n_classes = n_classes\n",
        "        self.vals = {i: 0 for i in range(n_classes)}\n",
        "        self.counts = {i: 0 for i in range(n_classes)}\n",
        "        # Defined up front so reading `avg` before the first update does not raise.\n",
        "        self.avg = np.nan\n",
        "\n",
        "    def update(self, val_d, count_d):\n",
        "        \"\"\"Accumulates per-class totals and recomputes `avg`.\n",
        "\n",
        "        Args:\n",
        "            val_d: indexable by class id 0..n_classes-1; values to add per class.\n",
        "            count_d: indexable the same way; counts to add per class.\n",
        "        \"\"\"\n",
        "        sqiou = []\n",
        "        for i in range(self.n_classes):\n",
        "            self.vals[i] += val_d[i]\n",
        "            self.counts[i] += count_d[i]\n",
        "            if self.counts[i] > 0:\n",
        "                sqiou.append(self.vals[i] / self.counts[i])\n",
        "\n",
        "        # np.mean([]) emits a RuntimeWarning; return the same NaN without it.\n",
        "        self.avg = np.mean(sqiou) if sqiou else np.nan"
1089
      ]
1090
    },
1091
    {
1092
      "cell_type": "code",
1093
      "execution_count": null,
1094
      "metadata": {
1095
        "id": "XSAwo2FG6E6-"
1096
      },
1097
      "outputs": [],
1098
      "source": [
1099
        "class Evaluator(object):\n",
        "    \"\"\"Used for updatable mIOU calculation among other metrics.\n",
        "\n",
        "    Accumulates a confusion matrix whose rows index ground-truth classes and whose\n",
        "    columns index predicted classes. Ground-truth pixels outside [0, num_class)\n",
        "    (e.g. an ignore label) are skipped.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_class):\n",
        "        self.num_class = num_class\n",
        "        self.confusion_matrix = np.zeros((self.num_class,) * 2)\n",
        "\n",
        "    def set_confusion_matrix(self, conf_mat):\n",
        "        \"\"\"Replaces the accumulated matrix with a copy of `conf_mat`.\"\"\"\n",
        "        self.confusion_matrix = np.copy(conf_mat)\n",
        "\n",
        "    def _per_class_iou(self):\n",
        "        \"\"\"Per-class IoU = TP / (TP + FP + FN); NaN for classes that never occur.\"\"\"\n",
        "        cm = self.confusion_matrix\n",
        "        tp = np.diag(cm)\n",
        "        # 0 / 0 for absent classes is expected; keep the NaN but silence the warning.\n",
        "        with np.errstate(divide='ignore', invalid='ignore'):\n",
        "            return tp / (np.sum(cm, axis=1) + np.sum(cm, axis=0) - tp)\n",
        "\n",
        "    def Pixel_Accuracy(self):\n",
        "        \"\"\"Fraction of counted pixels whose prediction equals the ground truth.\"\"\"\n",
        "        cm = self.confusion_matrix\n",
        "        with np.errstate(divide='ignore', invalid='ignore'):\n",
        "            return np.diag(cm).sum() / cm.sum()\n",
        "\n",
        "    def Pixel_Accuracy_Class(self):\n",
        "        \"\"\"Mean per-class pixel accuracy, ignoring classes without ground-truth pixels.\"\"\"\n",
        "        cm = self.confusion_matrix\n",
        "        with np.errstate(divide='ignore', invalid='ignore'):\n",
        "            acc = np.diag(cm) / cm.sum(axis=1)\n",
        "        return np.nanmean(acc)\n",
        "\n",
        "    def Mean_Intersection_over_Union(self):\n",
        "        \"\"\"Mean IoU over the classes whose IoU is defined.\"\"\"\n",
        "        return np.nanmean(self._per_class_iou())\n",
        "\n",
        "    def Mean_Intersection_over_Union_PerClass(self):\n",
        "        \"\"\"Per-class IoU array (NaN for classes that never occur).\"\"\"\n",
        "        return self._per_class_iou()\n",
        "\n",
        "    def Frequency_Weighted_Intersection_over_Union(self):\n",
        "        \"\"\"IoU averaged with weights proportional to each class's ground-truth pixel count.\"\"\"\n",
        "        cm = self.confusion_matrix\n",
        "        with np.errstate(divide='ignore', invalid='ignore'):\n",
        "            freq = np.sum(cm, axis=1) / np.sum(cm)\n",
        "        iu = self._per_class_iou()\n",
        "        present = freq > 0\n",
        "        return (freq[present] * iu[present]).sum()\n",
        "\n",
        "    def _generate_matrix(self, gt_image, pre_image):\n",
        "        \"\"\"Confusion matrix of one batch; pixels with gt outside [0, num_class) are ignored.\"\"\"\n",
        "        mask = (gt_image >= 0) & (gt_image < self.num_class)\n",
        "        # np.bincount rejects non-integer input, so predictions are cast like the\n",
        "        # ground truth (float-typed predictions would otherwise raise TypeError).\n",
        "        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask].astype('int')\n",
        "        count = np.bincount(label, minlength=self.num_class ** 2)\n",
        "        return count.reshape(self.num_class, self.num_class)\n",
        "\n",
        "    def add_batch(self, gt_image, pre_image):\n",
        "        \"\"\"Accumulates one batch of same-shaped ground-truth / prediction label arrays.\"\"\"\n",
        "        assert gt_image.shape == pre_image.shape\n",
        "        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\"Clears the accumulated confusion matrix.\"\"\"\n",
        "        self.confusion_matrix = np.zeros((self.num_class,) * 2)"
1152
      ]
1153
    },
1154
    {
1155
      "cell_type": "code",
1156
      "execution_count": null,
1157
      "metadata": {
1158
        "id": "EReFAghK6E6_"
1159
      },
1160
      "outputs": [],
1161
      "source": [
1162
        "\"\"\"\n",
1163
        "Mapping from object category label and factored part label to original dataset label category.\n",
1164
        "\"\"\"\n",
1165
        "def part_obj_to_datasetclass(animate, obj_classes=21):\n",
1166
        "    \n",
1167
        "    # Animate parts : Head(1), Torso(2), Leg(3), Tail(4), Wing(5), Arm(6), Neck(7)\n",
1168
        "    # Animate parts : Eye(8), Ear(9), Nose(10), Muzzle(11), Horn(12), Mouth(13), Hair(14), Foot(15),\n",
1169
        "    #                 Hand(16), Paw(17), Hoof(18), Beak(19)\n",
1170
        "    map_pc = {}\n",
1171
        "    for i in range(obj_classes):\n",
1172
        "        map_pc[i] = {}\n",
1173
        "        \n",
1174
        "    # Animate objects\n",
1175
        "\n",
1176
        "    for idx in [8, 15, 3, 5]:\n",
1177
        "        map_pc[3][idx] = {}\n",
1178
        "    map_pc[3][19]    = 15  # Bird\n",
1179
        "    map_pc[3][1]     = 16\n",
1180
        "    map_pc[3][8][1]  = 17\n",
1181
        "    map_pc[3][15][1] = 18\n",
1182
        "    map_pc[3][3][1]  = 19\n",
1183
        "    map_pc[3][5][1]  = 20\n",
1184
        "    map_pc[3][7]     = 21\n",
1185
        "    map_pc[3][8][2]  = 22\n",
1186
        "    map_pc[3][15][2] = 23\n",
1187
        "    map_pc[3][3][2]  = 24\n",
1188
        "    map_pc[3][5][2]  = 25\n",
1189
        "    map_pc[3][4]     = 26\n",
1190
        "    map_pc[3][2]     = 27\n",
1191
        "\n",
1192
        "    for idx in [8, 9, 3, 17]:\n",
1193
        "        map_pc[8][idx] = {}\n",
1194
        "    for idx in [3, 17]:\n",
1195
        "        map_pc[8][idx][1] = {}\n",
1196
        "        map_pc[8][idx][2] = {}\n",
1197
        "    map_pc[8][1]        = 57  # Cat\n",
1198
        "    map_pc[8][3][1][2]  = 58\n",
1199
        "    map_pc[8][17][1][2] = 59\n",
1200
        "    map_pc[8][9][1]     = 60\n",
1201
        "    map_pc[8][8][1]     = 61\n",
1202
        "    map_pc[8][3][1][1]  = 62\n",
1203
        "    map_pc[8][17][1][1] = 63\n",
1204
        "    map_pc[8][7]        = 64\n",
1205
        "    map_pc[8][10]       = 65\n",
1206
        "    map_pc[8][3][2][2]  = 66\n",
1207
        "    map_pc[8][17][2][2] = 67\n",
1208
        "    map_pc[8][9][2]     = 68\n",
1209
        "    map_pc[8][8][2]     = 69\n",
1210
        "    map_pc[8][3][2][1]  = 70\n",
1211
        "    map_pc[8][17][2][1] = 71\n",
1212
        "    map_pc[8][4]        = 72\n",
1213
        "    map_pc[8][2]        = 73\n",
1214
        "\n",
1215
        "    for idx in [8, 9, 12, 21, 22]:\n",
1216
        "        map_pc[10][idx] = {}\n",
1217
        "    for idx in [21, 22]:\n",
1218
        "        map_pc[10][idx][1] = {}\n",
1219
        "        map_pc[10][idx][2] = {}\n",
1220
        "    map_pc[10][1]        = 75  # Cow\n",
1221
        "    map_pc[10][21][1][2] = 76\n",
1222
        "    map_pc[10][22][1][2] = 77\n",
1223
        "    map_pc[10][9][1]     = 78\n",
1224
        "    map_pc[10][8][1]     = 79\n",
1225
        "    map_pc[10][21][1][1] = 80\n",
1226
        "    map_pc[10][22][1][1] = 81\n",
1227
        "    map_pc[10][12][1]    = 82\n",
1228
        "    map_pc[10][11]       = 83\n",
1229
        "    map_pc[10][7]        = 84\n",
1230
        "    map_pc[10][21][2][2] = 85\n",
1231
        "    map_pc[10][22][2][2] = 86\n",
1232
        "    map_pc[10][9][2]     = 87\n",
1233
        "    map_pc[10][8][2]     = 88\n",
1234
        "    map_pc[10][21][2][1] = 89\n",
1235
        "    map_pc[10][22][2][1] = 90\n",
1236
        "    map_pc[10][12][2]    = 91\n",
1237
        "    map_pc[10][4]        = 92\n",
1238
        "    map_pc[10][2]        = 93\n",
1239
        "\n",
1240
        "    for idx in [8, 9, 3, 17]:\n",
1241
        "        map_pc[12][idx] = {}\n",
1242
        "    for idx in [3, 17]:\n",
1243
        "        map_pc[12][idx][1] = {}\n",
1244
        "        map_pc[12][idx][2] = {}\n",
1245
        "    map_pc[12][1]        = 95  # Dog\n",
1246
        "    map_pc[12][3][1][2]  = 96\n",
1247
        "    map_pc[12][17][1][2] = 97\n",
1248
        "    map_pc[12][9][1]     = 98\n",
1249
        "    map_pc[12][8][1]     = 99\n",
1250
        "    map_pc[12][3][1][1]  = 100\n",
1251
        "    map_pc[12][17][1][1] = 101\n",
1252
        "    map_pc[12][11]       = 102\n",
1253
        "    map_pc[12][7]        = 103\n",
1254
        "    map_pc[12][10]       = 104\n",
1255
        "    map_pc[12][3][2][2]  = 105\n",
1256
        "    map_pc[12][17][2][2] = 106\n",
1257
        "    map_pc[12][9][2]     = 107\n",
1258
        "    map_pc[12][8][2]     = 108\n",
1259
        "    map_pc[12][3][2][1]  = 109\n",
1260
        "    map_pc[12][17][2][1] = 110\n",
1261
        "    map_pc[12][4]        = 111\n",
1262
        "    map_pc[12][2]        = 112\n",
1263
        "\n",
1264
        "    for idx in [8, 9, 18, 21, 22]:\n",
1265
        "        map_pc[13][idx] = {}\n",
1266
        "    for idx in [18, 21, 22]:\n",
1267
        "        map_pc[13][idx][1] = {}\n",
1268
        "        map_pc[13][idx][2] = {}\n",
1269
        "    map_pc[13][1]        = 113  # Horse\n",
1270
        "    map_pc[13][18][1][2] = 114\n",
1271
        "    map_pc[13][21][1][2] = 115\n",
1272
        "    map_pc[13][22][1][2] = 116\n",
1273
        "    map_pc[13][9][1]     = 117\n",
1274
        "    map_pc[13][8][1]     = 118\n",
1275
        "    map_pc[13][18][1][1] = 119\n",
1276
        "    map_pc[13][21][1][1] = 120\n",
1277
        "    map_pc[13][22][1][1] = 121\n",
1278
        "    map_pc[13][11]       = 122\n",
1279
        "    map_pc[13][7]        = 123\n",
1280
        "    map_pc[13][18][2][2] = 124\n",
1281
        "    map_pc[13][21][2][2] = 125\n",
1282
        "    map_pc[13][22][2][2] = 126\n",
1283
        "    map_pc[13][9][2]     = 127\n",
1284
        "    map_pc[13][8][2]     = 128\n",
1285
        "    map_pc[13][18][2][1] = 129\n",
1286
        "    map_pc[13][21][2][1] = 130\n",
1287
        "    map_pc[13][22][2][1] = 131\n",
1288
        "    map_pc[13][4]        = 132\n",
1289
        "    map_pc[13][2]        = 133\n",
1290
        "\n",
1291
        "    for idx in [8, 9, 20, 15, 16, 6, 21, 23, 22]:\n",
1292
        "        map_pc[15][idx] = {}\n",
1293
        "    map_pc[15][14]       = 140  # Person\n",
1294
        "    map_pc[15][1]        = 141\n",
1295
        "    map_pc[15][9][1]     = 142\n",
1296
        "    map_pc[15][8][1]     = 143\n",
1297
        "    map_pc[15][20][1]    = 144\n",
1298
        "    map_pc[15][15][1]    = 145\n",
1299
        "    map_pc[15][16][1]    = 146\n",
1300
        "    map_pc[15][6][1]     = 147\n",
1301
        "    map_pc[15][21][1]    = 148\n",
1302
        "    map_pc[15][23][1]    = 149\n",
1303
        "    map_pc[15][22][1]    = 150\n",
1304
        "    map_pc[15][13]       = 151\n",
1305
        "    map_pc[15][7]        = 152\n",
1306
        "    map_pc[15][10]       = 153\n",
1307
        "    map_pc[15][9][2]     = 154\n",
1308
        "    map_pc[15][8][2]     = 155\n",
1309
        "    map_pc[15][20][2]    = 156\n",
1310
        "    map_pc[15][15][2]    = 157\n",
1311
        "    map_pc[15][16][2]    = 158\n",
1312
        "    map_pc[15][6][2]     = 159\n",
1313
        "    map_pc[15][21][2]    = 160\n",
1314
        "    map_pc[15][23][2]    = 161\n",
1315
        "    map_pc[15][22][2]    = 162\n",
1316
        "    map_pc[15][2]        = 163\n",
1317
        "\n",
1318
        "    for idx in [8, 9, 12, 21, 22]:\n",
1319
        "        map_pc[17][idx] = {}\n",
1320
        "    for idx in [21, 22]:\n",
1321
        "        map_pc[17][idx][1] = {}\n",
1322
        "        map_pc[17][idx][2] = {}\n",
1323
        "    map_pc[17][1]        = 166  # Sheep\n",
1324
        "    map_pc[17][21][1][2] = 167\n",
1325
        "    map_pc[17][22][1][2] = 168\n",
1326
        "    map_pc[17][9][1]     = 169\n",
1327
        "    map_pc[17][8][1]     = 170\n",
1328
        "    map_pc[17][21][1][1] = 171\n",
1329
        "    map_pc[17][22][1][1] = 172\n",
1330
        "    map_pc[17][12][1]    = 173\n",
1331
        "    map_pc[17][11]       = 174\n",
1332
        "    map_pc[17][7]        = 175\n",
1333
        "    map_pc[17][21][2][2] = 176\n",
1334
        "    map_pc[17][22][2][2] = 177\n",
1335
        "    map_pc[17][9][2]     = 178\n",
1336
        "    map_pc[17][8][2]     = 179\n",
1337
        "    map_pc[17][21][2][1] = 180\n",
1338
        "    map_pc[17][22][2][1] = 181\n",
1339
        "    map_pc[17][12][2]    = 182\n",
1340
        "    map_pc[17][4]        = 183\n",
1341
        "    map_pc[17][2]        = 184\n",
1342
        "    \n",
1343
        "    map_pc[1][3] = {}\n",
1344
        "    map_pc[1][1]    = 1  # Aeroplane\n",
1345
        "    map_pc[1][5]    = 2\n",
1346
        "    map_pc[1][3][1] = 3\n",
1347
        "    map_pc[1][3][2] = 4\n",
1348
        "    map_pc[1][4]    = 5\n",
1349
        "    map_pc[1][28]   = 6\n",
1350
        "    map_pc[1][2]    = 7\n",
1351
        "\n",
1352
        "    map_pc[2][2]       = {}\n",
1353
        "    map_pc[2][2][0]    = {}\n",
1354
        "    map_pc[2][2][0][2] = 8  # Bicycle\n",
1355
        "    map_pc[2][16]      = 9\n",
1356
        "    map_pc[2][1]       = 10\n",
1357
        "    map_pc[2][2][0][1] = 11\n",
1358
        "    map_pc[2][15]      = 12\n",
1359
        "    map_pc[2][6]       = 13\n",
1360
        "    map_pc[2][14]      = 14\n",
1361
        "    \n",
1362
        "    map_pc[4][0] = 28  # Boat\n",
1363
        "\n",
1364
        "    map_pc[5][12] = 29 # Bottle\n",
1365
        "    map_pc[5][13] = 30\n",
1366
        "\n",
1367
        "    for idx in [19, 7, 17]:\n",
1368
        "        map_pc[6][idx] = {}\n",
1369
        "    map_pc[6][7][0] = {}\n",
1370
        "    map_pc[6][17][-1] = {}\n",
1371
        "\n",
1372
        "    map_pc[6][7][0][2]   = 31  # Bus\n",
1373
        "    map_pc[6][17][-1][4] = 32\n",
1374
        "    map_pc[6][20]        = 33\n",
1375
        "    map_pc[6][7][0][1]   = 34\n",
1376
        "    map_pc[6][17][-1][3] = 35\n",
1377
        "    map_pc[6][6]         = 36\n",
1378
        "    map_pc[6][19][1]     = 37\n",
1379
        "    map_pc[6][17][-1][1] = 38\n",
1380
        "    map_pc[6][19][2]     = 39\n",
1381
        "    map_pc[6][17][-1][2] = 40\n",
1382
        "    map_pc[6][18]        = 41\n",
1383
        "    map_pc[6][2]         = 42\n",
1384
        "    map_pc[6][11]        = 43\n",
1385
        "\n",
1386
        "    for idx in [19, 7, 17]:\n",
1387
        "        map_pc[7][idx] = {}\n",
1388
        "    map_pc[7][7][0] = {}\n",
1389
        "    map_pc[7][17][-1] = {}\n",
1390
        "\n",
1391
        "    map_pc[7][7][0][2]   = 44  # Car\n",
1392
        "    map_pc[7][17][-1][4] = 45\n",
1393
        "    map_pc[7][20]        = 46\n",
1394
        "    map_pc[7][7][0][1]   = 47\n",
1395
        "    map_pc[7][17][-1][3] = 48\n",
1396
        "    map_pc[7][6]         = 49\n",
1397
        "    map_pc[7][19][1]     = 50\n",
1398
        "    map_pc[7][17][-1][1] = 51\n",
1399
        "    map_pc[7][19][2]     = 52\n",
1400
        "    map_pc[7][17][-1][2] = 53\n",
1401
        "    map_pc[7][18]        = 54\n",
1402
        "    map_pc[7][2]         = 55\n",
1403
        "    map_pc[7][11]        = 56\n",
1404
        "    \n",
1405
        "    map_pc[9][0] = 74  # Chair\n",
1406
        "    \n",
1407
        "    map_pc[11][0] = 94  # Dining Table\n",
1408
        "\n",
1409
        "    map_pc[14][2]       = {}\n",
1410
        "    map_pc[14][2][0]    = {}\n",
1411
        "    map_pc[14][2][0][2] = 134  # Motorbike\n",
1412
        "    map_pc[14][1]       = 135\n",
1413
        "    map_pc[14][2][0][1] = 136\n",
1414
        "    map_pc[14][15]      = 137\n",
1415
        "    map_pc[14][6]       = 138\n",
1416
        "    map_pc[14][14]      = 149\n",
1417
        "\n",
1418
        "    map_pc[16][10] = 164  # Potted plant\n",
1419
        "    map_pc[16][9]  = 165\n",
1420
        "    \n",
1421
        "    map_pc[18][0] = 185  # Sofa\n",
1422
        "\n",
1423
        "    for idx in [25, 22]:\n",
1424
        "        map_pc[19][idx] = {}\n",
1425
        "        map_pc[19][idx][-1] = {}\n",
1426
        "\n",
1427
        "    map_pc[19][25][-1][4] = 186 # Train\n",
1428
        "    map_pc[19][25][-1][3] = 187\n",
1429
        "    map_pc[19][25][-1][1] = 188\n",
1430
        "    map_pc[19][25][-1][2] = 189\n",
1431
        "    map_pc[19][26] = 190\n",
1432
        "    map_pc[19][24] = 191\n",
1433
        "    map_pc[19][21] = 192\n",
1434
        "    map_pc[19][22][-1][4] = 193\n",
1435
        "    map_pc[19][22][-1][3] = 194\n",
1436
        "    map_pc[19][22][-1][1] = 195\n",
1437
        "    map_pc[19][22][-1][2] = 196\n",
1438
        "    map_pc[19][23] = 197\n",
1439
        "    map_pc[19][6] = 198\n",
1440
        "\n",
1441
        "    map_pc[20][27] = 199 # Tv monitor\n",
1442
        "    map_pc[20][8] = 200\n",
1443
        "    \n",
1444
        "    if animate is None:\n",
1445
        "        classes = list(range(1, 21))\n",
1446
        "    elif animate:\n",
1447
        "        classes = [3, 8, 10, 12, 13, 15, 17]\n",
1448
        "    else:\n",
1449
        "        classes = [1, 2, 4, 5, 6, 7, 9, 11, 14, 16, 18, 19, 20]\n",
1450
        "    return map_pc, classes"
1451
      ]
1452
    },
1453
    {
1454
      "cell_type": "code",
1455
      "execution_count": null,
1456
      "metadata": {
1457
        "id": "5G9GKdqo6E7B"
1458
      },
1459
      "outputs": [],
1460
      "source": [
1461
        "\"\"\"\n",
1462
        "Basic image and label transforms for the dataset.\n",
1463
        "\"\"\"\n",
1464
        "from PIL import Image, ImageOps, ImageFilter\n",
1465
        "from torchvision import transforms\n",
1466
        "\n",
1467
        "class Normalize(object):\n",
        "    \"\"\"Scale an image to [0, 1], then standardise it per channel.\n",
        "\n",
        "    Only 'image' is rescaled; the 'obj' and 'part' label maps are just\n",
        "    converted to float32 numpy arrays.\n",
        "\n",
        "    Args:\n",
        "        mean (tuple): per-channel means, subtracted after dividing by 255.\n",
        "        std (tuple): per-channel standard deviations, divided out last.\n",
        "    \"\"\"\n",
        "    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):\n",
        "        self.mean = mean\n",
        "        self.std = std\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        image = np.array(sample['image']).astype(np.float32)\n",
        "        obj_map = np.array(sample['obj']).astype(np.float32)\n",
        "        part_map = np.array(sample['part']).astype(np.float32)\n",
        "\n",
        "        # In-place ops keep the float32 dtype (tuple operands are float64).\n",
        "        image /= 255.0\n",
        "        image -= self.mean\n",
        "        image /= self.std\n",
        "        return {'image': image, 'obj': obj_map, 'part': part_map}\n",
1491
        "\n",
1492
        "\n",
1493
        "class ToTensor(object):\n",
        "    \"\"\"Convert the three entries of a sample into float32 torch tensors.\n",
        "\n",
        "    The image is transposed from numpy's H x W x C layout to torch's\n",
        "    C x H x W; the 'obj' and 'part' maps keep their shape.\n",
        "    \"\"\"\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        arrays = {key: np.array(sample[key]).astype(np.float32)\n",
        "                  for key in ('image', 'obj', 'part')}\n",
        "        arrays['image'] = arrays['image'].transpose((2, 0, 1))\n",
        "        return {key: torch.from_numpy(value).float()\n",
        "                for key, value in arrays.items()}\n",
1514
        "\n",
1515
        "\n",
1516
        "class RandomHorizontalFlip(object):\n",
        "    \"\"\"Mirror image and both label maps left-right with probability 0.5.\n",
        "\n",
        "    All three entries are flipped together, so the labels stay aligned.\n",
        "    \"\"\"\n",
        "    def __call__(self, sample):\n",
        "        keys = ('image', 'obj', 'part')\n",
        "        if random.random() >= 0.5:\n",
        "            return {key: sample[key] for key in keys}\n",
        "        return {key: sample[key].transpose(Image.FLIP_LEFT_RIGHT) for key in keys}\n",
1529
        "\n",
1530
        "\n",
1531
        "class RandomRotate(object):\n",
        "    \"\"\"Rotate image and label maps by one shared angle in [-degree, degree].\n",
        "\n",
        "    The image is resampled bilinearly; label maps use nearest neighbour so\n",
        "    class ids are never blended.\n",
        "    \"\"\"\n",
        "    def __init__(self, degree):\n",
        "        self.degree = degree\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        angle = random.uniform(-1*self.degree, self.degree)\n",
        "        return {'image': sample['image'].rotate(angle, Image.BILINEAR),\n",
        "                'obj': sample['obj'].rotate(angle, Image.NEAREST),\n",
        "                'part': sample['part'].rotate(angle, Image.NEAREST)}\n",
1547
        "\n",
1548
        "\n",
1549
        "class RandomGaussianBlur(object):\n",
        "    \"\"\"With probability 0.5, blur the image with a random radius in [0, 1).\n",
        "\n",
        "    The 'obj' and 'part' label maps are passed through unchanged.\n",
        "    \"\"\"\n",
        "    def __call__(self, sample):\n",
        "        image = sample['image']\n",
        "        if random.random() < 0.5:\n",
        "            radius = random.random()\n",
        "            image = image.filter(ImageFilter.GaussianBlur(radius=radius))\n",
        "        return {'image': image, 'obj': sample['obj'], 'part': sample['part']}\n",
1561
        "\n",
1562
        "\n",
1563
        "class RandomScaleCrop(object):\n",
        "    \"\"\"Randomly rescale the short side, pad if needed, then randomly crop.\n",
        "\n",
        "    The short side is resized to a random length in\n",
        "    [0.5 * base_size, 2 * base_size] (aspect ratio kept). If that length is\n",
        "    below crop_size, the right/bottom edges are padded (image with 0, label\n",
        "    maps with `fill`) before a crop_size x crop_size window is cut out at a\n",
        "    random position.\n",
        "    \"\"\"\n",
        "    def __init__(self, base_size, crop_size, fill=0):\n",
        "        self.base_size = base_size\n",
        "        self.crop_size = crop_size\n",
        "        self.fill = fill\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        img, obj, part = sample['image'], sample['obj'], sample['part']\n",
        "        crop = self.crop_size\n",
        "\n",
        "        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))\n",
        "        w, h = img.size\n",
        "        if h > w:\n",
        "            new_size = (short_size, int(1.0 * h * short_size / w))\n",
        "        else:\n",
        "            new_size = (int(1.0 * w * short_size / h), short_size)\n",
        "        img = img.resize(new_size, Image.BILINEAR)\n",
        "        obj = obj.resize(new_size, Image.NEAREST)\n",
        "        part = part.resize(new_size, Image.NEAREST)\n",
        "\n",
        "        if short_size < crop:\n",
        "            new_w, new_h = new_size\n",
        "            border = (0, 0, max(crop - new_w, 0), max(crop - new_h, 0))\n",
        "            img = ImageOps.expand(img, border=border, fill=0)\n",
        "            obj = ImageOps.expand(obj, border=border, fill=self.fill)\n",
        "            part = ImageOps.expand(part, border=border, fill=self.fill)\n",
        "\n",
        "        w, h = img.size\n",
        "        x1 = random.randint(0, w - crop)\n",
        "        y1 = random.randint(0, h - crop)\n",
        "        box = (x1, y1, x1 + crop, y1 + crop)\n",
        "        return {'image': img.crop(box), 'obj': obj.crop(box), 'part': part.crop(box)}\n",
1603
        "\n",
1604
        "\n",
1605
        "class FixScaleCrop(object):\n",
        "    \"\"\"Resize so the short side equals crop_size, then take the centre square.\n",
        "\n",
        "    Deterministic: the image uses bilinear resampling, the label maps\n",
        "    nearest neighbour, and all three share the same crop window.\n",
        "    \"\"\"\n",
        "    def __init__(self, crop_size):\n",
        "        self.crop_size = crop_size\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        side = self.crop_size\n",
        "        w, h = sample['image'].size\n",
        "        if w > h:\n",
        "            target = (int(1.0 * w * side / h), side)\n",
        "        else:\n",
        "            target = (side, int(1.0 * h * side / w))\n",
        "        resized = {'image': sample['image'].resize(target, Image.BILINEAR),\n",
        "                   'obj': sample['obj'].resize(target, Image.NEAREST),\n",
        "                   'part': sample['part'].resize(target, Image.NEAREST)}\n",
        "\n",
        "        w, h = resized['image'].size\n",
        "        left = int(round((w - side) / 2.))\n",
        "        top = int(round((h - side) / 2.))\n",
        "        box = (left, top, left + side, top + side)\n",
        "        return {key: value.crop(box) for key, value in resized.items()}\n",
1634
        "\n",
1635
        "class FixedResize(object):\n",
        "    \"\"\"Resize image and label maps to a size x size square.\n",
        "\n",
        "    PIL's resize expects (width, height); both are `size` here. The image is\n",
        "    resampled bilinearly, the label maps with nearest neighbour.\n",
        "    \"\"\"\n",
        "    def __init__(self, size):\n",
        "        self.size = (size, size)\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        resample = (('image', Image.BILINEAR), ('obj', Image.NEAREST), ('part', Image.NEAREST))\n",
        "        return {key: sample[key].resize(self.size, method) for key, method in resample}\n",
1653
        "    \n",
1654
        "class ResizeMasks(object):\n",
        "    \"\"\"Fit a sample into a crop_size x crop_size square without cropping.\n",
        "\n",
        "    The longer side is scaled to crop_size (aspect ratio kept); if the\n",
        "    shorter side ends up below crop_size, the right/bottom edges are\n",
        "    zero-padded up to crop_size. The image is resampled bilinearly, the\n",
        "    label maps with nearest neighbour.\n",
        "    \"\"\"\n",
        "    def __init__(self, crop_size):\n",
        "        self.crop_size = crop_size\n",
        "\n",
        "    def __call__(self, sample):\n",
        "        target = self.crop_size\n",
        "        w, h = sample['image'].size\n",
        "        if w > h:\n",
        "            ow, oh = target, int(1.0 * h * target / w)\n",
        "            short_size = oh\n",
        "        else:\n",
        "            ow, oh = int(1.0 * w * target / h), target\n",
        "            short_size = ow\n",
        "\n",
        "        resample = {'image': Image.BILINEAR, 'obj': Image.NEAREST, 'part': Image.NEAREST}\n",
        "        out = {key: sample[key].resize((ow, oh), method) for key, method in resample.items()}\n",
        "\n",
        "        if short_size < target:\n",
        "            border = (0, 0, max(target - ow, 0), max(target - oh, 0))\n",
        "            out = {key: ImageOps.expand(value, border=border, fill=0)\n",
        "                   for key, value in out.items()}\n",
        "        return out"
1687
      ]
1688
    },
1689
    {
1690
      "cell_type": "code",
1691
      "execution_count": null,
1692
      "metadata": {
1693
        "id": "QiYUlNnI6E7C"
1694
      },
1695
      "outputs": [],
1696
      "source": [
1697
        "\"\"\"\n",
        "Dataset class: loads image / object-label / part-label triples for evaluation.\n",
        "\"\"\"\n",
        "class SegmentationDataset(Dataset):\n",
        "    \"\"\"Dataset of (image, object mask, part mask) triples listed in a split file.\n",
        "\n",
        "    Args:\n",
        "        folder: dataset root containing the split file and the images/,\n",
        "            parts201/ and objs21/ sub-folders. Paths are built with\n",
        "            os.path.join, so a trailing separator is optional.\n",
        "        mode: split name; image ids are read one per line from the file\n",
        "            named mode + '.txt' inside `folder`.\n",
        "\n",
        "    NOTE: __getitem__ always applies transform_val, whatever the split;\n",
        "    transform_tr is defined but not used by __getitem__.\n",
        "    \"\"\"\n",
        "    def __init__(self, folder, mode='train'):\n",
        "        self.folder = folder\n",
        "        with open(os.path.join(folder, mode + '.txt')) as f:\n",
        "            self.image_path_list = f.read().splitlines()\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.image_path_list)\n",
        "\n",
        "    def _file_path(self, subdir, name):\n",
        "        # Image, part label and object label share the id and the .png suffix.\n",
        "        return os.path.join(self.folder, subdir, name + '.png')\n",
        "\n",
        "    def __getitem__(self, i):\n",
        "        \"\"\"Return the transformed sample plus 'path', 'orgsize' and 'org_img'.\n",
        "\n",
        "        'orgsize' is the untransformed PIL size (width, height) and\n",
        "        'org_img' is the untransformed image as a numpy array.\n",
        "        \"\"\"\n",
        "        name = self.image_path_list[i]\n",
        "\n",
        "        sample = {}\n",
        "        sample['image'] = Image.open(self._file_path('images', name))\n",
        "        org_img = sample['image'].copy()\n",
        "        org_size = sample['image'].size\n",
        "\n",
        "        sample['part'] = Image.open(self._file_path('parts201', name))\n",
        "        sample['obj'] = Image.open(self._file_path('objs21', name))\n",
        "\n",
        "        sample = self.transform_val(sample)\n",
        "\n",
        "        sample['path'] = name\n",
        "        sample['orgsize'] = org_size\n",
        "        sample['org_img'] = np.array(org_img)\n",
        "\n",
        "        return sample\n",
        "\n",
        "    def transform_tr(self, sample):\n",
        "        composed_transforms = transforms.Compose([\n",
        "            RandomHorizontalFlip(),\n",
        "            RandomScaleCrop(base_size=513, crop_size=513),\n",
        "            RandomGaussianBlur(),\n",
        "            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
        "            ToTensor()])\n",
        "\n",
        "        return composed_transforms(sample)\n",
        "\n",
        "    def transform_val(self, sample):\n",
        "        # Fit into 770 x 770 with right/bottom padding (ResizeMasks), then\n",
        "        # normalise with the given per-channel mean/std.\n",
        "        composed_transforms = transforms.Compose([\n",
        "            ResizeMasks(crop_size=770),\n",
        "            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
        "            ToTensor()])\n",
        "\n",
        "        return composed_transforms(sample)"
1749
      ]
1750
    },
1751
    {
1752
      "cell_type": "code",
1753
      "execution_count": null,
1754
      "metadata": {
1755
        "id": "uDxla8HS6E7D"
1756
      },
1757
      "outputs": [],
1758
      "source": [
1759
        "PATH = '/path/to/dataset/'\n",
        "batch_size = 1\n",
        "\n",
        "# Settings shared by both loaders; only the training loader shuffles.\n",
        "loader_kwargs = dict(batch_size=batch_size, drop_last=False, num_workers=2)\n",
        "\n",
        "train_dataset = SegmentationDataset(PATH)\n",
        "train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)\n",
        "valid_dataset = SegmentationDataset(PATH, mode='val')\n",
        "valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)"
1766
      ]
1767
    },
1768
    {
1769
      "cell_type": "code",
1770
      "execution_count": null,
1771
      "metadata": {
1772
        "id": "39h1M34A6E7F"
1773
      },
1774
      "outputs": [],
1775
      "source": [
1776
        "# DeepLab object model (num_classes=21) and the Float part model\n",
        "# (num_anim_classes=24, num_inanim_classes=29), both on ResNet-101.\n",
        "obj_model = DeepLab(num_classes=21, backbone='resnet101')\n",
        "model = Float(num_anim_classes=24, num_inanim_classes=29, backbone='resnet101')"
1778
      ]
1779
    },
1780
    {
1781
      "cell_type": "code",
1782
      "execution_count": null,
1783
      "metadata": {
1784
        "id": "4we8TLUb6E7F",
1785
        "outputId": "09a7bcd9-924a-407d-f490-a796b18bb957"
1786
      },
1787
      "outputs": [],
1788
      "source": [
1789
        "# map_location='cpu': both models are still on the CPU at this point, so\n",
        "# GPU-saved checkpoints load even on hosts without the original devices.\n",
        "obj_model.load_state_dict(torch.load('/path/to/obj/model', map_location='cpu'))\n",
        "model.load_state_dict(torch.load('/path/to/part/model', map_location='cpu'))"
1791
      ]
1792
    },
1793
    {
1794
      "cell_type": "code",
1795
      "execution_count": null,
1796
      "metadata": {
1797
        "id": "NjoEhLez6E7F"
1798
      },
1799
      "outputs": [],
1800
      "source": [
1801
        "# Use only GPUs that actually exist; a hard-coded [0, ..., 7] makes\n",
        "# DataParallel fail on machines with 2-7 GPUs.\n",
        "gpu_ids = list(range(torch.cuda.device_count()))"
1802
      ]
1803
    },
1804
    {
1805
      "cell_type": "code",
1806
      "execution_count": null,
1807
      "metadata": {
1808
        "id": "56aRgPwP6E7G"
1809
      },
1810
      "outputs": [],
1811
      "source": [
1812
        "\"\"\"\n",
        "Parallelise both models over the GPUs in gpu_ids when more than one GPU is visible.\n",
        "\"\"\"\n",
        "def _wrap_data_parallel(net):\n",
        "    # Same three steps for each model: wrap, register the replication\n",
        "    # callback (helper defined outside this cell), move to GPU.\n",
        "    net = torch.nn.DataParallel(net, device_ids=gpu_ids)\n",
        "    patch_replication_callback(net)\n",
        "    net.cuda()\n",
        "    return net\n",
        "\n",
        "\n",
        "if torch.cuda.device_count() > 1:\n",
        "    obj_model = _wrap_data_parallel(obj_model)\n",
        "    model = _wrap_data_parallel(model)"
1823
      ]
1824
    },
1825
    {
1826
      "cell_type": "code",
1827
      "execution_count": null,
1828
      "metadata": {
1829
        "id": "dXywaBGy6E7T"
1830
      },
1831
      "outputs": [],
1832
      "source": [
1833
        "def jaccard_perpart_perimg(y_pred, y_true, num_classes):\n",
        "    \"\"\"Per-class IoU between a predicted and a ground-truth label map of one image.\n",
        "\n",
        "    Each class is compared through boolean masks (labels == c) rather than\n",
        "    F.one_hot, which materialises an int64 tensor of shape\n",
        "    (*label_shape, num_classes) for each input.\n",
        "\n",
        "    Args:\n",
        "        y_pred: predicted label map (anything torch.Tensor() accepts).\n",
        "        y_true: ground-truth label map with the same shape.\n",
        "        num_classes: labels must lie in [0, num_classes).\n",
        "\n",
        "    Returns:\n",
        "        (ious, counts): dicts keyed by class id holding 0-d tensors.\n",
        "        ious[c] is the IoU of class c, or 0 if c is absent from y_true;\n",
        "        counts[c] is 1 if c occurs in y_true, else 0.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: if a label is outside [0, num_classes) (the same\n",
        "            exception type F.one_hot raises).\n",
        "    \"\"\"\n",
        "    y_pred = torch.Tensor(y_pred).type(torch.LongTensor)\n",
        "    y_true = torch.Tensor(y_true).type(torch.LongTensor)\n",
        "    for labels in (y_pred, y_true):\n",
        "        if labels.numel() and (labels.min() < 0 or labels.max() >= num_classes):\n",
        "            raise RuntimeError('Class values must be in [0, num_classes).')\n",
        "\n",
        "    ious = {}\n",
        "    counts = {}\n",
        "    for i in range(num_classes):\n",
        "        pred = y_pred == i\n",
        "        gt = y_true == i\n",
        "        inter = torch.logical_and(pred, gt)\n",
        "        union = torch.logical_or(pred, gt)\n",
        "        iou = torch.sum(inter) / torch.sum(union)\n",
        "        legal = torch.sum(gt) > 0\n",
        "        # iou is NaN (0/0) for a class absent from both maps; indexing with\n",
        "        # legal drops it, so the sum yields 0 for classes not in y_true.\n",
        "        ious[i] = torch.sum(iou[legal])\n",
        "        counts[i] = torch.sum(legal)\n",
        "\n",
        "    return ious, counts"
1854
      ]
1855
    },
1856
    {
1857
      "cell_type": "code",
1858
      "execution_count": null,
1859
      "metadata": {
1860
        "id": "HFOY-vO16E7T"
1861
      },
1862
      "outputs": [],
1863
      "source": [
1864
        "def _grow_interval(lo, hi, diff, limit):\n",
        "    \"\"\"Lengthen [lo, hi] by `diff`, splitting the growth between both ends.\n",
        "\n",
        "    If one end would cross 0 or `limit`, the shortfall is moved to the other\n",
        "    end; the result is clamped to [0, limit], so it may end up shorter.\n",
        "    \"\"\"\n",
        "    below = diff // 2\n",
        "    above = diff - below\n",
        "    if below > lo:\n",
        "        return 0, min(hi + above + (below - lo), limit)\n",
        "    if hi + above > limit:\n",
        "        return max(lo - below - (hi + above - limit), 0), limit\n",
        "    return lo - below, hi + above\n",
        "\n",
        "\n",
        "def pad_and_square(x_min, y_min, x_max, y_max, pad, orgsize):\n",
        "    \"\"\"Pad a bounding box, then square it.\n",
        "\n",
        "    The box is first grown by `pad` on every side (extra context and slack\n",
        "    for box errors), then its shorter side is lengthened to match the longer\n",
        "    one so the crop can be resized without distortion. Both axes are clamped\n",
        "    to [0, orgsize], so near the image border the result may stay non-square.\n",
        "\n",
        "    Returns:\n",
        "        (x_min, y_min, x_max, y_max) of the adjusted box.\n",
        "    \"\"\"\n",
        "    x_min, y_min = max(x_min - pad, 0), max(y_min - pad, 0)\n",
        "    x_max, y_max = min(x_max + pad, orgsize), min(y_max + pad, orgsize)\n",
        "\n",
        "    width, height = x_max - x_min, y_max - y_min\n",
        "    if height > width:\n",
        "        x_min, x_max = _grow_interval(x_min, x_max, height - width, orgsize)\n",
        "    elif width > height:\n",
        "        y_min, y_max = _grow_interval(y_min, y_max, width - height, orgsize)\n",
        "\n",
        "    return x_min, y_min, x_max, y_max"
1909
      ]
1910
    },
1911
    {
1912
      "cell_type": "code",
1913
      "execution_count": null,
1914
      "metadata": {
1915
        "id": "blNAMaA06E7U"
1916
      },
1917
      "outputs": [],
1918
      "source": [
1919
        "def bbox(img):\n",
        "    \"\"\"Tight bounding box of the non-zero pixels of a 2-D mask.\n",
        "\n",
        "    Args:\n",
        "        img: 2-D array-like; rows are y, columns are x.\n",
        "\n",
        "    Returns:\n",
        "        (x_min, y_min, x_max, y_max) as ints, all inclusive. For an all-zero\n",
        "        mask this is (width, height, -1, -1), i.e. x_min > x_max; that\n",
        "        sentinel is kept for backward compatibility and lets callers detect\n",
        "        an empty mask.\n",
        "    \"\"\"\n",
        "    mask = np.asarray(img)\n",
        "    rows = np.flatnonzero(np.any(mask, axis=1))  # indices of non-empty rows\n",
        "    cols = np.flatnonzero(np.any(mask, axis=0))  # indices of non-empty columns\n",
        "    if rows.size == 0:\n",
        "        height, width = mask.shape[:2]\n",
        "        return width, height, -1, -1\n",
        "    return int(cols[0]), int(rows[0]), int(cols[-1]), int(rows[-1])"
1951
      ]
1952
    },
1953
    {
1954
      "cell_type": "code",
1955
      "execution_count": null,
1956
      "metadata": {
1957
        "id": "oYMb-lds6E7U"
1958
      },
1959
      "outputs": [],
1960
      "source": [
1961
        "def pred_to_orgsize(part, org_size):\n",
        "    \"\"\"Map a network-resolution label map back to the original image size.\n",
        "\n",
        "    Assumes the input was prepared like ResizeMasks does it: longer side\n",
        "    scaled to the network size, right/bottom padded. The un-padded region\n",
        "    is cropped (same int() rounding as ResizeMasks) and resized with nearest\n",
        "    neighbour, so label ids are never interpolated.\n",
        "\n",
        "    Args:\n",
        "        part: 2-D numpy label map at network resolution.\n",
        "        org_size: original image size as (width, height), i.e. PIL's .size.\n",
        "\n",
        "    Returns:\n",
        "        uint8 label map of shape (height, width). Labels are cast to uint8,\n",
        "        so they must be below 256.\n",
        "    \"\"\"\n",
        "    ow, oh = org_size\n",
        "    h, w = part.shape\n",
        "    if ow > oh:\n",
        "        # Landscape: the padding sits below the image.\n",
        "        part = part[:int(1.0 * oh * w / ow), :]\n",
        "    else:\n",
        "        # Portrait or square: the padding sits to the right of the image.\n",
        "        part = part[:, :int(1.0 * ow * h / oh)]\n",
        "\n",
        "    part = Image.fromarray(part.astype(np.uint8))\n",
        "    part = part.resize((ow, oh), Image.NEAREST)\n",
        "    return np.array(part).astype(np.uint8)"
1981
      ]
1982
    },
1983
    {
1984
      "cell_type": "code",
1985
      "execution_count": null,
1986
      "metadata": {
1987
        "id": "yyVQHjGX6E7V"
1988
      },
1989
      "outputs": [],
1990
      "source": [
1991
        "def combine_obj_animpart_pred(objs, parts, lr_pred, fb_pred):\n",
        "    \"\"\"Combine object, animate-part, left/right and front/back predictions.\n",
        "\n",
        "    For each animate object class and each of its parts (as listed by\n",
        "    part_obj_to_datasetclass(animate=True)), pixels whose object and part\n",
        "    predictions match -- and, where the mapping is nested, whose left/right\n",
        "    (1 or 2) and front/back (1 or 2) predictions match too -- receive that\n",
        "    combination's dataset class id. Pixels matching no combination stay 0.\n",
        "\n",
        "    Args:\n",
        "        objs, parts, lr_pred, fb_pred: numpy label maps of the same shape.\n",
        "\n",
        "    Returns:\n",
        "        float array with objs.shape holding the combined class ids.\n",
        "    \"\"\"\n",
        "    map_pc, classes = part_obj_to_datasetclass(animate=True)\n",
        "    preds = np.zeros(objs.shape)\n",
        "\n",
        "    # These masks do not depend on the object/part being processed; build\n",
        "    # them once instead of inside the innermost loops.\n",
        "    lr_masks = {idx: (lr_pred == idx).astype(int) for idx in (1, 2)}\n",
        "    fb_masks = {idx: (fb_pred == idx).astype(int) for idx in (1, 2)}\n",
        "\n",
        "    for objkey in classes:\n",
        "        obj = (objs == objkey).astype(int)\n",
        "        for partkey, entry in map_pc[objkey].items():\n",
        "            obj_part = obj * (parts == partkey).astype(int)\n",
        "\n",
        "            if not isinstance(entry, dict):\n",
        "                preds += entry * obj_part\n",
        "                continue\n",
        "\n",
        "            for lr_idx in (1, 2):\n",
        "                lr_entry = entry[lr_idx]\n",
        "                side = obj_part * lr_masks[lr_idx]\n",
        "                if isinstance(lr_entry, dict):\n",
        "                    for fb_idx in (1, 2):\n",
        "                        preds += lr_entry[fb_idx] * side * fb_masks[fb_idx]\n",
        "                else:\n",
        "                    preds += lr_entry * side\n",
        "\n",
        "    return preds"
2024
      ]
2025
    },
2026
    {
2027
      "cell_type": "code",
2028
      "execution_count": null,
2029
      "metadata": {
2030
        "id": "hNd37u-k6E7V"
2031
      },
2032
      "outputs": [],
2033
      "source": [
2034
        "def combine_obj_inanimpart_pred(objs, parts, lr_pred, fb_pred, lrfb_pred):\n",
        "    \"\"\"Combine object, inanimate-part and direction predictions into class ids.\n",
        "\n",
        "    For an object key, `map_pc[objkey][partkey]` is one of:\n",
        "      * a class id: the whole part gets that id;\n",
        "      * a dict containing key 0: `map_pc[...][0][fb]` is selected by the\n",
        "        front/back label fb in (1, 2);\n",
        "      * a dict containing key -1: `map_pc[...][-1][k]` is selected by the\n",
        "        joint label k in (1, 2, 3, 4) taken from `lrfb_pred`;\n",
        "      * any other dict: `map_pc[...][lr]` is selected by the left/right\n",
        "        label lr in (1, 2).\n",
        "\n",
        "    Args:\n",
        "        objs: per-pixel object label map.\n",
        "        parts: per-pixel inanimate-part label map, same shape as `objs`.\n",
        "        lr_pred: per-pixel left/right label map (labels 1, 2 are used).\n",
        "        fb_pred: per-pixel front/back label map (labels 1, 2 are used).\n",
        "        lrfb_pred: per-pixel joint label map (labels 1-4 are used), as built\n",
        "            by `get_lrfb_pred` in the evaluation cells.\n",
        "\n",
        "    Returns:\n",
        "        Float array shaped like `objs` holding the combined class id per\n",
        "        pixel, or 0 where no (object, part) rule matches.\n",
        "    \"\"\"\n",
        "    map_pc, classes = part_obj_to_datasetclass(animate=False)\n",
        "    preds = np.zeros(objs.shape)\n",
        "\n",
        "    # The direction masks do not depend on the object or part, so build them\n",
        "    # once instead of once per (object, part) pair.\n",
        "    lr_masks = {idx: (lr_pred == idx).astype(int) for idx in (1, 2)}\n",
        "    fb_masks = {idx: (fb_pred == idx).astype(int) for idx in (1, 2)}\n",
        "    lrfb_masks = {idx: (lrfb_pred == idx).astype(int) for idx in (1, 2, 3, 4)}\n",
        "\n",
        "    for objkey in classes:\n",
        "        # Only depends on objkey, so it is computed outside the part loop.\n",
        "        obj = (objs == objkey).astype(int)\n",
        "        for partkey in map_pc[objkey]:\n",
        "            target = map_pc[objkey][partkey]\n",
        "            obj_part = obj * (parts == partkey).astype(int)\n",
        "\n",
        "            if not isinstance(target, dict):\n",
        "                # No direction split: one class id for the whole part.\n",
        "                preds += target * obj_part\n",
        "            elif 0 in target:\n",
        "                # Split by front/back only.\n",
        "                for fb_idx in (1, 2):\n",
        "                    preds += target[0][fb_idx] * obj_part * fb_masks[fb_idx]\n",
        "            elif -1 in target:\n",
        "                # Split by the joint left/right/front/back label.\n",
        "                for lrfb_idx in (1, 2, 3, 4):\n",
        "                    preds += target[-1][lrfb_idx] * obj_part * lrfb_masks[lrfb_idx]\n",
        "            else:\n",
        "                # Split by left/right only.\n",
        "                for lr_idx in (1, 2):\n",
        "                    preds += target[lr_idx] * obj_part * lr_masks[lr_idx]\n",
        "\n",
        "    return preds"
2073
      ]
2074
    },
2075
    {
2076
      "cell_type": "code",
2077
      "execution_count": null,
2078
      "metadata": {
2079
        "id": "b5Lz5wF56E7V"
2080
      },
2081
      "outputs": [],
2082
      "source": [
2083
        "def combine_obj_all_parts(objs, anim_parts, inanim_parts, lr_pred, fb_pred, lrfb_pred):\n",
        "    \"\"\"Merge the animate and inanimate class maps into one prediction map.\n",
        "\n",
        "    Each branch is resolved with its own combiner. A pixel to which both\n",
        "    branches assign a positive class id is ambiguous and is reset to 0; every\n",
        "    other pixel keeps the sum of the two maps.\n",
        "    \"\"\"\n",
        "    animate_map = combine_obj_animpart_pred(objs, anim_parts, lr_pred, fb_pred)\n",
        "    inanimate_map = combine_obj_inanimpart_pred(objs, inanim_parts, lr_pred, fb_pred, lrfb_pred)\n",
        "\n",
        "    # Pixels claimed by both branches are conflicts and fall back to 0.\n",
        "    conflict = np.logical_and(animate_map > 0, inanimate_map > 0)\n",
        "    return np.where(conflict, 0.0, animate_map + inanimate_map)"
2102
      ]
2103
    },
2104
    {
2105
      "cell_type": "code",
2106
      "execution_count": null,
2107
      "metadata": {
2108
        "id": "llaIUl2P6E7V"
2109
      },
2110
      "outputs": [],
2111
      "source": [
2112
        "def get_lrfb_pred(lr_pred, fb_pred):\n",
        "    \"\"\"Build the joint left/right/front/back label map.\n",
        "\n",
        "    The channels of `lr_pred` and `fb_pred` are stacked along axis 1, and the\n",
        "    index of the highest-scoring channel is returned 1-based: labels\n",
        "    1..C_lr pick an `lr_pred` channel, larger labels pick an `fb_pred` channel.\n",
        "    \"\"\"\n",
        "    stacked_scores = np.concatenate([lr_pred, fb_pred], axis=1)\n",
        "    return stacked_scores.argmax(axis=1) + 1"
2119
      ]
2120
    },
2121
    {
2122
      "cell_type": "code",
2123
      "execution_count": null,
2124
      "metadata": {
2125
        "id": "CH1lMPrT6E7W"
2126
      },
2127
      "outputs": [],
2128
      "source": [
2129
        "# FLOAT (with IZR): evaluation with inference-time zoom refinement.\n",
        "pad = 50\n",
        "inpsize = 770\n",
        "size = (513, 513)\n",
        "resize_tr = transforms.Compose([transforms.Resize(size)])\n",
        "\n",
        "# Object classes routed to the animate / inanimate part branches.\n",
        "ANIMATE_CLASSES = {3, 8, 10, 12, 13, 15, 17}\n",
        "INANIMATE_CLASSES = {1, 2, 4, 5, 6, 7, 9, 11, 14, 16, 18, 19, 20}\n",
        "MIN_COMPONENT_PIXELS = 25   # smaller connected components are not zoomed into\n",
        "MAX_ZOOM_AREA = 400 * 400   # padded boxes larger than this are not zoomed into\n",
        "\n",
        "num_classes = 201\n",
        "obj_model.eval()\n",
        "model.eval()\n",
        "valid_miou_avg = Evaluator(num_classes)\n",
        "valid_sqiou_avg = sqIOUMeter(num_classes)\n",
        "\n",
        "# eval() does not turn off autograd; no_grad() stops the forward passes from\n",
        "# keeping activations for a backward pass that never happens.\n",
        "with torch.no_grad():\n",
        "    for i, sample in enumerate(valid_dataloader, start=1):\n",
        "        images = sample['image'].float().cuda()\n",
        "        # Ground truth is only needed on the host, so it never goes to the GPU.\n",
        "        parts = sample['part'].type(torch.LongTensor).numpy()\n",
        "        orgsizes = sample['orgsize']\n",
        "        batch_size = images.shape[0]\n",
        "\n",
        "        objpred = obj_model(images)\n",
        "        animpred, inanimpred, lr_pred, fb_pred = model(images)\n",
        "\n",
        "        lr_pred = lr_pred.cpu().detach().numpy()\n",
        "        fb_pred = fb_pred.cpu().detach().numpy()\n",
        "        animpred = animpred.cpu().detach().numpy()\n",
        "        inanimpred = inanimpred.cpu().detach().numpy()\n",
        "        objpred = objpred.cpu().detach().numpy()\n",
        "        objpred_lbls = np.argmax(objpred, 1)\n",
        "\n",
        "        # --------------------------------------ZOOM---------------------------------------\n",
        "\n",
        "        for nb in range(batch_size):\n",
        "            zoom_info = []\n",
        "            zoomed_inp = []\n",
        "\n",
        "            # One crop (box from `bbox`, enlarged by `pad_and_square`) per large\n",
        "            # enough connected component of every non-background object class.\n",
        "            for obj_cls in np.unique(objpred_lbls[nb]):\n",
        "                if obj_cls == 0:\n",
        "                    continue\n",
        "                num_labels, labels = cv2.connectedComponents((objpred_lbls[nb] == obj_cls).astype(np.uint8))\n",
        "\n",
        "                for ncomp in range(1, num_labels):\n",
        "                    component = labels == ncomp\n",
        "                    if np.sum(component) < MIN_COMPONENT_PIXELS:\n",
        "                        continue\n",
        "\n",
        "                    x_min, y_min, x_max, y_max = bbox(component)\n",
        "                    x_min, y_min, x_max, y_max = pad_and_square(x_min, y_min, x_max, y_max, pad, inpsize)\n",
        "                    if (y_max - y_min) * (x_max - x_min) > MAX_ZOOM_AREA:\n",
        "                        continue\n",
        "\n",
        "                    if obj_cls in ANIMATE_CLASSES:\n",
        "                        animate = True\n",
        "                    elif obj_cls in INANIMATE_CLASSES:\n",
        "                        animate = False\n",
        "                    else:\n",
        "                        # An explicit error survives `python -O`, unlike assert.\n",
        "                        raise ValueError(f'unexpected object class {obj_cls}')\n",
        "\n",
        "                    zoom_info.append((y_min, y_max, x_min, x_max, animate, obj_cls))\n",
        "                    zoomed_inp.append(resize_tr(images[nb, :, y_min:y_max, x_min:x_max]))\n",
        "\n",
        "            if not zoomed_inp:\n",
        "                continue\n",
        "            zoomed_inp = torch.stack(zoomed_inp)\n",
        "\n",
        "            # Re-run both models on every crop and paste the zoomed logits back\n",
        "            # wherever the zoomed object model still predicts the same class.\n",
        "            for iobj, (y_min, y_max, x_min, x_max, animate, obj_cls) in enumerate(zoom_info):\n",
        "                crop = zoomed_inp[iobj:iobj + 1]\n",
        "                objpred_zoom = obj_model(crop)\n",
        "                animpred_zoom, inanimpred_zoom, lrpred_zoom, fbpred_zoom = model(crop)\n",
        "\n",
        "                resize_zoomed = transforms.Compose([transforms.Resize((y_max - y_min, x_max - x_min))])\n",
        "                obj_zoom = resize_zoomed(objpred_zoom[0]).cpu().detach().numpy()\n",
        "                lr_zoom = resize_zoomed(lrpred_zoom[0]).cpu().detach().numpy()\n",
        "                fb_zoom = resize_zoomed(fbpred_zoom[0]).cpu().detach().numpy()\n",
        "                keep = (np.argmax(obj_zoom, 0) == obj_cls).astype(int)\n",
        "                win = np.s_[nb, :, y_min:y_max, x_min:x_max]\n",
        "\n",
        "                objpred[win] = (keep * obj_zoom) + ((1 - keep) * objpred[win])\n",
        "                lr_pred[win] = (keep * lr_zoom) + ((1 - keep) * lr_pred[win])\n",
        "                fb_pred[win] = (keep * fb_zoom) + ((1 - keep) * fb_pred[win])\n",
        "\n",
        "                if animate:\n",
        "                    anim_zoom = resize_zoomed(animpred_zoom[0]).cpu().detach().numpy()\n",
        "                    animpred[win] = (keep * anim_zoom) + ((1 - keep) * animpred[win])\n",
        "                else:\n",
        "                    inanim_zoom = resize_zoomed(inanimpred_zoom[0]).cpu().detach().numpy()\n",
        "                    inanimpred[win] = (keep * inanim_zoom) + ((1 - keep) * inanimpred[win])\n",
        "\n",
        "        # -------------------------------------COMBINE--------------------------------------\n",
        "\n",
        "        objpred = np.argmax(objpred, 1)\n",
        "        animpred = np.argmax(animpred, 1)\n",
        "        inanimpred = np.argmax(inanimpred, 1)\n",
        "\n",
        "        # Channel 0 is dropped (presumably background - TODO confirm), so the\n",
        "        # argmax + 1 below yields direction labels starting at 1.\n",
        "        lr_pred = lr_pred[:, 1:, :, :]\n",
        "        fb_pred = fb_pred[:, 1:, :, :]\n",
        "\n",
        "        lrfb_pred = get_lrfb_pred(lr_pred, fb_pred)\n",
        "        lr_pred = np.argmax(lr_pred, 1) + 1\n",
        "        fb_pred = np.argmax(fb_pred, 1) + 1\n",
        "\n",
        "        preds = combine_obj_all_parts(objpred, animpred, inanimpred, lr_pred, fb_pred, lrfb_pred)\n",
        "        preds = preds.astype(int)\n",
        "\n",
        "        # Score every image after mapping it back to its original size.\n",
        "        for j in range(batch_size):\n",
        "            orgsize = (orgsizes[0][j], orgsizes[1][j])\n",
        "            pred = pred_to_orgsize(preds[j], orgsize)\n",
        "            gt = pred_to_orgsize(parts[j], orgsize)\n",
        "            valid_miou_avg.add_batch(gt, pred)\n",
        "\n",
        "            ious, counts = jaccard_perpart_perimg(pred, gt, num_classes)\n",
        "            for cl in range(num_classes):\n",
        "                ious[cl] = ious[cl].item()\n",
        "                counts[cl] = counts[cl].item()\n",
        "            valid_sqiou_avg.update(ious, counts)\n",
        "\n",
        "        print(i, valid_sqiou_avg.avg, valid_miou_avg.Mean_Intersection_over_Union())"
2264
      ]
2265
    },
2266
    {
2267
      "cell_type": "code",
2268
      "execution_count": null,
2269
      "metadata": {
2270
        "id": "DJynF12u6E7W",
2271
        "outputId": "b00654fd-8632-45c1-dcc3-cf017e3f0542",
2272
        "scrolled": true
2273
      },
2274
      "outputs": [],
2275
      "source": [
2276
        "# FLOAT (without IZR): evaluation on the full image only.\n",
        "num_classes = 201\n",
        "obj_model.eval()\n",
        "model.eval()\n",
        "valid_miou_avg = Evaluator(num_classes)\n",
        "valid_sqiou_avg = sqIOUMeter(num_classes)\n",
        "\n",
        "# eval() does not turn off autograd; no_grad() stops the forward passes from\n",
        "# keeping activations for a backward pass that never happens.\n",
        "with torch.no_grad():\n",
        "    for i, sample in enumerate(valid_dataloader, start=1):\n",
        "        images = sample['image'].float().cuda()\n",
        "        # Ground truth is only needed on the host, so it never goes to the GPU.\n",
        "        parts = sample['part'].type(torch.LongTensor).numpy()\n",
        "        orgsizes = sample['orgsize']\n",
        "        batch_size = images.shape[0]\n",
        "\n",
        "        objpred = obj_model(images)\n",
        "        anim_pred, inanim_pred, lr_pred, fb_pred = model(images)\n",
        "\n",
        "        objpred = np.argmax(objpred.cpu().detach().numpy(), 1)\n",
        "        anim_pred = np.argmax(anim_pred.cpu().detach().numpy(), 1)\n",
        "        inanim_pred = np.argmax(inanim_pred.cpu().detach().numpy(), 1)\n",
        "\n",
        "        # Channel 0 is dropped (presumably background - TODO confirm), so the\n",
        "        # argmax + 1 below yields direction labels starting at 1.\n",
        "        lr_pred = lr_pred.cpu().detach().numpy()[:, 1:, :, :]\n",
        "        fb_pred = fb_pred.cpu().detach().numpy()[:, 1:, :, :]\n",
        "\n",
        "        lrfb_pred = get_lrfb_pred(lr_pred, fb_pred)\n",
        "        lr_pred = np.argmax(lr_pred, 1) + 1\n",
        "        fb_pred = np.argmax(fb_pred, 1) + 1\n",
        "\n",
        "        preds = combine_obj_all_parts(objpred, anim_pred, inanim_pred, lr_pred, fb_pred, lrfb_pred)\n",
        "        preds = preds.astype(int)\n",
        "\n",
        "        # Score every image after mapping it back to its original size.\n",
        "        for j in range(batch_size):\n",
        "            orgsize = (orgsizes[0][j], orgsizes[1][j])\n",
        "            pred = pred_to_orgsize(preds[j], orgsize)\n",
        "            gt = pred_to_orgsize(parts[j], orgsize)\n",
        "\n",
        "            ious, counts = jaccard_perpart_perimg(pred, gt, num_classes)\n",
        "            for cl in range(num_classes):\n",
        "                ious[cl] = ious[cl].item()\n",
        "                counts[cl] = counts[cl].item()\n",
        "            valid_sqiou_avg.update(ious, counts)\n",
        "\n",
        "            valid_miou_avg.add_batch(gt, pred)\n",
        "\n",
        "        print(i, valid_sqiou_avg.avg, valid_miou_avg.Mean_Intersection_over_Union())\n",
        "\n",
        "print(valid_miou_avg.Mean_Intersection_over_Union())"
2334
      ]
2335
    },
2336
    {
2337
      "cell_type": "code",
2338
      "execution_count": null,
2339
      "metadata": {
2340
        "id": "L4I0j6bb6E7X"
2341
      },
2342
      "outputs": [],
2343
      "source": [
2344
        "# Per-class average of the sqIOU meter: accumulated `vals` / `counts`;\n",
        "# classes with a zero count are marked with -1.\n",
        "sqiou_pc = [valid_sqiou_avg.vals[c] / valid_sqiou_avg.counts[c] if valid_sqiou_avg.counts[c] > 0 else -1\n",
        "            for c in range(num_classes)]"
2345
      ]
2346
    },
2347
    {
2348
      "cell_type": "code",
2349
      "execution_count": null,
2350
      "metadata": {
2351
        "id": "ohiUeEEG6E7X"
2352
      },
2353
      "outputs": [],
2354
      "source": [
2355
        "# Per-class IoU from the Evaluator filled by the evaluation loop above.\n",
        "miou_pc = valid_miou_avg.Mean_Intersection_over_Union_PerClass()"
2356
      ]
2357
    },
2358
    {
2359
      "cell_type": "code",
2360
      "execution_count": null,
2361
      "metadata": {
2362
        "id": "U-rh1edu6E7X",
2363
        "outputId": "0b3f9cf2-7469-488d-b7e3-78259cc66160",
2364
        "scrolled": true
2365
      },
2366
      "outputs": [],
2367
      "source": [
2368
        "# Keep the per-class IoUs that are defined, printing each as a percentage\n",
        "# with one decimal. NaN entries (presumably classes absent from both\n",
        "# prediction and ground truth) and -1 entries are skipped.\n",
        "# NOTE(review): class 13 is always excluded; the reason is not visible here.\n",
        "miou_ = []\n",
        "for i in range(num_classes):\n",
        "    if np.isnan(miou_pc[i]) or miou_pc[i] == -1 or i == 13:\n",
        "        continue\n",
        "    print(np.round(miou_pc[i]*1000)/10)\n",
        "    miou_.append(miou_pc[i])"
2374
      ]
2375
    },
2376
    {
2377
      "cell_type": "code",
2378
      "execution_count": null,
2379
      "metadata": {
2380
        "id": "Ee__8jyN6E7Y",
2381
        "outputId": "bc4f5e6d-ac10-48e5-ffd2-870bc8f222e5"
2382
      },
2383
      "outputs": [],
2384
      "source": [
2385
        "# Mean IoU over the classes kept in `miou_`.\n",
        "np.asarray(miou_).mean()"
2386
      ]
2387
    },
2388
    {
2389
      "cell_type": "code",
2390
      "execution_count": null,
2391
      "metadata": {
2392
        "id": "pujJS5gp6E7Y"
2393
      },
2394
      "outputs": [],
2395
      "source": [
2396
        "# Positions in `miou_` at which `np.split` below starts a new group\n",
        "# (presumably one group of part classes per object class - TODO confirm).\n",
        "# NOTE(review): these index the *filtered* `miou_` list, not class ids; if the\n",
        "# filtering loop skipped any NaN class, every later group shifts - verify.\n",
        "all_split_map = [1, 8, 14, 27, 28, 30, 43, 56, 73, 74, 93, 94, 112, 133, 139, 163, 165, 184, 185, 198, 200]"
2397
      ]
2398
    },
2399
    {
2400
      "cell_type": "code",
2401
      "execution_count": null,
2402
      "metadata": {
2403
        "id": "vpiGWY5Y6E7Y"
2404
      },
2405
      "outputs": [],
2406
      "source": [
2407
        "# Split the kept per-class IoUs into consecutive groups at `all_split_map`.\n",
        "objiou = np.split(miou_, all_split_map)"
2408
      ]
2409
    },
2410
    {
2411
      "cell_type": "code",
2412
      "execution_count": null,
2413
      "metadata": {
2414
        "id": "inLoY-2H6E7Y",
2415
        "outputId": "c8ae4488-57ff-4b73-a350-1d213fcd4eeb"
2416
      },
2417
      "outputs": [],
2418
      "source": [
2419
        "# Mean IoU of every group except the last one produced by np.split.\n",
        "objavg = [np.mean(group) for group in objiou[:-1]]\n",
        "for group_mean in objavg:\n",
        "    print(group_mean)"
2423
      ]
2424
    },
2425
    {
2426
      "cell_type": "code",
2427
      "execution_count": null,
2428
      "metadata": {
2429
        "id": "2ehNTi5s6E7Z",
2430
        "outputId": "00f104a2-d7ec-485e-b80d-2682afa53233"
2431
      },
2432
      "outputs": [],
2433
      "source": [
2434
        "# Average of the per-group means.\n",
        "np.asarray(objavg).mean()"
2435
      ]
2436
    }
2437
  ],
2438
  "metadata": {
2439
    "kernelspec": {
2440
     "display_name": "Python 3",
2441
     "language": "python",
2442
     "name": "python3"
2443
    },
2444
    "language_info": {
2445
     "codemirror_mode": {
2446
      "name": "ipython",
2447
      "version": 3
2448
     },
2449
     "file_extension": ".py",
2450
     "mimetype": "text/x-python",
2451
     "name": "python",
2452
     "nbconvert_exporter": "python",
2453
     "pygments_lexer": "ipython3",
2454
     "version": "3.7.6"
2455
    }
2456
   },
2457
   "nbformat": 4,
2458
   "nbformat_minor": 4
2459
  }

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.