import time
from dataclasses import dataclass
from multiprocessing import Pool
from unittest import TestCase
from unittest.mock import patch

import multiprocess
import numpy as np
import pytest

from datasets.utils.py_utils import (
    NestedDataStructure,
    asdict,
    iflatmap_unordered,
    map_nested,
    temp_seed,
    temporary_assignment,
    zip_dict,
)

from .utils import require_tf, require_torch


def np_sum(x):  # picklable for multiprocessing
    return x.sum()


def add_one(i):  # picklable for multiprocessing
    return i + 1


@dataclass
class A:
    x: int
    y: str


class PyUtilsTest(TestCase):
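    # `map_nested` applies a function to every leaf of a nested dict/list
    # structure, e.g. map_nested(add_one, {"a": [1, 2]}) == {"a": [2, 3]}.
    # With num_proc > 1 the top-level items are split across worker
    # processes, and map_numpy controls whether numpy arrays are mapped
    # over as containers or passed to the function whole.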
    def test_map_nested(self):
        s1 = {}
        s2 = []
        s3 = 1
        s4 = [1, 2]
        s5 = {"a": 1, "b": 2}
        s6 = {"a": [1, 2], "b": [3, 4]}
        s7 = {"a": {"1": 1}, "b": 2}
        s8 = {"a": 1, "b": 2, "c": 3, "d": 4}
        expected_map_nested_s1 = {}
        expected_map_nested_s2 = []
        expected_map_nested_s3 = 2
        expected_map_nested_s4 = [2, 3]
        expected_map_nested_s5 = {"a": 2, "b": 3}
        expected_map_nested_s6 = {"a": [2, 3], "b": [4, 5]}
        expected_map_nested_s7 = {"a": {"1": 2}, "b": 3}
        expected_map_nested_s8 = {"a": 2, "b": 3, "c": 4, "d": 5}
        self.assertEqual(map_nested(add_one, s1), expected_map_nested_s1)
        self.assertEqual(map_nested(add_one, s2), expected_map_nested_s2)
        self.assertEqual(map_nested(add_one, s3), expected_map_nested_s3)
        self.assertEqual(map_nested(add_one, s4), expected_map_nested_s4)
        self.assertEqual(map_nested(add_one, s5), expected_map_nested_s5)
        self.assertEqual(map_nested(add_one, s6), expected_map_nested_s6)
        self.assertEqual(map_nested(add_one, s7), expected_map_nested_s7)
        self.assertEqual(map_nested(add_one, s8), expected_map_nested_s8)

        num_proc = 2
        self.assertEqual(map_nested(add_one, s1, num_proc=num_proc), expected_map_nested_s1)
        self.assertEqual(map_nested(add_one, s2, num_proc=num_proc), expected_map_nested_s2)
        self.assertEqual(map_nested(add_one, s3, num_proc=num_proc), expected_map_nested_s3)
        self.assertEqual(map_nested(add_one, s4, num_proc=num_proc), expected_map_nested_s4)
        self.assertEqual(map_nested(add_one, s5, num_proc=num_proc), expected_map_nested_s5)
        self.assertEqual(map_nested(add_one, s6, num_proc=num_proc), expected_map_nested_s6)
        self.assertEqual(map_nested(add_one, s7, num_proc=num_proc), expected_map_nested_s7)
        self.assertEqual(map_nested(add_one, s8, num_proc=num_proc), expected_map_nested_s8)

        sn1 = {"a": np.eye(2), "b": np.zeros(3), "c": np.ones(2)}
        expected_map_nested_sn1_sum = {"a": 2, "b": 0, "c": 2}
        expected_map_nested_sn1_int = {
            "a": np.eye(2).astype(int),
            "b": np.zeros(3).astype(int),
            "c": np.ones(2).astype(int),
        }
        self.assertEqual(map_nested(np_sum, sn1, map_numpy=False), expected_map_nested_sn1_sum)
        self.assertEqual(
            {k: v.tolist() for k, v in map_nested(int, sn1, map_numpy=True).items()},
            {k: v.tolist() for k, v in expected_map_nested_sn1_int.items()},
        )
        self.assertEqual(map_nested(np_sum, sn1, map_numpy=False, num_proc=num_proc), expected_map_nested_sn1_sum)
        self.assertEqual(
            {k: v.tolist() for k, v in map_nested(int, sn1, map_numpy=True, num_proc=num_proc).items()},
            {k: v.tolist() for k, v in expected_map_nested_sn1_int.items()},
        )
        with self.assertRaises(AttributeError):  # can't pickle a local lambda
            map_nested(lambda x: x + 1, sn1, num_proc=num_proc)

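    # `zip_dict` iterates over several dicts that share the same keys,
    # yielding (key, tuple_of_values) pairs.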
    def test_zip_dict(self):
        d1 = {"a": 1, "b": 2}
        d2 = {"a": 3, "b": 4}
        d3 = {"a": 5, "b": 6}
        expected_zip_dict_result = sorted([("a", (1, 3, 5)), ("b", (2, 4, 6))])
        self.assertEqual(sorted(zip_dict(d1, d2, d3)), expected_zip_dict_result)

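    # `temporary_assignment` is a context manager that sets an attribute on
    # an object for the duration of the block and restores the original
    # value on exit.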
    def test_temporary_assignment(self):
        class Foo:
            my_attr = "bar"

        foo = Foo()
        self.assertEqual(foo.my_attr, "bar")
        with temporary_assignment(foo, "my_attr", "BAR"):
            self.assertEqual(foo.my_attr, "BAR")
        self.assertEqual(foo.my_attr, "bar")


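# `map_nested` falls back to the sequential `_single_map_nested` path unless
# the data structure has at least `parallel_min_length` top-level items
# (16 here); even then, the pool size is capped by the number of items,
# e.g. asking for 17 processes on 16 items yields a 16-process pool.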
@pytest.mark.parametrize(
    "iterable_length, num_proc, expected_num_proc",
    [
        (1, None, 1),
        (1, 1, 1),
        (2, None, 1),
        (2, 1, 1),
        (2, 2, 1),
        (2, 3, 1),
        (3, 2, 1),
        (16, 16, 16),
        (16, 17, 16),
        (17, 16, 16),
    ],
)
def test_map_nested_num_proc(iterable_length, num_proc, expected_num_proc):
    with patch("datasets.utils.py_utils._single_map_nested") as mock_single_map_nested, patch(
        "datasets.parallel.parallel.Pool"
    ) as mock_multiprocessing_pool:
        data_struct = {f"{i}": i for i in range(iterable_length)}
        _ = map_nested(lambda x: x + 10, data_struct, num_proc=num_proc, parallel_min_length=16)
        if expected_num_proc == 1:
            assert mock_single_map_nested.called
            assert not mock_multiprocessing_pool.called
        else:
            assert not mock_single_map_nested.called
            assert mock_multiprocessing_pool.called
            assert mock_multiprocessing_pool.call_args[0][0] == expected_num_proc


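# `temp_seed` is a context manager that seeds the random number generators
# for the duration of the block and restores the previous RNG state on exit,
# so two identically seeded blocks produce identical outputs while code
# outside them stays random. NumPy is seeded by default; PyTorch and
# TensorFlow are opt-in via set_pytorch / set_tensorflow.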
class TempSeedTest(TestCase):
    @require_tf
    def test_tensorflow(self):
        import tensorflow as tf
        from tensorflow.keras import layers

        model = layers.Dense(2)

        def gen_random_output():
            x = tf.random.uniform((1, 3))
            return model(x).numpy()

        with temp_seed(42, set_tensorflow=True):
            out1 = gen_random_output()
        with temp_seed(42, set_tensorflow=True):
            out2 = gen_random_output()
        out3 = gen_random_output()

        np.testing.assert_equal(out1, out2)
        self.assertGreater(np.abs(out1 - out3).sum(), 0)

    @require_torch
    def test_torch(self):
        import torch

        def gen_random_output():
            model = torch.nn.Linear(3, 2)
            x = torch.rand(1, 3)
            return model(x).detach().numpy()

        with temp_seed(42, set_pytorch=True):
            out1 = gen_random_output()
        with temp_seed(42, set_pytorch=True):
            out2 = gen_random_output()
        out3 = gen_random_output()

        np.testing.assert_equal(out1, out2)
        self.assertGreater(np.abs(out1 - out3).sum(), 0)

    def test_numpy(self):
        def gen_random_output():
            return np.random.rand(1, 3)

        with temp_seed(42):
            out1 = gen_random_output()
        with temp_seed(42):
            out2 = gen_random_output()
        out3 = gen_random_output()

        np.testing.assert_equal(out1, out2)
        self.assertGreater(np.abs(out1 - out3).sum(), 0)


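# `NestedDataStructure` wraps arbitrarily nested dicts and lists: `.data`
# returns the wrapped object unchanged, and `.flatten()` returns the leaf
# values as a flat list regardless of nesting depth.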
@pytest.mark.parametrize("input_data", [{}])
def test_nested_data_structure_data(input_data):
    output_data = NestedDataStructure(input_data).data
    assert output_data == input_data


@pytest.mark.parametrize(
    "data, expected_output",
    [
        ({}, []),
        ([], []),
        ("foo", ["foo"]),
        (["foo", "bar"], ["foo", "bar"]),
        ([["foo", "bar"]], ["foo", "bar"]),
        ([[["foo"], ["bar"]]], ["foo", "bar"]),
        ([[["foo"], "bar"]], ["foo", "bar"]),
        ({"a": 1, "b": 2}, [1, 2]),
        ({"a": [1, 2], "b": [3, 4]}, [1, 2, 3, 4]),
        ({"a": [[1, 2]], "b": [[3, 4]]}, [1, 2, 3, 4]),
        ({"a": [[1, 2]], "b": [3, 4]}, [1, 2, 3, 4]),
        ({"a": [[[1], [2]]], "b": [[[3], [4]]]}, [1, 2, 3, 4]),
        ({"a": [[[1], [2]]], "b": [[3, 4]]}, [1, 2, 3, 4]),
        ({"a": [[[1], [2]]], "b": [3, 4]}, [1, 2, 3, 4]),
        ({"a": [[[1], [2]]], "b": [3, [4]]}, [1, 2, 3, 4]),
        ({"a": {"1": 1}, "b": 2}, [1, 2]),
        ({"a": {"1": [1]}, "b": 2}, [1, 2]),
        ({"a": {"1": [1]}, "b": [2]}, [1, 2]),
    ],
)
def test_flatten(data, expected_output):
    output = NestedDataStructure(data).flatten()
    assert output == expected_output


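# This `asdict` recursively converts dataclass instances to dicts, including
# instances nested inside plain dicts and lists (which `dataclasses.asdict`
# does not accept at the top level); the last assertion checks that an
# unsupported input still raises TypeError.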
def test_asdict():
    input = A(x=1, y="foobar")
    expected_output = {"x": 1, "y": "foobar"}
    assert asdict(input) == expected_output

    input = {"a": {"b": A(x=10, y="foo")}, "c": [A(x=20, y="bar")]}
    expected_output = {"a": {"b": {"x": 10, "y": "foo"}}, "c": [{"x": 20, "y": "bar"}]}
    assert asdict(input) == expected_output

    with pytest.raises(TypeError):
        asdict([1, A(x=10, y="foo")])


def _split_text(text: str):
    return text.split()


def _2seconds_generator_of_2items_with_timing(content):
    yield (time.time(), content)
    time.sleep(2)
    yield (time.time(), content)


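# `iflatmap_unordered` maps a function over a pool using keyword-argument
# dicts and yields the elements of each returned iterable as soon as they
# arrive, in no guaranteed order. It is exercised with both
# multiprocessing.Pool and multiprocess.Pool (the pathos fork, which uses
# dill and can therefore pickle a wider range of objects).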
def test_iflatmap_unordered():
    with Pool(2) as pool:
        out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
        assert out.count("hello") == 10
        assert out.count("there") == 10
        assert len(out) == 20

    # check multiprocess from pathos (uses dill for pickling)
    with multiprocess.Pool(2) as pool:
        out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
        assert out.count("hello") == 10
        assert out.count("there") == 10
        assert len(out) == 20

    # check that we get items as fast as possible
    with Pool(2) as pool:
        out = []
        for yield_time, content in iflatmap_unordered(
            pool, _2seconds_generator_of_2items_with_timing, kwargs_iterable=[{"content": "a"}, {"content": "b"}]
        ):
            assert yield_time < time.time() + 0.1, "we should get each item directly after it is yielded"
            out.append(content)
        assert out.count("a") == 2
        assert out.count("b") == 2
        assert len(out) == 4