10
def __init__(self, cls, allow_default):
# Per-class context bookkeeping. `cls` is the managed class being tracked
# (it is read later as self.cls; the assignment is presumably on a line not
# visible in this view — confirm). `allow_default` decides whether
# get_active() may create and push a `cls()` instance when none is active.
12
self.allow_default = allow_default
13
# threading.local() gives each thread its own attribute namespace, so the
# context stack stored on it is separate per thread.
self._local_stack = threading.local()
17
# Lazily create the calling thread's stack on first access: attributes set on
# a threading.local on one thread are not visible from other threads.
# NOTE(review): the enclosing header is not visible here; it is presumably a
# `_stack` accessor, given the `self._stack` uses in enter/exit/get_active.
if not hasattr(self._local_stack, 'obj'):
18
self._local_stack.obj = []
19
return self._local_stack.obj
21
def enter(self, value):
    """Push *value* so it becomes the innermost (active) context.

    Args:
      value: the context object to activate.
    """
    active_stack = self._stack
    active_stack.append(value)
24
def exit(self, value):
    """Pop *value* off the context stack, checking enter/exit pairing.

    Args:
      value: the context object expected to be on top of the stack.

    Raises:
      AssertionError: if the stack is empty, or if the popped object is not
        equal to *value* (mismatched enter/exit). The popped object is still
        removed in the mismatch case.
    """
    # Explicit checks instead of `assert` statements: asserts are stripped
    # under `python -O`. A side-effecting `assert stack.pop() == value` would
    # then skip the pop itself and leave the stack permanently unbalanced.
    # AssertionError is kept so callers see the same exception type as before.
    stack = self._stack
    if not stack:
        raise AssertionError('Context %s is empty.' % self.cls)
    popped = stack.pop()
    if popped != value:
        raise AssertionError(
            'Context %s exit mismatch: expected %r, got %r.'
            % (self.cls, value, popped))
28
def get_active(self, required=True):
# Return the innermost active context (top of the stack). If nothing is
# active, a fresh `self.cls()` instance is created and pushed, and it stays
# pushed for later calls.
# NOTE(review): `required` is not used in the visible lines. The lines between
# the `if` and the assert are not shown here and presumably handle
# required=False — confirm.
29
if len(self._stack) == 0:
32
# NOTE(review): this guard is an `assert`, so it is stripped under
# `python -O`. A default instance would then be created even when
# allow_default is False.
assert self.allow_default, (
33
'Context %s is required but none is active.' % self.cls)
34
self.enter(self.cls())
35
return self._stack[-1]
38
class _ContextRegistry:
# Maps each managed class to its _ContextInfo, creating the entry the first
# time the class is requested.
# NOTE(review): the constructor (which presumably initializes self._ctxs) and
# the lookup method header (presumably `get(self, cls)`, matching the
# `_context_registry().get(cls)` call sites) are not visible in this view.
43
if cls not in self._ctxs:
44
# Only Managed subclasses may be registered. DefaultManaged subclasses get
# allow_default=True, which lets _ContextInfo.get_active() create an
# instance when none is active.
assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls)
45
self._ctxs[cls] = _ContextInfo(cls, allow_default=issubclass(cls, DefaultManaged))
46
return self._ctxs[cls]
49
# Module-level singleton registry, read through _context_registry().
_CONTEXT_REGISTRY = _ContextRegistry()
52
def _context_registry():
    """Return the module-level _ContextRegistry singleton."""
    # No `global` statement: the name is only read here, never rebound, so
    # normal module-scope lookup already finds it.
    return _CONTEXT_REGISTRY
57
def _get_managed_classes(obj):
# Collect the classes in obj's MRO (in inspect.getmro order) that are Managed
# subclasses, excluding the Managed and DefaultManaged bases themselves.
# __enter__/__exit__ push/pop obj on the registry entry of each returned class.
# NOTE(review): the line opening this comprehension and its closing line are
# not visible in this view.
59
cls for cls in inspect.getmro(obj.__class__)
60
if issubclass(cls, Managed) and cls != Managed and cls != DefaultManaged
67
Managed makes the inheriting class context-managed: while an instance is active (entered via ``with``), ``current()`` returns it.
69
class Foo(Managed): ...
72
assert f == Foo.current()
76
def current(cls, value=None, required=True):
# Return the innermost active instance of `cls`, delegating to this class's
# registry entry (see _ContextInfo.get_active).
# NOTE(review): the first parameter is `cls`, so this is presumably a
# classmethod; its decorator line is not visible here.
# NOTE(review): the lines around the isinstance check are not visible. Since
# `value` defaults to None, the check presumably runs only when a value is
# passed — confirm.
77
ctx_info = _context_registry().get(cls)
79
assert isinstance(value, cls), (
80
'Wrong context type. Expected: %s, got %s.' % (cls, type(value)))
82
return ctx_info.get_active(required=required)
85
# Activate self by pushing it onto the context stack of every managed class
# in its MRO.
# NOTE(review): presumably the body of __enter__; the method header and any
# return statement are not visible in this view.
for cls in _get_managed_classes(self):
86
_context_registry().get(cls).enter(self)
89
def __exit__(self, *args):
    """Deactivate this object by popping it from each managed class's stack.

    The exception details passed by ``with`` are ignored and None is
    returned, so any exception raised inside the block propagates.
    """
    managed_classes = _get_managed_classes(self)
    registry = _context_registry()
    for managed_cls in managed_classes:
        registry.get(managed_cls).exit(self)
93
def __call__(self, func):
# Lets an instance be used as a function decorator. functools.wraps copies
# func's metadata (name, docstring, ...) onto the wrapper.
# NOTE(review): the visible wrapper only forwards to func. The line that
# presumably activates `self` around the call (e.g. `with self:`) and the
# `return wrapper` line are not visible in this view — confirm.
94
@functools.wraps(func)
95
def wrapper(*args, **kwargs):
97
return func(*args, **kwargs)
101
class DefaultManaged(Managed):
103
DefaultManaged is like Managed, but if no instance is active when
104
current() is called, a new default instance is created and made active.