# Owner(s): ["oncall: mobile"]

import functools
import io
import os
import tempfile
import unittest
import urllib.parse

import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines


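# These tests exercise torch.utils.model_dump, which summarizes a saved
# TorchScript model either as JSON or as a self-contained HTML report.


# A small float model: two Linear layers, each followed by a ReLU.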
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        act = features
        act = self.layer1(act)
        act = self.relu1(act)
        act = self.layer2(act)
        act = self.relu2(act)
        return act


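# SimpleModel wrapped in quant/dequant stubs so it can go through the
# eager-mode quantization flow in get_quant_model below.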
class QuantModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x


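# Stores tensors inside list attributes; scripting infers List[Tensor] for
# rt and List[Optional[Tensor]] for ot, which the dump code must handle.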
class ModelWithLists(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        arg = arg + self.rt[0]
        o = self.ot[0]
        if o is not None:
            arg = arg + o
        return arg


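# Decorator that runs a test once per browser (Firefox, then Chrome) via
# Selenium.  Webdriver runs are opt-in: set RUN_WEBDRIVER=1 to enable them.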
def webdriver_test(testfunc):
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()
        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        from selenium import webdriver

        for driver in ["Firefox", "Chrome"]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                testfunc(self, wd, *args, **kwds)
                wd.close()

    return wrapper


class TestModelDump(TestCase):
    def needs_resources(self):
        pass

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        assert "unpkg.org" not in skel
        assert "src=" not in skel

    def do_dump_model(self, model, extra_files=None):
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        assert info is not None

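    # Render the full report for a model and load it in the browser through
    # a data: URI, so no local web server is needed.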
    def open_html_model(self, wd, model, extra_files=None):
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        container = wd.find_element_by_xpath(f"//div[@data-hider-title='{name}']")
        caret = container.find_element_by_class_name("caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element_by_tag_name("div")
        return content

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

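    # Exercise the command-line entry point in both output styles; the regexes
    # below check for JSON ('{') and HTML ('<!DOCTYPE') output respectively.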
    def test_main(self):
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI.  Just skip it.
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below.
            tf.flush()

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

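    # Build a quantized model via the eager-mode flow: fuse Linear+ReLU pairs,
    # attach a qnnpack qconfig, calibrate on one batch, then convert.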
    def get_quant_model(self):
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

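    # Verify the "Tensor Memory" table of the rendered page: the per-device
    # byte total must match a hand-computed count of parameter storage.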
    @webdriver_test
    def test_memory_computation(self, wd):
        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            device = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # Four bytes per float32 element.
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants.  Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        # Freezing with w1 preserved keeps w1 as a data tensor and folds w2
        # into a constant; both must be counted: (2*2 + 1*2) floats, 4 bytes each.
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 * 2 + 1 * 2))


if __name__ == '__main__':
    run_tests()