HairFastGAN

Форк
0
/
main.py 
80 строк · 3.1 Кб
1
import argparse
2
import os
3
import sys
4
from pathlib import Path
5

6
from torchvision.utils import save_image
7
from tqdm.auto import tqdm
8

9
from hair_swap import HairFast, get_parser
10

11

12
def main(model_args, args):
13
    hair_fast = HairFast(model_args)
14

15
    experiments: list[str | tuple[str, str, str]] = []
16
    if args.file_path is not None:
17
        with open(args.file_path, 'r') as file:
18
            experiments.extend(file.readlines())
19

20
    if all(path is not None for path in (args.face_path, args.shape_path, args.color_path)):
21
        experiments.append((args.face_path, args.shape_path, args.color_path))
22

23
    for exp in tqdm(experiments):
24
        if isinstance(exp, str):
25
            file_1, file_2, file_3 = exp.split()
26
        else:
27
            file_1, file_2, file_3 = exp
28

29
        face_path = args.input_dir / file_1
30
        shape_path = args.input_dir / file_2
31
        color_path = args.input_dir / file_3
32

33
        base_name = '_'.join([path.stem for path in (face_path, shape_path, color_path)])
34
        exp_name = base_name if model_args.save_all else None
35

36
        if isinstance(exp, str) or args.result_path is None:
37
            os.makedirs(args.output_dir, exist_ok=True)
38
            output_image_path = args.output_dir / f'{base_name}.png'
39
        else:
40
            os.makedirs(args.result_path.parent, exist_ok=True)
41
            output_image_path = args.result_path
42

43
        final_image = hair_fast.swap(face_path, shape_path, color_path, benchmark=args.benchmark, exp_name=exp_name)
44
        save_image(final_image, output_image_path)
45

46

47
if __name__ == "__main__":
48
    model_parser = get_parser()
49
    parser = argparse.ArgumentParser(description='HairFast evaluate')
50
    parser.add_argument('--input_dir', type=Path, default='', help='The directory of the images to be inverted')
51
    parser.add_argument('--benchmark', action='store_true', help='Calculates the speed of the method during the session')
52

53
    # Arguments for a set of experiments
54
    parser.add_argument('--file_path', type=Path, default=None,
55
                        help='File with experiments with the format "face_path.png shape_path.png color_path.png"')
56
    parser.add_argument('--output_dir', type=Path, default=Path('output'), help='The directory for final results')
57

58
    # Arguments for single experiment
59
    parser.add_argument('--face_path', type=Path, default=None, help='Path to the face image')
60
    parser.add_argument('--shape_path', type=Path, default=None, help='Path to the shape image')
61
    parser.add_argument('--color_path', type=Path, default=None, help='Path to the color image')
62
    parser.add_argument('--result_path', type=Path, default=None, help='Path to save the result')
63

64
    args, unknown1 = parser.parse_known_args()
65
    model_args, unknown2 = model_parser.parse_known_args()
66

67
    unknown_args = set(unknown1) & set(unknown2)
68
    if unknown_args:
69
        file_ = sys.stderr
70
        print(f"Unknown arguments: {unknown_args}", file=file_)
71

72
        print("\nExpected arguments for the model:", file=file_)
73
        model_parser.print_help(file=file_)
74

75
        print("\nExpected arguments for evaluate:", file=file_)
76
        parser.print_help(file=file_)
77

78
        sys.exit(1)
79

80
    main(model_args, args)
81

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.