# Scraped page header: pytorch · Fork 0 / test_bundled_inputs.py
# 443 lines · 16.1 KB
1
#!/usr/bin/env python3
2
# Owner(s): ["oncall: mobile"]
3

4
import io
5
import textwrap
6
from typing import List, Optional, Dict
7

8
import torch
9
import torch.utils.bundled_inputs
10
from torch.testing._internal.common_utils import TestCase, run_tests
11

12

13
def model_size(sm):
    """Return the size in bytes of ``sm`` serialized with ``torch.jit.save``.

    Args:
        sm: The module to measure (passed straight to ``torch.jit.save``).

    Returns:
        Number of bytes ``torch.jit.save`` wrote for ``sm``.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    # getbuffer() exposes the written bytes without copying them, whereas
    # getvalue() would materialize a second copy just to take its length.
    return buffer.getbuffer().nbytes
17

18

19
def save_and_load(sm):
    """Round-trip ``sm`` through in-memory ``torch.jit.save``/``torch.jit.load``.

    Args:
        sm: The module to serialize (passed straight to ``torch.jit.save``).

    Returns:
        The module that ``torch.jit.load`` reconstructs from the saved bytes.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    # Rewind so torch.jit.load reads from the start of what was just written.
    buffer.seek(0)
    return torch.jit.load(buffer)
24

25

26
class TestBundledInputs(TestCase):
    """Tests for ``torch.utils.bundled_inputs``.

    Most tests attach sample inputs to a scripted module, round-trip the
    module through ``save_and_load`` and compare what the generated getters
    (``get_all_bundled_inputs`` and friends) return with the original samples.
    """

    def test_single_tensors(self):
        """Single-tensor samples are bundled compactly and inflate back with
        their original values and strides."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        # Passed as _receive_inflate_expr; only inspected when debugging.
        get_expr: List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        # Position of the bundle_randn sample above. Its inflated value is
        # random, so it is checked statistically rather than for equality.
        randn_idx = 5
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, _receive_inflate_expr=get_expr)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        # forward returns its argument, so the very same object comes back.
        self.assertIs(loaded(*inflated[0]), inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != randn_idx:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        randn_inp = inflated[randn_idx][0]
        self.assertEqual(randn_inp.shape, (1 << 16,))
        self.assertEqual(randn_inp.mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(randn_inp.std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        """A large tensor wrapped with ``bundle_large_tensor`` can be bundled
        and is restored unchanged."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        """Tensors that cannot be stored compactly are rejected."""
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)])

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        """Non-tensor (str/int) samples round-trip unchanged."""
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertEqual(loaded(*inflated[0]), "first 1")

    def test_multiple_methods_with_inputs(self):
        """Inputs and info bundled for several methods are exposed through
        both the generic and the per-method getters."""
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        # One description per entry of `samples`, in the same order.
        info = [
            'Tensor with small numel and small storage.',
            'Tensor with large numel and small storage.',
            'Tensor with small numel and large storage.',
            'Large zero tensor.',
            'Large channels-last ones tensor.',
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={
                mm.forward: samples,
                mm.foo: samples,
            },
            info={
                mm.forward: info,
                mm.foo: info,
            },
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertIs(loaded(*inflated[0]), inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

        # Check helpers that work on all functions
        all_info = loaded.get_bundled_inputs_functions_and_info()
        self.assertEqual(set(all_info.keys()), {'forward', 'foo'})
        self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward'])
        self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo'])
        self.assertEqual(all_info['forward']['info'], info)
        self.assertEqual(all_info['foo']['info'], info)

        # Example of how to turn 'get_inputs_function_name' into the actual
        # list of bundled inputs.
        for func_info in all_info.values():
            input_func_name = func_info['get_inputs_function_name'][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        """Passing inputs for a method that already has a user-defined
        ``_generate_bundled_inputs_for_<method>`` fails."""
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # Set up outside assertRaises so a setup error cannot be mistaken
        # for the failure under test.
        mm = torch.jit.script(MultipleMethodModel())
        definition = textwrap.dedent("""
            def _generate_bundled_inputs_for_forward(self):
                return []
            """)
        mm.define(definition)

        # inputs defined 2 ways (explicit samples and a generator) so should fail
        with self.assertRaisesRegex(Exception, "_generate_bundled_inputs_for_forward"):
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: samples,
                    mm.foo: samples,
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        """A method with neither explicit inputs nor a user-defined
        ``_generate_bundled_inputs_for_<method>`` cannot be augmented."""
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # Set up outside assertRaises so a setup error cannot be mistaken
        # for the failure under test. Note: no generator is defined for
        # forward here, so the augment call itself must be what raises.
        mm = torch.jit.script(MultipleMethodModel())

        # inputs not defined so should fail
        with self.assertRaisesRegex(Exception, "_generate_bundled_inputs_for_forward"):
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward: None,
                    mm.foo: samples,
                },
            )

    def test_bad_inputs(self):
        """Malformed ``inputs`` arguments raise TypeError."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # Non list for input list
        m = torch.jit.script(SingleTensorModel())
        with self.assertRaises(TypeError):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo"  # type: ignore[arg-type]
            )

        # List of non tuples. Most common error using the api.
        m = torch.jit.script(SingleTensorModel())
        with self.assertRaises(TypeError):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2), ]  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        """``augment_model_with_bundled_inputs`` refuses a second augmentation."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m,
            inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(Exception, "Models can only be augmented with bundled inputs once."):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        """``bundle_inputs`` returns a bundled model and leaves its argument
        without bundled-input methods."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m,
            inputs=[(torch.ones(1),)]
        )
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        """``bundle_inputs`` can re-bundle an already bundled model."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m,
            inputs={m.forward: [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model,
            inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        """``InflatableArg`` values with a custom ``fmt_fn`` bundle Optional
        dict/list arguments compactly, one inflate helper per such arg."""
        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[Dict[str, torch.Tensor]],
                arg2: Optional[List[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = dict(
            a=torch.zeros([10, 20]),
            b=torch.zeros([1, 1]),
            c=torch.zeros([10, 20]),
        )
        small_list = [torch.zeros([10, 20])]

        big_sample = dict(
            a=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            b=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            c=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

        def condensed(t):
            # A zero scalar expanded (stride 0) to t's shape: same shape,
            # dtype and device as t, but backed by a single storage element.
            # Starting from new_zeros(()) avoids allocating a full-size
            # temporary and keeps the serialized value deterministic.
            ret = t.new_zeros(()).expand(t.shape)
            assert ret.storage().size() == 1
            return ret

        def bundle_optional_dict_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is None:
                        return None
                    output = {{}}
                    for k, v in value.items():
                        output[k] = torch.randn_like(v)
                    return output
                """,
            )

        def bundle_optional_list_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is None:
                        return None
                    output = []
                    for v in value:
                        output.append(torch.randn_like(v))
                    return output
                """,
            )

        out: List[str] = []
        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [
                big_inputs,
                small_inputs,
            ],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # assert the size has not increased more than 8KB
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 2)
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
            loaded
        )

        # One Function (forward)
        # two bundled inputs (big_inputs and small_inputs)
        # two args which have InflatableArg with fmt_fn
        # 1 * 2 * 2 = 4
        self.assertEqual(
            sum(method.startswith("_inflate_helper") for method in methods), 4
        )
440

441

442
# Entry point when the file is executed directly: hand control to the
# PyTorch test runner.
if __name__ == '__main__':
    run_tests()
444

# Scraped page footer (cookie notice), translated from Russian:
#
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.