paddlenlp

Форк
0
164 строки · 5.0 Кб
1
# Copyright (c) 2023  PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License"
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import contextlib
16
import datetime
17
import functools
18
import logging
19
import threading
20
import time
21

22
import colorlog
23

24
loggers = {}
25

26
log_config = {
27
    "DEBUG": {"level": 10, "color": "purple"},
28
    "INFO": {"level": 20, "color": "green"},
29
    "TRAIN": {"level": 21, "color": "cyan"},
30
    "EVAL": {"level": 22, "color": "blue"},
31
    "WARNING": {"level": 30, "color": "yellow"},
32
    "ERROR": {"level": 40, "color": "red"},
33
    "CRITICAL": {"level": 50, "color": "bold_red"},
34
}
35

36

37
class Logger(object):
38
    """
39
    Deafult logger in PaddleFleetX
40

41
    Args:
42
        name(str) : Logger name, default is 'PaddleFleetX'
43
    """
44

45
    def __init__(self, name: str = None):
46
        name = "PaddleFleetX" if not name else name
47
        self.logger = logging.getLogger(name)
48

49
        for key, conf in log_config.items():
50
            logging.addLevelName(conf["level"], key)
51
            self.__dict__[key] = functools.partial(self.__call__, conf["level"])
52
            self.__dict__[key.lower()] = functools.partial(self.__call__, conf["level"])
53

54
        self.format = colorlog.ColoredFormatter(
55
            "%(log_color)s[%(asctime)-15s] [%(levelname)s]%(reset)s - %(message)s",
56
            log_colors={key: conf["color"] for key, conf in log_config.items()},
57
        )
58

59
        self.handler = logging.StreamHandler()
60
        self.handler.setFormatter(self.format)
61

62
        self.logger.addHandler(self.handler)
63
        self.logLevel = "DEBUG"
64
        self.logger.setLevel(logging.DEBUG)
65
        self.logger.propagate = False
66
        self._is_enable = True
67

68
    def disable(self):
69
        self._is_enable = False
70

71
    def enable(self):
72
        self._is_enable = True
73

74
    @property
75
    def is_enable(self) -> bool:
76
        return self._is_enable
77

78
    def __call__(self, log_level: str, msg: str):
79
        if not self.is_enable:
80
            return
81

82
        self.logger.log(log_level, msg)
83

84
    @contextlib.contextmanager
85
    def use_terminator(self, terminator: str):
86
        old_terminator = self.handler.terminator
87
        self.handler.terminator = terminator
88
        yield
89
        self.handler.terminator = old_terminator
90

91
    @contextlib.contextmanager
92
    def processing(self, msg: str, interval: float = 0.1):
93
        """
94
        Continuously print a progress bar with rotating special effects.
95

96
        Args:
97
            msg(str): Message to be printed.
98
            interval(float): Rotation interval. Default to 0.1.
99
        """
100
        end = False
101

102
        def _printer():
103
            index = 0
104
            flags = ["\\", "|", "/", "-"]
105
            while not end:
106
                flag = flags[index % len(flags)]
107
                with self.use_terminator("\r"):
108
                    self.info("{}: {}".format(msg, flag))
109
                time.sleep(interval)
110
                index += 1
111

112
        t = threading.Thread(target=_printer)
113
        t.start()
114
        yield
115
        end = True
116

117

118
logger = Logger()
119

120

121
def advertise():
122
    """
123
    Show the advertising message like the following:
124
    ===========================================================
125
    ==        PaddleFleetX is powered by PaddlePaddle !        ==
126
    ===========================================================
127
    ==                                                       ==
128
    ==   For more info please go to the following website.   ==
129
    ==                                                       ==
130
    ==       https://github.com/PaddlePaddle/PaddleFleetX    ==
131
    ===========================================================
132
    """
133
    copyright = "PaddleFleetX is powered by PaddlePaddle !"
134
    ad = "For more info please go to the following website."
135
    website = "https://github.com/PaddlePaddle/PaddleFleetX"
136
    AD_LEN = 6 + len(max([copyright, ad, website], key=len))
137

138
    logger.info(
139
        "\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
140
            "=" * (AD_LEN + 4),
141
            "=={}==".format(copyright.center(AD_LEN)),
142
            "=" * (AD_LEN + 4),
143
            "=={}==".format(" " * AD_LEN),
144
            "=={}==".format(ad.center(AD_LEN)),
145
            "=={}==".format(" " * AD_LEN),
146
            "=={}==".format(website.center(AD_LEN)),
147
            "=" * (AD_LEN + 4),
148
        )
149
    )
150

151

152
from .device import synchronize  # noqa: E402
153

154

155
def get_timestamp():
156
    if synchronize():
157
        return time.time()
158
    else:
159
        logger.warning("Device synchronizing failed, which may result uncorrect time")
160
    return time.time()
161

162

163
def convert_timestamp_to_data(timeStamp):
164
    return str(datetime.timedelta(seconds=int(timeStamp)))
165

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.