# StyleFeatureEditor
# 80 lines · 1.9 KB
1import os2
3from pathlib import Path4from typing import Optional, List, Tuple, Dict5from dataclasses import dataclass, field6from omegaconf import OmegaConf, MISSING7from utils.class_registry import ClassRegistry8from models.methods import methods_registry9from metrics.metrics import metrics_registry10
11
12
# Registry of the per-section argument dataclasses. The classes that follow
# register themselves via `@args.add_to_registry(<section key>)`, and
# `args.make_dataclass_from_classes("Args")` is later called on it to build
# the top-level schema used by load_config().
args = ClassRegistry()
15
@args.add_to_registry("exp")
@dataclass()
class ExperimentArgs:
    """Experiment-level settings, registered as the ``exp`` config section."""

    # Directory that holds experiment config files: ``configs/`` next to this
    # module (symlinks resolved).
    config_dir: str = str(Path(__file__).resolve().parent.joinpath("configs"))
    # Config file name, relative to ``config_dir``; MISSING = no default value.
    config: str = MISSING
    output_dir: str = "results_dir"
    seed: int = 1
    # Taken from the EXP_ROOT environment variable (fallback ".") once, when
    # this module is imported.
    root: str = os.environ.get("EXP_ROOT", ".")
    domain: str = "human_faces"
    # Boolean switch; presumably enables Weights & Biases logging — confirm.
    wandb: bool = False
27
@args.add_to_registry("data")
@dataclass()
class DataArgs:
    """Data settings, registered as the ``data`` config section."""

    # Empty by default; presumably the directory of inputs to run inference
    # on — confirm against the data-loading code.
    inference_dir: str = ""
    # Name of a transform preset (default "face_1024"); presumably resolved
    # through a transforms registry elsewhere — confirm.
    transform: str = "face_1024"
34
@args.add_to_registry("inference")
@dataclass
class InferenceArgs:
    """Inference settings, registered as the ``inference`` config section."""

    # Name of the inference runner; presumably a key in a runner registry —
    # confirm.
    inference_runner: str = "base_inference_runner"
    # Mutable default: `default_factory=dict` gives every instance its own
    # fresh empty dict (no wrapper lambda needed).
    editings_data: Dict = field(default_factory=dict)
41
@args.add_to_registry("model")
@dataclass()
class ModelArgs:
    """Model settings, registered as the ``model`` config section."""

    # Selected method name; load_config() keeps only the matching entry of
    # ``methods_args``.
    method: str = "fse_full"
    # Stored as a string (default "0"); presumably a CUDA device index — confirm.
    device: str = "0"
    batch_size: int = 4
    # Presumably the number of data-loader worker processes — confirm.
    workers: int = 4
    # Empty by default; presumably the path of a checkpoint to load — confirm.
    checkpoint_path: str = ""
51
52
# Dataclass generated from the arguments known to `methods_registry`, registered
# as the "methods_args" section. load_config() treats its keys as method names
# and keeps only the one equal to `model.method`.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
args.add_to_registry("methods_args")(MethodsArgs)
# Dataclass generated from the arguments known to `metrics_registry`
# (presumably one entry per metric — confirm in metrics.metrics), registered
# as the "metrics" section.
MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)
59
60
# Top-level structured-config schema built from the classes registered in
# `args` (presumably one field per section: exp, data, inference, model,
# methods_args, metrics — confirm in utils.class_registry). load_config()
# passes it to OmegaConf.structured.
Args = args.make_dataclass_from_classes("Args")
63
def load_config():
    """Build the experiment config from defaults, a config file and the CLI.

    Sources are merged in this order, later ones overriding earlier ones:

    1. defaults from the structured ``Args`` schema;
    2. the file named by ``exp.config``, looked up in ``exp.config_dir``;
    3. ``key=value`` overrides from the command line.

    ``exp.config`` must be given on the command line because it names the
    file to load. ``exp.config_dir`` may also be given there; if it is not,
    the schema default is kept.

    After the final merge, every ``methods_args`` entry except the one for
    the selected ``model.method`` is removed. Because this runs after the
    CLI merge, a command-line override of ``model.method`` is honoured.

    Returns:
        The merged OmegaConf config.

    Raises:
        ValueError: if ``exp.config`` is not supplied on the command line.
    """
    config = OmegaConf.structured(Args)

    conf_cli = OmegaConf.from_cli()
    # Use `.get` rather than attribute access: an absent `exp` section or
    # key must neither raise nor overwrite the schema default with None.
    cli_exp = conf_cli.get("exp") or {}
    config_name = cli_exp.get("config")
    if config_name is None:
        raise ValueError(
            "exp.config must be passed on the command line, "
            "e.g. `exp.config=<file>.yaml`"
        )
    config.exp.config = config_name
    config.exp.config_dir = cli_exp.get("config_dir", config.exp.config_dir)

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    conf_file = OmegaConf.load(config_path)
    config = OmegaConf.merge(config, conf_file)
    config = OmegaConf.merge(config, conf_cli)

    # Keep only the selected method's arguments. Done after the CLI merge so
    # that `model.method=<name>` on the command line selects <name>.
    selected_method = config.model.method
    # Iterate over a snapshot of the keys: entries are deleted in the loop.
    for method in list(config.methods_args.keys()):
        if method != selected_method:
            delattr(config.methods_args, method)

    return config