HairFastGAN

import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import yaml

from PIL import Image
from tqdm import tqdm
from torchvision import transforms, utils
from tensorboard_logger import Logger

from utils.datasets import *
from utils.functions import *
from trainer import *
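
# Note: cudnn.benchmark=True lets cuDNN auto-tune kernels, which can undermine
# cudnn.deterministic=True; set_detect_anomaly(True) adds significant overhead
# and is usually enabled only while debugging.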
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(True)
Image.MAX_IMAGE_PIXELS = None
device = torch.device('cuda')

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='001', help='Path to the config file.')
parser.add_argument('--real_dataset_path', type=str, default='./data/ffhq-dataset/images/', help='real dataset path')
parser.add_argument('--dataset_path', type=str, default='./data/stylegan2-generate-images/ims/', help='synthetic dataset path')
parser.add_argument('--label_path', type=str, default='./data/stylegan2-generate-images/seeds_pytorch_1.8.1.npy', help='label path')
parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained StyleGAN2 model')
parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained ArcFace model')
parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained face parsing model')
parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')
parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path')
opts = parser.parse_args()

log_dir = os.path.join(opts.log_path, opts.config) + '/'
os.makedirs(log_dir, exist_ok=True)
logger = Logger(log_dir)

config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'), Loader=yaml.FullLoader)
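
# The config is expected to provide the keys read below (batch_size, epochs,
# iter_per_epoch, resolution, log_iter, image_save_iter) plus the optional
# train_split, fixed_noise and use_realimg flags. Illustrative values only,
# not the repository's shipped configs:
#   batch_size: 8
#   epochs: 50
#   iter_per_epoch: 500
#   resolution: 1024
#   log_iter: 100
#   image_save_iter: 1000
#   use_realimg: true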
batch_size = config['batch_size']
epochs = config['epochs']
iter_per_epoch = config['iter_per_epoch']
img_size = (config['resolution'], config['resolution'])
video_data_input = False  # unused in this script

img_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
img_to_tensor_car = transforms.Compose([
    transforms.Resize((384, 512)),
    transforms.Pad(padding=(0, 64, 0, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
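# Both pipelines map images from [0, 1] to [-1, 1], the range StyleGAN2 models
# conventionally work in; img_to_tensor_car (resize to 384x512, pad to 512x512)
# targets the cars domain and is unused in this script.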

# Initialize trainer
trainer = Trainer(config, opts)
trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path)
trainer.to(device)

noise_example = trainer.noise_inputs
train_data_split = config.get('train_split', 0.9)

# Load the synthetic dataset (StyleGAN2 seeds/latents plus generated images)
dataset_A = MyDataSet(image_dir=opts.dataset_path, label_dir=opts.label_path, output_size=img_size, noise_in=noise_example, training_set=True, train_split=train_data_split)
loader_A = data.DataLoader(dataset_A, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
# Load the real dataset (FFHQ images, no labels)
dataset_B = MyDataSet(image_dir=opts.real_dataset_path, label_dir=None, output_size=img_size, noise_in=noise_example, training_set=True, train_split=train_data_split)
loader_B = data.DataLoader(dataset_B, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
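# Each loader_A batch yields latent codes z and per-layer noise (plus the
# matching generated image when dataset_path is set); loader_B yields real
# images only.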

# Start training
epoch_0 = 0

# Auto-resume if a checkpoint exists in the log directory
if 'checkpoint.pth' in os.listdir(log_dir):
    epoch_0 = trainer.load_checkpoint(os.path.join(log_dir, 'checkpoint.pth'))

# An explicit --resume flag overrides the auto-resume above
if opts.resume:
    epoch_0 = trainer.load_checkpoint(os.path.join(opts.log_path, opts.checkpoint))

torch.manual_seed(0)
os.makedirs(log_dir + 'validation/', exist_ok=True)

print("Start!")

for n_epoch in tqdm(range(epoch_0, epochs)):

    iter_A = iter(loader_A)
    iter_B = iter(loader_B)
    iter_0 = n_epoch * iter_per_epoch

    trainer.enc_opt.zero_grad()

    for n_iter in range(iter_0, iter_0 + iter_per_epoch):

        # Sample a synthetic batch: latents z, per-layer noise and,
        # when dataset_path is set, the matching pre-generated image
        if opts.dataset_path is None:
            z, noise = next(iter_A)
            img_A = None
        else:
            z, img_A, noise = next(iter_A)
            img_A = img_A.to(device)

        z = z.to(device)
        noise = [ee.to(device) for ee in noise]
        w = trainer.mapping(z)
        # With fixed_noise enabled, discard the stored image and noise
        if 'fixed_noise' in config and config['fixed_noise']:
            img_A, noise = None, None

        # Optionally draw a batch of real images, restarting the loader
        # when it is exhausted or returns a short final batch
        img_B = None
        if 'use_realimg' in config and config['use_realimg']:
            try:
                img_B = next(iter_B)
                if img_B.size(0) != batch_size:
                    iter_B = iter(loader_B)
                    img_B = next(iter_B)
            except StopIteration:
                iter_B = iter(loader_B)
                img_B = next(iter_B)
            img_B = img_B.to(device)

        trainer.update(w=w, img=img_A, noise=noise, real_img=img_B, n_iter=n_iter)
        if (n_iter + 1) % config['log_iter'] == 0:
            trainer.log_loss(logger, n_iter, prefix='scripts')
        if (n_iter + 1) % config['image_save_iter'] == 0:
            trainer.save_image(log_dir, n_epoch, n_iter, prefix='/scripts/', w=w, img=img_A, noise=noise)
            trainer.save_image(log_dir, n_epoch, n_iter + 1, prefix='/scripts/', w=w, img=img_B, noise=noise, training_mode=False)

    trainer.enc_scheduler.step()
    trainer.save_checkpoint(n_epoch, log_dir)

    # Validate on the first ten CelebA-HQ images and log validation losses
    with torch.no_grad():
        trainer.enc.eval()
        for i in range(10):
            image_A = img_to_tensor(Image.open('./data/celeba_hq/%d.jpg' % i)).unsqueeze(0).to(device)
            output = trainer.test(img=image_A)
            out_img = torch.cat(output, 3)
            utils.save_image(clip_img(out_img[:1]), log_dir + 'validation/' + 'epoch_' + str(n_epoch + 1) + '_' + str(i) + '.jpg')
        trainer.compute_loss(w=w, img=img_A, noise=noise, real_img=img_B)
        trainer.log_loss(logger, n_iter, prefix='validation')

trainer.save_model(log_dir)
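
# Example invocation (the script filename is assumed; configs are resolved
# from ./configs/<name>.yaml):
#   python train.py --config 001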
