# pytorch
# 194 lines · 7.2 KB
1import uuid2from collections import OrderedDict3from functools import wraps4from typing import Callable, Dict, List, Optional, Type5
6import torch.nn as nn7from torch.distributed._composable_state import _State8
def generate_state_key(string="__composable_api_state_key"):
    """Build a unique key by appending ``_`` and a random UUID4 to ``string``.

    Each call draws a fresh ``uuid.uuid4()``, so two keys produced by separate
    calls are practically guaranteed to differ.
    """
    suffix = str(uuid.uuid4())
    return "{}_{}".format(string, suffix)


# Key under which ``contract`` stores the per-API state dict in a module's
# ``__dict__``.
STATE_KEY = generate_state_key()
# Key under which ``contract`` stores the name -> RegistryItem registry in a
# module's ``__dict__``.
REGISTRY_KEY = generate_state_key()
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    """Marker value stored in a module's composable-API registry.

    ``contract`` records one instance per applied API, keyed by the API
    function's ``__name__``. It currently carries no data of its own.
    """
25
26def contract(state_cls: Type[_State] = _State):27r"""28Decorate a function as a composable distributed API, where the first
29argument of the function must be an :class:`nn.Module` instance. The
30decorator verifies that the wrapped function does not modify parameter,
31buffer or sub-module fully-qualified names (FQN).
32
33When a function ``func`` is decorated by ``@contract()``, a
34``.state(module: nn.Module)`` method will be installed to the decorated
35function. Then you can retrieve and modify the state on a module by calling
36``func.state(module)``.
37
38Example::
39>>> # xdoctest: +SKIP
40>>> import torch.nn as nn
41>>>
42>>> class MyModel(nn.Module):
43>>> def __init__(self):
44>>> super().__init__()
45>>> self.l1 = nn.Linear(10, 10)
46>>> self.l2 = nn.Linear(10, 10)
47>>>
48>>> def forward(self, x):
49>>> return self.l2(self.l1(x))
50>>>
51>>> @contract()
52>>> def my_feature(module: nn.Module) -> nn.Module:
53>>> my_feature.state(module).some_state = "any value"
54>>> return module
55>>>
56>>> model = MyModel()
57>>> my_feature(model.l1)
58>>> assert my_feature.state(model.l1).some_state == "any value"
59>>> my_feature(model.l2)
60>>> model(torch.randn(2, 10)).sum().backward()
61"""
62
63# wraps will make functions decorated with contract() pickleable - needed for integration with torch.package64@wraps(state_cls)65def inner(func):66@wraps(func)67def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:68# get existing global states69default_all_state: Dict[Callable, _State] = OrderedDict()70all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]71STATE_KEY, default_all_state72)73assert isinstance(74all_state, dict75), "Distributed composable API states corrupted"76
77# get global registry78default_registry: Dict[str, RegistryItem] = OrderedDict()79registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]80REGISTRY_KEY, default_registry81)82
83assert isinstance(84registry, dict85), "Distributed composable API registry corrupted"86
87# make sure the API func has not been applied to the input module yet.88assert func not in all_state and func.__name__ not in registry, (89"Each distinct composable distributed API can only be applied to a "90f"module once. {func.__name__} has already been applied to the "91f"following module.\n{module}"92)93
94# install states specific to the wrapped ``func``95all_state.setdefault(func, state_cls())96# register ``func`` in the global registry by name97registry.setdefault(func.__name__, RegistryItem())98
99orig_named_params = OrderedDict(module.named_parameters())100orig_named_buffers = OrderedDict(101module.named_buffers(remove_duplicate=False)102)103orig_named_modules = OrderedDict(104module.named_modules(remove_duplicate=False)105)106
107updated = func(module, *args, **kwargs)108
109if updated is None:110updated = module111
112new_named_params = OrderedDict(updated.named_parameters())113new_named_buffers = OrderedDict(114updated.named_buffers(remove_duplicate=False)115)116new_named_modules = OrderedDict(117updated.named_modules(remove_duplicate=False)118)119
120assert isinstance(updated, nn.Module), (121"Output of composable distributed APIs must be either None or "122f"nn.Module, but got {type(updated)}"123)124
125def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):126if orig_fqns == new_fqns:127return128
129orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)130orig_only = orig_fqn_set - new_fqn_set131new_only = new_fqn_set - orig_fqn_set132if len(orig_only) or len(new_only):133raise RuntimeError(134f"{check_key}"135"Composable distributed API implementations cannot modify "136"FQNs.\n"137f"Only in original FQNs: {orig_only},\n"138f"Only in new FQNs: {new_only}"139)140else:141raise RuntimeError(142f"{check_key}"143"Composable distributed API implementations cannot modify "144"the order of FQNs.\n"145f"Original FQNs: {orig_only}\n"146f"New FQNs: {new_only}"147)148
149check_fqn(150list(orig_named_params.keys()),151list(new_named_params.keys()),152"Check parameters, ",153)154check_fqn(155list(orig_named_buffers.keys()),156list(new_named_buffers.keys()),157"Check buffer, ",158)159check_fqn(160list(orig_named_modules.keys()),161list(new_named_modules.keys()),162"Check modules, ",163)164
165# TODO: a stricter verification should also reject changing module166# types and monkey-patching forward() method implementations.167
168# TODO: verify that installed distributed paradigms are compatible with169# each other.170
171return updated172
173def get_state(module: nn.Module) -> Optional[_State]:174return module.__dict__.setdefault( # type: ignore[call-overload]175STATE_KEY,176{}, # TODO(@yhcharles): this is a temporary fix, need a better way177).get(178func
179) # type: ignore[call-overload]180
181wrapper.state = get_state # type: ignore[attr-defined]182
183return wrapper184
185return inner186
187
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Return the registry of composable APIs applied to ``module``.

    The registry maps each applied API's name to its :class:`RegistryItem`,
    in the order the APIs were applied. If no API has been applied, ``None``
    is returned instead.
    """
    try:
        # Attribute lookup (not a raw ``__dict__`` read) so any ``__getattr__``
        # defined by the module's class still participates.
        return getattr(module, REGISTRY_KEY)
    except AttributeError:
        return None