# Owner(s): ["oncall: mobile"]

import io
import textwrap
from typing import Dict, List, Optional

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests
def model_size(sm):
    """Return the serialized size, in bytes, of a scripted module.

    Serializes ``sm`` with ``torch.jit.save`` into an in-memory buffer and
    measures the result; used below to check that bundling inputs does not
    bloat the model file.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())
def save_and_load(sm):
    """Round-trip a scripted module through an in-memory buffer.

    Saves ``sm`` with ``torch.jit.save`` and returns the freshly loaded
    module, exercising the same serialization path a real file would use.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    # Rewind to the start; torch.jit.load reads from the current position.
    buffer.seek(0)
    return torch.jit.load(buffer)
class TestBundledInputs(TestCase):
    """Tests for ``torch.utils.bundled_inputs``: attaching compact,
    inflatable sample inputs to scripted models so they round-trip through
    serialization without significantly growing the model file."""

    def test_single_tensors(self):
        """Bundle a variety of compressible tensors on a single-arg model and
        verify sizes, round-trip values, strides, and dtypes."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr : List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
            # Quantized uniform tensor.
            (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr)
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            # Index 5 is the bundle_randn sample, which inflates to fresh
            # random data and cannot be compared element-wise.
            if idx != 5:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        self.assertEqual(inflated[5][0].shape, (1 << 16,))
        self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)

    def test_large_tensor_with_inflation(self):
        """A tensor of any size may be bundled when given a custom inflation
        function (here the identity, via bundle_large_tensor)."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        """Tensors that cannot be represented compactly are rejected."""
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)])

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the layout.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        """Non-tensor arguments (str, int) can be bundled as-is."""
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertTrue(loaded(*inflated[0]) == "first 1")

    def test_multiple_methods_with_inputs(self):
        """Bundle inputs (plus info strings) for several exported methods and
        verify all the generated accessor helpers agree."""
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        mm = torch.jit.script(MultipleMethodModel())
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
        ]
        info = [
            'Tensor with small numel and small storage.',
            'Tensor with large numel and small storage.',
            'Tensor with small numel and large storage.',
            'Large zero tensor.',
            'Large channels-last ones tensor.',
            'Special encoding of random tensor.',
        ]
        torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
            mm,
            inputs={
                mm.forward : samples,
                mm.foo : samples
            },
            info={
                mm.forward : info,
                mm.foo : info
            }
        )
        loaded = save_and_load(mm)
        inflated = loaded.get_all_bundled_inputs()

        # Make sure these functions are all consistent.
        self.assertEqual(inflated, samples)
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_forward())
        self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

        # Check running and size helpers
        self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

        # Check helper that work on all functions
        all_info = loaded.get_bundled_inputs_functions_and_info()
        self.assertEqual(set(all_info.keys()), {'forward', 'foo'})
        self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward'])
        self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo'])
        self.assertEqual(all_info['forward']['info'], info)
        self.assertEqual(all_info['foo']['info'], info)

        # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
        for func_name in all_info.keys():
            input_func_name = all_info[func_name]['get_inputs_function_name'][0]
            func_to_run = getattr(loaded, input_func_name)
            self.assertEqual(func_to_run(), samples)

    def test_multiple_methods_with_inputs_both_defined_failure(self):
        """Defining inputs both via a hand-written _generate_bundled_inputs_*
        method AND via the augment call must raise."""
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # inputs defined 2 ways so should fail
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_forward(self):
                    return []
                """)
            mm.define(definition)
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward : samples,
                    mm.foo : samples
                },
            )

    def test_multiple_methods_with_inputs_neither_defined_failure(self):
        """Requesting augmentation when inputs are defined neither way
        must raise."""
        class MultipleMethodModel(torch.nn.Module):
            def forward(self, arg):
                return arg

            @torch.jit.export
            def foo(self, arg):
                return arg

        samples = [(torch.tensor([1]),)]

        # inputs not defined so should fail
        with self.assertRaises(Exception):
            mm = torch.jit.script(MultipleMethodModel())
            mm._generate_bundled_inputs_for_forward()
            torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                mm,
                inputs={
                    mm.forward : samples,
                    mm.foo : samples
                },
            )

    def test_bad_inputs(self):
        """Malformed ``inputs`` arguments raise TypeError."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        # Non list for input list
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs="foo"  # type: ignore[arg-type]
            )

        # List of non tuples. Most common error using the api.
        with self.assertRaises(TypeError):
            m = torch.jit.script(SingleTensorModel())
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[torch.ones(1, 2), ]  # type: ignore[list-item]
            )

    def test_double_augment_fail(self):
        """Augmenting the same model twice in place must raise."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            m,
            inputs=[(torch.ones(1),)]
        )
        with self.assertRaisesRegex(Exception, "Models can only be augmented with bundled inputs once."):
            torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                m,
                inputs=[(torch.ones(1),)]
            )

    def test_double_augment_non_mutator(self):
        """bundle_inputs returns a new model and leaves the original
        untouched."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m,
            inputs=[(torch.ones(1),)]
        )
        with self.assertRaises(AttributeError):
            m.get_all_bundled_inputs()
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
        self.assertEqual(bundled_model.forward(torch.ones(1)), torch.ones(1))

    def test_double_augment_success(self):
        """bundle_inputs can be applied repeatedly; the newest inputs win on
        the newest copy."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        m = torch.jit.script(SingleTensorModel())
        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            m,
            inputs={m.forward : [(torch.ones(1),)]}
        )
        self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])

        bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
            bundled_model,
            inputs=[(torch.ones(2),)]
        )
        self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])

    def test_dict_args(self):
        """Bundle Optional[Dict]/Optional[List] arguments using custom
        InflatableArg fmt_fn helpers, keeping the stored model small."""
        class MyModel(torch.nn.Module):
            def forward(
                self,
                arg1: Optional[Dict[str, torch.Tensor]],
                arg2: Optional[List[torch.Tensor]],
                arg3: torch.Tensor,
            ):
                if arg1 is None:
                    return arg3
                elif arg2 is None:
                    return arg1["a"] + arg1["b"]
                else:
                    return arg1["a"] + arg1["b"] + arg2[0]

        small_sample = dict(
            a=torch.zeros([10, 20]),
            b=torch.zeros([1, 1]),
            c=torch.zeros([10, 20]),
        )
        small_list = [torch.zeros([10, 20])]

        big_sample = dict(
            a=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            b=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
            c=torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )
        big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])]

        def condensed(t):
            # Collapse the tensor to a single stored element (expanded back
            # to the full shape) so the bundled value is tiny.
            ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
            assert ret.storage().size() == 1
            # ret.storage()[0] = 0
            return ret

        def bundle_optional_dict_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(
                    None
                    if template is None
                    else {k: condensed(v) for (k, v) in template.items()}
                ),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[Dict[str, Tensor]]):
                    if value is not None:
                        output = {{}}
                        for k, v in value.items():
                            output[k] = torch.randn_like(v)
                        return output
                    else:
                        return None
                """,
            )

        def bundle_optional_list_of_randn(template):
            return torch.utils.bundled_inputs.InflatableArg(
                value=(None if template is None else [condensed(v) for v in template]),
                fmt="{}",
                fmt_fn="""
                def {}(self, value: Optional[List[Tensor]]):
                    if value is not None:
                        output = []
                        for v in value:
                            output.append(torch.randn_like(v))
                        return output
                    else:
                        return None
                """,
            )

        out : List[str] = []

        sm = torch.jit.script(MyModel())
        original_size = model_size(sm)
        small_inputs = (
            bundle_optional_dict_of_randn(small_sample),
            bundle_optional_list_of_randn(small_list),
            torch.zeros([3, 4]),
        )
        big_inputs = (
            bundle_optional_dict_of_randn(big_sample),
            bundle_optional_list_of_randn(big_list),
            torch.zeros([1 << 5, 1 << 8, 1 << 10]),
        )

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm,
            [
                big_inputs,
                small_inputs,
            ],
            _receive_inflate_expr=out,
        )
        augmented_size = model_size(sm)
        # assert the size has not increased more than 8KB
        self.assertLess(augmented_size, original_size + (1 << 13))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated[0]), len(small_inputs))

        methods, _ = torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
            loaded
        )

        # One Function (forward)
        # two bundled inputs (big_inputs and small_inputs)
        # two args which have InflatableArg with fmt_fn
        # 1 * 2 * 2 = 4
        self.assertEqual(
            sum([method.startswith("_inflate_helper") for method in methods]), 4
        )
if __name__ == '__main__':
    # Dispatch through the shared PyTorch test runner.
    run_tests()