HairFastGAN
/
main.py
80 строк · 3.1 Кб
import argparse
import os
import sys
from pathlib import Path

from torchvision.utils import save_image
from tqdm.auto import tqdm

from hair_swap import HairFast, get_parser


def main(model_args, args):
    """Run HairFast hair transfer for every requested experiment and save the results.

    Experiments are gathered from two sources, which may be combined:

    * ``args.file_path`` -- a text file; each non-blank line holds three
      whitespace-separated image names in the order ``face shape color``.
      Each result is saved as ``<face>_<shape>_<color>.png`` (the file stems
      joined by ``_``) inside ``args.output_dir``.
    * ``args.face_path`` / ``args.shape_path`` / ``args.color_path`` -- one extra
      experiment, used only when all three are set. Its result is written to
      ``args.result_path`` when that is given, otherwise into
      ``args.output_dir`` as above.

    All image names are joined onto ``args.input_dir``. When
    ``model_args.save_all`` is truthy, the same ``<face>_<shape>_<color>`` name
    is passed to ``HairFast.swap`` as ``exp_name``.

    Args:
        model_args: Parsed model options used to build ``HairFast``.
        args: Parsed evaluation options (see the ``__main__`` parser).

    Raises:
        ValueError: If a non-blank line of ``args.file_path`` does not contain
            exactly three names.
    """
    # Collect and validate all experiments before building the model, so a
    # malformed experiments file is reported without waiting for model setup.
    experiments: list[str | tuple[Path, Path, Path]] = []
    if args.file_path is not None:
        with open(args.file_path, 'r') as file:
            for line_no, line in enumerate(file, start=1):
                names = line.split()
                if not names:
                    # Skip blank lines (e.g. a trailing empty line); they would
                    # otherwise break the 3-way unpacking below.
                    continue
                if len(names) != 3:
                    raise ValueError(
                        f'{args.file_path}:{line_no}: expected "face_path shape_path color_path", '
                        f'got {line.strip()!r}'
                    )
                experiments.append(line)

    # The single experiment is used only when all three of its paths are set.
    if all(path is not None for path in (args.face_path, args.shape_path, args.color_path)):
        experiments.append((args.face_path, args.shape_path, args.color_path))

    hair_fast = HairFast(model_args)

    for exp in tqdm(experiments):
        # File lines are stored as str and the command-line experiment as a
        # tuple; the type also decides where the result is written (below).
        if isinstance(exp, str):
            file_1, file_2, file_3 = exp.split()
        else:
            file_1, file_2, file_3 = exp

        face_path = args.input_dir / file_1
        shape_path = args.input_dir / file_2
        color_path = args.input_dir / file_3

        base_name = '_'.join([path.stem for path in (face_path, shape_path, color_path)])
        exp_name = base_name if model_args.save_all else None

        if isinstance(exp, str) or args.result_path is None:
            os.makedirs(args.output_dir, exist_ok=True)
            output_image_path = args.output_dir / f'{base_name}.png'
        else:
            # Explicit destination for the single experiment: ensure its parent exists.
            os.makedirs(args.result_path.parent, exist_ok=True)
            output_image_path = args.result_path

        final_image = hair_fast.swap(face_path, shape_path, color_path, benchmark=args.benchmark, exp_name=exp_name)
        save_image(final_image, output_image_path)

if __name__ == "__main__":
    # Command-line entry point: two independent parsers (model options from
    # hair_swap.get_parser and the evaluation options below) share one argv.
    model_parser = get_parser()
    parser = argparse.ArgumentParser(description='HairFast evaluate')
    parser.add_argument('--input_dir', type=Path, default='', help='The directory of the images to be inverted')
    parser.add_argument('--benchmark', action='store_true', help='Calculates the speed of the method during the session')

    # Arguments for a set of experiments
    parser.add_argument('--file_path', type=Path, default=None,
                        help='File with experiments with the format "face_path.png shape_path.png color_path.png"')
    parser.add_argument('--output_dir', type=Path, default=Path('output'), help='The directory for final results')

    # Arguments for single experiment
    parser.add_argument('--face_path', type=Path, default=None, help='Path to the face image')
    parser.add_argument('--shape_path', type=Path, default=None, help='Path to the shape image')
    parser.add_argument('--color_path', type=Path, default=None, help='Path to the color image')
    parser.add_argument('--result_path', type=Path, default=None, help='Path to save the result')

    # Each parser sees the full command line and leaves the other parser's
    # options in its "unknown" list, so only arguments rejected by BOTH
    # parsers are genuinely unknown.
    args, unknown1 = parser.parse_known_args()
    model_args, unknown2 = model_parser.parse_known_args()

    unknown_args = set(unknown1) & set(unknown2)
    if unknown_args:
        file_ = sys.stderr
        print(f"Unknown arguments: {unknown_args}", file=file_)

        print("\nExpected arguments for the model:", file=file_)
        model_parser.print_help(file=file_)

        print("\nExpected arguments for evaluate:", file=file_)
        parser.print_help(file=file_)

        sys.exit(1)

    # Reject argument combinations that main() would otherwise silently ignore
    # (parser.error prints the usage and exits with status 2).
    single_experiment = (args.face_path, args.shape_path, args.color_path)
    n_single_given = sum(path is not None for path in single_experiment)
    if 0 < n_single_given < len(single_experiment):
        parser.error('--face_path, --shape_path and --color_path must be given together')
    if n_single_given == 0:
        if args.result_path is not None:
            parser.error('--result_path applies only to a single experiment '
                         '(--face_path, --shape_path, --color_path)')
        if args.file_path is None:
            parser.error('no experiments given: use --file_path and/or '
                         '--face_path, --shape_path, --color_path')

    main(model_args, args)