pytorch
71 lines · 1.8 KB
1import torch
2
3
def check_error(desc, fn, *required_substrings):
    """Run ``fn`` and verify it raises an error mentioning every required substring.

    The error message is printed between separator lines, together with
    ``desc``, so the output can also be inspected by eye.

    Args:
        desc: Human-readable description of the scenario; used in the printed
            banner and in failure messages.
        fn: Zero-argument callable that is expected to raise.
        *required_substrings: Substrings that must all appear in the message
            of the exception raised by ``fn``.

    Raises:
        AssertionError: If ``fn`` does not raise, or if the error message is
            missing any of ``required_substrings``.
    """
    try:
        fn()
    except Exception as e:
        # `e` is unbound when the except clause ends, so keep a reference.
        caught = e
    else:
        raise AssertionError(f"given function ({desc}) didn't raise an error")

    # Use the first positional argument (the raw message for most
    # exceptions), coerced to str so non-string payloads such as KeyError(5)
    # can still be searched. Fall back to str() for argument-less exceptions,
    # which would otherwise fail here with IndexError.
    error_message = str(caught.args[0]) if caught.args else str(caught)

    print('=' * 80)
    print(desc)
    print('-' * 80)
    print(error_message)
    print('')

    # Raise explicitly instead of using `assert`, so the check still runs
    # under `python -O` and a failure names the substrings that were missing.
    missing = [sub for sub in required_substrings if sub not in error_message]
    if missing:
        raise AssertionError(
            f"error message for ({desc}) is missing {missing!r}: {error_message!r}"
        ) from caught
18
# Constructor and indexing misuse. Each entry is
# (description, callable expected to raise, substrings the message must contain),
# checked in order.
# NOTE(review): the 'Invalid size' / 'Negative size' cases expect '2' although
# the source storage has 3 elements; presumably the message reports the 2
# elements left after offset 1 — confirm against the current error text.
for _desc, _fn, _expected in (
    ('Wrong argument types',
     lambda: torch.FloatStorage(object()),
     ('object',)),
    ('Unknown keyword argument',
     lambda: torch.FloatStorage(content=1234.),
     ('keyword',)),
    ('Invalid types inside a sequence',
     lambda: torch.FloatStorage(['a', 'b']),
     ('list', 'str')),
    ('Invalid size type',
     lambda: torch.FloatStorage(1.5),
     ('float',)),
    ('Invalid offset',
     lambda: torch.FloatStorage(torch.FloatStorage(2), 4),
     ('2', '4')),
    ('Negative offset',
     lambda: torch.FloatStorage(torch.FloatStorage(2), -1),
     ('2', '-1')),
    ('Invalid size',
     lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
     ('2', '1', '5')),
    ('Negative size',
     lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
     ('2', '1', '-5')),
    ('Invalid index type',
     lambda: torch.FloatStorage(10)['first item'],
     ('str',)),
):
    check_error(_desc, _fn, *_expected)
# Drop the loop temporaries so no extra module-level names are left behind.
del _desc, _fn, _expected
55
56
def assign():
    """Assign the string ``'1'`` to the slice ``[1:-1]`` of a fresh 10-element FloatStorage.

    Used as a negative test case: a string is not a valid value for a float
    storage, so this call is expected to raise.
    """
    storage = torch.FloatStorage(10)
    storage[1:-1] = '1'
# Value-type misuse on mutating operations. Each entry is
# (description, callable expected to raise, substrings the message must contain),
# checked in order.
for _desc, _fn, _expected in (
    ('Invalid value type',
     assign,
     ('str',)),
    ('resize_ with invalid type',
     lambda: torch.FloatStorage(10).resize_(1.5),
     ('float',)),
    ('fill_ with invalid type',
     lambda: torch.IntStorage(10).fill_('asdf'),
     ('str',)),
):
    check_error(_desc, _fn, *_expected)
# Drop the loop temporaries so no extra module-level names are left behind.
del _desc, _fn, _expected

# TODO: frombuffer
72