otter

Форк
0
/
serving_utils.py 
131 строка · 3.9 Кб
1
import logging
2
import logging.handlers
3
import os
4
import sys
5

6
import requests
7

8
CONTROLLER_HEART_BEAT_EXPIRATION = 2 * 60
9
WORKER_HEART_BEAT_INTERVAL = 30
10

11
LOGDIR = "./logs"
12

13
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
14
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15

16
handler = None
17

18

19
def build_logger(logger_name, logger_filename):
20
    global handler
21

22
    formatter = logging.Formatter(
23
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
24
        datefmt="%Y-%m-%d %H:%M:%S",
25
    )
26

27
    # Set the format of root handlers
28
    if not logging.getLogger().handlers:
29
        logging.basicConfig(level=logging.INFO)
30
    logging.getLogger().handlers[0].setFormatter(formatter)
31

32
    # Redirect stdout and stderr to loggers
33
    stdout_logger = logging.getLogger("stdout")
34
    stdout_logger.setLevel(logging.INFO)
35
    sl = StreamToLogger(stdout_logger, logging.INFO)
36
    sys.stdout = sl
37

38
    stderr_logger = logging.getLogger("stderr")
39
    stderr_logger.setLevel(logging.ERROR)
40
    sl = StreamToLogger(stderr_logger, logging.ERROR)
41
    sys.stderr = sl
42

43
    # Get logger
44
    logger = logging.getLogger(logger_name)
45
    logger.setLevel(logging.INFO)
46

47
    # Add a file handler for all loggers
48
    if handler is None:
49
        os.makedirs(LOGDIR, exist_ok=True)
50
        filename = os.path.join(LOGDIR, logger_filename)
51
        handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
52
        handler.setFormatter(formatter)
53

54
        for name, item in logging.root.manager.loggerDict.items():
55
            if isinstance(item, logging.Logger):
56
                item.addHandler(handler)
57

58
    return logger
59

60

61
class StreamToLogger(object):
62
    """
63
    Fake file-like stream object that redirects writes to a logger instance.
64
    """
65

66
    def __init__(self, logger, log_level=logging.INFO):
67
        self.terminal = sys.stdout
68
        self.logger = logger
69
        self.log_level = log_level
70
        self.linebuf = ""
71

72
    def __getattr__(self, attr):
73
        return getattr(self.terminal, attr)
74

75
    def write(self, buf):
76
        temp_linebuf = self.linebuf + buf
77
        self.linebuf = ""
78
        for line in temp_linebuf.splitlines(True):
79
            # From the io.TextIOWrapper docs:
80
            #   On output, if newline is None, any '\n' characters written
81
            #   are translated to the system default line separator.
82
            # By default sys.stdout.write() expects '\n' newlines and then
83
            # translates them so this is still cross platform.
84
            if line[-1] == "\n":
85
                self.logger.log(self.log_level, line.rstrip())
86
            else:
87
                self.linebuf += line
88

89
    def flush(self):
90
        if self.linebuf != "":
91
            self.logger.log(self.log_level, self.linebuf.rstrip())
92
        self.linebuf = ""
93

94

95
def disable_torch_init():
96
    """
97
    Disable the redundant torch default initialization to accelerate model creation.
98
    """
99
    import torch
100

101
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
102
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
103

104

105
def violates_moderation(text):
106
    """
107
    Check whether the text violates OpenAI moderation API.
108
    """
109
    url = "https://api.openai.com/v1/moderations"
110
    headers = {
111
        "Content-Type": "application/json",
112
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
113
    }
114
    text = text.replace("\n", "")
115
    data = "{" + '"input": ' + f'"{text}"' + "}"
116
    data = data.encode("utf-8")
117
    try:
118
        ret = requests.post(url, headers=headers, data=data, timeout=25)
119
        flagged = ret.json()["results"][0]["flagged"]
120
    except requests.exceptions.RequestException as e:
121
        flagged = False
122
    except KeyError as e:
123
        flagged = False
124

125
    return flagged
126

127

128
def pretty_print_semaphore(semaphore):
129
    if semaphore is None:
130
        return "None"
131
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.