pytorch-lightning
137 строк · 4.5 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import json
16import queue
17from dataclasses import dataclass
18from threading import Thread
19from typing import Callable, Iterator, List, Optional
20
21import dateutil.parser
22from websocket import WebSocketApp
23
24from lightning.app.utilities.log_helpers import _error_callback, _OrderedLogEntry
25from lightning.app.utilities.logs_socket_api import _LightningLogsSocketAPI
26
27
28@dataclass
29class _LogEventLabels:
30app: Optional[str] = None
31container: Optional[str] = None
32filename: Optional[str] = None
33job: Optional[str] = None
34namespace: Optional[str] = None
35node_name: Optional[str] = None
36pod: Optional[str] = None
37component: Optional[str] = None
38projectID: Optional[str] = None
39stream: Optional[str] = None
40
41
@dataclass
class _LogEvent(_OrderedLogEntry):
    """A log entry tagged with the component it was read for.

    Extends ``_OrderedLogEntry`` with the two fields declared below. In this
    module instances are created with ``message``, ``timestamp``,
    ``component_name`` and ``labels`` keyword arguments and are stored in a
    ``queue.PriorityQueue``.
    """

    # Name of the component whose log socket delivered this entry.
    component_name: str
    # Labels parsed from the "labels" object of the raw event.
    labels: _LogEventLabels
46
47
def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
    """Build a websocket message handler that pushes ``_LogEvent`` objects to ``read_queue``.

    Args:
        component_name: Name stored on every event produced by the handler.
        read_queue: Queue the parsed events are put on.

    Returns:
        A ``callback(ws_app, msg)`` function, intended to be used as the
        ``on_message_callback`` of ``websocket.WebSocketApp``.

    """
    # Function-scope import: nothing else in the module needs ``fields``.
    from dataclasses import fields

    # Computed once per handler instead of once per message.
    known_labels = {label_field.name for label_field in fields(_LogEventLabels)}

    def callback(ws_app: WebSocketApp, msg: str):
        """Parse one JSON log message and enqueue it if it carries a ``"message"``."""
        event_dict = json.loads(msg)

        # Payloads without a "message" key produce no event, so skip all
        # further parsing for them.
        if "message" not in event_dict:
            return

        # ``or {}`` also covers an explicit ``"labels": null``. Keys that
        # _LogEventLabels does not declare are dropped: passing them on would
        # make the constructor raise TypeError and the log line would never
        # reach the queue.
        raw_labels = event_dict.get("labels") or {}
        labels = _LogEventLabels(**{key: value for key, value in raw_labels.items() if key in known_labels})

        read_queue.put(
            _LogEvent(
                message=event_dict["message"],
                timestamp=dateutil.parser.isoparse(event_dict["timestamp"]),
                component_name=component_name,
                labels=labels,
            )
        )

    return callback
72
73
def _app_logs_reader(
    logs_api_client: _LightningLogsSocketAPI,
    project_id: str,
    app_id: str,
    component_names: List[str],
    follow: bool,
    on_error_callback: Optional[Callable] = None,
) -> Iterator[_LogEvent]:
    """Yield log events of the given components as they arrive.

    One log socket is opened per component and each socket runs on its own
    daemon thread, pushing parsed events to a shared priority queue that this
    generator drains.

    Events of a component are only yielded once a "start" message has been
    seen for it ("Your app has started." for the component named ``flow``,
    "USER_RUN_WORK" for every other component), and only if their timestamp
    is not older than that start message. Messages containing "launcher" are
    never yielded.

    Args:
        logs_api_client: Client used to create the log sockets.
        project_id: Passed through to the socket factory.
        app_id: Passed through to the socket factory.
        component_names: Components to read logs for; one socket each.
        follow: If true, block indefinitely waiting for new events. Otherwise
            stop as soon as the queue stays empty for one second.
        on_error_callback: Error handler for the sockets; ``_error_callback``
            is used when it is not given.

    Yields:
        The ``_LogEvent`` objects that pass the filtering described above.

    """
    read_queue = queue.PriorityQueue()
    error_handler = on_error_callback or _error_callback

    def open_socket(name: str):
        # Each component gets its own socket feeding the shared queue.
        return logs_api_client.create_lightning_logs_socket(
            project_id=project_id,
            app_id=app_id,
            component=name,
            on_message_callback=_push_log_events_to_read_queue_callback(name, read_queue),
            on_error_callback=error_handler,
        )

    log_sockets = [open_socket(name) for name in component_names]

    # One thread per socket; run_forever() keeps going until the connection
    # is closed from outside (done in the ``finally`` below).
    log_threads = [Thread(target=log_socket.run_forever, daemon=True) for log_socket in log_sockets]

    # Establish the connections and start filling the queue.
    for thread in log_threads:
        thread.start()

    flow_start_marker = "Your app has started."
    work_start_marker = "USER_RUN_WORK"
    started_at = {}

    # Block without limit when following, otherwise give up after a second.
    wait_seconds = None if follow else 1.0

    try:
        while True:
            log_event: _LogEvent = read_queue.get(timeout=wait_seconds)
            name = log_event.component_name

            # Remember when this component (most recently) reported its start.
            start_marker = work_start_marker if name != "flow" else flow_start_marker
            if start_marker in log_event.message:
                started_at[name] = log_event.timestamp

            # Drop everything logged before the start marker.
            start_time = started_at.get(name)
            is_after_start = bool(start_time) and log_event.timestamp >= start_time
            if is_after_start and "launcher" not in log_event.message:
                yield log_event

    except (queue.Empty, KeyboardInterrupt):
        # queue.Empty: the timeout expired (follow=False case).
        # KeyboardInterrupt: the user pressed CTRL+C; exit quietly.
        pass

    finally:
        # Closing a connection makes its run_forever() return, which lets the
        # corresponding thread finish as well.
        for log_socket in log_sockets:
            log_socket.close()

        # All sockets are closed, so the threads can simply be waited for.
        for thread in log_threads:
            thread.join()
138