demo-ml-pennfudanped / masks_for_mask_r_cnn_dataset.py
import os
import numpy as np
import torch
from PIL import Image


class MasksForMaskRCNNDataset(torch.utils.data.Dataset):
    def __init__(self, images_root, masks_root, transforms=None):
        self.images_root = images_root
        self.masks_root = masks_root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(images_root)))
        self.masks = list(sorted(os.listdir(masks_root)))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.images_root, self.imgs[idx])
        mask_path = os.path.join(self.masks_root, self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance
        # with 0 being background
        mask = Image.open(mask_path)
        # convert the PIL Image into a numpy array
        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set
        # of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        # convert everything into a torch.Tensor
        # (torchvision detection models expect float boxes, not int64)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def get_categories(self):
        return {
            1: {'id': 1, 'name': 'human'}
        }

    def __len__(self):
        return len(self.imgs)
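
A minimal usage sketch, not part of this file: wiring the dataset into a DataLoader and a torchvision Mask R-CNN with two classes (background plus the single 'human' class from get_categories()). The PennFudanPed directory layout, the to_tensor_transform helper, and the collate_fn are assumptions for illustration.

import torch
import torchvision
from torchvision.transforms import functional as F

from masks_for_mask_r_cnn_dataset import MasksForMaskRCNNDataset


def to_tensor_transform(image, target):
    # the dataset calls transforms(img, target), so this callable takes both;
    # convert the PIL image to a float tensor and pass the target through unchanged
    return F.to_tensor(image), target


def collate_fn(batch):
    # detection models take lists of images and targets of varying sizes,
    # so keep samples separate instead of stacking them into one tensor
    return tuple(zip(*batch))


# assumed PennFudanPed layout: PNGImages/ holds photos, PedMasks/ holds instance masks
dataset = MasksForMaskRCNNDataset(
    images_root="PennFudanPed/PNGImages",
    masks_root="PennFudanPed/PedMasks",
    transforms=to_tensor_transform,
)

loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
)

# two classes: background (0) and 'human' (1)
model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2)
model.train()

images, targets = next(iter(loader))
loss_dict = model(list(images), list(targets))  # in train mode this returns a dict of losses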