pytorch

Форк
0
/
functional_call.py 
248 строк · 10.1 Кб
1
from collections import Counter
2
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3

4
import torch
5
import torch.nn as nn
6
from torch import Tensor
7
from torch._functorch.utils import exposed_in
8

9

10
@exposed_in("torch.func")
def functional_call(
    module: "torch.nn.Module",
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameter_and_buffer_dicts`` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    An example of passing multiple dictionaries

    .. code-block:: python

        a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
        mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
        print(mod.weight)  # tensor(...)
        print(mod.buffer)  # tensor(...)
        x = torch.randn((1, 1))
        print(x)
        functional_call(mod, a, x)  # same as x
        print(mod.weight)  # same as before functional_call


    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.func import functional_call, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)

        def compute_loss(params, x, t):
            y = functional_call(model, params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
        parameters for better performance and memory usage

        Example::

            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
            >>> grad_weights['weight'].grad_fn  # None--it's not tracking gradients outside of grad

        This means that the user cannot call ``grad_weights['weight'].backward()``. However, if they don't need
        autograd tracking outside of the transforms, this will result in less memory usage and faster speeds.

    Args:
        module (torch.nn.Module): the module to call
        parameter_and_buffer_dicts (Dict[str, Tensor] or sequence of Dict[str, Tensor]): the parameters that will be
            used in the module call. If given a sequence (e.g. tuple or list) of dictionaries, they must have distinct
            keys so that all dictionaries can be used together
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict, optional): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error. Default: False.

    Returns:
        Any: the result of calling ``module``.

    Raises:
        ValueError: if ``parameter_and_buffer_dicts`` is neither a dict nor a sequence of dicts, or if the same key
            appears in more than one of the given dictionaries.
    """
    if isinstance(parameter_and_buffer_dicts, dict):
        parameters_and_buffers = parameter_and_buffer_dicts
    elif isinstance(parameter_and_buffer_dicts, Sequence):
        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
            raise ValueError(
                "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
            )
        # A key present in two dictionaries would make the merged result depend on
        # dictionary order, so such inputs are rejected instead of silently merged.
        key_counts = Counter(k for d in parameter_and_buffer_dicts for k in d)
        repeated_keys = [key for key, n in key_counts.items() if n > 1]
        if repeated_keys:
            raise ValueError(
                f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
            )
        # NOTE(review): this merged dict is a new object, so anything written back
        # into the dict handed to _functional_call will not reach the caller's
        # individual dictionaries -- confirm against the in-place note above.
        parameters_and_buffers = {
            k: v for d in parameter_and_buffer_dicts for k, v in d.items()
        }
    else:
        raise ValueError(
            "Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
            f"but got {type(parameter_and_buffer_dicts)}"
        )

    return nn.utils.stateless._functional_call(
        module,
        parameters_and_buffers,
        args,
        kwargs,
        tie_weights=tie_weights,
        strict=strict,
    )
151

152

153
@exposed_in("torch.func")
def stack_module_state(
    models: List[nn.Module],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """stack_module_state(models) -> params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
    that stack all of their parameters and buffers together, indexed by name.
    Each stacked tensor gains a new leading dimension of size ``M``.
    The stacked parameters are optimizable (i.e. they are new leaf nodes in the
    autograd history that are unrelated to the original parameters and can be
    passed directly to an optimizer).

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        import torch
        from torch.func import functional_call, stack_module_state, vmap

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, in_features)

        def wrapper(params, buffers, data):
            return functional_call(models[0], (params, buffers), data)

        params, buffers = stack_module_state(models)
        output = vmap(wrapper, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)

    When there's submodules, this follows state dict naming conventions

    .. code-block:: python

        import torch.nn as nn
        class Foo(nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                hidden = 4
                self.l1 = nn.Linear(in_features, hidden)
                self.l2 = nn.Linear(hidden, out_features)

            def forward(self, x):
                return self.l2(self.l1(x))

        num_models = 5
        in_features, out_features = 3, 3
        models = [Foo(in_features, out_features) for i in range(num_models)]
        params, buffers = stack_module_state(models)
        print(list(params.keys()))  # ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias']

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).

    Args:
        models (List[nn.Module]): the modules to stack. Must be non-empty.

    Returns:
        Tuple[Dict[str, Any], Dict[str, Any]]: ``(params, buffers)``, mapping every
        name reported by ``named_parameters()`` / ``named_buffers()`` to the
        corresponding tensors of all models stacked together.

    Raises:
        RuntimeError: if ``models`` is empty; if the models are not all in the same
            training/eval mode; if they are not all of exactly the same class; if they
            do not all expose the same parameter names or the same buffer names; or if
            they disagree on whether a given parameter/buffer requires grad.
    """
    if not models:
        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
    if not (all(m.training for m in models) or all(not m.training for m in models)):
        raise RuntimeError(
            "stack_module_state: Expected all models to have the same training/eval mode."
        )
    model0_typ = type(models[0])
    # Exact type check: an instance of a subclass is rejected as well.
    if not all(type(m) is model0_typ for m in models):
        raise RuntimeError(
            "stack_module_state: Expected all models to be of the same class."
        )

    all_params = [dict(model.named_parameters()) for model in models]
    all_buffers = [dict(model.named_buffers()) for model in models]

    # Instances of one class can still differ structurally (e.g. a constructor flag
    # that drops a bias). Names are taken from model 0 below, so without this check
    # a name missing from a later model would surface as a bare KeyError and an
    # extra name on a later model would be silently dropped.
    for kind, per_model in (("parameter", all_params), ("buffer", all_buffers)):
        expected_names = per_model[0].keys()
        for idx, named in enumerate(per_model[1:], start=1):
            if named.keys() != expected_names:
                raise RuntimeError(
                    f"stack_module_state: Expected all models to have the same {kind} "
                    f"names, but model 0 has {sorted(expected_names)} and model {idx} "
                    f"has {sorted(named)}."
                )

    params = {
        name: construct_stacked_leaf(tuple(named[name] for named in all_params), name)
        for name in all_params[0]
    }
    buffers = {
        name: construct_stacked_leaf(tuple(named[name] for named in all_buffers), name)
        for name in all_buffers[0]
    }

    return params, buffers
234

235

236
def construct_stacked_leaf(
    tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
    """Stack the same-named tensor taken from each model into a single tensor.

    Args:
        tensors: one tensor per model; they are combined with ``torch.stack``,
            which adds a new leading dimension.
        name: the parameter/buffer name, used only in the error message.

    Returns:
        The stacked tensor. When every input requires grad, the result is detached
        and re-marked as requiring grad, making it a fresh leaf with no autograd
        link to the inputs. Otherwise the ``torch.stack`` output is returned as-is.

    Raises:
        RuntimeError: if some, but not all, of ``tensors`` require grad.
    """
    grad_flags = [t.requires_grad for t in tensors]
    every_one = all(grad_flags)
    if any(grad_flags) and not every_one:
        raise RuntimeError(
            f"Expected {name} from each model to have the same .requires_grad"
        )

    stacked = torch.stack(tensors)
    if not every_one:
        return stacked
    # Cut the autograd edge back to the per-model tensors so the stacked tensor
    # can be optimized on its own.
    return stacked.detach().requires_grad_()
249

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.