pytorch

Форк
0
/
test_jit_string.py 
333 строки · 12.8 Кб
1
# Owner(s): ["oncall: jit"]
2

3
from test_jit import JitTestCase
4
from torch.testing._internal.common_utils import run_tests
5

6
from typing import List, Tuple
7

8
class TestScript(JitTestCase):
    """TorchScript tests for builtin ``str`` methods and string slicing.

    Each nested helper is handed to ``checkScript`` / ``checkScriptRaisesRegex``
    (inherited from ``JitTestCase``, defined outside this file). Presumably these
    compare the scripted behaviour (result or raised error) against eager
    Python -- confirm against ``JitTestCase``.
    """

    def test_str_ops(self):
        """Exercise scripted str predicates, case mapping, padding, search,
        split/partition/join, comparisons and truthiness."""
        def test_str_is(s: str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return s.isupper(), s.islower(), s.isdigit(), s.isspace(), \
                s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(), \
                s.isidentifier(), s.istitle(), s.isprintable()

        def test_str_to(s: str) -> Tuple[str, str, str, str, str]:
            return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase()

        def test_str_strip(s: str) -> Tuple[str, str, str]:
            return (
                s.lstrip(),
                s.rstrip(),
                s.strip(),
            )

        def test_str_strip_char_set(s: str, char_set: str) -> Tuple[str, str, str]:
            return (
                s.lstrip(char_set),
                s.rstrip(char_set),
                s.strip(char_set),
            )

        # Shared inputs covering empty, alphanumeric, whitespace-only and
        # punctuation strings ("\x0a" is simply "\n").
        inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
                  "  \t", "  \n", "\na", "abc", "123.3", "s a", "b12a ",
                  "more strings with spaces", "Titular Strings", "\x0acan'tprintthis",
                  "spaces at the end ", " begin"]

        def test_str_center(i: int, s: str) -> str:
            return s.center(i)

        def test_str_center_fc(i: int, s: str) -> str:
            return s.center(i, '*')

        def test_str_center_error(s: str) -> str:
            return s.center(10, '**')

        def test_ljust(s: str, i: int) -> str:
            return s.ljust(i)

        def test_ljust_fc(s: str, i: int, fc: str) -> str:
            return s.ljust(i, fc)

        def test_ljust_fc_err(s: str) -> str:
            return s.ljust(10, '**')

        def test_rjust(s: str, i: int) -> str:
            return s.rjust(i)

        def test_rjust_fc(s: str, i: int, fc: str) -> str:
            return s.rjust(i, fc)

        def test_rjust_fc_err(s: str) -> str:
            return s.rjust(10, '**')

        def test_zfill(s: str, i: int) -> str:
            return s.zfill(i)

        # `text` rather than `input` so the builtin is not shadowed.
        for text in inputs:
            self.checkScript(test_str_is, (text,))
            self.checkScript(test_str_to, (text,))
            self.checkScript(test_str_strip, (text,))
            for char_set in ["abc", "123", " ", "\t"]:
                self.checkScript(test_str_strip_char_set, (text, char_set))
            # Widths 0..6 cover both "shorter than" and "longer than" the inputs.
            for width in range(7):
                self.checkScript(test_str_center, (width, text))
                self.checkScript(test_str_center_fc, (width, text))
                self.checkScript(test_ljust, (text, width))
                self.checkScript(test_ljust_fc, (text, width, '*'))
                self.checkScript(test_rjust, (text, width))
                self.checkScript(test_rjust_fc, (text, width, '*'))
                self.checkScript(test_zfill, (text, width))

        # A multi-character fill string must raise, in both eager Python and
        # TorchScript. Check each case separately: inside one `assertRaises`
        # block only the first raising call would ever execute.
        for fill_err_fn in (test_str_center_error, test_ljust_fc_err, test_rjust_fc_err):
            self.checkScriptRaisesRegex(fill_err_fn, ("error",), Exception,
                                        "fill character")

        def test_count() -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
            return (
                "hello".count("h"),
                "hello".count("h", 0, 1),
                "hello".count("h", -3),
                "hello".count("h", -10, 1),
                "hello".count("h", 0, -10),
                "hello".count("h", 0, 10),
                "hello".count("ell"),
                "hello".count("ell", 0, 1),
                "hello".count("ell", -3),
                "hello".count("ell", -10, 1),
                "hello".count("ell", 0, -10),
                "hello".count("ell", 0, 10)
            )
        self.checkScript(test_count, ())

        def test_endswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                "hello".endswith("lo"),
                "hello".endswith("lo", 0),
                "hello".endswith("lo", -2),
                "hello".endswith("lo", -8),
                "hello".endswith("lo", 0, -5),
                "hello".endswith("lo", -2, 3),
                "hello".endswith("lo", -8, 4),
                "hello".endswith("l"),
                "hello".endswith("l", 0),
                "hello".endswith("l", -2),
                "hello".endswith("l", -8),
                "hello".endswith("l", 0, -5),
                "hello".endswith("l", -2, 3),
                "hello".endswith("l", -8, 4)
            )
        self.checkScript(test_endswith, ())

        def test_startswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                "hello".startswith("lo"),
                "hello".startswith("lo", 0),
                "hello".startswith("lo", -2),
                "hello".startswith("lo", -8),
                "hello".startswith("lo", 0, -5),
                "hello".startswith("lo", -2, 3),
                "hello".startswith("lo", -8, 4),
                "hello".startswith("l"),
                "hello".startswith("l", 0),
                "hello".startswith("l", -2),
                "hello".startswith("l", -8),
                "hello".startswith("l", 0, -5),
                "hello".startswith("l", -2, 3),
                "hello".startswith("l", -8, 4)
            )
        self.checkScript(test_startswith, ())

        def test_expandtabs() -> Tuple[str, str, str, str, str, str]:
            return (
                'xyz\t82345\tabc'.expandtabs(),
                'xyz\t32345\tabc'.expandtabs(3),
                'xyz\t52345\tabc'.expandtabs(5),
                'xyz\t62345\tabc'.expandtabs(6),
                'xyz\t72345\tabc'.expandtabs(7),
                'xyz\t62345\tabc'.expandtabs(-5),
            )
        self.checkScript(test_expandtabs, ())

        def test_rfind() -> Tuple[int, int, int, int, int, int, int, int, int]:
            return (
                "hello123abc".rfind("llo"),
                "hello123abc".rfind("12"),
                "hello123abc".rfind("ab"),
                "hello123abc".rfind("ll", -1),
                "hello123abc".rfind("12", 4),
                "hello123abc".rfind("ab", -7),
                "hello123abc".rfind("ll", -1, 8),
                "hello123abc".rfind("12", 4, -4),
                "hello123abc".rfind("ab", -7, -20),
            )
        self.checkScript(test_rfind, ())

        def test_find() -> Tuple[int, int, int, int, int, int, int, int, int]:
            return (
                "hello123abc".find("llo"),
                "hello123abc".find("12"),
                "hello123abc".find("ab"),
                "hello123abc".find("ll", -1),
                "hello123abc".find("12", 4),
                "hello123abc".find("ab", -7),
                "hello123abc".find("ll", -1, 8),
                "hello123abc".find("12", 4, -4),
                "hello123abc".find("ab", -7, -20),
            )
        self.checkScript(test_find, ())

        def test_index() -> Tuple[int, int, int, int, int, int]:
            return (
                "hello123abc".index("llo"),
                "hello123abc".index("12"),
                "hello123abc".index("ab"),
                "hello123abc".index("12", 4),
                "hello123abc".index("ab", -7),
                "hello123abc".index("12", 4, -4),
            )
        self.checkScript(test_index, ())

        def test_rindex() -> Tuple[int, int, int, int, int, int]:
            return (
                "hello123abc".rindex("llo"),
                "hello123abc".rindex("12"),
                "hello123abc".rindex("ab"),
                "hello123abc".rindex("12", 4),
                "hello123abc".rindex("ab", -7),
                "hello123abc".rindex("12", 4, -4),
            )
        self.checkScript(test_rindex, ())

        def test_replace() -> Tuple[str, str, str, str, str, str, str]:
            return (
                "hello123abc".replace("llo", "sdf"),
                "ff".replace("f", "ff"),
                "abc123".replace("a", "testing"),
                "aaaaaa".replace("a", "testing", 3),
                "bbb".replace("a", "testing", 3),
                "ccc".replace("c", "ccc", 3),
                "cc".replace("c", "ccc", -3),
            )
        self.checkScript(test_replace, ())

        def test_partition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                      Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                      Tuple[str, str, str]]:
            return (
                "hello123abc".partition("llo"),
                "ff".partition("f"),
                "abc123".partition("a"),
                "aaaaaa".partition("testing"),
                "bbb".partition("a"),
                "ccc".partition("ccc"),
                "cc".partition("ccc"),
            )
        self.checkScript(test_partition, ())

        def test_rpartition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                       Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                       Tuple[str, str, str]]:
            return (
                "hello123abc".rpartition("llo"),
                "ff".rpartition("f"),
                "abc123".rpartition("a"),
                "aaaaaa".rpartition("testing"),
                "bbb".rpartition("a"),
                "ccc".rpartition("ccc"),
                "cc".rpartition("ccc"),
            )
        self.checkScript(test_rpartition, ())

        def test_split() -> Tuple[List[str], List[str], List[str], List[str], List[str],
                                  List[str], List[str], List[str], List[str], List[str], List[str]]:
            return (
                "a a a a a".split(),
                "a  a a   a a".split(),
                "   a a\ta \v a \v\f\n a \t   ".split(),
                " a a a a a ".split(" "),
                "a a a a a ".split(" ", 10),
                "a a a a a ".split(" ", -1),
                "a a a a a ".split(" ", 3),
                " a a a a a ".split("*"),
                " a*a a*a a".split("*"),
                " a*a a*a a ".split("*", -1),
                " a*a a*a a ".split("a*", 10),
            )
        self.checkScript(test_split, ())

        # test raising error for empty separator
        def test_split_empty_separator():
            s = "test"
            return s.split("")

        self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
                                    "empty separator")

        def test_rsplit() -> Tuple[List[str], List[str], List[str], List[str], List[str],
                                   List[str], List[str], List[str], List[str]]:
            return (
                "a a a a a".rsplit(),
                " a a a a a ".rsplit(" "),
                "a a a a a ".rsplit(" ", 10),
                "a a a a a ".rsplit(" ", -1),
                "a a a a a ".rsplit(" ", 3),
                " a a a a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*", -1),
                " a*a a*a a".rsplit("a*", 10),
            )
        self.checkScript(test_rsplit, ())

        def test_splitlines() -> Tuple[List[str], List[str], List[str], List[str],
                                       List[str], List[str]]:
            return (
                "hello\ntest".splitlines(),
                "hello\n\ntest\n".splitlines(),
                "hello\ntest\n\n".splitlines(),
                "hello\vtest".splitlines(),
                "hello\v\f\ntest".splitlines(),
                "hello\ftest".splitlines(),
            )
        self.checkScript(test_splitlines, ())

        def test_str_cmp(a: str, b: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
            return a != b, a == b, a < b, a > b, a <= b, a >= b

        # Compare every input with the one that follows it in `inputs`.
        for lhs, rhs in zip(inputs, inputs[1:]):
            self.checkScript(test_str_cmp, (lhs, rhs))

        def test_str_join():
            return (
                ",".join(["a"]),
                ",".join(["a", "b", "c"]),
                ",".join(["aa", "bb", "cc"]),
                ",".join(["a,a", "bb", "c,c"]),
                "**a**".join(["b", "c", "d", "e"]),
                "".join(["a", "b", "c"]),
            )
        self.checkScript(test_str_join, ())

        def test_bool_conversion(a: str):
            if a:
                return a
            else:
                return "default"

        # Truthiness of str: non-empty takes the first branch, "" the second.
        self.checkScript(test_bool_conversion, ("nonempty",))
        self.checkScript(test_bool_conversion, ("",))

    def test_string_slice(self):
        """Check scripted extended slicing (start:stop:step) on a str."""
        def test_slice(a: str) -> Tuple[str, str, str, str, str]:
            return (
                a[0:1:2],
                a[0:6:1],
                a[4:1:2],
                a[0:3:2],
                a[-1:1:3],
            )

        self.checkScript(test_slice, ("hellotest",))

332
if __name__ == '__main__':
    # Only when executed as a script (not on import): delegate to the shared
    # PyTorch test runner imported at the top of the file.
    run_tests()
334

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.