StyleFeatureEditor

Форк
0
/
inference_arguments.py 
80 строк · 1.9 Кб
1
import os
2

3
from pathlib import Path
4
from typing import Optional, List, Tuple, Dict
5
from dataclasses import dataclass, field
6
from omegaconf import OmegaConf, MISSING
7
from utils.class_registry import ClassRegistry
8
from models.methods import methods_registry
9
from metrics.metrics import metrics_registry
10

11

12

13
args = ClassRegistry()
14

15

16
@args.add_to_registry("exp")
17
@dataclass
18
class ExperimentArgs:
19
    config_dir: str = str(Path(__file__).resolve().parent / "configs")
20
    config: str = MISSING
21
    output_dir: str = "results_dir"
22
    seed: int = 1
23
    root: str = os.getenv("EXP_ROOT", ".")
24
    domain: str = "human_faces"
25
    wandb: bool = False
26

27

28
@args.add_to_registry("data")
29
@dataclass
30
class DataArgs:
31
    inference_dir: str = ""
32
    transform: str = "face_1024"
33

34

35
@args.add_to_registry("inference")
36
@dataclass
37
class InferenceArgs:
38
    inference_runner: str = "base_inference_runner"
39
    editings_data: Dict = field(default_factory=lambda: {})
40

41

42
@args.add_to_registry("model")
43
@dataclass
44
class ModelArgs:
45
    method: str = "fse_full"
46
    device: str = "0"
47
    batch_size: int = 4
48
    workers: int = 4
49
    checkpoint_path: str = ""
50

51

52

53
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
54
args.add_to_registry("methods_args")(MethodsArgs)
55

56
MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
57
args.add_to_registry("metrics")(MetricsArgs)
58

59

60

61
Args = args.make_dataclass_from_classes("Args")
62

63

64
def load_config():
65
    config = OmegaConf.structured(Args)
66

67
    conf_cli = OmegaConf.from_cli()
68
    config.exp.config = conf_cli.exp.config
69
    config.exp.config_dir = conf_cli.exp.config_dir
70

71
    config_path = os.path.join(config.exp.config_dir, config.exp.config)
72
    conf_file = OmegaConf.load(config_path)
73
    config = OmegaConf.merge(config, conf_file)
74
    for method in list(config.methods_args.keys()):
75
        if method != config.model.method:
76
            config.methods_args.__delattr__(method)
77

78
    config = OmegaConf.merge(config, conf_cli)
79

80
    return config
81

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.