pytorch

Форк
0
/
string_ops_test.py 
146 строк · 4.0 Кб
1

2

3

4

5

6
from caffe2.python import core
7
from hypothesis import given, settings
8
import caffe2.python.hypothesis_test_util as hu
9
import caffe2.python.serialized_test.serialized_test_util as serial
10
import hypothesis.strategies as st
11
import numpy as np
12

13

14
def _string_lists(alphabet=None):
15
    return st.lists(
16
        elements=st.text(alphabet=alphabet) if alphabet else st.text(),
17
        min_size=0,
18
        max_size=3)
19

20

21
class TestStringOps(serial.SerializedTestCase):
22
    @given(strings=_string_lists())
23
    @settings(deadline=10000)
24
    def test_string_prefix(self, strings):
25
        length = 3
26
        # although we are utf-8 encoding below to avoid python exceptions,
27
        # StringPrefix op deals with byte-length prefixes, which may produce
28
        # an invalid utf-8 string. The goal here is just to avoid python
29
        # complaining about the unicode -> str conversion.
30
        strings = np.array(
31
            [a.encode('utf-8') for a in strings], dtype=object
32
        )
33

34
        def string_prefix_ref(strings):
35
            return (
36
                np.array([a[:length] for a in strings], dtype=object),
37
            )
38

39
        op = core.CreateOperator(
40
            'StringPrefix',
41
            ['strings'],
42
            ['stripped'],
43
            length=length)
44
        self.assertReferenceChecks(
45
            hu.cpu_do,
46
            op,
47
            [strings],
48
            string_prefix_ref)
49

50
    @given(strings=_string_lists())
51
    @settings(deadline=10000)
52
    def test_string_suffix(self, strings):
53
        length = 3
54
        strings = np.array(
55
            [a.encode('utf-8') for a in strings], dtype=object
56
        )
57

58
        def string_suffix_ref(strings):
59
            return (
60
                np.array([a[-length:] for a in strings], dtype=object),
61
            )
62

63
        op = core.CreateOperator(
64
            'StringSuffix',
65
            ['strings'],
66
            ['stripped'],
67
            length=length)
68
        self.assertReferenceChecks(
69
            hu.cpu_do,
70
            op,
71
            [strings],
72
            string_suffix_ref)
73

74
    @given(strings=st.text(alphabet=['a', 'b']))
75
    @settings(deadline=10000)
76
    def test_string_starts_with(self, strings):
77
        prefix = 'a'
78
        strings = np.array(
79
            [str(a) for a in strings], dtype=object
80
        )
81

82
        def string_starts_with_ref(strings):
83
            return (
84
                np.array([a.startswith(prefix) for a in strings], dtype=bool),
85
            )
86

87
        op = core.CreateOperator(
88
            'StringStartsWith',
89
            ['strings'],
90
            ['bools'],
91
            prefix=prefix)
92
        self.assertReferenceChecks(
93
            hu.cpu_do,
94
            op,
95
            [strings],
96
            string_starts_with_ref)
97

98
    @given(strings=st.text(alphabet=['a', 'b']))
99
    @settings(deadline=10000)
100
    def test_string_ends_with(self, strings):
101
        suffix = 'a'
102
        strings = np.array(
103
            [str(a) for a in strings], dtype=object
104
        )
105

106
        def string_ends_with_ref(strings):
107
            return (
108
                np.array([a.endswith(suffix) for a in strings], dtype=bool),
109
            )
110

111
        op = core.CreateOperator(
112
            'StringEndsWith',
113
            ['strings'],
114
            ['bools'],
115
            suffix=suffix)
116
        self.assertReferenceChecks(
117
            hu.cpu_do,
118
            op,
119
            [strings],
120
            string_ends_with_ref)
121

122
    @given(strings=st.text(alphabet=['a', 'b']))
123
    @settings(deadline=10000)
124
    def test_string_equals(self, strings):
125
        text = ""
126
        if strings:
127
            text = strings[0]
128

129
        strings = np.array(
130
            [str(a) for a in strings], dtype=object
131
        )
132

133
        def string_equals_ref(strings):
134
            return (
135
                np.array([a == text for a in strings], dtype=bool),
136
            )
137

138
        op = core.CreateOperator(
139
            'StringEquals',
140
            ['strings'],
141
            ['bools'],
142
            text=text)
143
        self.assertReferenceChecks(
144
            hu.cpu_do,
145
            op,
146
            [strings],
147
            string_equals_ref)
148

149
if __name__ == "__main__":
150
    import unittest
151
    unittest.main()
152

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.