pytorch

Форк
0
/
test_model_dump.py 
242 строки · 7.8 Кб
1
#!/usr/bin/env python3
2
# Owner(s): ["oncall: mobile"]
3

4
import os
5
import io
6
import functools
7
import tempfile
8
import urllib
9
import unittest
10

11
import torch
12
import torch.backends.xnnpack
13
import torch.utils.model_dump
14
import torch.utils.mobile_optimizer
15
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
16
from torch.testing._internal.common_quantized import supported_qengines
17

18

19
class SimpleModel(torch.nn.Module):
    """Small two-layer MLP used as the reference model throughout these tests.

    Computes Linear(16, 64) -> ReLU -> Linear(64, 8) -> ReLU, so inputs must
    have a last dimension of 16 and outputs have a last dimension of 8.

    The submodule names (``layer1``, ``relu1``, ``layer2``, ``relu2``) are
    targeted by ``fuse_modules`` in ``TestModelDump.get_quant_model``, and the
    layer sizes are baked into the expected byte count in
    ``test_memory_computation``; keep both in sync if changing this class.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        act = features
        act = self.layer1(act)
        act = self.relu1(act)
        act = self.layer2(act)
        act = self.relu2(act)
        return act
34

35

36
class QuantModel(torch.nn.Module):
    """``SimpleModel`` wrapped in eager-mode quantization stubs.

    ``quant`` and ``dequant`` mark where tensors enter and leave the region
    that ``torch.ao.quantization.prepare``/``convert`` quantize (see
    ``TestModelDump.get_quant_model``).  The inner model is stored as
    ``core``, so its submodules are addressed as ``core.layer1`` etc.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x
48

49

50
class ModelWithLists(torch.nn.Module):
    """Model whose state is held in plain Python lists instead of parameters.

    ``rt`` holds only a tensor; ``ot`` mixes a tensor with ``None``, so
    scripting this module produces list attributes, one with optional
    elements, for the dumper to handle.  ``forward`` adds ``rt[0]`` to the
    input and, when it is not ``None``, also ``ot[0]``.
    """

    def __init__(self):
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        arg = arg + self.rt[0]
        o = self.ot[0]
        # Bind to a local first so TorchScript can refine Optional[Tensor]
        # to Tensor inside the branch.
        if o is not None:
            arg = arg + o
        return arg
62

63

64
def webdriver_test(testfunc):
    """Run a test method once per Selenium browser (Firefox, then Chrome).

    The wrapped method is called as ``testfunc(self, wd, *args, **kwds)``
    where ``wd`` is a freshly created WebDriver.  The test is skipped unless
    the environment variable ``RUN_WEBDRIVER`` equals ``"1"``; Selenium is
    imported only after that check, so the rest of the suite does not need
    it installed.  Each browser run is wrapped in ``self.subTest`` so results
    are labeled per driver.
    """
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()

        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        from selenium import webdriver

        for driver in [
                "Firefox",
                "Chrome",
        ]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                try:
                    testfunc(self, wd, *args, **kwds)
                finally:
                    # Shut the browser session down even when the test body
                    # fails; otherwise the browser/driver process is leaked.
                    # quit() ends the whole session, close() only a window.
                    wd.quit()

    return wrapper
83

84

85
class TestModelDump(TestCase):
    """Tests for ``torch.utils.model_dump``.

    Most tests only check that dumping a saved TorchScript model runs without
    error.  The ``@webdriver_test`` tests additionally render the HTML dump
    in a real browser and inspect the resulting page.
    """

    def needs_resources(self):
        """Hook called by resource-heavy tests; a no-op in this class."""
        pass

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        # The inlined skeleton must be self-contained: no unpkg.org CDN
        # references and no ``src=`` attributes.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` to memory and check ``get_model_info`` handles it.

        Args:
            model: a scripted or traced module accepted by ``torch.jit.save``.
            extra_files: optional mapping forwarded to ``torch.jit.save`` as
                ``_extra_files``.
        """
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        self.assertIsNotNone(info)

    def open_html_model(self, wd, model, extra_files=None):
        """Render ``model``'s standalone HTML dump and load it into ``wd``.

        The page is passed to the browser as a percent-encoded ``data:`` URL.
        """
        # ``import urllib`` at file level does not guarantee that the
        # ``urllib.parse`` submodule is loaded, so import it explicitly.
        import urllib.parse

        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the collapsible section titled ``name`` and return its body.

        The section container is the ``div`` whose ``data-hider-title`` is
        ``name``; its caret is clicked only when ``data-shown`` is not
        ``"true"``, and the first ``div`` inside the container is returned.
        """
        # find_element(by, value) with string locators works on both old and
        # new Selenium; the find_element_by_* helpers were removed in 4.3.
        container = wd.find_element(
            "xpath", f"//div[@data-hider-title='{name}']")
        caret = container.find_element("class name", "caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element("tag name", "div")
        return content

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        """Run the command-line entry point in JSON and HTML styles."""
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI.  Just skip it.
            # (Possibly because Windows cannot reopen a NamedTemporaryFile
            # by name while it is still open.)
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,  # argv[0] placeholder
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,  # argv[0] placeholder
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        """Build an eager-mode statically quantized ``QuantModel``.

        Fuses each Linear+ReLU pair, attaches the default ``qnnpack`` qconfig,
        calibrates on one random ``(2, 16)`` batch, and converts.
        """
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        # Single calibration pass so the observers record value ranges.
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        # A malformed ``.json`` extra file must not prevent dumping.
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        def check_memory(model, expected):
            """Check the first row of the "Tensor Memory" table: device
            ``cpu`` and ``expected`` bytes."""
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            # ".//" scopes the lookup to this section's body; a bare "//"
            # would search the whole document from its root.
            device = memory_table.find_element(
                "xpath", ".//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element(
                "xpath", ".//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants.  Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        # w1 (2 floats) stays as data via preserved_attrs; w2 (4 floats) is
        # frozen into a constant.  Both must be counted: 4 bytes each.
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))
239

240

241
# Run the suite through PyTorch's test runner when executed as a script.
if __name__ == '__main__':
    run_tests()
243

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.