1
from __future__ import annotations
12
from collections.abc import Sequence
13
from functools import wraps
14
from importlib.abc import Loader, MetaPathFinder
15
from importlib.machinery import ModuleSpec
16
from types import ModuleType, TracebackType
22
import brotlicffi as brotli # type: ignore[import-not-found]
24
import brotli # type: ignore[import-not-found]
29
import zstandard as _unused_module_zstd # noqa: F401
35
from urllib3.connectionpool import ConnectionPool
36
from urllib3.exceptions import HTTPWarning
39
import urllib3.contrib.pyopenssl as pyopenssl
41
pyopenssl = None # type: ignore[assignment]
44
_RT = typing.TypeVar("_RT") # return type
45
_TestFuncT = typing.TypeVar("_TestFuncT", bound=typing.Callable[..., typing.Any])
48
# We need a host that will not immediately close the connection with a TCP
# reset, so that connection attempts to it are left pending instead of being
# refused outright.
if platform.system() == "Windows":
    # Reserved loopback subnet address
    TARPIT_HOST = "127.0.0.0"
else:
    # Reserved internet scoped address
    # https://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
    TARPIT_HOST = "240.0.0.0"
# Each entry is ((host, port) arguments for socket, is it an IPv6 address?).
# Loopback addresses, one per address family.
VALID_SOURCE_ADDRESSES = [(("::1", 0), True), (("127.0.0.1", 0), False)]
# RFC 5737: 192.0.2.0/24 is for testing only.
# RFC 3849: 2001:db8::/32 is for documentation only.
INVALID_SOURCE_ADDRESSES = [(("192.0.2.255", 0), False), (("2001:db8::1", 0), True)]
# We use timeouts in three different ways in our tests
66
# 1. To make sure that the operation times out, we can use a short timeout.
67
# 2. To make sure that the test does not hang even if the operation should succeed, we
68
# want to use a long timeout, even more so on CI where tests can be really slow
69
# 3. To test our timeout logic by using two different values, e.g. by using different
70
# values at the pool level and at the request level.
# NOTE(review): the bare integers in this region (66, 67, ...) look like original
# line numbers left behind by extraction; as code they are no-op expression
# statements and can be deleted.
73
# Detects a CI run via the CI or GITHUB_ACTIONS environment variables.
# NOTE(review): this `if` has no visible body (original lines 74-75 are not shown),
# and original lines 71-72 just before it are missing too; presumably they define
# the timeout values described above — confirm against the full file. As shown,
# an `if` with no indented body does not parse.
if os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") == "true":
76
# Module-level ConnectionPool instance named "dummy".
# NOTE(review): with indentation stripped it is unclear whether this assignment
# is meant to sit inside the `if` above — confirm.
DUMMY_POOL = ConnectionPool("dummy")
79
def _can_resolve(host: str) -> bool:
80
"""Returns True if the system can resolve host to an address."""
82
socket.getaddrinfo(host, None, socket.AF_UNSPEC)
84
except socket.gaierror:
88
# Some systems might not resolve "localhost." correctly.
# See https://github.com/urllib3/urllib3/issues/1809 and
# https://github.com/urllib3/urllib3/pull/1475#issuecomment-440788064.
# Probed once, when this module is imported; read by the
# resolvesLocalhostFQDN() skip marker.
RESOLVES_LOCALHOST_FQDN = _can_resolve("localhost.")
def clear_warnings(cls: type[Warning] = HTTPWarning) -> None:
    """Remove every installed warnings filter whose category is a subclass of *cls*.

    Filters for other categories are kept in their original order. The list
    is replaced in place (slice assignment), so every existing reference to
    ``warnings.filters`` sees the change.
    """
    # Index 2 of a warnings filter entry is its category class.
    warnings.filters[:] = [  # type: ignore[index]
        entry for entry in warnings.filters if not issubclass(entry[2], cls)
    ]
# NOTE(review): the embedded original line numbers jump from 100 to 105 just
# before this statement, so an enclosing definition (or other lines) may be
# missing from this view — confirm. At module level, as shown, this installs an
# "ignore" filter for HTTPWarning and its subclasses when the module is imported.
warnings.simplefilter("ignore", HTTPWarning)
108
def notWindows() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a marker that skips the decorated test on Windows."""
    return pytest.mark.skipif(
        platform.system() == "Windows",
        reason="Test does not run on Windows",
    )
def onlyBrotli() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a marker that skips the decorated test when ``brotli`` is None.

    In other words, the test only runs if a brotli library is present.
    """
    return pytest.mark.skipif(
        brotli is None,
        reason="only run if brotli library is present",
    )
def notBrotli() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a marker that skips the decorated test when ``brotli`` is not None.

    In other words, the test only runs if no brotli library is available.
    """
    return pytest.mark.skipif(
        brotli is not None,
        reason="only run if a brotli library is absent",
    )
def onlyZstd() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a marker that skips the decorated test unless ``HAS_ZSTD`` is true."""
    return pytest.mark.skipif(
        not HAS_ZSTD,
        reason="only run if a python-zstandard library is installed",
    )
def notZstd() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a marker that skips the decorated test when ``HAS_ZSTD`` is true.

    The condition must be passed explicitly: ``pytest.mark.skipif`` given only
    a ``reason`` skips unconditionally.
    """
    return pytest.mark.skipif(
        HAS_ZSTD,
        reason="only run if a python-zstandard library is not installed",
    )
_requires_network_has_route = None
144
# Returns a decorator that skips a test when no network route is available and
# marks it with pytest.mark.requires_network.
# NOTE(review): this definition is incomplete in this view. Indentation is
# stripped, and the embedded line numbers skip several original lines (e.g. 149,
# 154, 156-157, 159-160, 162-165, 187-189), so it does not parse as shown. The
# comments below describe only the visible lines.
def requires_network() -> typing.Callable[[_TestFuncT], _TestFuncT]:
145
"""Helps you skip tests that require the network"""
147
# Classifies an exception as "unreachable" by comparing its errno attribute
# (None if absent) against a tuple of errno constants.
# NOTE(review): original line 149 is not shown; the tuple presumably has at
# least one more entry besides EHOSTUNREACH — confirm against the full file.
def _is_unreachable_err(err: Exception) -> bool:
148
return getattr(err, "errno", None) in (
150
errno.EHOSTUNREACH, # For OSX
153
# Probes for a route by connecting to TARPIT_HOST port 80 with a 0.0001 s
# timeout. socket.timeout is caught, and other errors are classified with
# _is_unreachable_err.
# NOTE(review): the `try:` line, the return statements, and the
# `except ... as e` clause that binds `e` are not visible here (original lines
# 154, 156-157, 159-160, 162-165) — confirm.
def _has_route() -> bool:
155
sock = socket.create_connection((TARPIT_HOST, 80), 0.0001)
158
except socket.timeout:
161
if _is_unreachable_err(e):
166
# Wraps a test so that it is skipped when the cached route probe says the
# network is unreachable.
def _skip_if_no_route(f: _TestFuncT) -> _TestFuncT:
167
"""Skip test execution if network is unreachable"""
# NOTE(review): a decorator line for `wrapper` (e.g. functools.wraps, which is
# imported but not visibly used) may be among the missing lines 168-169 — confirm.
170
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
171
global _requires_network_has_route
172
# Probe lazily on first use; later calls reuse the module-level cached result.
if _requires_network_has_route is None:
173
_requires_network_has_route = _has_route()
174
if not _requires_network_has_route:
175
pytest.skip("Can't run the test because the network is unreachable")
176
return f(*args, **kwargs)
178
return typing.cast(_TestFuncT, wrapper)
180
# Wraps `decorator` so that every test it decorates also carries the
# pytest.mark.requires_network mark.
def _decorator_requires_internet(
181
decorator: typing.Callable[[_TestFuncT], _TestFuncT]
182
) -> typing.Callable[[_TestFuncT], _TestFuncT]:
183
"""Mark a decorator with the "requires_network" mark"""
185
def wrapper(f: _TestFuncT) -> typing.Any:
186
return pytest.mark.requires_network(decorator(f))
# NOTE(review): the `return wrapper` for _decorator_requires_internet is not
# visible here (original lines 187-189 are missing) — confirm.
190
# The returned decorator combines the route check with the requires_network mark.
return _decorator_requires_internet(_skip_if_no_route)
193
def resolvesLocalhostFQDN() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a marker that skips the decorated test unless 'localhost.' resolves.

    The decision uses the module-level RESOLVES_LOCALHOST_FQDN flag.
    """
    return pytest.mark.skipif(
        not RESOLVES_LOCALHOST_FQDN,
        reason="Can't resolve localhost.",
    )
def withPyOpenSSL(test: typing.Callable[..., _RT]) -> typing.Callable[..., _RT]:
    """Decorate *test* so it runs with urllib3's PyOpenSSL backend injected.

    The test is skipped when ``pyopenssl`` is None, which is the module-level
    fallback when the PyOpenSSL contrib module cannot be imported. The backend
    is always extracted again, even if the test raises, so the injection cannot
    leak into later tests.
    """

    # functools.wraps keeps the test's name and exposes its signature via
    # __wrapped__, instead of the generic (*args, **kwargs) of the wrapper.
    @wraps(test)
    def wrapper(*args: typing.Any, **kwargs: typing.Any) -> _RT:
        if pyopenssl is None:
            # pytest.skip() raises, so nothing below runs in this case.
            pytest.skip("pyopenssl not available, skipping test.")

        pyopenssl.inject_into_urllib3()
        try:
            return test(*args, **kwargs)
        finally:
            pyopenssl.extract_from_urllib3()

    return wrapper
class _ListHandler(logging.Handler):
217
def __init__(self) -> None:
219
self.records: list[logging.LogRecord] = []
221
def emit(self, record: logging.LogRecord) -> None:
222
self.records.append(record)
226
# NOTE(review): the class header for the methods below is not visible (the
# embedded line numbers jump from 222 to 226). They cannot belong to _ListHandler:
# this __init__ constructs a _ListHandler itself, which would recurse endlessly.
# Confirm the missing class name in the full file.
# Sets up capture of records logged to `target` (default: the root logger)
# through a private _ListHandler.
def __init__(self, target: logging.Logger = logging.root) -> None:
228
self._target = target
229
self._handler = _ListHandler()
232
# Returns the handler's own list of captured records (not a copy).
# NOTE(review): a decorator line (original line 231) may precede this def — confirm.
def records(self) -> list[logging.LogRecord]:
233
return self._handler.records
235
# Attaches the capturing handler to the target logger.
def install(self) -> None:
236
self._target.addHandler(self._handler)
238
# Detaches the capturing handler from the target logger.
def uninstall(self) -> None:
239
self._target.removeHandler(self._handler)
241
# Context-manager entry, annotated to return a list of captured records.
# NOTE(review): the body of __enter__ and the header of the exit method
# (original lines 242-246) are not visible, nor is the exit body (lines 251-254).
# The parameters below are the (exc_type, exc_value, traceback) triple, and the
# Literal[False] return annotation declares that exceptions are not suppressed.
def __enter__(self) -> list[logging.LogRecord]:
247
exc_type: type[BaseException] | None,
248
exc_value: BaseException | None,
249
traceback: TracebackType | None,
250
) -> typing.Literal[False]:
255
class ImportBlockerLoader(Loader):
    """Loader that refuses to load the module it was created for.

    Both the legacy ``load_module`` hook and the ``exec_module`` hook raise
    ImportError, so the import fails whichever hook is used.
    """

    def __init__(self, fullname: str) -> None:
        self._fullname = fullname

    def load_module(self, fullname: str) -> ModuleType:
        raise ImportError(f"import of {fullname} is blocked")

    def exec_module(self, module: ModuleType) -> None:
        raise ImportError(f"import of {self._fullname} is blocked")
class ImportBlocker(MetaPathFinder):
    """Block imports of the given module names.

    To be placed on ``sys.meta_path``. This ensures that the modules
    specified cannot be imported, even if they are a builtin.
    """

    def __init__(self, *namestoblock: str) -> None:
        self.namestoblock = namestoblock

    def find_module(
        self, fullname: str, path: typing.Sequence[bytes | str] | None = None
    ) -> Loader | None:
        """Return a blocking loader if *fullname* is blocked, otherwise None."""
        if fullname in self.namestoblock:
            return ImportBlockerLoader(fullname)
        return None

    def find_spec(
        self,
        fullname: str,
        path: Sequence[bytes | str] | None,
        target: ModuleType | None = None,
    ) -> ModuleSpec | None:
        """Return a spec whose loader refuses *fullname*, or None if not blocked.

        Returning None lets the remaining finders on ``sys.meta_path`` handle
        modules that are not blocked.
        """
        loader = self.find_module(fullname, path)
        if loader is None:
            # spec_from_loader(fullname, None) would still build a spec, but one
            # with no loader, which makes the import fail instead of falling
            # through to the next finder.
            return None
        return importlib.util.spec_from_loader(fullname, loader)
class ModuleStash(MetaPathFinder):
    """
    Stashes away previously imported modules

    If we reimport a module the data from coverage is lost, so we reuse the old
    module objects instead: ``stash()`` moves a namespace's modules out of
    ``self.modules`` and ``pop()`` puts the saved objects back.
    """

    def __init__(
        self, namespace: str, modules: dict[str, ModuleType] = sys.modules
    ) -> None:
        self.namespace = namespace
        # Defaults to the live sys.modules mapping; stash()/pop() mutate it in place.
        self.modules = modules
        self._data: dict[str, ModuleType] = {}

    def _submodule_names(self) -> list[str]:
        """Snapshot the names in ``self.modules`` that are submodules of the namespace."""
        prefix = self.namespace + "."
        return [name for name in self.modules if name.startswith(prefix)]

    def stash(self) -> None:
        """Move the namespace package and its submodules into the stash."""
        if self.namespace in self.modules:
            self._data[self.namespace] = self.modules.pop(self.namespace)
        for name in self._submodule_names():
            self._data[name] = self.modules.pop(name)

    def pop(self) -> None:
        """Drop the namespace's current modules and restore the stashed ones."""
        self.modules.pop(self.namespace, None)
        for name in self._submodule_names():
            self.modules.pop(name)
        self.modules.update(self._data)