# pytorch-lightning
#
# Форк
# 0
# 137 строк · 4.5 Кб
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import queue
from dataclasses import dataclass, fields
from threading import Thread
from typing import Callable, Iterator, List, Optional

import dateutil.parser
from websocket import WebSocketApp

from lightning.app.utilities.log_helpers import _error_callback, _OrderedLogEntry
from lightning.app.utilities.logs_socket_api import _LightningLogsSocketAPI


@dataclass
class _LogEventLabels:
    """Labels attached to a log event by the Lightning logs websocket API.

    Every field is optional and defaults to ``None``. Instances are built from the ``"labels"``
    object of a websocket message by passing its keys as keyword arguments, so the field names
    must match the API's label keys (hence the camelCase ``projectID``). The declaration order is
    also the positional-argument order, so keep it stable.

    """

    app: Optional[str] = None
    container: Optional[str] = None
    filename: Optional[str] = None
    job: Optional[str] = None
    namespace: Optional[str] = None
    node_name: Optional[str] = None
    pod: Optional[str] = None
    component: Optional[str] = None
    projectID: Optional[str] = None
    stream: Optional[str] = None


@dataclass
class _LogEvent(_OrderedLogEntry):
    """A single log line received from a component's logs websocket.

    Besides the fields inherited from ``_OrderedLogEntry`` (the message callback in this module
    fills in ``message`` and ``timestamp``), each event records which component produced it and
    the labels the API attached to it.

    Events are put on a ``queue.PriorityQueue``, so instances must be mutually comparable. That
    ordering presumably comes from ``_OrderedLogEntry`` (likely by timestamp); TODO confirm there.

    """

    # Name of the component whose socket produced this event; the reader treats "flow" specially.
    component_name: str
    # Labels parsed from the websocket message.
    labels: _LogEventLabels


def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
    """Build a websocket ``on_message`` callback that turns log frames into ``_LogEvent`` objects.

    The returned callback decodes every frame as JSON. Frames that are JSON objects with a
    ``"message"`` key become a ``_LogEvent`` tagged with ``component_name`` and are pushed onto
    ``read_queue``. Every other frame is ignored.

    Args:
        component_name: Name of the component whose logs the socket streams.
        read_queue: Queue shared by all component sockets; the logs reader drains it.

    Returns:
        A ``callback(ws_app, msg)`` function to pass as ``on_message_callback`` of a
        ``websocket.WebSocketApp``.

    """
    # Label keys _LogEventLabels can store. Other keys sent by the server are dropped, so a
    # label added on the API side (or "labels": null) cannot make parsing of a log line fail.
    known_labels = {label_field.name for label_field in fields(_LogEventLabels)}

    def callback(ws_app: WebSocketApp, msg: str):
        event_dict = json.loads(msg)

        # Only JSON objects carrying a message are log lines; skip anything else up front.
        if not isinstance(event_dict, dict) or "message" not in event_dict:
            return

        raw_labels = event_dict.get("labels")
        if not isinstance(raw_labels, dict):
            raw_labels = {}
        labels = _LogEventLabels(**{key: value for key, value in raw_labels.items() if key in known_labels})

        event = _LogEvent(
            message=event_dict["message"],
            timestamp=dateutil.parser.isoparse(event_dict["timestamp"]),
            component_name=component_name,
            labels=labels,
        )
        read_queue.put(event)

    return callback


def _app_logs_reader(
    logs_api_client: _LightningLogsSocketAPI,
    project_id: str,
    app_id: str,
    component_names: List[str],
    follow: bool,
    on_error_callback: Optional[Callable] = None,
) -> Iterator[_LogEvent]:
    """Stream the log events of several app components through a single iterator.

    One logs websocket is opened per component, and each runs ``run_forever()`` on its own daemon
    thread. All sockets push their events into one shared ``queue.PriorityQueue``, which this
    generator drains. An event is yielded only if all of these hold:

    * its component has already logged its start marker (``"Your app has started."`` for the
      component named ``"flow"``, ``"USER_RUN_WORK"`` for any other component);
    * the event is not older than that marker;
    * its message does not contain ``"launcher"``.

    Args:
        logs_api_client: Client used to open one logs websocket per component.
        project_id: Project the app belongs to.
        app_id: App whose logs are read.
        component_names: Components to read logs from.
        follow: If ``True``, wait for new events indefinitely. Otherwise stop as soon as the queue
            stays empty for one second.
        on_error_callback: Websocket error handler; ``_error_callback`` is used when omitted.

    Yields:
        ``_LogEvent`` objects, in the order the priority queue hands them out.

    Sockets are always closed and their threads joined when the generator finishes, is closed by
    the consumer, or is interrupted with CTRL+C.

    """
    read_queue = queue.PriorityQueue()
    error_callback = on_error_callback or _error_callback
    get_timeout = None if follow else 1.0

    flow_start_token = "Your app has started."
    work_start_token = "USER_RUN_WORK"
    # Timestamp of the start-marker event, per component name.
    start_timestamps = {}

    log_sockets = []
    started_threads = []

    try:
        # Setup happens inside the ``try`` so that, if creating a socket or starting a thread fails
        # partway, everything already opened is still closed and joined by the ``finally`` below.
        for component_name in component_names:
            log_sockets.append(
                logs_api_client.create_lightning_logs_socket(
                    project_id=project_id,
                    app_id=app_id,
                    component=component_name,
                    on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue),
                    on_error_callback=error_callback,
                )
            )

        # Each socket gets its own thread. run_forever() runs until the socket is closed from outside.
        for log_socket in log_sockets:
            log_thread = Thread(target=log_socket.run_forever, daemon=True)
            log_thread.start()
            started_threads.append(log_thread)

        # Yield events from the queue as they become available.
        while True:
            log_event: _LogEvent = read_queue.get(timeout=get_timeout)

            start_token = flow_start_token if log_event.component_name == "flow" else work_start_token
            if start_token in log_event.message:
                start_timestamps[log_event.component_name] = log_event.timestamp

            start_timestamp = start_timestamps.get(log_event.component_name)
            if (
                start_timestamp is not None
                and log_event.timestamp >= start_timestamp
                and "launcher" not in log_event.message
            ):
                yield log_event

    except queue.Empty:
        # Raised by read_queue.get() when its timeout expires, which only happens when follow=False.
        pass

    except KeyboardInterrupt:
        # The user pressed CTRL+C to exit; stop streaming quietly.
        pass

    finally:
        # Closing a socket makes its run_forever() return, so its thread can finish.
        for log_socket in log_sockets:
            log_socket.close()

        # Every socket is closed now, so wait for the threads that were actually started.
        for log_thread in started_threads:
            log_thread.join()

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.