# pytorch / _named_member_accessor.py (374 lines · 13.9 KB)
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Iterable, List, Tuple

import torch


# Sentinel for "no tensor here": swap_tensor deletes an entry when given
# _MISSING as the new value, and returns _MISSING when allow_missing=True
# and the named attribute does not exist.
_MISSING: torch.Tensor = object()  # type: ignore[assignment]


def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(tensor, torch.Tensor) and tensor is not None:
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')
    if name in module._parameters:
        module._parameters[name] = tensor  # type: ignore[assignment]
    elif name in module._buffers:
        module._buffers[name] = tensor
    else:
        setattr(module, name, tensor)

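# Usage sketch for set_tensor (illustrative comment, not part of the module;
# `linear` is a hypothetical example module). The helper routes the assignment
# to `_parameters`, `_buffers`, or a plain attribute, as appropriate:
#
#     linear = torch.nn.Linear(4, 4)
#     set_tensor(linear, "weight", torch.nn.Parameter(torch.zeros(4, 4)))
#     assert bool((linear.weight == 0).all())
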
def swap_tensor(
    module: "torch.nn.Module",
    name: str,
    tensor: torch.Tensor,
    allow_missing: bool = False,
) -> torch.Tensor:
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if (
        tensor is not _MISSING
        and not isinstance(tensor, torch.Tensor)
        and tensor is not None
    ):
        raise TypeError(f"{tensor} is not an instance of torch.Tensor")
    if "." in name:
        raise KeyError('tensor name can\'t contain "."')
    if name == "":
        raise KeyError('tensor name can\'t be empty string ""')

    orig_tensor: torch.Tensor
    if name in module._parameters:
        orig_tensor = module._parameters[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._parameters[name] = tensor  # type: ignore[assignment]
        else:
            del module._parameters[name]
    elif name in module._buffers:
        orig_tensor = module._buffers[name]  # type: ignore[assignment]
        if tensor is not _MISSING:
            module._buffers[name] = tensor
        else:
            del module._buffers[name]
    else:
        try:
            orig_tensor = getattr(module, name)
        except AttributeError as ex:
            if not allow_missing:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{name}`"
                ) from ex
            orig_tensor = _MISSING
        if (
            orig_tensor is not _MISSING
            and not isinstance(orig_tensor, torch.Tensor)
            and orig_tensor is not None
        ):
            raise TypeError(
                f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
            )
        if tensor is not _MISSING:
            setattr(module, name, tensor)
        elif hasattr(module, name):
            delattr(module, name)
    return orig_tensor

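# Usage sketch for swap_tensor (illustrative comment; the names below are
# hypothetical). The displaced tensor is returned, so a swap can be undone:
#
#     linear = torch.nn.Linear(4, 4)
#     orig_weight = swap_tensor(linear, "weight", torch.zeros(4, 4))
#     # ... run the module with the zeroed weight ...
#     swap_tensor(linear, "weight", orig_weight)  # restore the original
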
def swap_submodule(
    module: "torch.nn.Module",
    name: str,
    submodule: "torch.nn.Module",
) -> "torch.nn.Module":
    if not isinstance(module, torch.nn.Module):
        raise TypeError(f"{module} is not an instance of torch.nn.Module")
    if not isinstance(submodule, torch.nn.Module):
        raise TypeError(f"{submodule} is not an instance of torch.nn.Module")
    if "." in name:
        raise KeyError('submodule name can\'t contain "."')
    if name == "":
        raise KeyError('submodule name can\'t be empty string ""')
    if name not in module._modules:
        raise KeyError(f"submodule {name} does not exist")

    orig_submodule = module._modules[name]
    if not isinstance(orig_submodule, torch.nn.Module):
        raise TypeError(f"{name} attribute is not an instance of torch.nn.Module")
    module._modules[name] = submodule
    return orig_submodule

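# Usage sketch for swap_submodule (illustrative comment; `seq` is a
# hypothetical example). Children registered in `_modules` are replaced by
# name, and the displaced child is returned:
#
#     seq = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
#     orig_act = swap_submodule(seq, "1", torch.nn.Tanh())
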
class NamedMemberAccessor:
    """
    A class that provides a way to access the submodules and parameters/buffers of a module.

    It provides a caching mechanism to speed up submodule lookups.
    This is useful for functional programming to manipulate the module state.
    """

    def __init__(self, module: "torch.nn.Module") -> None:
        self.module = module
        self.memo: Dict[str, torch.nn.Module] = {}

    # Nested attribute access

    def get_submodule(self, name: str) -> "torch.nn.Module":
        """
        Return the submodule specified by the given path.

        For example, to get the submodule mod.layer1.conv1,
        use accessor.get_submodule("layer1.conv1")

        Compared to mod.get_submodule("layer1.conv1"), this method caches the
        intermediate submodule accesses to speed up future lookups.
        """
        if not name:
            return self.module

        try:
            return self.memo[name]
        except KeyError:
            prefix, dot, attr = name.rpartition(".")
            if dot:
                module = self.get_submodule(prefix)
            else:
                module = self.module
            try:
                submodule = getattr(module, attr)
            except AttributeError as ex:
                raise AttributeError(
                    f"{module._get_name()} has no attribute `{attr}`"
                ) from ex
            if not isinstance(submodule, torch.nn.Module):
                raise TypeError(  # noqa: TRY200
                    f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
                )
            self.memo[name] = submodule
            return submodule

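    # Caching sketch (illustrative comment; `model` is a hypothetical module
    # with a `layer1.conv1` submodule). The first lookup walks the attribute
    # chain and memoizes every intermediate submodule, so repeated lookups
    # under the same prefix are served from `self.memo`:
    #
    #     accessor = NamedMemberAccessor(model)
    #     conv1 = accessor.get_submodule("layer1.conv1")  # walks and caches
    #     conv1 = accessor.get_submodule("layer1.conv1")  # cache hit
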
    def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module":
        """
        Swap the submodule specified by the given ``path`` to ``value``.

        For example, to swap the submodule mod.layer1.conv1, use
        ``accessor.swap_submodule("layer1.conv1", conv2)``.
        """
        prefix, _, attr = path.rpartition(".")
        return swap_submodule(self.get_submodule(prefix), attr, value)

    def get_tensor(self, name: str) -> torch.Tensor:
        """
        Get the tensor specified by the given path.

        For example, to get the attribute mod.layer1.conv1.weight,
        use accessor.get_tensor("layer1.conv1.weight")

        Compared to mod.get_parameter("layer1.conv1.weight"), this method
        caches the intermediate submodule accesses to speed up future lookups.
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            tensor = getattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex
        if not isinstance(tensor, torch.Tensor) and tensor is not None:
            raise TypeError(f"{tensor} is not an instance of torch.Tensor")
        return tensor  # type: ignore[return-value]

    def set_tensor(self, name: str, value: torch.Tensor) -> None:
        """
        Set the attribute specified by the given path to value.

        For example, to set the attribute mod.layer1.conv1.weight,
        use accessor.set_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        set_tensor(self.get_submodule(prefix), attr, value)

    def del_tensor(self, name: str) -> None:
        """
        Delete the attribute specified by the given path.

        For example, to delete the attribute mod.layer1.conv1.weight,
        use accessor.del_tensor("layer1.conv1.weight")
        """
        prefix, _, attr = name.rpartition(".")
        submodule = self.get_submodule(prefix)
        try:
            delattr(submodule, attr)
        except AttributeError as ex:
            raise AttributeError(
                f"{submodule._get_name()} has no attribute `{name}`"
            ) from ex

    def swap_tensor(
        self, name: str, value: torch.Tensor, allow_missing: bool = False
    ) -> torch.Tensor:
        """
        Swap the attribute specified by the given path to value.

        For example, to swap the attribute mod.layer1.conv1.weight,
        use accessor.swap_tensor("layer1.conv1.weight", value)
        """
        prefix, _, attr = name.rpartition(".")
        return swap_tensor(
            self.get_submodule(prefix), attr, value, allow_missing=allow_missing
        )

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
        """
        Get the tensors specified by the given paths.

        For example, to get the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        return [self.get_tensor(name) for name in names]

    def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        for name, value in zip(names, values):
            self.set_tensor(name, value)

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

        For example, to set the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.set_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        for name, value in named_tensors.items():
            self.set_tensor(name, value)

    def del_tensors(self, names: Iterable[str]) -> None:
        """
        Delete the attributes specified by the given paths.

        For example, to delete the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"])
        """
        for name in names:
            self.del_tensor(name)

    def swap_tensors(
        self,
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
        "layer1.conv1.bias"], [weight, bias])
        """
        if not isinstance(names, (list, tuple)):
            names = list(names)
        if not isinstance(values, (list, tuple)):
            values = list(values)
        assert len(names) == len(values), "names and values must have the same length"

        return [
            self.swap_tensor(name, value, allow_missing=allow_missing)
            for name, value in zip(names, values)
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            for name, tensor in named_tensors.items():
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys

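    # Usage sketch for swap_tensors_dict (illustrative comment; `model`,
    # `new_weights`, and `inputs` are hypothetical). This is the building
    # block for stateless evaluation: swap replacement tensors in, run the
    # module, then swap the originals back:
    #
    #     accessor = NamedMemberAccessor(model)
    #     orig, _ = accessor.swap_tensors_dict(new_weights)
    #     try:
    #         out = model(inputs)
    #     finally:
    #         accessor.swap_tensors_dict(orig, allow_missing=True)
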
    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
        """Check the given keys against the module's named tensors.

        Returns a tuple of sorted lists ``(missing_keys, unexpected_keys)``.
        """
        keys = set(keys)
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
        missing_keys = valid_keys - keys
        unexpected_keys = keys - valid_keys
        return sorted(missing_keys), sorted(unexpected_keys)

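    # Example (illustrative comment): for a torch.nn.Linear, passing only
    # "weight" reports "bias" as missing and flags unknown names as unexpected:
    #
    #     accessor = NamedMemberAccessor(torch.nn.Linear(4, 4))
    #     missing, unexpected = accessor.check_keys(["weight", "scale"])
    #     # missing == ["bias"], unexpected == ["scale"]
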
    # Shortcut methods

    def named_parameters(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the parameters in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)

    def named_buffers(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the buffers in the module."""
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_tensors(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterate over all the tensors (parameters and buffers) in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_modules(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, "torch.nn.Module"]]:
        """Iterate over all the modules in the module."""
        yield from self.module.named_modules(remove_duplicate=remove_duplicate)