demo-ml-pennfudanped

masks_for_mask_r_cnn_dataset.py
78 lines · 2.6 KB
import os
import numpy as np
import torch
from PIL import Image


class MasksForMaskRCNNDataset(torch.utils.data.Dataset):
    def __init__(self, images_root, masks_root, transforms=None):
        self.images_root = images_root
        self.masks_root = masks_root
        self.transforms = transforms
        # load all image files, sorting them to ensure
        # that they are aligned with their masks
        self.imgs = list(sorted(os.listdir(images_root)))
        self.masks = list(sorted(os.listdir(masks_root)))

    def __getitem__(self, idx):
        # load the image and its instance mask
        img_path = os.path.join(self.images_root, self.imgs[idx])
        mask_path = os.path.join(self.masks_root, self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that the mask is not converted to RGB,
        # because each color corresponds to a different instance,
        # with 0 being background
        mask = Image.open(mask_path)
        # convert the PIL Image into a numpy array
        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # the first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set
        # of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        # convert everything into a torch.Tensor;
        # torchvision detection models expect float boxes
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def get_categories(self):
        return {
            1: {'id': 1, 'name': 'human'}
        }

    def __len__(self):
        return len(self.imgs)
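
The one-line split into binary masks relies on numpy broadcasting; a tiny self-contained sketch of the same comparison, with a hand-made 2x3 id mask:

import numpy as np

# 0 = background; 1 and 2 are two separate instances
mask = np.array([[0, 1, 1],
                 [0, 2, 2]])
obj_ids = np.unique(mask)[1:]  # drop the background id -> array([1, 2])
# comparing the (H, W) mask against ids reshaped to (N, 1, 1) broadcasts
# to shape (N, H, W): one boolean mask per instance
masks = mask == obj_ids[:, None, None]
print(masks.astype(np.uint8))  # two stacked (2, 3) binary masks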
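
A minimal usage sketch, assuming the standard PennFudan directory layout (PennFudanPed/PNGImages and PennFudanPed/PedMasks) unpacked next to the script; to_tensor_transform is a hypothetical helper, and the collate_fn keeps batches as tuples because each image carries a different number of instances:

import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader

def to_tensor_transform(img, target):
    # convert the PIL image to a float tensor; the target is left untouched
    return TF.to_tensor(img), target

dataset = MasksForMaskRCNNDataset(
    "PennFudanPed/PNGImages", "PennFudanPed/PedMasks",
    transforms=to_tensor_transform)

# variable-size targets cannot be stacked into one tensor, so batches
# are kept as tuples of samples
loader = DataLoader(dataset, batch_size=2, shuffle=True,
                    collate_fn=lambda batch: tuple(zip(*batch)))

images, targets = next(iter(loader))
print(targets[0]["boxes"].shape)  # (num_instances, 4)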
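
Continuing the sketch above, a hedged example of a single training-mode forward pass through torchvision's Mask R-CNN (assuming torchvision >= 0.13, where the weights keyword replaced pretrained); the float32 boxes, int64 labels, and uint8 masks produced by __getitem__ match what the model expects:

import torchvision

# num_classes=2: background plus the single 'human' category
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=None, num_classes=2)
model.train()

# in training mode the model takes a list of image tensors and a list of
# target dicts, and returns a dict of scalar losses
loss_dict = model(list(images), list(targets))
print({name: loss.item() for name, loss in loss_dict.items()})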