# Evaluation entry point for the in-context-impersonation project.
1from typing import List, Tuple2
3import hydra4import lightning as L5import pyrootutils6from lightning import LightningDataModule, LightningModule, Trainer7from lightning.pytorch.loggers import Logger8from omegaconf import DictConfig, OmegaConf9
10pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)11# ------------------------------------------------------------------------------------ #
12# the setup_root above is equivalent to:
13# - adding project root dir to PYTHONPATH
14# (so you don't need to force user to install project as a package)
15# (necessary before importing any local modules e.g. `from src import utils`)
16# - setting up PROJECT_ROOT environment variable
17# (which is used as a base for paths in "configs/paths/default.yaml")
18# (this way all filepaths are the same no matter where you run the code)
19# - loading environment variables from ".env" in root dir
20#
21# you can remove it if you:
22# 1. either install project as a package or move entry files to project root dir
23# 2. set `root_dir` to "." in "configs/paths/default.yaml"
24#
25# more info: https://github.com/ashleve/pyrootutils
26# ------------------------------------------------------------------------------------ #
27
28from src import utils29from src.utils.configure_torch import configure_torch30
# Module-level logger for this script, created via the project's logging helper.
log = utils.get_pylogger(__name__)
33
@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
    """Run the test loop for a given checkpoint on the configured datamodule.

    Wrapped in the optional ``@task_wrapper`` decorator, which controls the
    behavior on failure — useful for multiruns, saving info about crashes, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Metrics reported by the trainer and a dict with all
        instantiated objects.
    """
    # Seed pytorch, numpy and python.random RNGs when a seed is configured.
    seed = cfg.get("seed")
    if seed:
        L.seed_everything(seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    # Resolve interpolations now so the logged config is plain data;
    # fail loudly on any value still marked as missing.
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    object_dict = dict(
        cfg=resolved_cfg,
        datamodule=datamodule,
        model=model,
        logger=logger,
        trainer=trainer,
    )

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # For predictions, use trainer.predict(...) instead:
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    return trainer.callback_metrics, object_dict
87
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: apply extras, configure torch, then run evaluation.

    Args:
        cfg (DictConfig): Configuration composed by Hydra from ``eval.yaml``.
    """
    # Extra utilities, e.g. asking for tags when none are provided in cfg,
    # printing the cfg tree, etc.
    utils.extras(cfg)

    # Project-level torch setup (see src.utils.configure_torch for details).
    configure_torch()

    # TODO: Do something about hydra eating the errors:
    # https://github.com/facebookresearch/hydra/issues/2664
    evaluate(cfg)
101
# Standard script entry-point guard: run only when executed directly.
if __name__ == "__main__":
    main()