{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "0df75687",
   "metadata": {
    "cellId": "jziuwa87tkxdnpvgjqd3q",
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: pycocotools in /home/jupyter/.local/lib/python3.8/site-packages (2.0.5)\n",
      "Requirement already satisfied: matplotlib>=2.1.0 in /kernel/lib/python3.8/site-packages (from pycocotools) (3.3.3)\n",
      "Requirement already satisfied: numpy in /kernel/fallback/lib/python3.8/site-packages (from pycocotools) (1.19.4)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (2.4.7)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (9.2.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (1.4.4)\n",
      "Requirement already satisfied: python-dateutil>=2.1 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (2.8.2)\n",
      "Requirement already satisfied: cycler>=0.10 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (0.11.0)\n",
      "Requirement already satisfied: six>=1.5 in /kernel/lib/python3.8/site-packages (from python-dateutil>=2.1->matplotlib>=2.1.0->pycocotools) (1.16.0)\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (4.50.0)\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: torchvision in /usr/local/lib/python3.8/dist-packages (0.10.1+cu111)\n",
      "Collecting torchvision\n",
      "  Downloading torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl (19.1 MB)\n",
      "     |████████████████████████████████| 19.1 MB 1.7 MB/s            \n",
      "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torchvision) (3.7.4.3)\n",
      "Collecting torch==1.12.1\n",
      "  Downloading torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl (776.3 MB)\n",
      "     |████████████████████████████████| 776.3 MB 372 bytes/s           \n",
      "\u001b[?25hRequirement already satisfied: requests in /kernel/lib/python3.8/site-packages (from torchvision) (2.25.1)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /kernel/lib/python3.8/site-packages (from torchvision) (9.2.0)\n",
      "Requirement already satisfied: numpy in /kernel/fallback/lib/python3.8/site-packages (from torchvision) (1.19.4)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (2.10)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (1.26.12)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (4.0.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (2022.9.24)\n",
      "Installing collected packages: torch, torchvision\n",
      "\u001b[33m  WARNING: The scripts convert-caffe2-to-onnx, convert-onnx-to-caffe2 and torchrun are installed in '/home/jupyter/.local/bin' which is not on PATH.\n",
      "  Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "torchaudio 0.9.1 requires torch==1.9.1, but you have torch 1.12.1 which is incompatible.\u001b[0m\n",
      "Successfully installed torch-1.12.1 torchvision-0.13.1\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: torch in /home/jupyter/.local/lib/python3.8/site-packages (1.12.1)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (3.7.4.3)\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# %pip install pycocotools\n",
    "# %pip install tqdm\n",
    "# %pip install torchvision -U\n",
    "# %pip install torch -U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e8527c6e",
   "metadata": {
    "cellId": "ech9jthm9zrjuq7fn2bs"
   },
   "outputs": [],
   "source": [
    "# stdlib modules used by the training loop (math.isfinite, sys.exit, copy.deepcopy)\n",
    "import copy\n",
    "import math\n",
    "import sys\n",
    "\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from masks_for_mask_r_cnn_dataset import MasksForMaskRCNNDataset\n",
    "from custom_segmentation_transforms import Compose, ToTensor, RandomHorizontalFlip\n",
    "from image_utils import show_image, build_box_masks, merge_image_and_masks_boxes, merge_masks_with_colors\n",
    "from maskrcnn_model import build_maskrsnn_model\n",
    "from metrics import accumulate_metrics, compute_metrics, coco_metric_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "57636714",
   "metadata": {
    "cellId": "zamp0lyvjrjlqyekmhu2am"
   },
   "outputs": [],
   "source": [
    "ROOT = '/home/jupyter/mnt/s3'\n",
    "DS_ROOT = f'{ROOT}/pennfudanped'\n",
    "DS_MASKS = f'{DS_ROOT}/PedMasks'\n",
    "DS_IMAGES = f'{DS_ROOT}/PNGImages'\n",
    "\n",
    "PARAMS = {\n",
    "    'batch_size': 1,\n",
    "    'epochs': 1,\n",
    "    'lr': 0.001,\n",
    "    'momentum': 0.9,\n",
    "    'weight_decay': 0.0005,\n",
    "    'step_size': 3,\n",
    "    'gamma': 0.1,\n",
    "    'pretrained': torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "02cfedc8",
   "metadata": {
    "cellId": "0opprkt2dlxsvv4kuogazah"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_of_targets = 1\n"
     ]
    }
   ],
   "source": [
    "full_dataset = MasksForMaskRCNNDataset(\n",
    "    images_root=DS_IMAGES,\n",
    "    masks_root=DS_MASKS,\n",
    "    transforms=Compose([\n",
    "        ToTensor(),\n",
    "        RandomHorizontalFlip(0.5),\n",
    "    ])\n",
    ")\n",
    "# num_of_targets is passed to Mask R-CNN as the class count and must include\n",
    "# the background class: 0 = background, 1 = pedestrian. The recorded output\n",
    "# below (and the CUDA device-side assert in the training cell) came from the\n",
    "# old value of 1, which left label 1 with no valid class index.\n",
    "num_of_targets = 2\n",
    "print(f'num_of_targets = {num_of_targets}')"
   ]
  },
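  {
   "cell_type": "markdown",
   "id": "labels-range-note",
   "metadata": {},
   "source": [
    "A quick sanity check, assuming `MasksForMaskRCNNDataset` yields `(image, target)` pairs with an integer `labels` tensor (the shape printout further below is consistent with this): torchvision's Mask R-CNN reserves label 0 for background, so every ground-truth label must lie in `[1, num_of_targets)`. Running this once on the CPU gives a readable error up front instead of the asynchronous CUDA device-side assert recorded later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "labels-range-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: verify every ground-truth label fits the class count before training.\n",
    "# Assumes each dataset item is an (image, target) pair; adjust if the dataset\n",
    "# class returns something else.\n",
    "for idx in range(len(full_dataset)):\n",
    "    _, target = full_dataset[idx]\n",
    "    labels = target['labels']\n",
    "    assert labels.min() >= 1 and labels.max() < num_of_targets, \\\n",
    "        f'item {idx}: labels {labels.tolist()} out of range [1, {num_of_targets})'"
   ]
  },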
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "142cf22a",
   "metadata": {
    "cellId": "eam1tkh4we33jiixj8t46"
   },
   "outputs": [],
   "source": [
    "train_size = int(0.8 * len(full_dataset))\n",
    "val_size = len(full_dataset) - train_size\n",
    "\n",
    "dataset_train, dataset_val = random_split(full_dataset,\n",
    "                                          [train_size, val_size],\n",
    "                                          generator=torch.Generator().manual_seed(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "839f0590",
   "metadata": {
    "cellId": "47b0jkr0hu7uznuewj20p"
   },
   "outputs": [],
   "source": [
    "train_loader = DataLoader(\n",
    "    dataset_train,\n",
    "    batch_size=PARAMS['batch_size'],\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    drop_last=True\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    dataset_val,\n",
    "    batch_size=PARAMS['batch_size'],\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    drop_last=True\n",
    ")"
   ]
  },
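  {
   "cell_type": "markdown",
   "id": "collate-note",
   "metadata": {},
   "source": [
    "These loaders work only because `batch_size` is 1: the default collate function stacks every target tensor with a leading batch dimension, which the training cell below strips off again by hand. A minimal sketch of the alternative used by the torchvision detection references, which keeps targets as a tuple of per-image dicts and therefore supports any batch size (`detection_collate` is a name introduced here, not one of the project modules):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "collate-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: collate (image, target) pairs without stacking them, so a batch\n",
    "# becomes (tuple_of_images, tuple_of_target_dicts).\n",
    "def detection_collate(batch):\n",
    "    return tuple(zip(*batch))\n",
    "\n",
    "# Usage (this would also make the manual [0, ...] indexing in the training\n",
    "# loop unnecessary):\n",
    "# train_loader = DataLoader(dataset_train, batch_size=2, shuffle=True,\n",
    "#                           pin_memory=True, drop_last=True,\n",
    "#                           collate_fn=detection_collate)"
   ]
  },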
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c46c38ac",
   "metadata": {
    "cellId": "tii3mrngw7gi0heqnjzir"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 383, 456])\n",
      "torch.Size([1, 1, 4])\n",
      "torch.Size([1, 1])\n",
      "torch.Size([1, 1])\n",
      "torch.Size([1, 1])\n",
      "torch.Size([1, 1])\n"
     ]
    }
   ],
   "source": [
    "image, targets = next(iter(train_loader))\n",
    "\n",
    "print(image.shape)\n",
    "print(targets['boxes'].shape)\n",
    "print(targets['labels'].shape)\n",
    "print(targets['image_id'].shape)\n",
    "print(targets['area'].shape)\n",
    "print(targets['iscrowd'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "efa59496",
   "metadata": {
    "cellId": "3nww1emr9a6n4r82fa53vj"
   },
   "outputs": [],
   "source": [
    "model = build_maskrsnn_model(num_of_targets, PARAMS['pretrained'])\n",
    "\n",
    "params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.SGD(params, lr=PARAMS['lr'], momentum=PARAMS['momentum'], weight_decay=PARAMS['weight_decay'])\n",
    "lr_scheduler_global = torch.optim.lr_scheduler.StepLR(optimizer, step_size=PARAMS['step_size'], gamma=PARAMS['gamma'])\n",
    "\n",
    "epoch_start = 0\n",
    "best_accuracy = 0\n",
    "best_state = None"
   ]
  },
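  {
   "cell_type": "markdown",
   "id": "builder-note",
   "metadata": {},
   "source": [
    "`build_maskrsnn_model` lives in the `maskrcnn_model` module, which is not part of this notebook. For orientation, a minimal sketch of what such a builder typically does, following the torchvision finetuning tutorial; this is an assumption about the module, not its actual code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "builder-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models.detection import maskrcnn_resnet50_fpn\n",
    "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n",
    "from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor\n",
    "\n",
    "def build_maskrcnn_model_sketch(num_classes, weights):\n",
    "    # Start from the COCO-pretrained Mask R-CNN and replace both heads so\n",
    "    # they predict num_classes classes (background included).\n",
    "    model = maskrcnn_resnet50_fpn(weights=weights)\n",
    "    in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
    "    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)\n",
    "    in_channels = model.roi_heads.mask_predictor.conv5_mask.in_channels\n",
    "    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_channels, 256, num_classes)\n",
    "    return model"
   ]
  },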
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "83f5867f",
   "metadata": {
    "cellId": "hmz9wi2duwu11ihuuu8m9u"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device = cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth\" to /tmp/xdg_cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "979c05f7b1334fad87c241ff5f61e8c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=178090079.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 1/136 [00:05<12:16,  5.46s/it]../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [9,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [11,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [14,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [18,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [20,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [22,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [23,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [25,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [26,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-ab096cc10900>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     51\u001b[0m         \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscaler\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m             \u001b[0mloss_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     54\u001b[0m             \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mloss_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1128\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1131\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1132\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, images, targets)\u001b[0m\n\u001b[1;32m    103\u001b[0m             \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"0\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m         \u001b[0mproposals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproposal_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrpn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m         \u001b[0mdetections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdetector_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroi_heads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproposals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    106\u001b[0m         \u001b[0mdetections\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpostprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdetections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moriginal_image_sizes\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[operator]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1128\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1131\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1132\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, features, proposals, image_shapes, targets)\u001b[0m\n\u001b[1;32m    770\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mregression_targets\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    771\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"regression_targets cannot be None\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 772\u001b[0;31m             \u001b[0mloss_classifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_box_reg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfastrcnn_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox_regression\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mregression_targets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    773\u001b[0m             \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"loss_classifier\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss_classifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"loss_box_reg\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss_box_reg\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    774\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py\u001b[0m in \u001b[0;36mfastrcnn_loss\u001b[0;34m(class_logits, box_regression, labels, regression_targets)\u001b[0m\n\u001b[1;32m     34\u001b[0m     \u001b[0;31m# the corresponding ground truth labels, to be used with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m     \u001b[0;31m# advanced indexing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m     \u001b[0msampled_pos_inds_subset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabels\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     37\u001b[0m     \u001b[0mlabels_pos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msampled_pos_inds_subset\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     38\u001b[0m     \u001b[0mN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_classes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclass_logits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1."
     ]
    }
   ],
   "source": [
    "#!g2.mig\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'device = {device}')\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "for epoch in range(PARAMS['epochs']):\n",
    "    pbar = tqdm(total=len(train_loader.dataset))\n",
    "\n",
    "    acc_loss_value = 0\n",
    "    acc_loss_classifier = 0\n",
    "    acc_loss_box_reg = 0\n",
    "    acc_loss_mask = 0\n",
    "    acc_loss_objectness = 0\n",
    "    acc_loss_rpn_box_reg = 0\n",
    "\n",
    "    model.train()\n",
    "    scaler = None\n",
    "    lr_scheduler = None\n",
    "    if epoch == 0:\n",
    "        # linear warmup over the first (at most) 1000 iterations\n",
    "        warmup_factor = 1.0 / 1000\n",
    "        warmup_iters = min(1000, len(train_loader) - 1)\n",
    "        lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "            optimizer, start_factor=warmup_factor, total_iters=warmup_iters\n",
    "        )\n",
    "    for images, targets in train_loader:\n",
    "        pbar.update(len(images))\n",
    "\n",
    "        images = list(image.to(device) for image in images)\n",
    "        # The default collate stacked every target tensor with a leading\n",
    "        # batch dimension; strip it again (batch_size is 1) and wrap the\n",
    "        # result in the list of per-image dicts that Mask R-CNN expects.\n",
    "        device_target = {}\n",
    "        for k in ('boxes', 'labels', 'masks', 'image_id', 'area', 'iscrowd'):\n",
    "            device_target[k] = targets[k].to(device)\n",
    "        targets = [{\n",
    "            'boxes': device_target['boxes'][0, :, :],\n",
    "            'labels': device_target['labels'][0, :],\n",
    "            'masks': device_target['masks'][0, :],\n",
    "            'image_id': device_target['image_id'][0, :],\n",
    "            'area': device_target['area'][0, :],\n",
    "            'iscrowd': device_target['iscrowd'][0, :],\n",
    "        }]\n",
    "\n",
    "        # NOTE: the error recorded below came from a run with num_of_targets = 1;\n",
    "        # labels equal to 1 then fall outside [0, n_classes) and trigger the\n",
    "        # device-side assert in the classification loss.\n",
    "        with torch.cuda.amp.autocast(enabled=scaler is not None):\n",
    "            loss_dict = model(images, targets)\n",
    "            losses = sum(loss for loss in loss_dict.values())\n",
    "\n",
    "        loss_value = losses.item()\n",
    "        acc_loss_value += loss_value\n",
    "        acc_loss_classifier += loss_dict['loss_classifier'].item()\n",
    "        acc_loss_box_reg += loss_dict['loss_box_reg'].item()\n",
    "        acc_loss_mask += loss_dict['loss_mask'].item()\n",
    "        acc_loss_objectness += loss_dict['loss_objectness'].item()\n",
    "        acc_loss_rpn_box_reg += loss_dict['loss_rpn_box_reg'].item()\n",
    "\n",
    "        if not math.isfinite(loss_value):\n",
    "            print(f\"Loss is {loss_value}, stopping training\")\n",
    "            sys.exit(1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        if scaler is not None:\n",
    "            scaler.scale(losses).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "        else:\n",
    "            losses.backward()\n",
    "            optimizer.step()\n",
    "        if lr_scheduler is not None:\n",
    "            lr_scheduler.step()\n",
    "\n",
    "    pbar.close()\n",
    "\n",
    "    current_lr = optimizer.param_groups[0][\"lr\"]\n",
    "    lr_scheduler_global.step()\n",
    "\n",
    "    model.eval()\n",
    "    pbar = tqdm(total=len(val_loader.dataset))\n",
    "    # iou_types = \"segm\"  # \"segm\", \"bbox\", \"keypoints\"\n",
    "    segmPredicted = []\n",
    "    bboxPredicted = []\n",
    "    for images, targets in val_loader:\n",
    "        pbar.update(len(images))\n",
    "        images = list(img.to(device) for img in images)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.synchronize()\n",
    "        with torch.no_grad():  # inference only, no autograd graph needed\n",
    "            predictions = model(images)\n",
    "        accumulate_metrics(segmPredicted, bboxPredicted, targets, predictions)\n",
    "\n",
    "    metrics = compute_metrics(val_loader.dataset.coco, segmPredicted, bboxPredicted)\n",
    "    pbar.close()\n",
    "\n",
    "    print(\"Epoch: {0}; lr={1:.4f}; loss={2:.4f}; loss_mask={3:.4f}; cls={4:.4f}; box={5:.4f}; segm.mAP={6:.3f}; bbox.mAP={7:.3f}\"\n",
    "          .format(epoch, current_lr, acc_loss_value, acc_loss_mask, acc_loss_classifier, acc_loss_box_reg, metrics['segm']['mAP'], metrics['bbox']['mAP']))\n",
    "\n",
    "    metrics_summary = {\n",
    "        'current_lr': current_lr,\n",
    "        'acc_loss_value': acc_loss_value,\n",
    "        'acc_loss_mask': acc_loss_mask,\n",
    "        'acc_loss_classifier': acc_loss_classifier,\n",
    "        'acc_loss_box_reg': acc_loss_box_reg,\n",
    "        'acc_loss_objectness': acc_loss_objectness,\n",
    "        'acc_loss_rpn_box_reg': acc_loss_rpn_box_reg,\n",
    "    }\n",
    "    for mtype in ('segm', 'bbox'):\n",
    "        for metric_key in coco_metric_names.keys():\n",
    "            metrics_summary[mtype + '.' + metric_key] = metrics[mtype][metric_key]\n",
    "    # mlflow.log_metrics(metrics_summary)\n",
    "\n",
    "#     mlflow.pytorch.log_state_dict(artifact_path=\"checkpoint\", state_dict={\n",
    "#         \"model\": model.state_dict(),\n",
    "#         \"optimizer\": optimizer.state_dict(),\n",
    "#         \"epoch\": epoch,\n",
    "#         \"best_accuracy\": metrics['segm']['mAP']\n",
    "#     })\n",
    "    if metrics['segm']['mAP'] > best_accuracy:\n",
    "        best_accuracy = metrics['segm']['mAP']\n",
    "        best_state = copy.deepcopy(model.state_dict())"
   ]
  },
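  {
   "cell_type": "markdown",
   "id": "checkpoint-note",
   "metadata": {},
   "source": [
    "The MLflow logging in the training cell is commented out, so `best_state` only lives in memory. A minimal way to persist the best checkpoint (the file name is illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "checkpoint-save",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the best validation weights; restore them later with\n",
    "# model.load_state_dict(torch.load('best_maskrcnn.pth')).\n",
    "if best_state is not None:\n",
    "    torch.save(best_state, 'best_maskrcnn.pth')"
   ]
  },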
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a926cacf",
   "metadata": {
    "cellId": "74qvm7e2emcpjd0k860d6"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "notebookId": "f1223451-1595-404f-8bde-8aa2570af0fe",
  "notebookPath": "demo-ml-pennfudanped/train_model.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}