2
from dataclasses import dataclass
3
from multiprocessing import Pool
4
from unittest import TestCase
5
from unittest.mock import patch
11
from datasets.utils.py_utils import (
21
from .utils import require_tf, require_torch
24
def np_sum(x): # picklable for multiprocessing
28
def add_one(i): # picklable for multiprocessing
38
class PyUtilsTest(TestCase):
39
def test_map_nested(self):
45
s6 = {"a": [1, 2], "b": [3, 4]}
46
s7 = {"a": {"1": 1}, "b": 2}
47
s8 = {"a": 1, "b": 2, "c": 3, "d": 4}
48
expected_map_nested_s1 = {}
49
expected_map_nested_s2 = []
50
expected_map_nested_s3 = 2
51
expected_map_nested_s4 = [2, 3]
52
expected_map_nested_s5 = {"a": 2, "b": 3}
53
expected_map_nested_s6 = {"a": [2, 3], "b": [4, 5]}
54
expected_map_nested_s7 = {"a": {"1": 2}, "b": 3}
55
expected_map_nested_s8 = {"a": 2, "b": 3, "c": 4, "d": 5}
56
self.assertEqual(map_nested(add_one, s1), expected_map_nested_s1)
57
self.assertEqual(map_nested(add_one, s2), expected_map_nested_s2)
58
self.assertEqual(map_nested(add_one, s3), expected_map_nested_s3)
59
self.assertEqual(map_nested(add_one, s4), expected_map_nested_s4)
60
self.assertEqual(map_nested(add_one, s5), expected_map_nested_s5)
61
self.assertEqual(map_nested(add_one, s6), expected_map_nested_s6)
62
self.assertEqual(map_nested(add_one, s7), expected_map_nested_s7)
63
self.assertEqual(map_nested(add_one, s8), expected_map_nested_s8)
66
self.assertEqual(map_nested(add_one, s1, num_proc=num_proc), expected_map_nested_s1)
67
self.assertEqual(map_nested(add_one, s2, num_proc=num_proc), expected_map_nested_s2)
68
self.assertEqual(map_nested(add_one, s3, num_proc=num_proc), expected_map_nested_s3)
69
self.assertEqual(map_nested(add_one, s4, num_proc=num_proc), expected_map_nested_s4)
70
self.assertEqual(map_nested(add_one, s5, num_proc=num_proc), expected_map_nested_s5)
71
self.assertEqual(map_nested(add_one, s6, num_proc=num_proc), expected_map_nested_s6)
72
self.assertEqual(map_nested(add_one, s7, num_proc=num_proc), expected_map_nested_s7)
73
self.assertEqual(map_nested(add_one, s8, num_proc=num_proc), expected_map_nested_s8)
75
sn1 = {"a": np.eye(2), "b": np.zeros(3), "c": np.ones(2)}
76
expected_map_nested_sn1_sum = {"a": 2, "b": 0, "c": 2}
77
expected_map_nested_sn1_int = {
78
"a": np.eye(2).astype(int),
79
"b": np.zeros(3).astype(int),
80
"c": np.ones(2).astype(int),
82
self.assertEqual(map_nested(np_sum, sn1, map_numpy=False), expected_map_nested_sn1_sum)
84
{k: v.tolist() for k, v in map_nested(int, sn1, map_numpy=True).items()},
85
{k: v.tolist() for k, v in expected_map_nested_sn1_int.items()},
87
self.assertEqual(map_nested(np_sum, sn1, map_numpy=False, num_proc=num_proc), expected_map_nested_sn1_sum)
89
{k: v.tolist() for k, v in map_nested(int, sn1, map_numpy=True, num_proc=num_proc).items()},
90
{k: v.tolist() for k, v in expected_map_nested_sn1_int.items()},
92
with self.assertRaises(AttributeError): # can't pickle a local lambda
93
map_nested(lambda x: x + 1, sn1, num_proc=num_proc)
95
def test_zip_dict(self):
99
expected_zip_dict_result = sorted([("a", (1, 3, 5)), ("b", (2, 4, 6))])
100
self.assertEqual(sorted(zip_dict(d1, d2, d3)), expected_zip_dict_result)
102
def test_temporary_assignment(self):
107
self.assertEqual(foo.my_attr, "bar")
108
with temporary_assignment(foo, "my_attr", "BAR"):
109
self.assertEqual(foo.my_attr, "BAR")
110
self.assertEqual(foo.my_attr, "bar")
113
@pytest.mark.parametrize(
114
"iterable_length, num_proc, expected_num_proc",
128
def test_map_nested_num_proc(iterable_length, num_proc, expected_num_proc):
129
with patch("datasets.utils.py_utils._single_map_nested") as mock_single_map_nested, patch(
130
"datasets.parallel.parallel.Pool"
131
) as mock_multiprocessing_pool:
132
data_struct = {f"{i}": i for i in range(iterable_length)}
133
_ = map_nested(lambda x: x + 10, data_struct, num_proc=num_proc, parallel_min_length=16)
134
if expected_num_proc == 1:
135
assert mock_single_map_nested.called
136
assert not mock_multiprocessing_pool.called
138
assert not mock_single_map_nested.called
139
assert mock_multiprocessing_pool.called
140
assert mock_multiprocessing_pool.call_args[0][0] == expected_num_proc
143
class TempSeedTest(TestCase):
145
def test_tensorflow(self):
146
import tensorflow as tf
147
from tensorflow.keras import layers
149
model = layers.Dense(2)
151
def gen_random_output():
152
x = tf.random.uniform((1, 3))
153
return model(x).numpy()
155
with temp_seed(42, set_tensorflow=True):
156
out1 = gen_random_output()
157
with temp_seed(42, set_tensorflow=True):
158
out2 = gen_random_output()
159
out3 = gen_random_output()
161
np.testing.assert_equal(out1, out2)
162
self.assertGreater(np.abs(out1 - out3).sum(), 0)
165
def test_torch(self):
168
def gen_random_output():
169
model = torch.nn.Linear(3, 2)
171
return model(x).detach().numpy()
173
with temp_seed(42, set_pytorch=True):
174
out1 = gen_random_output()
175
with temp_seed(42, set_pytorch=True):
176
out2 = gen_random_output()
177
out3 = gen_random_output()
179
np.testing.assert_equal(out1, out2)
180
self.assertGreater(np.abs(out1 - out3).sum(), 0)
182
def test_numpy(self):
183
def gen_random_output():
184
return np.random.rand(1, 3)
187
out1 = gen_random_output()
189
out2 = gen_random_output()
190
out3 = gen_random_output()
192
np.testing.assert_equal(out1, out2)
193
self.assertGreater(np.abs(out1 - out3).sum(), 0)
196
@pytest.mark.parametrize("input_data", [{}])
197
def test_nested_data_structure_data(input_data):
198
output_data = NestedDataStructure(input_data).data
199
assert output_data == input_data
202
@pytest.mark.parametrize(
203
"data, expected_output",
208
(["foo", "bar"], ["foo", "bar"]),
209
([["foo", "bar"]], ["foo", "bar"]),
210
([[["foo"], ["bar"]]], ["foo", "bar"]),
211
([[["foo"], "bar"]], ["foo", "bar"]),
212
({"a": 1, "b": 2}, [1, 2]),
213
({"a": [1, 2], "b": [3, 4]}, [1, 2, 3, 4]),
214
({"a": [[1, 2]], "b": [[3, 4]]}, [1, 2, 3, 4]),
215
({"a": [[1, 2]], "b": [3, 4]}, [1, 2, 3, 4]),
216
({"a": [[[1], [2]]], "b": [[[3], [4]]]}, [1, 2, 3, 4]),
217
({"a": [[[1], [2]]], "b": [[3, 4]]}, [1, 2, 3, 4]),
218
({"a": [[[1], [2]]], "b": [3, 4]}, [1, 2, 3, 4]),
219
({"a": [[[1], [2]]], "b": [3, [4]]}, [1, 2, 3, 4]),
220
({"a": {"1": 1}, "b": 2}, [1, 2]),
221
({"a": {"1": [1]}, "b": 2}, [1, 2]),
222
({"a": {"1": [1]}, "b": [2]}, [1, 2]),
225
def test_flatten(data, expected_output):
226
output = NestedDataStructure(data).flatten()
227
assert output == expected_output
231
input = A(x=1, y="foobar")
232
expected_output = {"x": 1, "y": "foobar"}
233
assert asdict(input) == expected_output
235
input = {"a": {"b": A(x=10, y="foo")}, "c": [A(x=20, y="bar")]}
236
expected_output = {"a": {"b": {"x": 10, "y": "foo"}}, "c": [{"x": 20, "y": "bar"}]}
237
assert asdict(input) == expected_output
239
with pytest.raises(TypeError):
240
asdict([1, A(x=10, y="foo")])
243
def _split_text(text: str):
247
def _2seconds_generator_of_2items_with_timing(content):
248
yield (time.time(), content)
250
yield (time.time(), content)
253
def test_iflatmap_unordered():
254
with Pool(2) as pool:
255
out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
256
assert out.count("hello") == 10
257
assert out.count("there") == 10
258
assert len(out) == 20
260
# check multiprocess from pathos (uses dill for pickling)
261
with multiprocess.Pool(2) as pool:
262
out = list(iflatmap_unordered(pool, _split_text, kwargs_iterable=[{"text": "hello there"}] * 10))
263
assert out.count("hello") == 10
264
assert out.count("there") == 10
265
assert len(out) == 20
267
# check that we get items as fast as possible
268
with Pool(2) as pool:
270
for yield_time, content in iflatmap_unordered(
271
pool, _2seconds_generator_of_2items_with_timing, kwargs_iterable=[{"content": "a"}, {"content": "b"}]
273
assert yield_time < time.time() + 0.1, "we should each item directly after it was yielded"
275
assert out.count("a") == 2
276
assert out.count("b") == 2