pytorch
248 lines · 10.1 KB
1from collections import Counter
2from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
4import torch
5import torch.nn as nn
6from torch import Tensor
7from torch._functorch.utils import exposed_in
8
9
@exposed_in("torch.func")
def functional_call(
    module: "torch.nn.Module",
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameter_and_buffer_dicts`` input.


        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    An example of passing multiple dictionaries

    .. code-block:: python

        a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
        mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
        print(mod.weight)  # tensor(...)
        print(mod.buffer)  # tensor(...)
        x = torch.randn((1, 1))
        print(x)
        functional_call(mod, a, x)  # same as x
        print(mod.weight)  # same as before functional_call


    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.func import functional_call, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)

        def compute_loss(params, x, t):
            y = functional_call(model, params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
        parameters for better performance and memory usage

        Example::

            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
            >>> grad_weights.grad_fn  # None--it's not tracking gradients outside of grad

    This means that the user cannot call ``grad_weight.backward()``. However, if they don't need autograd tracking
    outside of the transforms, this will result in less memory usage and faster speeds.

    Args:
        module (torch.nn.Module): the module to call
        parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
            the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
            be used together
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error. Default: False.

    Returns:
        Any: the result of calling ``module``.
    """
    if isinstance(parameter_and_buffer_dicts, dict):
        parameters_and_buffers = parameter_and_buffer_dicts
    elif isinstance(parameter_and_buffer_dicts, Sequence):
        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
            raise ValueError(
                "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
            )
        # Merging the dicts is only well-defined when no key appears twice;
        # detect collisions up front so we can report all of them at once.
        all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
        repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1]
        if len(repeated_keys) > 0:
            raise ValueError(
                f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
            )
        parameters_and_buffers = {
            k: v for d in parameter_and_buffer_dicts for k, v in d.items()
        }
    else:
        raise ValueError(
            f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
            f"but got {type(parameter_and_buffer_dicts)}"
        )

    # Delegate to the stateless helper, which swaps the tensors in, runs the
    # module, and restores the original parameters/buffers afterwards.
    return nn.utils.stateless._functional_call(
        module,
        parameters_and_buffers,
        args,
        kwargs,
        tie_weights=tie_weights,
        strict=strict,
    )
151
152
@exposed_in("torch.func")
def stack_module_state(
    models: List[nn.Module],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """stack_module_state(models) -> params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
    that stack all of their parameters and buffers together, indexed by name.
    The stacked parameters are optimizable (i.e. they are new leaf nodes in the
    autograd history that are unrelated to the original parameters and can be
    passed directly to an optimizer).

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, 3)

        def wrapper(params, buffers, data):
            return torch.func.functional_call(models[0], (params, buffers), data)

        params, buffers = stack_module_state(models)
        output = vmap(wrapper, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)

    When there's submodules, this follows state dict naming conventions

    .. code-block:: python

        import torch.nn as nn
        class Foo(nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                hidden = 4
                self.l1 = nn.Linear(in_features, hidden)
                self.l2 = nn.Linear(hidden, out_features)

            def forward(self, x):
                return self.l2(self.l1(x))

        num_models = 5
        in_features, out_features = 3, 3
        models = [Foo(in_features, out_features) for i in range(num_models)]
        params, buffers = stack_module_state(models)
        print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).
    """
    if len(models) == 0:
        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
    # All models must agree on train/eval mode; mixing modes would make the
    # stacked ensemble's behavior (dropout, batchnorm, ...) ill-defined.
    if not (all(m.training for m in models) or all(not m.training for m in models)):
        raise RuntimeError(
            "stack_module_state: Expected all models to have the same training/eval mode."
        )
    model0_typ = type(models[0])
    # Exact class match is required (not isinstance): subclasses may declare a
    # different set of parameters/buffers, which could not be stacked by name.
    if not all(type(m) == model0_typ for m in models):
        raise RuntimeError(
            "stack_module_state: Expected all models to be of the same class."
        )
    # Stack each named parameter across models; keys come from the first model
    # since all models are the same class and hence share the same names.
    all_params = [dict(model.named_parameters()) for model in models]
    params = {
        k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
        for k in all_params[0]
    }
    all_buffers = [dict(model.named_buffers()) for model in models]
    buffers = {
        k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
        for k in all_buffers[0]
    }

    return params, buffers
234
235
def construct_stacked_leaf(
    tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
    """Stack per-model tensors named ``name`` into a single tensor.

    All inputs must agree on ``.requires_grad``. When they require grad, the
    stacked result is detached and re-marked as requiring grad so that it is a
    fresh autograd leaf, unrelated to the original tensors.
    """
    grad_flags = [t.requires_grad for t in tensors]
    if any(grad_flags) and not all(grad_flags):
        raise RuntimeError(
            f"Expected {name} from each model to have the same .requires_grad"
        )
    stacked = torch.stack(tensors)
    if all(grad_flags):
        # Detach cuts ties to the source tensors; requires_grad_ makes the
        # stack an optimizable leaf in its own right.
        stacked = stacked.detach().requires_grad_()
    return stacked
249