# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Iterable, List, Tuple

import torch


# Sentinel for "no such tensor": a unique object that can never be identical
# to a real tensor value, so identity checks against it are unambiguous.
_MISSING: torch.Tensor = object()  # type: ignore[assignment]


def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(tensor, torch.Tensor) and tensor is not None:
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')
    if name in module._parameters:
        module._parameters[name] = tensor  # type: ignore[assignment]
    elif name in module._buffers:
        module._buffers[name] = tensor
    else:
        setattr(module, name, tensor)
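
# Illustrative usage sketch (the names below are hypothetical, not part of this
# file): `set_tensor` writes through `_parameters` / `_buffers`, so registered
# state stays registered even when assigned a plain tensor.
#
#     linear = torch.nn.Linear(2, 2)
#     set_tensor(linear, "weight", torch.zeros(2, 2))
#     assert "weight" in dict(linear.named_parameters())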


def swap_tensor(
    module: "torch.nn.Module",
    name: str,
    tensor: torch.Tensor,
    allow_missing: bool = False,
) -> torch.Tensor:
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if (
        tensor is not _MISSING
        and not isinstance(tensor, torch.Tensor)
        and tensor is not None
    ):
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')

    orig_tensor: torch.Tensor
    if name in module._parameters:
        orig_tensor = module._parameters[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._parameters[name] = tensor  # type: ignore[assignment]
        else:
            del module._parameters[name]
    elif name in module._buffers:
        orig_tensor = module._buffers[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._buffers[name] = tensor
        else:
            del module._buffers[name]
    else:
        try:
            orig_tensor = getattr(module, name)
        except AttributeError as ex:
            if not allow_missing:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{name}`"
                ) from ex
            orig_tensor = _MISSING
        if (
            orig_tensor is not _MISSING
            and not isinstance(orig_tensor, torch.Tensor)
            and orig_tensor is not None
        ):
            raise TypeError(
                f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
            )
        if tensor is not _MISSING:
            setattr(module, name, tensor)
        elif hasattr(module, name):
            delattr(module, name)
    return orig_tensor
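
# Illustrative usage sketch (hypothetical names): the return value is the tensor
# that was displaced, which makes the operation reversible, and passing the
# `_MISSING` sentinel deletes the entry instead of replacing it.
#
#     linear = torch.nn.Linear(2, 2)
#     old_weight = swap_tensor(linear, "weight", torch.ones(2, 2))
#     swap_tensor(linear, "weight", old_weight)  # restores the original weight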


def swap_submodule(
    module: "torch.nn.Module",
    name: str,
    submodule: "torch.nn.Module",
) -> "torch.nn.Module":
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(submodule, torch.nn.Module):
        raise TypeError(f"{submodule} is not an instance of torch.nn.Module")
    if "." in name:
        raise KeyError('submodule name can\'t contain "."')
    if name == "":
        raise KeyError('submodule name can\'t be empty string ""')
    if name not in module._modules:
        raise KeyError(f"submodule {name} does not exist")

    orig_submodule = module._modules[name]
    if not isinstance(orig_submodule, torch.nn.Module):
        raise TypeError(f"{name} attribute is not an instance of torch.nn.Module")
    module._modules[name] = submodule
    return orig_submodule
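
# Illustrative usage sketch (hypothetical names): like `swap_tensor`, the
# original child is returned so the caller can undo the swap.
#
#     seq = torch.nn.Sequential(torch.nn.Linear(2, 2))
#     orig = swap_submodule(seq, "0", torch.nn.Identity())
#     swap_submodule(seq, "0", orig)  # puts the original layer back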


class NamedMemberAccessor:
    """
    A class that provides a way to access the submodules and parameters/buffers of a module.

    It provides a caching mechanism to speed up submodule lookups.
    This is useful for functional programming when manipulating the module state.
    """

    def __init__(self, module: "torch.nn.Module") -> None:
        self.module = module
        self.memo: Dict[str, torch.nn.Module] = {}

    # Nested attribute access

    def get_submodule(self, name: str) -> "torch.nn.Module":
        """
        Return the submodule specified by the given path.

        For example, to get the submodule mod.layer1.conv1,
        use accessor.get_submodule("layer1.conv1")

        Compared to mod.get_submodule("layer1.conv1"), this method caches the
        intermediate submodule access to speed up future lookups.
        """
        if not name:
            return self.module

        if name in self.memo:
            return self.memo[name]
        else:
            prefix, dot, attr = name.rpartition(".")
            if prefix:
                module = self.get_submodule(prefix)
            else:
                module = self.module
            try:
                submodule = getattr(module, attr)
            except AttributeError as ex:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{attr}`"
                ) from ex
            if not isinstance(submodule, torch.nn.Module):
                raise TypeError(  # noqa: TRY200
                    f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
                )
            self.memo[name] = submodule
            return submodule
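
    # Illustrative note: repeated lookups are served from `self.memo`, so a deep
    # dotted path is resolved attribute-by-attribute only once.
    #
    #     accessor.get_submodule("layer1.conv1")  # walks layer1, then conv1
    #     accessor.get_submodule("layer1.conv1")  # returned from the memo cache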

    def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module":
        """
        Swap the submodule specified by the given ``path`` to ``value``.

        For example, to swap the attribute mod.layer1.conv1 use
        ``accessor.swap_submodule("layer1.conv1", conv2)``.
        """
        prefix, _, attr = path.rpartition(".")
        return swap_submodule(self.get_submodule(prefix), attr, value)

    def get_tensor(self, name: str) -> torch.Tensor:
        """
        Get the tensor specified by the given path.

        For example, to get the attribute mod.layer1.conv1.weight,
        use accessor.get_tensor('layer1.conv1.weight')

        Compared to mod.get_parameter("layer1.conv1.weight"), this method caches
        the intermediate submodule access to speed up future lookups.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            tensor = getattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex
        if not isinstance(tensor, torch.Tensor) and tensor is not None:
            raise TypeError(f"{tensor} is not an instance of torch.Tensor")
        return tensor  # type: ignore[return-value]

    def set_tensor(self, name: str, value: torch.Tensor) -> None:
        """
        Set the attribute specified by the given path to value.

        For example, to set the attribute mod.layer1.conv1.weight,
        use accessor.set_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        set_tensor(self.get_submodule(prefix), attr, value)

    def del_tensor(self, name: str) -> None:
        """
        Delete the attribute specified by the given path.

        For example, to delete the attribute mod.layer1.conv1.weight,
        use accessor.del_tensor("layer1.conv1.weight")
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            delattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex

    def swap_tensor(
        self, name: str, value: torch.Tensor, allow_missing: bool = False
    ) -> torch.Tensor:
        """
        Swap the attribute specified by the given path to value.

        For example, to swap the attribute mod.layer1.conv1.weight,
        use accessor.swap_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        return swap_tensor(
            self.get_submodule(prefix), attr, value, allow_missing=allow_missing
        )

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
        """
        Get the tensors specified by the given paths.

        For example, to get the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        return [self.get_tensor(name) for name in names]
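
    # Illustrative note (hypothetical paths): batched getters accept full dotted
    # paths and return the tensors in the same order.
    #
    #     weight, bias = accessor.get_tensors(["fc.weight", "fc.bias"])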

    def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        for name, value in zip(names, values):
            self.set_tensor(name, value)
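
    # Illustrative note (hypothetical names): names and values are paired
    # positionally, so the two sequences must line up.
    #
    #     accessor.set_tensors(["fc.weight", "fc.bias"], [new_weight, new_bias])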

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        for name, value in named_tensors.items():
            self.set_tensor(name, value)

    def del_tensors(self, names: Iterable[str]) -> None:
        """
        Delete the attributes specified by the given paths.

        For example, to delete the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        for name in names:
            self.del_tensor(name)

    def swap_tensors(
        self,
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        return [
            self.swap_tensor(name, value, allow_missing=allow_missing)
            for name, value in zip(names, values)
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            for name, tensor in named_tensors.items():
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys
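
    # Illustrative roundtrip sketch (hypothetical `model` with an `fc` child):
    # the returned mapping is enough to undo the whole swap.
    #
    #     accessor = NamedMemberAccessor(model)
    #     orig, missing = accessor.swap_tensors_dict({"fc.weight": new_weight})
    #     ...  # run the model with the substituted state
    #     accessor.swap_tensors_dict(orig)  # restore the original tensors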

    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
        """Check that the given keys are valid."""
        keys = set(keys)
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)

    # Shortcut methods

    def named_parameters(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the parameters in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)

    def named_buffers(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the buffers in the module."""
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_tensors(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the tensors in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_modules(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, "torch.nn.Module"]]:
        """Iterate over all the modules in the module."""
        yield from self.module.named_modules(remove_duplicate=remove_duplicate)
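

if __name__ == "__main__":
    # Minimal self-contained demo (a sketch, not part of the library API):
    # functionally evaluate a module under substituted parameters, then restore
    # the originals via the mapping returned by swap_tensors_dict.
    model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU())
    accessor = NamedMemberAccessor(model)

    # Build an all-zero replacement for every registered parameter.
    new_params = {
        name: torch.zeros_like(param) for name, param in accessor.named_parameters()
    }
    orig_params, _ = accessor.swap_tensors_dict(new_params)
    # With zero weights and biases, the forward pass is identically zero.
    assert torch.equal(model(torch.randn(1, 3)), torch.zeros(1, 3))
    accessor.swap_tensors_dict(orig_params)  # original parameters are back in place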