1
"""Functionality for Python <-> C++ frontend inter-op."""
6
class OrderedDictWrapper:
    """A live view of an OrderedDict attribute on a bound C++ module.

    The wrapped dictionary is looked up on the C++ module again on every
    access, so changes made on the C++ side are picked up. Otherwise accessing
    e.g. ``cpp_module._parameters`` just once would get a frozen copy of the
    parameters at the time of access. ``torch.nn.Module`` accesses
    ``_parameters`` et al. via ``self.__dict__``, so using properties on the
    module itself does not work; an instance of this class is stored there
    instead.

    Only read accessors are exposed: ``items``, ``keys``, ``values``,
    iteration, ``len``, ``in`` and subscription.
    """

    def __init__(self, cpp_module, attr):
        """Remember which attribute of which module to look up.

        Args:
            cpp_module: Object that owns the dictionary attribute.
            attr: Name of the attribute on ``cpp_module`` holding the dict.
        """
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Return the current dictionary, fetched afresh from the module."""
        return getattr(self.cpp_module, self.attr)

    # Special methods are looked up on the type rather than the instance, so
    # they cannot be forwarded through ``__getattr__``; each one is delegated
    # explicitly below.

    def items(self):
        """Return the ``items()`` of the current dictionary."""
        return self.cpp_dict.items()

    def keys(self):
        """Return the ``keys()`` of the current dictionary."""
        return self.cpp_dict.keys()

    def values(self):
        """Return the ``values()`` of the current dictionary."""
        return self.cpp_dict.values()

    def __iter__(self):
        """Iterate over the current dictionary."""
        return iter(self.cpp_dict)

    def __len__(self):
        """Return the number of entries in the current dictionary."""
        return len(self.cpp_dict)

    def __contains__(self, key):
        """Return whether ``key`` is present in the current dictionary."""
        return key in self.cpp_dict

    def __getitem__(self, key):
        """Return the entry stored under ``key`` in the current dictionary."""
        return self.cpp_dict[key]
49
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.

    Parameters, buffers and submodules are exposed through
    ``OrderedDictWrapper`` views of the C++ module, the ``training`` flag is
    read from and written to the C++ module, and every public attribute of the
    C++ module is copied onto the wrapper.
    """

    def __init__(self, cpp_module):
        """Wrap ``cpp_module``.

        Args:
            cpp_module: The bound C++ module. It must provide ``_parameters``,
                ``_buffers``, ``_modules``, ``training`` and ``train(mode)``.
        """
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor (the ``training`` setter
        # below forwards to ``self.cpp_module``).
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` to the data of every parameter, gradient and buffer.

        Args:
            fn: Callable mapping a tensor to its replacement tensor.
            recurse: Accepted for signature compatibility; not used here.

        Returns:
            This wrapper.
        """
        for param in self.parameters():
            # Operate on ``.data`` so the stored tensors themselves are
            # replaced in place on the existing parameter objects.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # ``nn.Module`` defines ``training`` as a plain boolean attribute; here it
    # is a property so the C++ module stays the single source of truth.
    @property  # type: ignore[override]
    def training(self):
        """Return the ``training`` flag of the wrapped C++ module."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        """Forward the new mode to ``cpp_module.train``."""
        self.cpp_module.train(mode)

    def __repr__(self):
        """Return the wrapped C++ module's own representation."""
        return self.cpp_module.__repr__()