demo-ml-pennfudanped
/
masks_for_unet.py
36 строк · 985.0 Байт
1import os2import numpy as np3import torch4from PIL import Image5
6
class MasksForUnet(torch.utils.data.Dataset):
    """Dataset of ``(image, mask)`` pairs for UNet-style segmentation.

    Images and masks are read from two separate directories and paired by
    position after sorting each directory listing by filename. The i-th image
    in sorted order must therefore correspond to the i-th mask.
    """

    def __init__(self, images_root, masks_root, transforms=None):
        """
        Args:
            images_root: directory containing the input images.
            masks_root: directory containing the masks. After sorting, its
                entries must line up one-to-one with those of ``images_root``.
            transforms: optional callable applied separately to the image and
                to the mask (each a NumPy array at that point). Both calls see
                the same RNG state, so random augmentations stay aligned.

        Raises:
            ValueError: if the two directories contain a different number of
                entries, which would make the positional pairing wrong.
        """
        self.images_root = images_root
        self.masks_root = masks_root
        self.transforms = transforms

        # Pairing is positional, so both listings must be sorted the same way
        # and have equal length.
        self.imgs = sorted(os.listdir(images_root))
        self.masks = sorted(os.listdir(masks_root))
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"found {len(self.imgs)} entries in {images_root!r} but "
                f"{len(self.masks)} entries in {masks_root!r}; images and "
                f"masks are paired by sorted position and must match 1:1"
            )

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(image, mask)`` pair.

        The image is converted to RGB. The mask is loaded in whatever mode it
        is stored in. Both are NumPy arrays unless ``transforms`` changes them.
        """
        img_path = os.path.join(self.images_root, self.imgs[idx])
        mask_path = os.path.join(self.masks_root, self.masks[idx])

        # ``with`` closes the underlying file handles deterministically.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        with Image.open(mask_path) as pil_mask:
            target = np.array(pil_mask)

        return self._apply_paired(self.transforms, img, target)

    @staticmethod
    def _apply_paired(transform, img, target):
        """Apply ``transform`` to ``img`` and ``target`` with identical randomness.

        Calling a random augmentation twice normally draws different random
        numbers each time, e.g. flipping the image but not its mask. To avoid
        that, this snapshots the torch (CPU default generator), NumPy and
        Python ``random`` states before transforming the image, then restores
        them before transforming the mask.

        Returns:
            ``(img, target)`` unchanged when ``transform`` is ``None``,
            otherwise the two transformed values.
        """
        if transform is None:
            return img, target

        import random  # local import: not part of the module-level imports

        torch_state = torch.get_rng_state()
        np_state = np.random.get_state()
        py_state = random.getstate()

        img = transform(img)

        # Rewind every RNG so the mask sees exactly the same random draws.
        torch.set_rng_state(torch_state)
        np.random.set_state(np_state)
        random.setstate(py_state)

        target = transform(target)
        return img, target

    def get_categories(self):
        """Return category metadata; this dataset defines none, so ``None``."""
        return None

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs)