HairFastGAN
157 строк · 5.6 Кб
1import os2import glob3import numpy as np4import torch5import torch.nn as nn6import torch.nn.functional as F7import torch.utils.data as data8
9from PIL import Image10from torchvision import transforms, utils11
class MyDataSet(data.Dataset):
    """Dataset of images and/or per-sample latent seeds, split into train/test.

    Either source is optional:

    * ``image_dir`` -- directory whose ``*jpg`` / ``*png`` entries are sorted
      by name; the first ``train_split`` fraction is the training set, the
      remainder the test set.
    * ``label_dir`` -- path passed to ``np.load``; each loaded entry seeds
      torch's global RNG to draw a latent ``z`` (and noise) in
      ``__getitem__``. When images are also given, the seeds are split at the
      same index as the image list.

    ``video_data`` and ``random_rotation`` are stored but not used by this
    class's ``__getitem__``.
    """

    def __init__(self, image_dir=None, label_dir=None, output_size=(256, 256), noise_in=None, training_set=True, video_data=False, train_split=0.9):
        self.image_dir = image_dir
        # ToTensor yields values in [0, 1]; this maps them to [-1, 1].
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.resize = transforms.Compose([
            transforms.Resize(output_size),
            transforms.ToTensor(),
        ])
        self.noise_in = noise_in
        self.video_data = video_data
        self.random_rotation = transforms.Compose([
            transforms.Resize(output_size),
            transforms.RandomPerspective(distortion_scale=0.05, p=1.0),
            transforms.ToTensor(),
        ])

        # load image file
        train_len = None
        self.length = 0
        if image_dir is not None:
            image_list = self._list_images(image_dir)
            train_len = int(train_split * len(image_list))
            self.image_list = image_list[:train_len] if training_set else image_list[train_len:]
            self.length = len(self.image_list)

        # load label (seed) file
        self.label_dir = label_dir
        if label_dir is not None:
            self.seeds = np.load(label_dir)
            if train_len is None:
                train_len = int(train_split * len(self.seeds))
            self.seeds = self.seeds[:train_len] if training_set else self.seeds[train_len:]
            # Only size the dataset by the seeds when there are no images: an
            # empty image split must stay empty, otherwise __getitem__ would
            # index past the end of image_list.
            if image_dir is None:
                self.length = len(self.seeds)

    @staticmethod
    def _list_images(image_dir):
        """Return the sorted names of the ``*jpg`` / ``*png`` entries in ``image_dir``."""
        # glob.glob1 is undocumented and deprecated (Python 3.13); escape the
        # directory so glob metacharacters in the path itself match literally.
        root = glob.escape(os.fspath(image_dir))
        names = [
            os.path.basename(path)
            for ext in ('*jpg', '*png')
            for path in glob.glob(os.path.join(root, ext))
        ]
        return sorted(names)

    def _load_image(self, idx):
        """Open image ``idx`` as RGB and return it resized as a [0, 1] tensor."""
        path = os.path.join(self.image_dir, self.image_list[idx])
        # `with` releases the file handle. convert("RGB") turns grayscale,
        # palette, RGBA and CMYK files into the 3 channels `normalize` expects.
        with Image.open(path) as image:
            return self.resize(image.convert("RGB"))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``img``, ``(z, n)`` or ``(z, img, n)`` depending on the configured sources."""
        img = None
        if self.image_dir is not None:
            img = self.normalize(self._load_image(idx))

        if self.label_dir is None:
            return img

        # Re-seeding the global RNG makes the latent and noise of sample
        # `idx` reproducible.
        torch.manual_seed(self.seeds[idx])
        z = torch.randn(1, 512)[0]
        if self.noise_in is None:
            n = [torch.randn(1, 1)]
        else:
            n = [torch.randn(noise.size())[0] for noise in self.noise_in]
        if img is None:
            return z, n
        return z, img, n
class Car_DataSet(data.Dataset):
    """Dataset of car images and/or per-sample latent seeds, split into train/test.

    * ``image_dir`` -- directory whose ``*jpg`` / ``*png`` entries are sorted
      by name; the first ``train_split`` fraction is the training set, the
      remainder the test set. Images are letterboxed to 512x512.
    * ``label_dir`` -- path passed to ``np.load``; each loaded entry seeds
      torch's global RNG to draw a latent ``z`` and noise in ``__getitem__``.
      When images are also given, the seeds are split at the same index.

    With ``video_data=True`` each image sample also carries a randomly
    perspective-warped copy, concatenated along the channel axis.
    """

    def __init__(self, image_dir=None, label_dir=None, output_size=(512, 512), noise_in=None, training_set=True, video_data=False, train_split=0.9):
        self.image_dir = image_dir
        # ToTensor yields values in [0, 1]; this maps them to [-1, 1].
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        # Letterbox: resize to 384 x 512 (H x W), then pad 64 rows at the top
        # and bottom (Pad order is left, top, right, bottom) -> 512 x 512.
        # NOTE(review): this ignores `output_size` while `random_rotation`
        # honours it, so with video_data=True the two views only combine when
        # output_size == (512, 512), and the warped view is stretched rather
        # than letterboxed -- confirm this is intended.
        self.resize = transforms.Compose([
            transforms.Resize((384, 512)),
            transforms.Pad(padding=(0, 64, 0, 64)),
            transforms.ToTensor(),
        ])
        self.noise_in = noise_in
        self.video_data = video_data
        self.random_rotation = transforms.Compose([
            transforms.Resize(output_size),
            transforms.RandomPerspective(distortion_scale=0.05, p=1.0),
            transforms.ToTensor(),
        ])

        # load image file
        train_len = None
        self.length = 0
        if image_dir is not None:
            image_list = self._list_images(image_dir)
            train_len = int(train_split * len(image_list))
            self.image_list = image_list[:train_len] if training_set else image_list[train_len:]
            self.length = len(self.image_list)

        # load label (seed) file
        self.label_dir = label_dir
        if label_dir is not None:
            self.seeds = np.load(label_dir)
            if train_len is None:
                train_len = int(train_split * len(self.seeds))
            self.seeds = self.seeds[:train_len] if training_set else self.seeds[train_len:]
            # Only size the dataset by the seeds when there are no images: an
            # empty image split must stay empty, otherwise __getitem__ would
            # index past the end of image_list.
            if image_dir is None:
                self.length = len(self.seeds)

    @staticmethod
    def _list_images(image_dir):
        """Return the sorted names of the ``*jpg`` / ``*png`` entries in ``image_dir``."""
        # glob.glob1 is undocumented and deprecated (Python 3.13); escape the
        # directory so glob metacharacters in the path itself match literally.
        root = glob.escape(os.fspath(image_dir))
        names = [
            os.path.basename(path)
            for ext in ('*jpg', '*png')
            for path in glob.glob(os.path.join(root, ext))
        ]
        return sorted(names)

    def _open_rgb(self, idx):
        """Load image ``idx`` fully into memory as an RGB PIL image and close the file."""
        path = os.path.join(self.image_dir, self.image_list[idx])
        # convert("RGB") turns grayscale, palette, RGBA and CMYK files into
        # the 3 channels both transforms and `normalize` expect.
        with Image.open(path) as src:
            return src.convert("RGB")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``img``, ``(z, n)`` or ``(z, img, n)`` depending on the configured sources."""
        img = None
        if self.image_dir is not None:
            image = self._open_rgb(idx)
            img = self.normalize(self.resize(image))
            if self.video_data:
                # Second view: the same picture under a random perspective
                # warp. Pixels at -1 in the warped view (assumed to be the
                # warp's fill of 0 after normalization, but this also catches
                # genuinely black pixels) are taken from the letterboxed view.
                warped = self.normalize(self.random_rotation(image))
                warped = torch.where(warped > -1, warped, img)
                img = torch.cat([img, warped], dim=0)

        if self.label_dir is None:
            return img

        # Re-seeding the global RNG makes the latent and noise of sample
        # `idx` reproducible.
        torch.manual_seed(self.seeds[idx])
        z = torch.randn(1, 512)[0]
        if self.noise_in is None:
            # Without noise templates, fall back to a single 1x1 noise map
            # instead of failing while iterating None.
            n = [torch.randn(1, 1)]
        else:
            n = [torch.randn_like(noise[0]) for noise in self.noise_in]
        if img is None:
            return z, n
        return z, img, n
159