urllib3

Форк
0
/
__init__.py 
327 строк · 9.6 Кб
1
from __future__ import annotations
2

3
import errno
4
import importlib.util
5
import logging
6
import os
7
import platform
8
import socket
9
import sys
10
import typing
11
import warnings
12
from collections.abc import Sequence
13
from functools import wraps
14
from importlib.abc import Loader, MetaPathFinder
15
from importlib.machinery import ModuleSpec
16
from types import ModuleType, TracebackType
17

18
import pytest
19

20
# Prefer ``brotlicffi`` and fall back to ``brotli``. When neither can be
# imported, ``brotli`` is None, which onlyBrotli()/notBrotli() key off.
try:
    import brotlicffi as brotli  # type: ignore[import-not-found]
except ImportError:
    try:
        import brotli  # type: ignore[import-not-found]
    except ImportError:
        brotli = None
27

28
# HAS_ZSTD records whether python-zstandard is importable. The module object
# itself is not used here, only the success of the import.
try:
    import zstandard as _unused_module_zstd  # noqa: F401

    HAS_ZSTD = True
except ImportError:
    HAS_ZSTD = False
34

35
from urllib3.connectionpool import ConnectionPool
36
from urllib3.exceptions import HTTPWarning
37

38
# Optional PyOpenSSL backend. withPyOpenSSL() skips its test when this
# import failed and ``pyopenssl`` is therefore None.
try:
    import urllib3.contrib.pyopenssl as pyopenssl
except ImportError:
    pyopenssl = None  # type: ignore[assignment]
42

43

44
_RT = typing.TypeVar("_RT")  # return type
45
_TestFuncT = typing.TypeVar("_TestFuncT", bound=typing.Callable[..., typing.Any])
46

47

48
# A host that will not immediately close the connection with a TCP reset.
#  - Windows: a reserved loopback subnet address.
#  - Elsewhere: a reserved internet-scoped address, see
#    https://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
TARPIT_HOST = "127.0.0.0" if platform.system() == "Windows" else "240.0.0.0"
57

58
# Each entry is ``(socket address tuple, whether it is an IPv6 address)``.
VALID_SOURCE_ADDRESSES = [
    (("::1", 0), True),
    (("127.0.0.1", 0), False),
]
# Addresses from blocks reserved for testing/documentation, so they are not
# expected to exist on a local interface:
#  - RFC 5737: 192.0.2.0/24 is for testing only.
#  - RFC 3849: 2001:db8::/32 is for documentation only.
INVALID_SOURCE_ADDRESSES = [
    (("192.0.2.255", 0), False),
    (("2001:db8::1", 0), True),
]
63

64
# Timeouts are used in three different ways in the tests:
#
# 1. To make sure an operation times out, use a short timeout.
# 2. To make sure a test does not hang even when the operation should succeed,
#    use a long timeout -- longer still on CI, where tests can be really slow.
# 3. To test the timeout logic itself by using two different values, e.g.
#    different values at the pool level and at the request level.
SHORT_TIMEOUT = 0.001
LONG_TIMEOUT = (
    0.5
    if os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") == "true"
    else 0.1
)
75

76
# Module-level ConnectionPool for the placeholder host "dummy", created at
# import time. NOTE(review): presumably shared by tests that need a pool
# object without opening a real connection -- confirm against callers.
DUMMY_POOL = ConnectionPool("dummy")
77

78

79
def _can_resolve(host: str) -> bool:
80
    """Returns True if the system can resolve host to an address."""
81
    try:
82
        socket.getaddrinfo(host, None, socket.AF_UNSPEC)
83
        return True
84
    except socket.gaierror:
85
        return False
86

87

88
# Not every system resolves the fully-qualified, trailing-dot name "localhost."
# correctly; see https://github.com/urllib3/urllib3/issues/1809 and
# https://github.com/urllib3/urllib3/pull/1475#issuecomment-440788064.
# Probed once, at import time.
RESOLVES_LOCALHOST_FQDN: bool = _can_resolve("localhost.")
92

93

94
def clear_warnings(cls: type[Warning] = HTTPWarning) -> None:
    """Remove every installed warnings filter whose category subclasses *cls*.

    The surviving filters are written back into ``warnings.filters`` in place,
    keeping their relative order.
    """
    # entry[2] is the filter's warning category.
    survivors = [
        entry for entry in warnings.filters if not issubclass(entry[2], cls)
    ]
    warnings.filters[:] = survivors  # type: ignore[index]
101

102

103
def setUp() -> None:
    """Ignore HTTPWarning from now on.

    First drops any filters for HTTPWarning and its subclasses, then installs
    an "ignore" filter for HTTPWarning.
    """
    clear_warnings()
    warnings.simplefilter(action="ignore", category=HTTPWarning)
106

107

108
def notWindows() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a pytest mark that skips the decorated test on Windows."""
    on_windows = platform.system() == "Windows"
    return pytest.mark.skipif(on_windows, reason="Test does not run on Windows")
114

115

116
def onlyBrotli() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a pytest mark that skips the test when no Brotli module imported."""
    missing = brotli is None
    return pytest.mark.skipif(
        missing,
        reason="only run if brotli library is present",
    )
120

121

122
def notBrotli() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a pytest mark that skips the test when a Brotli module imported."""
    present = brotli is not None
    return pytest.mark.skipif(
        present,
        reason="only run if a brotli library is absent",
    )
126

127

128
def onlyZstd() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a pytest mark that skips the test unless zstandard is importable."""
    zstd_missing = not HAS_ZSTD
    return pytest.mark.skipif(
        zstd_missing,
        reason="only run if a python-zstandard library is installed",
    )
132

133

134
def notZstd() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a pytest mark that skips the test when zstandard is importable."""
    zstd_present = HAS_ZSTD
    return pytest.mark.skipif(
        zstd_present,
        reason="only run if a python-zstandard library is not installed",
    )
139

140

141
_requires_network_has_route = None
142

143

144
def requires_network() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a decorator for tests that need a routable network.

    The decorated test is tagged with the ``requires_network`` pytest mark.
    When it runs, it is skipped if a probe connection to ``TARPIT_HOST`` on
    port 80 fails with ENETUNREACH or EHOSTUNREACH. The first completed
    probe's result is cached in ``_requires_network_has_route`` and reused by
    later calls. Any other OSError from the probe propagates to the test.
    """

    def network_is_routable() -> bool:
        # A successful connect or a timeout both count as "a route exists";
        # only the unreachable errno values count as "no route".
        try:
            with socket.create_connection((TARPIT_HOST, 80), 0.0001):
                return True
        except socket.timeout:
            return True
        except OSError as exc:
            unreachable = (
                errno.ENETUNREACH,
                errno.EHOSTUNREACH,  # For OSX
            )
            if getattr(exc, "errno", None) in unreachable:
                return False
            raise

    def skip_unless_routable(func: _TestFuncT) -> _TestFuncT:
        """Skip test execution if the network is unreachable."""

        @wraps(func)
        def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
            global _requires_network_has_route
            if _requires_network_has_route is None:
                _requires_network_has_route = network_is_routable()
            if not _requires_network_has_route:
                pytest.skip("Can't run the test because the network is unreachable")
            return func(*args, **kwargs)

        return typing.cast(_TestFuncT, wrapper)

    def mark_and_wrap(func: _TestFuncT) -> typing.Any:
        # Wrap with the skip logic first, then apply the "requires_network"
        # pytest mark to the wrapped function.
        return pytest.mark.requires_network(skip_unless_routable(func))

    return mark_and_wrap
191

192

193
def resolvesLocalhostFQDN() -> typing.Callable[[_TestFuncT], _TestFuncT]:
    """Return a pytest mark that skips the test unless "localhost." resolves."""
    cannot_resolve = not RESOLVES_LOCALHOST_FQDN
    return pytest.mark.skipif(cannot_resolve, reason="Can't resolve localhost.")
199

200

201
def withPyOpenSSL(test: typing.Callable[..., _RT]) -> typing.Callable[..., _RT]:
    """Run *test* with PyOpenSSL injected into urllib3, or skip it.

    If ``urllib3.contrib.pyopenssl`` could not be imported, the test is
    skipped. Otherwise PyOpenSSL is injected before the test runs and is
    always extracted afterwards, including when the test fails or raises, so
    a failing test cannot leave the injection active for later tests.
    """

    @wraps(test)
    def wrapper(*args: typing.Any, **kwargs: typing.Any) -> _RT:
        if not pyopenssl:
            # pytest.skip() raises, so nothing after it runs.
            pytest.skip("pyopenssl not available, skipping test.")

        pyopenssl.inject_into_urllib3()
        try:
            return test(*args, **kwargs)
        finally:
            pyopenssl.extract_from_urllib3()

    return wrapper
214

215

216
class _ListHandler(logging.Handler):
217
    def __init__(self) -> None:
218
        super().__init__()
219
        self.records: list[logging.LogRecord] = []
220

221
    def emit(self, record: logging.LogRecord) -> None:
222
        self.records.append(record)
223

224

225
class LogRecorder:
    """Capture the records emitted through a logger (the root logger by default).

    Use install()/uninstall() directly, or use it as a context manager whose
    ``with`` target is the live list of captured records.
    """

    def __init__(self, target: logging.Logger = logging.root) -> None:
        super().__init__()
        self._target = target
        self._handler = _ListHandler()

    @property
    def records(self) -> list[logging.LogRecord]:
        """The handler's own record list (shared, not a copy)."""
        return self._handler.records

    def install(self) -> None:
        """Attach the capturing handler to the target logger."""
        self._target.addHandler(self._handler)

    def uninstall(self) -> None:
        """Detach the capturing handler from the target logger."""
        self._target.removeHandler(self._handler)

    def __enter__(self) -> list[logging.LogRecord]:
        self.install()
        return self.records

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> typing.Literal[False]:
        # Always detach; returning False lets any exception propagate.
        self.uninstall()
        return False
253

254

255
class ImportBlockerLoader(Loader):
    """Loader whose load_module() and exec_module() always raise ImportError."""

    def __init__(self, fullname: str) -> None:
        # Remembered for exec_module(), which only receives the module object.
        self._fullname = fullname

    def load_module(self, fullname: str) -> ModuleType:
        # Reports the name passed in, not the one given to __init__.
        message = f"import of {fullname} is blocked"
        raise ImportError(message)

    def exec_module(self, module: ModuleType) -> None:
        message = f"import of {self._fullname} is blocked"
        raise ImportError(message)
264

265

266
class ImportBlocker(MetaPathFinder):
    """
    Block Imports

    Meant to be placed on ``sys.meta_path``. Every module whose full name is
    in ``namestoblock`` gets a spec whose loader raises ImportError, so those
    modules cannot be imported, even if they are builtins (this presumably
    relies on the blocker being consulted before the finder that would
    otherwise locate them).
    """

    def __init__(self, *namestoblock: str) -> None:
        # Exact, fully qualified module names, kept as the tuple received.
        self.namestoblock = namestoblock

    def find_module(
        self, fullname: str, path: typing.Sequence[bytes | str] | None = None
    ) -> Loader | None:
        """Return a blocking loader for *fullname* if it is blocked, else None."""
        if fullname not in self.namestoblock:
            return None
        return ImportBlockerLoader(fullname)

    def find_spec(
        self,
        fullname: str,
        path: Sequence[bytes | str] | None,
        target: ModuleType | None = None,
    ) -> ModuleSpec | None:
        """Return a spec for the blocking loader, or None to defer to others."""
        blocking_loader = self.find_module(fullname, path)
        if blocking_loader is not None:
            return importlib.util.spec_from_loader(fullname, blocking_loader)
        return None
295

296

297
class ModuleStash(MetaPathFinder):
    """
    Stashes away previously imported modules.

    ``stash()`` moves ``namespace`` and every ``namespace.*`` entry out of
    ``modules`` (``sys.modules`` by default) into a private store. ``pop()``
    drops whatever is now registered under that namespace and puts the stashed
    module objects back. The original objects are reused because re-importing
    a module loses the data coverage collected for it.
    """

    def __init__(
        self, namespace: str, modules: dict[str, ModuleType] = sys.modules
    ) -> None:
        self.namespace = namespace
        # Defaults to the live sys.modules mapping itself, not a copy.
        self.modules = modules
        # Stashed module objects keyed by full dotted name. Never cleared, so
        # pop() restores the same objects each time it is called.
        self._data: dict[str, ModuleType] = {}

    def stash(self) -> None:
        """Move the namespace package and its submodules into the stash."""
        prefix = f"{self.namespace}."
        # The package itself first, then its submodules in table order.
        names = [self.namespace] if self.namespace in self.modules else []
        names.extend(name for name in self.modules if name.startswith(prefix))
        for name in names:
            self._data[name] = self.modules.pop(name)

    def pop(self) -> None:
        """Drop the current namespace entries and restore the stashed ones."""
        prefix = f"{self.namespace}."
        self.modules.pop(self.namespace, None)
        leftovers = [name for name in self.modules if name.startswith(prefix)]
        for name in leftovers:
            self.modules.pop(name)
        self.modules.update(self._data)
328

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.