16
from typing import Any, Callable, Optional
18
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
19
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
20
from streamlit.runtime.scriptrunner import ScriptRunContext, add_script_run_ctx
21
from streamlit.runtime.state import SafeSessionState, SessionState
25
func: Callable[[int], Any],
27
timeout: Optional[float] = 0.25,
28
attach_script_run_ctx: bool = True,
30
"""Call a function on multiple threads simultaneously and assert that no
31
thread raises an unhandled exception.
33
The function must take single `int` param, which will be the index of
34
the thread it's being called on.
36
Note that a passing multi-threaded test does not generally guarantee that
37
the tested code is thread safe! Because threading issues tend to be
38
non-deterministic, a flaky test that fails only occasionally is a good
39
indicator of an underlying issue.
44
The function to call on each thread.
46
The number of threads to create.
48
If the thread runs for longer than this amount of time, raise an
51
If True, attach a mock ScriptRunContext to each thread before
55
ExceptionCapturingThread(name=f"Thread {ii}", target=func, args=[ii])
56
for ii in range(num_threads)
59
if attach_script_run_ctx:
60
for ii in range(num_threads):
61
ctx = ScriptRunContext(
62
session_id=f"Thread{ii}_Session",
63
_enqueue=ForwardMsgQueue().enqueue,
65
session_state=SafeSessionState(SessionState(), lambda: None),
66
uploaded_file_mgr=MemoryUploadedFileManager("/mock/upload"),
69
user_info={"email": "test@test.com"},
72
add_script_run_ctx(thread, ctx)
74
for thread in threads:
77
for thread in threads:
78
thread.join(timeout=timeout)
79
thread.assert_no_unhandled_exception()
82
class ExceptionCapturingThread(threading.Thread):
83
"""Thread subclass that captures unhandled exceptions."""
86
self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None
96
self._unhandled_exception: Optional[BaseException] = None
99
def unhandled_exception(self) -> Optional[BaseException]:
100
"""The unhandled exception raised by the thread's target, if it raised one."""
101
return self._unhandled_exception
103
def assert_no_unhandled_exception(self) -> None:
104
"""If the thread target raised an unhandled exception, re-raise it.
107
if self._unhandled_exception is not None:
109
f"Unhandled exception in thread '{self.name}'"
110
) from self._unhandled_exception
112
def run(self) -> None:
115
except Exception as e:
116
self._unhandled_exception = e