demo-ml-pennfudanped

Форк
0
/
masks_for_unet.py 
36 строк · 985.0 Байт
1
import os
2
import numpy as np
3
import torch
4
from PIL import Image
5

6

7
class MasksForUnet(torch.utils.data.Dataset):
8
    def __init__(self, images_root, masks_root, transforms=None):
9
        self.images_root = images_root
10
        self.masks_root = masks_root
11
        self.transforms = transforms
12

13
        self.imgs = list(sorted(os.listdir(images_root)))
14
        self.masks = list(sorted(os.listdir(masks_root)))
15

16
    def __getitem__(self, idx):
17
        img_path = os.path.join(self.images_root, self.imgs[idx])
18
        mask_path = os.path.join(self.masks_root, self.masks[idx])
19

20
        img = Image.open(img_path).convert("RGB")
21
        img = np.array(img)
22

23
        target = Image.open(mask_path)
24
        target = np.array(target)
25

26
        if self.transforms is not None:
27
            img = self.transforms(img)
28
            target = self.transforms(target)
29

30
        return img, target
31

32
    def get_categories(self):
33
        return None
34

35
    def __len__(self):
36
        return len(self.imgs)
37

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.