pytorch

Форк
0
/
_contextlib.py 
152 строки · 5.8 Кб
1
# Extra utilities for working with context managers that should have been
2
# in the standard library but are not
3

4
import functools
5
import inspect
6
import warnings
7
import sys
8
from typing import Any, Callable, TypeVar, cast
9

10
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
11
# 'no_grad' and 'enable_grad').
12
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
13
FuncType = Callable[..., Any]
14
F = TypeVar('F', bound=FuncType)
15

16

17
def _wrap_generator(ctx_factory, func):
    """
    Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.

    Args:
        ctx_factory: zero-argument callable returning a context manager; it is
            called anew for every resumption of the wrapped generator.
        func: the generator function to wrap.

    Returns:
        A generator function with the metadata of ``func`` (via
        ``functools.wraps``) whose generator forwards values, sent requests,
        thrown exceptions and closure to the generator produced by ``func``,
        entering ``ctx_factory()`` around each interaction with it.
    """
    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the context is properly entered every time the execution
        # flow returns into the wrapped generator and exited when it
        # returns through our `yield` to our caller (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException as exc:
                    # Propagate the exception thrown at us by the caller.
                    # Pass the exception instance itself: the three-argument
                    # `throw(type, value, traceback)` form is deprecated since
                    # Python 3.12 and emits a DeprecationWarning.
                    with ctx_factory():
                        response = gen.throw(exc)

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context
67

68

69
def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator.

    But with the following differences:
    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.

    Args:
        ctx: a context manager instance, or a zero-argument callable returning
            one. It must not be both callable and have ``__enter__``.
        func: the function or generator function to wrap.

    Returns:
        A wrapper carrying the metadata of ``func`` that runs each call (or, for
        generator functions, each resumption) inside the context manager.

    Raises:
        AssertionError: if ``ctx`` is both callable and a context manager.
        RuntimeError: if ``func`` is a class.
    """
    # Raise explicitly instead of using an `assert` statement so that this
    # validation is not stripped when Python runs with -O. The exception type
    # and message are the same as the assert would have produced.
    if callable(ctx) and hasattr(ctx, '__enter__'):
        raise AssertionError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use.  If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    if not callable(ctx):
        # A plain context manager instance: hand back the same object each time
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type.  "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    if inspect.isgeneratorfunction(func):
        # Generators need the context re-entered on every resumption
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context
118

119

120
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses are expected to override ``__enter__`` and ``__exit__`` (the
    versions here raise ``NotImplementedError``). Calling an instance on a
    function returns that function wrapped by ``context_decorator``, with
    ``self.clone`` supplied as the context manager factory.
    """

    def __call__(self, orig_func: F) -> F:
        """Return ``orig_func`` wrapped so that it runs inside this context manager.

        Decorating a class is deprecated: a warning is issued and the class is
        replaced by a plain function that forwards its arguments to the class.
        """
        if inspect.isclass(orig_func):
            # stacklevel=2 attributes the warning to the code applying the
            # decorator rather than to this line.
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                stacklevel=2,
            )
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

    def clone(self):
        """Return a new instance of this object's class, built with no arguments."""
        # override this method if your children class takes __init__ parameters
        return self.__class__()
144

145

146
class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a context manager to be used as a decorator without parentheses."""

    def __new__(cls, orig_func=None):
        # An argument means the class itself was applied as a decorator
        # (``@Manager`` rather than ``@Manager()``): build an instance and
        # hand the target to it.
        if orig_func is not None:
            return cls()(orig_func)

        # No argument: ordinary construction
        return super().__new__(cls)
153

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.