in-context-impersonation

Форк
0
103 строки · 3.6 Кб
1
from typing import List, Tuple
2

3
import hydra
4
import lightning as L
5
import pyrootutils
6
from lightning import LightningDataModule, LightningModule, Trainer
7
from lightning.pytorch.loggers import Logger
8
from omegaconf import DictConfig, OmegaConf
9

10
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
11
# ------------------------------------------------------------------------------------ #
12
# the setup_root above is equivalent to:
13
# - adding project root dir to PYTHONPATH
14
#       (so you don't need to force user to install project as a package)
15
#       (necessary before importing any local modules e.g. `from src import utils`)
16
# - setting up PROJECT_ROOT environment variable
17
#       (which is used as a base for paths in "configs/paths/default.yaml")
18
#       (this way all filepaths are the same no matter where you run the code)
19
# - loading environment variables from ".env" in root dir
20
#
21
# you can remove it if you:
22
# 1. either install project as a package or move entry files to project root dir
23
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
24
#
25
# more info: https://github.com/ashleve/pyrootutils
26
# ------------------------------------------------------------------------------------ #
27

28
from src import utils
29
from src.utils.configure_torch import configure_torch
30

31
log = utils.get_pylogger(__name__)
32

33

34
@utils.task_wrapper
35
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
36
    """Evaluates given checkpoint on a datamodule testset.
37

38
    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
39
    failure. Useful for multiruns, saving info about the crash, etc.
40

41
    Args:
42
        cfg (DictConfig): Configuration composed by Hydra.
43

44
    Returns:
45
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
46
    """
47

48
    # set seed for random number generators in pytorch, numpy and python.random
49
    if cfg.get("seed"):
50
        L.seed_everything(cfg.seed, workers=True)
51

52
    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
53
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
54

55
    log.info(f"Instantiating model <{cfg.model._target_}>")
56
    model: LightningModule = hydra.utils.instantiate(cfg.model)
57

58
    log.info("Instantiating loggers...")
59
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
60

61
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
62
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
63

64
    converted_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
65
    object_dict = {
66
        "cfg": converted_cfg,
67
        "datamodule": datamodule,
68
        "model": model,
69
        "logger": logger,
70
        "trainer": trainer,
71
    }
72

73
    if logger:
74
        log.info("Logging hyperparameters!")
75
        utils.log_hyperparameters(object_dict)
76

77
    log.info("Starting testing!")
78
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
79

80
    # for predictions use trainer.predict(...)
81
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
82

83
    metric_dict = trainer.callback_metrics
84

85
    return metric_dict, object_dict
86

87

88
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
89
def main(cfg: DictConfig) -> None:
90
    # apply extra utilities
91
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
92
    utils.extras(cfg)
93

94
    # configure torch
95
    configure_torch()
96

97
    # TODO: Do something about hydra eating the errors:
98
    # https://github.com/facebookresearch/hydra/issues/2664
99
    evaluate(cfg)
100

101

102
if __name__ == "__main__":
103
    main()
104

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.