HairFastGAN

Форк
0
157 строк · 5.6 Кб
1
import os
2
import glob
3
import numpy as np
4
import torch
5
import torch.nn as nn
6
import torch.nn.functional as F
7
import torch.utils.data as data
8

9
from PIL import Image
10
from torchvision import transforms, utils
11

12
class MyDataSet(data.Dataset):
13
    def __init__(self, image_dir=None, label_dir=None, output_size=(256, 256), noise_in=None, training_set=True, video_data=False, train_split=0.9):
14
        self.image_dir = image_dir
15
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
16
        self.resize = transforms.Compose([
17
            transforms.Resize(output_size),
18
            transforms.ToTensor()
19
        ])
20
        self.noise_in = noise_in
21
        self.video_data = video_data
22
        self.random_rotation = transforms.Compose([
23
            transforms.Resize(output_size),
24
            transforms.RandomPerspective(distortion_scale=0.05, p=1.0),
25
            transforms.ToTensor()
26
        ])
27

28
        # load image file
29
        train_len = None
30
        self.length = 0
31
        self.image_dir = image_dir
32
        if image_dir is not None:
33
            img_list = [glob.glob1(self.image_dir, ext) for ext in ['*jpg','*png']]
34
            image_list = [item for sublist in img_list for item in sublist]
35
            image_list.sort()
36
            train_len = int(train_split*len(image_list))
37
            if training_set:
38
                self.image_list = image_list[:train_len]
39
            else:
40
                self.image_list = image_list[train_len:]
41
            self.length = len(self.image_list)
42

43
        # load label file
44
        self.label_dir = label_dir
45
        if label_dir is not None:
46
            self.seeds = np.load(label_dir)
47
            if train_len is None:
48
                train_len = int(train_split*len(self.seeds))
49
            if training_set:
50
                self.seeds = self.seeds[:train_len]
51
            else:
52
                self.seeds = self.seeds[train_len:]
53
            if self.length == 0:
54
                self.length = len(self.seeds)
55

56
    def __len__(self):
57
        return self.length
58

59
    def __getitem__(self, idx):
60
        img = None
61
        if self.image_dir is not None:
62
            img_name = os.path.join(self.image_dir, self.image_list[idx])
63
            image = Image.open(img_name)
64
            img = self.resize(image)
65
            if img.size(0) == 1:
66
                img = torch.cat((img, img, img), dim=0)
67
            img = self.normalize(img)
68

69
        # generate image 
70
        if self.label_dir is not None:
71
            torch.manual_seed(self.seeds[idx])
72
            z = torch.randn(1, 512)[0]
73
            if self.noise_in is None:
74
                n = [torch.randn(1, 1)]
75
            else:
76
                n = [torch.randn(noise.size())[0] for noise in self.noise_in]
77
            if img is None:
78
                return z, n 
79
            else:
80
                return z, img, n
81
        else:
82
            return img
83

84
class Car_DataSet(data.Dataset):
85
    def __init__(self, image_dir=None, label_dir=None, output_size=(512, 512), noise_in=None, training_set=True, video_data=False, train_split=0.9):
86
        self.image_dir = image_dir
87
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
88
        self.resize = transforms.Compose([
89
            transforms.Resize((384, 512)),
90
            transforms.Pad(padding=(0, 64, 0, 64)),
91
            transforms.ToTensor()
92
        ])
93
        self.noise_in = noise_in
94
        self.video_data = video_data
95
        self.random_rotation = transforms.Compose([
96
            transforms.Resize(output_size),
97
            transforms.RandomPerspective(distortion_scale=0.05, p=1.0),
98
            transforms.ToTensor()
99
        ])
100

101
        # load image file
102
        train_len = None
103
        self.length = 0
104
        self.image_dir = image_dir
105
        if image_dir is not None:
106
            img_list = [glob.glob1(self.image_dir, ext) for ext in ['*jpg','*png']]
107
            image_list = [item for sublist in img_list for item in sublist]
108
            image_list.sort()
109
            train_len = int(train_split*len(image_list))
110
            if training_set:
111
                self.image_list = image_list[:train_len]
112
            else:
113
                self.image_list = image_list[train_len:]
114
            self.length = len(self.image_list)
115

116
        # load label file
117
        self.label_dir = label_dir
118
        if label_dir is not None:
119
            self.seeds = np.load(label_dir)
120
            if train_len is None:
121
                train_len = int(train_split*len(self.seeds))
122
            if training_set:
123
                self.seeds = self.seeds[:train_len]
124
            else:
125
                self.seeds = self.seeds[train_len:]
126
            if self.length == 0:
127
                self.length = len(self.seeds)
128

129
    def __len__(self):
130
        return self.length
131

132
    def __getitem__(self, idx):
133
        img = None
134
        if self.image_dir is not None:
135
            img_name = os.path.join(self.image_dir, self.image_list[idx])
136
            image = Image.open(img_name)
137
            img = self.resize(image)
138
            if img.size(0) == 1:
139
                img = torch.cat((img, img, img), dim=0)
140
            img = self.normalize(img)
141
            if self.video_data:
142
                img_2 = self.random_rotation(image)
143
                img_2 = self.normalize(img_2)
144
                img_2 = torch.where(img_2 > -1, img_2, img)
145
                img = torch.cat([img, img_2], dim=0)
146

147
        # generate image 
148
        if self.label_dir is not None:
149
            torch.manual_seed(self.seeds[idx])
150
            z = torch.randn(1, 512)[0]
151
            n = [torch.randn_like(noise[0]) for noise in self.noise_in]
152
            if img is None:
153
                return z, n 
154
            else:
155
                return z, img, n
156
        else:
157
            return img
158

159

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.