h2o-llmstudio

Форк
0
184 строки · 5.0 Кб
1
import dataclasses
2
import logging
3
import os
4
from typing import Any, Dict, List, Optional
5

6
import numpy as np
7
from sqlitedict import SqliteDict
8

9
__all__ = ["Loggers"]
10

11
from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS
12

13
logger = logging.getLogger(__name__)
14

15

16
def get_cfg(cfg: Any) -> Dict:
17
    """Returns simplified config elements
18

19
    Args:
20
        cfg: configuration
21

22
    Returns:
23
        Dict of config elements
24
    """
25

26
    items: Dict = {}
27
    type_annotations = cfg.get_annotations()
28

29
    cfg_dict = cfg.__dict__
30

31
    cfg_dict = {key: cfg_dict[key] for key in cfg._get_order(warn_if_unset=False)}
32

33
    for k, v in cfg_dict.items():
34
        if k.startswith("_") or cfg._get_visibility(k) < 0:
35
            continue
36

37
        if any([x in k for x in ["api"]]):
38
            continue
39

40
        if dataclasses.is_dataclass(v):
41
            elements_group = get_cfg(cfg=v)
42
            t = elements_group
43
            items = {**items, **t}
44
        else:
45
            type_annotation = type_annotations[k]
46
            if type_annotation == float:
47
                items[k] = float(v)
48
            else:
49
                items[k] = v
50

51
    return items
52

53

54
class NeptuneLogger:
55
    def __init__(self, cfg: Any):
56
        import neptune as neptune
57
        from neptune.utils import stringify_unsupported
58

59
        if cfg.logging._neptune_debug:
60
            mode = "debug"
61
        else:
62
            mode = "async"
63

64
        self.logger = neptune.init_run(
65
            project=cfg.logging.neptune_project,
66
            api_token=os.getenv("NEPTUNE_API_TOKEN", ""),
67
            name=cfg.experiment_name,
68
            mode=mode,
69
            capture_stdout=False,
70
            capture_stderr=False,
71
            source_files=[],
72
        )
73

74
        self.logger["cfg"] = stringify_unsupported(get_cfg(cfg))
75

76
    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
77
        name = f"{subset}/{name}"
78
        self.logger[name].append(value, step=step)
79

80

81
class LocalLogger:
82
    def __init__(self, cfg: Any):
83
        logging.getLogger("sqlitedict").setLevel(logging.ERROR)
84

85
        self.logs = f"{cfg.output_directory}/charts.db"
86

87
        params = get_cfg(cfg)
88

89
        with SqliteDict(self.logs) as logs:
90
            logs["cfg"] = params
91
            logs.commit()
92

93
    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
94
        if subset in PLOT_ENCODINGS:
95
            with SqliteDict(self.logs) as logs:
96
                if subset not in logs:
97
                    subset_dict = dict()
98
                else:
99
                    subset_dict = logs[subset]
100
                subset_dict[name] = value
101
                logs[subset] = subset_dict
102
                logs.commit()
103
            return
104

105
        # https://github.com/h2oai/wave/issues/447
106
        if np.isnan(value):
107
            value = None
108
        else:
109
            value = float(value)
110
        with SqliteDict(self.logs) as logs:
111
            if subset not in logs:
112
                subset_dict = dict()
113
            else:
114
                subset_dict = logs[subset]
115
            if name not in subset_dict:
116
                subset_dict[name] = {"steps": [], "values": []}
117

118
            subset_dict[name]["steps"].append(step)
119
            subset_dict[name]["values"].append(value)
120

121
            logs[subset] = subset_dict
122
            logs.commit()
123

124

125
class DummyLogger:
126
    def __init__(self, cfg: Optional[Any] = None):
127
        return
128

129
    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
130
        return
131

132

133
class MainLogger:
134
    """Main logger"""
135

136
    def __init__(self, cfg: Any):
137
        self.loggers = {
138
            "local": LocalLogger(cfg),
139
            "external": Loggers.get(cfg.logging.logger),
140
        }
141

142
        try:
143
            self.loggers["external"] = self.loggers["external"](cfg)
144
        except Exception as e:
145
            logger.warning(
146
                f"Error when initializing logger. "
147
                f"Disabling custom logging functionality. "
148
                f"Please ensure logger configuration is correct and "
149
                f"you have a stable Internet connection: {e}"
150
            )
151
            self.loggers["external"] = DummyLogger(cfg)
152

153
    def reset_external(self):
154
        self.loggers["external"] = DummyLogger()
155

156
    def log(self, subset: str, name: str, value: str | float, step: float = None):
157
        for k, logger in self.loggers.items():
158
            if "validation_predictions" in name and k == "external":
159
                continue
160
            if subset == "internal" and not isinstance(logger, LocalLogger):
161
                continue
162
            logger.log(subset=subset, name=name, value=value, step=step)
163

164

165
class Loggers:
166
    """Loggers factory."""
167

168
    _loggers = {"None": DummyLogger, "Neptune": NeptuneLogger}
169

170
    @classmethod
171
    def names(cls) -> List[str]:
172
        return sorted(cls._loggers.keys())
173

174
    @classmethod
175
    def get(cls, name: str) -> Any:
176
        """Access to Loggers.
177

178
        Args:
179
            name: loggers name
180
        Returns:
181
            A class to build the Loggers
182
        """
183

184
        return cls._loggers.get(name, DummyLogger)
185

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.