pytorch

Форк
0
194 строки · 7.2 Кб
1
import uuid
2
from collections import OrderedDict
3
from functools import wraps
4
from typing import Callable, Dict, List, Optional, Type
5

6
import torch.nn as nn
7
from torch.distributed._composable_state import _State
8

9

10
def generate_state_key(string: str = "__composable_api_state_key") -> str:
    """Return ``string`` followed by ``_`` and a freshly generated UUID4.

    The random suffix makes each returned key unique per call, so separately
    generated keys do not collide.

    Args:
        string: Human-readable prefix for the key.

    Returns:
        ``f"{string}_{uuid4}"``, where ``uuid4`` is the canonical 36-character
        string form of a new random UUID.
    """
    return f"{string}_{uuid.uuid4()}"
12

13

14
# Keys under which the per-module API state and the per-module API registry
# are stored in ``nn.Module.__dict__``. Each gets a random suffix, generated
# once when this module is imported.
STATE_KEY, REGISTRY_KEY = generate_state_key(), generate_state_key()
16

17

18
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    """Entry stored in a module's composable-API registry for each applied API.

    It currently carries no data. Its presence under an API's name in the
    registry is what records that the API has been applied to that module.
    """
24

25

26
def _check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str) -> None:
    """Raise if ``new_fqns`` differs from ``orig_fqns`` in content or order.

    Args:
        orig_fqns: FQNs collected before the composable API ran.
        new_fqns: FQNs collected after the composable API ran.
        check_key: Prefix of the error message that names what is being
            checked (e.g. ``"Check parameters, "``).

    Raises:
        RuntimeError: if any FQN was added or removed, or if the same FQNs now
            appear in a different order.
    """
    if orig_fqns == new_fqns:
        return

    orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
    orig_only = orig_fqn_set - new_fqn_set
    new_only = new_fqn_set - orig_fqn_set
    if orig_only or new_only:
        raise RuntimeError(
            f"{check_key}"
            "Composable distributed API implementations cannot modify "
            "FQNs.\n"
            f"Only in original FQNs: {orig_only},\n"
            f"Only in new FQNs: {new_only}"
        )
    # Same set of names in a different sequence. The set differences are empty
    # here, so report the full ordered lists to show what actually changed.
    raise RuntimeError(
        f"{check_key}"
        "Composable distributed API implementations cannot modify "
        "the order of FQNs.\n"
        f"Original FQNs: {orig_fqns}\n"
        f"New FQNs: {new_fqns}"
    )


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance. The
    decorator verifies that the wrapped function does not modify parameter,
    buffer or sub-module fully-qualified names (FQN).

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: Class instantiated with no arguments to create the state
            object attached to each module the decorated API is applied to.
            Defaults to :class:`_State`.

    Returns:
        A decorator. Given ``func``, it returns a wrapper that calls
        ``func(module, *args, **kwargs)`` and returns its result, or ``module``
        itself if ``func`` returned ``None``.

    The returned wrapper raises ``AssertionError`` in three cases:

    * the per-module state or registry storage is not a ``dict``;
    * the same API is applied to one module twice;
    * ``func`` returns something other than ``None`` or an ``nn.Module``.

    It raises ``RuntimeError`` if ``func`` adds, removes, or reorders
    parameter, buffer, or sub-module FQNs.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
            # Per-module mapping from API function to its state object, kept
            # directly in ``module.__dict__`` under STATE_KEY.
            default_all_state: Dict[Callable, _State] = OrderedDict()
            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY, default_all_state
            )
            assert isinstance(
                all_state, dict
            ), "Distributed composable API states corrupted"

            # Per-module registry of applied APIs, keyed by function name and
            # kept in ``module.__dict__`` under REGISTRY_KEY.
            default_registry: Dict[str, RegistryItem] = OrderedDict()
            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                REGISTRY_KEY, default_registry
            )
            assert isinstance(
                registry, dict
            ), "Distributed composable API registry corrupted"

            # make sure the API func has not been applied to the input module yet.
            assert func not in all_state and func.__name__ not in registry, (
                "Each distinct composable distributed API can only be applied to a "
                f"module once. {func.__name__} has already been applied to the "
                f"following module.\n{module}"
            )

            # install states specific to the wrapped ``func``
            all_state.setdefault(func, state_cls())
            # register ``func`` in the global registry by name
            registry.setdefault(func.__name__, RegistryItem())

            # Snapshot FQNs before ``func`` runs so they can be compared after.
            orig_named_params = OrderedDict(module.named_parameters())
            orig_named_buffers = OrderedDict(
                module.named_buffers(remove_duplicate=False)
            )
            orig_named_modules = OrderedDict(
                module.named_modules(remove_duplicate=False)
            )

            updated = func(module, *args, **kwargs)

            if updated is None:
                updated = module

            # Validate the return type before calling nn.Module methods on it,
            # so a bad return value triggers this message instead of an
            # AttributeError.
            assert isinstance(updated, nn.Module), (
                "Output of composable distributed APIs must be either None or "
                f"nn.Module, but got {type(updated)}"
            )

            new_named_params = OrderedDict(updated.named_parameters())
            new_named_buffers = OrderedDict(
                updated.named_buffers(remove_duplicate=False)
            )
            new_named_modules = OrderedDict(
                updated.named_modules(remove_duplicate=False)
            )

            _check_fqn(
                list(orig_named_params),
                list(new_named_params),
                "Check parameters, ",
            )
            _check_fqn(
                list(orig_named_buffers),
                list(new_named_buffers),
                "Check buffer, ",
            )
            _check_fqn(
                list(orig_named_modules),
                list(new_named_modules),
                "Check modules, ",
            )

            # TODO: a stricter verification should also reject changing module
            # types and monkey-patching forward() method implementations.

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            """Return the state ``func`` installed on ``module``, or ``None``.

            Side effect: if ``module`` has no state dict yet, an empty one is
            stored in ``module.__dict__`` under STATE_KEY.
            """
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
186

187

188
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Get an ``OrderedDict`` of composable APIs that have been applied to the
    ``module``, indexed by the API name. If no API has been applied, then this
    returns ``None``.

    Args:
        module: The module to inspect.

    Returns:
        The registry stored on ``module`` under ``REGISTRY_KEY``, or ``None`` if
        that attribute is absent.
    """
    return getattr(module, REGISTRY_KEY, None)
195

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.