h2o-llmstudio
184 строки · 5.0 Кб
1import dataclasses2import logging3import os4from typing import Any, Dict, List, Optional5
6import numpy as np7from sqlitedict import SqliteDict8
9__all__ = ["Loggers"]10
11from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS12
13logger = logging.getLogger(__name__)14
15
def get_cfg(cfg: Any) -> Dict:
    """Returns simplified config elements.

    Flattens ``cfg`` into a single-level dict, following the key order given
    by ``cfg._get_order``. Nested dataclass sections are flattened
    recursively and merged into the same namespace, so a key in a later
    section overwrites an equally named key seen earlier.

    Skipped keys:
        * private keys (leading underscore),
        * keys whose ``cfg._get_visibility`` is below zero,
        * any key containing the substring ``"api"``.

    Values whose type annotation is exactly ``float`` are cast with
    ``float()``.

    Args:
        cfg: configuration

    Returns:
        Dict of config elements
    """
    items: Dict = {}
    type_annotations = cfg.get_annotations()

    cfg_dict = cfg.__dict__
    # Re-order the instance attributes according to the config's own ordering.
    ordered = {key: cfg_dict[key] for key in cfg._get_order(warn_if_unset=False)}

    for k, v in ordered.items():
        if k.startswith("_") or cfg._get_visibility(k) < 0:
            continue

        # Keys containing "api" are excluded from the exported config
        # (presumably to keep API keys/tokens out of logs — confirm).
        if "api" in k:
            continue

        if dataclasses.is_dataclass(v):
            # Nested section: flatten it and merge into the same namespace.
            items.update(get_cfg(cfg=v))
        elif type_annotations[k] is float:
            items[k] = float(v)
        else:
            items[k] = v

    return items
53
class NeptuneLogger:
    """Logger that sends the flattened config and metric series to a Neptune run."""

    def __init__(self, cfg: Any):
        # Imported here so neptune is only needed when this logger is instantiated.
        import neptune
        from neptune.utils import stringify_unsupported

        run_mode = "debug" if cfg.logging._neptune_debug else "async"

        self.logger = neptune.init_run(
            project=cfg.logging.neptune_project,
            api_token=os.getenv("NEPTUNE_API_TOKEN", ""),
            name=cfg.experiment_name,
            mode=run_mode,
            capture_stdout=False,
            capture_stderr=False,
            source_files=[],
        )

        flat_cfg = get_cfg(cfg)
        self.logger["cfg"] = stringify_unsupported(flat_cfg)

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Append ``value`` (at ``step``) to the run field ``<subset>/<name>``."""
        series_key = f"{subset}/{name}"
        self.logger[series_key].append(value, step=step)
80
class LocalLogger:
    """Logger that stores the config and logged values in ``charts.db``.

    The database file lives in ``cfg.output_directory`` and is accessed via
    ``SqliteDict``; every write opens the file, updates it and commits.
    """

    def __init__(self, cfg: Any):
        # Keep sqlitedict's own log output to errors only.
        logging.getLogger("sqlitedict").setLevel(logging.ERROR)

        self.logs = f"{cfg.output_directory}/charts.db"

        params = get_cfg(cfg)

        with SqliteDict(self.logs) as logs:
            logs["cfg"] = params
            logs.commit()

    @staticmethod
    def _load_subset(logs: Any, subset: str) -> Dict:
        """Return the dict stored under ``subset`` in ``logs``, or a new empty dict."""
        return logs[subset] if subset in logs else dict()

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Persist ``value`` under ``subset``/``name``.

        For subsets listed in ``PLOT_ENCODINGS`` the value replaces any
        previous value stored for ``name``. For every other subset ``value``
        is treated as a scalar: it is appended, together with ``step``, to the
        ``{"steps": [...], "values": [...]}`` series for ``name``.

        Args:
            subset: group the value belongs to.
            name: series / entry name within the subset.
            value: value to store; for scalar subsets NaN or ``None`` is stored
                as ``None``, anything else is converted with ``float()``.
            step: optional step index recorded alongside scalar values.
        """
        if subset in PLOT_ENCODINGS:
            with SqliteDict(self.logs) as logs:
                subset_dict = self._load_subset(logs, subset)
                subset_dict[name] = value
                logs[subset] = subset_dict
                logs.commit()
            return

        # https://github.com/h2oai/wave/issues/447
        # Missing / NaN scalars are stored as None instead of a float.
        if value is None or np.isnan(value):
            value = None
        else:
            value = float(value)

        with SqliteDict(self.logs) as logs:
            subset_dict = self._load_subset(logs, subset)
            series = subset_dict.setdefault(name, {"steps": [], "values": []})
            series["steps"].append(step)
            series["values"].append(value)
            # Values read from SqliteDict are deserialized copies, so the
            # updated dict must be assigned back for the change to persist.
            logs[subset] = subset_dict
            logs.commit()
124
class DummyLogger:
    """No-op logger: accepts the same calls as the real loggers and discards them."""

    def __init__(self, cfg: Optional[Any] = None):
        # cfg is accepted only so this class is interchangeable with other loggers.
        pass

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Ignore the logged value."""
        pass
132
class MainLogger:
    """Main logger.

    Fans each ``log`` call out to a local chart logger (always present) and to
    the external logger selected by ``cfg.logging.logger``.
    """

    def __init__(self, cfg: Any):
        self.loggers = {"local": LocalLogger(cfg)}

        # Loggers.get falls back to DummyLogger for unregistered names.
        external_cls = Loggers.get(cfg.logging.logger)
        try:
            self.loggers["external"] = external_cls(cfg)
        except Exception as e:
            # Best effort: a failing external logger must not stop the run,
            # so replace it with a no-op logger and warn.
            logger.warning(
                f"Error when initializing logger. "
                f"Disabling custom logging functionality. "
                f"Please ensure logger configuration is correct and "
                f"you have a stable Internet connection: {e}"
            )
            self.loggers["external"] = DummyLogger(cfg)

    def reset_external(self):
        """Replace the external logger with a no-op ``DummyLogger``."""
        self.loggers["external"] = DummyLogger()

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Forward a value to every applicable logger.

        Args:
            subset: group the value belongs to; ``"internal"`` values are only
                passed to ``LocalLogger`` instances.
            name: entry name; names containing ``"validation_predictions"``
                are not sent to the external logger.
            value: value to log.
            step: optional step index.
        """
        # `sink` (not `logger`) to avoid shadowing the module-level logger.
        for kind, sink in self.loggers.items():
            if "validation_predictions" in name and kind == "external":
                continue
            if subset == "internal" and not isinstance(sink, LocalLogger):
                continue
            sink.log(subset=subset, name=name, value=value, step=step)
164
class Loggers:
    """Loggers factory.

    Maps the logger names selectable in the config to logger classes.
    """

    _loggers = {
        "None": DummyLogger,
        "Neptune": NeptuneLogger,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return all registered logger names, sorted alphabetically."""
        return sorted(cls._loggers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Loggers.

        Args:
            name: loggers name

        Returns:
            A class to build the Loggers; ``DummyLogger`` when ``name`` is not
            registered.
        """
        try:
            return cls._loggers[name]
        except KeyError:
            return DummyLogger