15
"""Base class for DeltaGenerator-related unit tests."""
19
from typing import List
20
from unittest.mock import MagicMock
22
from streamlit.proto.Delta_pb2 import Delta
23
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
24
from streamlit.runtime import Runtime
25
from streamlit.runtime.caching.storage.dummy_cache_storage import (
26
MemoryCacheStorageManager,
28
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
29
from streamlit.runtime.media_file_manager import MediaFileManager
30
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
31
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
32
from streamlit.runtime.scriptrunner import (
37
from streamlit.runtime.scriptrunner.script_requests import ScriptRequests
38
from streamlit.runtime.state import SafeSessionState, SessionState
39
from streamlit.web.server.server import MEDIA_ENDPOINT, UPLOAD_FILE_ENDPOINT
42
class DeltaGeneratorTestCase(unittest.TestCase):
# TestCase base that installs a test ScriptRunContext on the current thread
# (whose _enqueue pushes ForwardMsgs into self.forward_msg_queue) and a
# MagicMock(spec=Runtime) as Runtime._instance, backed by in-memory
# cache / media / uploaded-file managers.
# NOTE(review): this chunk appears whitespace-mangled. The bare integers on
# their own lines look like original line numbers, and gaps in them suggest
# lost lines (e.g. 43 -> presumably `def setUp(self):`; 72-74 -> presumably
# `def tearDown(self):`). Restore from version control; confirm.
44
# Queue that collects every ForwardMsg enqueued through the test context.
self.forward_msg_queue = ForwardMsgQueue()
47
# Remember whatever context was attached before, so it can be re-attached
# by the add_script_run_ctx(..., self.orig_report_ctx) call further down.
self.orig_report_ctx = get_script_run_ctx()
50
# Test context: messages it enqueues go to self.forward_msg_queue.
self.script_run_ctx = ScriptRunContext(
51
session_id="test session id",
52
_enqueue=self.forward_msg_queue.enqueue,
54
# NOTE(review): original line 53 is missing here (presumably another
# keyword argument) — confirm.
session_state=SafeSessionState(SessionState(), lambda: None),
55
uploaded_file_mgr=MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT),
58
# NOTE(review): original lines 56-57 are missing here (presumably further
# keyword arguments) — confirm.
user_info={"email": "test@test.com"},
59
script_requests=ScriptRequests(),
61
# NOTE(review): original line 60 (presumably the closing `)` of the
# ScriptRunContext(...) call above) is missing — confirm.
# Attach the test context to the current thread.
add_script_run_ctx(threading.current_thread(), self.script_run_ctx)
65
# In-memory media storage; it backs the mocked runtime's MediaFileManager below.
self.media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
67
# Replace the Runtime instance with a spec'd mock that exposes in-memory
# cache storage, the media file manager above, and the test context's
# uploaded-file manager.
mock_runtime = MagicMock(spec=Runtime)
68
mock_runtime.cache_storage_manager = MemoryCacheStorageManager()
69
mock_runtime.media_file_mgr = MediaFileManager(self.media_file_storage)
70
mock_runtime.uploaded_file_mgr = self.script_run_ctx.uploaded_file_mgr
71
Runtime._instance = mock_runtime
75
# Cleanup: re-attach the previously saved context to this thread and drop
# the mocked Runtime instance.
# NOTE(review): original lines 72-74 are missing just before this
# (presumably a blank line and a `def tearDown(self):` header) — confirm.
add_script_run_ctx(threading.current_thread(), self.orig_report_ctx)
76
Runtime._instance = None
78
# Indexes the queue's internal `_queue` directly. The default index=-1
# selects the last queued item, assuming `_queue` is a list (TODO confirm).
def get_message_from_queue(self, index=-1) -> ForwardMsg:
79
"""Get a ForwardMsg proto from the queue, by index."""
80
return self.forward_msg_queue._queue[index]
82
# Indexes into the list produced by get_all_deltas_from_queue().
def get_delta_from_queue(self, index=-1) -> Delta:
83
"""Get a Delta proto from the queue, by index."""
84
deltas = self.get_all_deltas_from_queue()
87
# NOTE(review): as shown, `deltas` is computed but never returned. Original
# line 85 (presumably `return deltas[index]`) appears to be missing — confirm.
# Filters the queue down to messages whose `delta` field is set and
# extracts that field.
def get_all_deltas_from_queue(self) -> List[Delta]:
88
"""Return all the delta messages in our ForwardMsgQueue"""
90
# NOTE(review): original lines 89 and 91 (presumably `return [` and `]`)
# appear to be missing around this comprehension fragment — confirm.
msg.delta for msg in self.forward_msg_queue._queue if msg.HasField("delta")
93
# Empty the collected messages by delegating to ForwardMsgQueue.clear().
def clear_queue(self) -> None:
94
self.forward_msg_queue.clear()