streamlit
589 lines · 19.7 KB
1# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""
16Global pytest fixtures for e2e tests.
17This file is automatically run by pytest before tests are executed.
18"""
19from __future__ import annotations
20
21import hashlib
22import os
23import re
24import shlex
25import shutil
26import socket
27import subprocess
28import sys
29import time
30from io import BytesIO
31from pathlib import Path
32from random import randint
33from tempfile import TemporaryFile
34from types import ModuleType
35from typing import Any, Dict, Generator, List, Literal, Protocol, Tuple
36from urllib import parse
37
38import pytest
39import requests
40from PIL import Image
41from playwright.sync_api import ElementHandle, Locator, Page
42from pytest import FixtureRequest
43
44
def reorder_early_fixtures(metafunc: pytest.Metafunc):
    """Move fixtures marked with `pytest.mark.early` to the front of execution.

    This allows patch of configurations before the application is initialized.

    Copied from: https://github.com/pytest-dev/pytest/issues/1216#issuecomment-456109892
    """
    for fixture_defs in metafunc._arg2fixturedefs.values():
        primary_def = fixture_defs[0]
        marks = getattr(primary_def.func, "pytestmark", [])
        if any(mark.name == "early" for mark in marks):
            names = metafunc.fixturenames
            # Pop the fixture out of its current slot and re-insert it first:
            names.insert(0, names.pop(names.index(primary_def.argname)))
59
60
def pytest_generate_tests(metafunc: pytest.Metafunc):
    """Pytest collection hook: ensure `early`-marked fixtures run before all others."""
    reorder_early_fixtures(metafunc)
63
64
class AsyncSubprocess:
    """A context manager. Wraps subprocess.Popen to capture output safely.

    Combined stdout/stderr is captured to a temp file rather than a PIPE,
    because a filled PIPE buffer can deadlock the child process.
    """

    def __init__(self, args, cwd=None, env=None):
        """Store launch configuration; the child process is not started yet.

        Parameters
        ----------
        args : list[str]
            Command line to run as an argv list (never a shell string).
        cwd : str | None
            Working directory for the child process.
        env : dict | None
            Extra environment variables layered on top of os.environ.
        """
        self.args = args
        self.cwd = cwd
        self.env = env or {}
        self._proc = None
        self._stdout_file = None

    def terminate(self):
        """Terminate the process and return its stdout/stderr in a string."""
        if self._proc is not None:
            self._proc.terminate()
            self._proc.wait()
            self._proc = None

        # Read the stdout file and close it
        stdout = None
        if self._stdout_file is not None:
            self._stdout_file.seek(0)
            stdout = self._stdout_file.read()
            self._stdout_file.close()
            self._stdout_file = None

        return stdout

    def __enter__(self):
        self.start()
        return self

    def start(self):
        """Start the child process, capturing its combined stdout/stderr."""
        # Start the process and capture its stdout/stderr output to a temp
        # file. We do this instead of using subprocess.PIPE (which causes the
        # Popen object to capture the output to its own internal buffer),
        # because large amounts of output can cause it to deadlock.
        self._stdout_file = TemporaryFile("w+")
        print(f"Running: {shlex.join(self.args)}")
        self._proc = subprocess.Popen(
            self.args,
            cwd=self.cwd,
            stdout=self._stdout_file,
            stderr=subprocess.STDOUT,
            text=True,
            env={**os.environ.copy(), **self.env},
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._proc is not None:
            self._proc.terminate()
            # BUGFIX: reap the terminated child so it does not linger as a
            # zombie (the original called terminate() without wait()).
            self._proc.wait()
            self._proc = None
        if self._stdout_file is not None:
            self._stdout_file.close()
            self._stdout_file = None
119
120
def resolve_test_to_script(test_module: ModuleType) -> str:
    """Map a `<name>_test.py` test module to its `<name>.py` app script path."""
    module_file = test_module.__file__
    assert module_file is not None
    return module_file.replace("_test.py", ".py")
125
126
127def hash_to_range(
128text: str,
129min: int = 10000,
130max: int = 65535,
131) -> int:
132sha256_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()
133return min + (int(sha256_hash, 16) % (max - min + 1))
134
135
def is_port_available(port: int, host: str) -> bool:
    """Return True when nothing is accepting connections on host:port."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns 0 on a successful connection, i.e. the port is taken.
        return probe.connect_ex((host, port)) != 0
    finally:
        probe.close()
140
141
def find_available_port(
    min_port: int = 10000,
    max_port: int = 65535,
    max_tries: int = 50,
    host: str = "localhost",
) -> int:
    """Pick a random free port on *host*, retrying up to *max_tries* times."""
    attempts = 0
    while attempts < max_tries:
        candidate = randint(min_port, max_port)
        if is_port_available(candidate, host):
            return candidate
        attempts += 1
    raise RuntimeError("Unable to find an available port.")
154
155
def is_app_server_running(port: int, host: str = "localhost") -> bool:
    """Return True when the Streamlit health endpoint answers "ok"."""
    health_url = f"http://{host}:{port}/_stcore/health"
    try:
        # Any failure (refused connection, timeout, ...) means "not running".
        return requests.get(health_url, timeout=1).text == "ok"
    except Exception:
        return False
164
165
def wait_for_app_server_to_start(port: int, timeout: int = 5) -> bool:
    """Wait for the app server to start.

    Parameters
    ----------
    port : int
        The port on which the app server is running.

    timeout : int
        The number of minutes to wait for the app server to start.

    Returns
    -------
    bool
        True if the app server is started, False otherwise.
    """
    print(f"Waiting for app to start... {port}")
    deadline = time.time() + 60 * timeout
    while not is_app_server_running(port):
        time.sleep(3)
        if time.time() > deadline:
            return False
    return True
190
191
@pytest.fixture(scope="module")
def app_port(worker_id: str) -> int:
    """Fixture that returns an available port on localhost."""
    running_under_xdist = bool(worker_id) and worker_id != "master"
    if running_under_xdist:
        # This is run with xdist; derive a stable port by hashing the worker ID
        candidate = hash_to_range(worker_id)
        if is_port_available(candidate, "localhost"):
            return candidate
    # Fall back to a random available port:
    return find_available_port()
202
203
@pytest.fixture(scope="module", autouse=True)
def app_server(
    app_port: int, request: FixtureRequest
) -> Generator[AsyncSubprocess, None, None]:
    """Fixture that starts and stops the Streamlit app server."""
    script_path = resolve_test_to_script(request.module)
    command = [
        "streamlit",
        "run",
        script_path,
        "--server.headless",
        "true",
        "--global.developmentMode",
        "false",
        "--server.port",
        str(app_port),
        "--browser.gatherUsageStats",
        "false",
        "--server.fileWatcherType",
        "none",
    ]
    server_proc = AsyncSubprocess(command, cwd=".")
    server_proc.start()
    if not wait_for_app_server_to_start(app_port):
        # Dump captured output to help diagnose the failed startup:
        print(server_proc.terminate())
        raise RuntimeError("Unable to start Streamlit app")
    yield server_proc
    # Teardown: stop the server and echo its captured output.
    print(server_proc.terminate())
235
236
@pytest.fixture(scope="function")
def app(page: Page, app_port: int) -> Page:
    """Fixture that opens the app."""
    app_url = f"http://localhost:{app_port}/"
    page.goto(app_url)
    wait_for_app_loaded(page)
    return page
243
244
@pytest.fixture(scope="function")
def app_with_query_params(
    page: Page, app_port: int, request: FixtureRequest
) -> Tuple[Page, Dict]:
    """Fixture that opens the app with additional query parameters.
    The query parameters are passed as a dictionary in the 'param' key of the request.
    """
    params = request.param
    encoded_query = parse.urlencode(params, doseq=True)
    page.goto(f"http://localhost:{app_port}/?{encoded_query}")
    wait_for_app_loaded(page)
    return page, params
259
260
@pytest.fixture(scope="session")
def browser_type_launch_args(browser_type_launch_args: Dict, browser_name: str):
    """Fixture that adds the fake device and ui args to the browser type launch args."""
    # The browser context fixture in pytest-playwright is defined in session scope, and
    # depends on the browser_type_launch_args fixture. This means that we can't
    # redefine the browser_type_launch_args fixture more narrow scope
    # e.g. function or module scope.
    # https://github.com/microsoft/playwright-pytest/blob/ef99541352b307411dbc15c627e50f95de30cc71/pytest_playwright/pytest_playwright.py#L128

    # We need to extend browser launch args to support fake video stream for
    # st.camera_input test.
    # https://github.com/microsoft/playwright/issues/4532#issuecomment-1491761713

    if browser_name == "chromium":
        browser_type_launch_args = dict(
            browser_type_launch_args,
            args=[
                "--use-fake-device-for-media-stream",
                "--use-fake-ui-for-media-stream",
            ],
        )
    elif browser_name == "firefox":
        browser_type_launch_args = dict(
            browser_type_launch_args,
            firefox_user_prefs={
                "media.navigator.streams.fake": True,
                "media.navigator.permission.disabled": True,
                "permissions.default.microphone": 1,
                "permissions.default.camera": 1,
            },
        )
    return browser_type_launch_args
294
295
@pytest.fixture(scope="function", params=["light_theme", "dark_theme"])
def app_theme(request) -> str:
    """Fixture that parametrizes tests over the light and dark theme names."""
    return str(request.param)
300
301
@pytest.fixture(scope="function")
def themed_app(page: Page, app_port: int, app_theme: str) -> Page:
    """Fixture that opens the app with the given theme."""
    themed_url = f"http://localhost:{app_port}/?embed_options={app_theme}"
    page.goto(themed_url)
    wait_for_app_loaded(page)
    return page
308
309
class ImageCompareFunction(Protocol):
    """Callable signature of the `assert_snapshot` fixture's compare function."""

    def __call__(
        self,
        element: ElementHandle | Locator | Page,
        *,
        image_threshold: float = 0.002,
        pixel_threshold: float = 0.05,
        name: str | None = None,
        fail_fast: bool = False,
        # CONSISTENCY FIX: the actual compare() implementation also accepts
        # file_type; the Protocol previously omitted it.
        file_type: Literal["png", "jpg"] = "png",
    ) -> None:
        """Compare a screenshot with screenshot from a past run.

        Parameters
        ----------
        element : ElementHandle or Locator
            The element to take a screenshot of.
        image_threshold : float, optional
            The allowed percentage of different pixels in the image.
        pixel_threshold : float, optional
            The allowed percentage of difference for a single pixel.
        name : str | None, optional
            The name of the screenshot without an extension. If not provided, the name
            of the test function will be used.
        fail_fast : bool, optional
            If True, the comparison will stop at the first pixel mismatch.
        file_type : "png" or "jpg", optional
            The file type of the screenshot. Defaults to "png".
        """
336
337
@pytest.fixture(scope="session")
def output_folder(pytestconfig: Any) -> Path:
    """Fixture that returns the directory that is used for all test failures information.

    This includes:
    - snapshot-tests-failures: This directory contains all the snapshots that did not
    match with the snapshots from past runs. The folder structure is based on the folder
    structure used in the main snapshots folder.
    - snapshot-updates: This directory contains all the snapshots that got updated in
    the current run based on folder structure used in the main snapshots folder.
    """
    output_option = pytestconfig.getoption("--output")
    return Path(output_option).resolve()
350
351
@pytest.fixture(scope="function")
def assert_snapshot(
    request: FixtureRequest, output_folder: Path
) -> Generator[ImageCompareFunction, None, None]:
    """Fixture that compares a screenshot with screenshot from a past run.

    Yields a `compare` callable; at teardown, fails the test if any snapshots
    were missing (collected instead of failing immediately so that a single
    run can generate all missing snapshots).
    """
    root_path = Path(os.getcwd()).resolve()
    platform = str(sys.platform)
    module_name = request.module.__name__.split(".")[-1]
    test_function_name = request.node.originalname

    # Reference snapshots are stored per-platform and per-module:
    snapshot_dir: Path = root_path / "__snapshots__" / platform / module_name

    module_snapshot_failures_dir: Path = (
        output_folder / "snapshot-tests-failures" / platform / module_name
    )
    module_snapshot_updates_dir: Path = (
        output_folder / "snapshot-updates" / platform / module_name
    )

    snapshot_file_suffix = ""
    # Extract the parameter ids if they exist
    match = re.search(r"\[(.*?)\]", request.node.name)
    if match:
        snapshot_file_suffix = f"[{match.group(1)}]"

    snapshot_default_file_name: str = test_function_name + snapshot_file_suffix

    test_failure_messages: List[str] = []

    def compare(
        element: ElementHandle | Locator | Page,
        *,
        image_threshold: float = 0.002,
        pixel_threshold: float = 0.05,
        name: str | None = None,
        fail_fast: bool = False,
        file_type: Literal["png", "jpg"] = "png",
    ) -> None:
        """Compare a screenshot with screenshot from a past run.

        Parameters
        ----------
        element : ElementHandle or Locator
            The element to take a screenshot of.
        image_threshold : float, optional
            The allowed percentage of different pixels in the image.
        pixel_threshold : float, optional
            The allowed percentage of difference for a single pixel to be considered
            different.
        name : str | None, optional
            The name of the screenshot without an extension. If not provided, the name
            of the test function will be used.
        fail_fast : bool, optional
            If True, the comparison will stop at the first pixel mismatch.
        file_type: "png" or "jpg"
            The file type of the screenshot. Defaults to "png".
        """
        nonlocal test_failure_messages
        nonlocal snapshot_default_file_name
        nonlocal module_snapshot_updates_dir
        nonlocal module_snapshot_failures_dir
        nonlocal snapshot_file_suffix

        if file_type == "jpg":
            file_extension = ".jpg"
            img_bytes = element.screenshot(
                type="jpeg", quality=90, animations="disabled"
            )

        else:
            file_extension = ".png"
            img_bytes = element.screenshot(type="png", animations="disabled")

        snapshot_file_name: str = snapshot_default_file_name
        if name:
            snapshot_file_name = name + snapshot_file_suffix

        snapshot_file_path: Path = (
            snapshot_dir / f"{snapshot_file_name}{file_extension}"
        )

        snapshot_updates_file_path: Path = (
            module_snapshot_updates_dir / f"{snapshot_file_name}{file_extension}"
        )

        snapshot_file_path.parent.mkdir(parents=True, exist_ok=True)

        test_failures_dir = module_snapshot_failures_dir / snapshot_file_name
        if test_failures_dir.exists():
            # Remove the past runs failure dir for this specific screenshot
            shutil.rmtree(test_failures_dir)

        if not snapshot_file_path.exists():
            snapshot_file_path.write_bytes(img_bytes)
            # Update this in updates folder:
            snapshot_updates_file_path.parent.mkdir(parents=True, exist_ok=True)
            snapshot_updates_file_path.write_bytes(img_bytes)
            # For missing snapshots, we don't want to directly fail in order to generate
            # all missing snapshots in one run.
            test_failure_messages.append(f"Missing snapshot for {snapshot_file_name}")
            return

        from pixelmatch.contrib.PIL import pixelmatch

        # Compare the new screenshot with the screenshot from past runs:
        img_a = Image.open(BytesIO(img_bytes))
        img_b = Image.open(snapshot_file_path)
        img_diff = Image.new("RGBA", img_a.size)
        try:
            mismatch = pixelmatch(
                img_a,
                img_b,
                img_diff,
                threshold=pixel_threshold,
                fail_fast=fail_fast,
                alpha=0,
            )
        except ValueError as ex:
            # ValueError is thrown when the images have different sizes
            # Update this in updates folder:
            snapshot_updates_file_path.parent.mkdir(parents=True, exist_ok=True)
            snapshot_updates_file_path.write_bytes(img_bytes)
            pytest.fail(f"Snapshot matching for {snapshot_file_name} failed: {ex}")
        max_diff_pixels = int(image_threshold * img_a.size[0] * img_a.size[1])

        # BUGFIX: use <= instead of < so a pixel-perfect match (mismatch == 0)
        # still passes when max_diff_pixels rounds down to 0 for small
        # screenshots (e.g. image_threshold * area < 1).
        if mismatch <= max_diff_pixels:
            return

        # Update this in updates folder:
        snapshot_updates_file_path.parent.mkdir(parents=True, exist_ok=True)
        snapshot_updates_file_path.write_bytes(img_bytes)

        # Create new failures folder for this test:
        test_failures_dir.mkdir(parents=True, exist_ok=True)
        img_diff.save(f"{test_failures_dir}/diff_{snapshot_file_name}{file_extension}")
        img_a.save(f"{test_failures_dir}/actual_{snapshot_file_name}{file_extension}")
        img_b.save(f"{test_failures_dir}/expected_{snapshot_file_name}{file_extension}")

        pytest.fail(
            f"Snapshot mismatch for {snapshot_file_name} ({mismatch} pixels difference)"
        )

    yield compare

    if test_failure_messages:
        pytest.fail("Missing snapshots: \n" + "\n".join(test_failure_messages))
498
499
500# Public utility methods:
501
502
def wait_for_app_run(page: Page, wait_delay: int = 100):
    """Wait for the given page to finish running."""
    # The status widget is present while a script run is in progress; waiting
    # for it to detach means the run has finished.
    page.wait_for_selector(
        "[data-testid='stStatusWidget']", timeout=20000, state="detached"
    )

    if wait_delay <= 0:
        return
    # Give the app a little more time to render everything
    page.wait_for_timeout(wait_delay)
512
513
def wait_for_app_loaded(page: Page, embedded: bool = False):
    """Wait for the app to fully load."""
    # The app view container appearing means the frontend has booted:
    page.wait_for_selector(
        "[data-testid='stAppViewContainer']", timeout=30000, state="attached"
    )

    if not embedded:
        # The main menu is only rendered outside of embedded mode:
        page.wait_for_selector(
            "[data-testid='stMainMenu']", timeout=20000, state="attached"
        )

    wait_for_app_run(page)
528
529
def rerun_app(page: Page):
    """Triggers an app rerun and waits for the run to be finished."""
    app_root = page.get_by_test_id("stApp")
    # Click somewhere to clear the focus from elements:
    app_root.click(position={"x": 0, "y": 0})
    # Press "r" to rerun the app:
    page.keyboard.press("r")
    wait_for_app_run(page)
537
538
def wait_until(page: Page, fn: callable, timeout: int = 5000, interval: int = 100):
    """Run a test function in a loop until it evaluates to True
    or times out.

    For example:
    >>> wait_until(page, lambda: x.values() == ['x'])

    Parameters
    ----------
    page : playwright.sync_api.Page
        Playwright page
    fn : callable
        Callback returning True/False, or None when the callback uses asserts
    timeout : int, optional
        Total timeout in milliseconds, by default 5000
    interval : int, optional
        Waiting interval, by default 100

    Raises
    ------
    TimeoutError
        If *fn* has not succeeded within *timeout* milliseconds.
    ValueError
        If *fn* returns anything other than None, True or False.

    Adapted from panel.
    """
    # Hide this function traceback from the pytest output if the test fails
    __tracebackhide__ = True

    start = time.time()

    def timed_out():
        elapsed = time.time() - start
        elapsed_ms = elapsed * 1000
        return elapsed_ms > timeout

    timeout_msg = f"wait_until timed out in {timeout} milliseconds"

    while True:
        try:
            result = fn()
        except AssertionError as e:
            # Assertion-based callbacks: keep retrying until the deadline.
            if timed_out():
                raise TimeoutError(timeout_msg) from e
        else:
            if result not in (None, True, False):
                raise ValueError(
                    "`wait_until` callback must return None, True or "
                    f"False, returned {result!r}"
                )
            # Stop if result is True or None
            # None is returned when the function has an assert
            if result is None or result:
                return
        if timed_out():
            raise TimeoutError(timeout_msg)
        page.wait_for_timeout(interval)
590