"""Functionality for Python <-> C++ frontend inter-op."""

from torch import nn

class OrderedDictWrapper:
    """Live, read-through view of an ``OrderedDict`` held by a C++ module.

    Every access re-fetches the dict from the bound C++ module, so mutations
    made on the C++ side are always visible. Reading e.g.
    ``cpp_module._parameters`` just once would instead freeze a snapshot taken
    at access time. ``torch.nn.Module`` reaches ``_parameters`` and friends
    through ``self.__dict__``, which is why plain properties cannot be used
    for this.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        # Re-evaluate the getter on every access so the view is never stale.
        return getattr(self.cpp_module, self.attr)

    # Dunder lookups happen on the type and bypass ``getattr``, so they cannot
    # be attached dynamically -- each one is delegated explicitly below.

    def items(self):
        return self.cpp_dict.items()

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]

class ModuleWrapper(nn.Module):
    """``torch.nn.Module`` facade over a C++ frontend module.

    Parameters, buffers and submodules are exposed through live
    ``OrderedDictWrapper`` views onto the wrapped C++ module, and every public
    attribute of the C++ module is mirrored onto this wrapper.
    """

    def __init__(self, cpp_module):
        # ``nn.Module.__init__`` assigns ``self.training``, and our property
        # setter below forwards that to the C++ module -- so the handle must
        # already exist before calling the base constructor.
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        # Mirror public C++ attributes; names with a leading underscore
        # (dunders and the three dicts handled above) are skipped.
        for name in dir(cpp_module):
            if name.startswith("_"):
                continue
            setattr(self, name, getattr(self.cpp_module, name))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to every parameter, gradient and buffer."""
        for p in self.parameters():
            # Parameters are autograd leaves; rewrite ``.data`` directly so no
            # copy nodes are recorded in the graph.
            p.data = fn(p.data)
            grad = p._grad
            if grad is not None:
                grad.data = fn(grad.data)

        for buffer in self.buffers():
            buffer.data = fn(buffer.data)

        return self

    # ``nn.Module`` stores ``training`` as a plain bool; here it proxies the
    # mode flag of the wrapped C++ module instead.
    @property  # type: ignore[override]
    def training(self):
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        self.cpp_module.train(mode)

    def __repr__(self):
        return repr(self.cpp_module)
