from unittest import TestCase

from datasets import Sequence, Value
from datasets.arrow_dataset import Dataset


class DatasetListTest(TestCase):
    """Tests for building a ``Dataset`` from a list of row dictionaries via ``Dataset.from_list``."""

    def _create_example_records(self):
        """Return four example rows, each a dict with the keys ``col_1`` and ``col_2``."""
        return [
            {"col_1": 3, "col_2": "a"},
            {"col_1": 2, "col_2": "b"},
            {"col_1": 1, "col_2": "c"},
            {"col_1": 0, "col_2": "d"},
        ]

    def _create_example_dict(self):
        """Return a ``Dataset`` built with ``from_dict`` from the same values in column-oriented form."""
        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
        return Dataset.from_dict(data)

    def test_create(self):
        """``from_list`` exposes the record keys as columns and yields the rows back in order."""
        example_records = self._create_example_records()
        dset = Dataset.from_list(example_records)
        self.assertListEqual(dset.column_names, ["col_1", "col_2"])
        # The row loop below only visits rows the dataset actually yields, so a
        # truncated dataset would pass it; pin the row count explicitly first.
        self.assertEqual(len(dset), len(example_records))
        for i, r in enumerate(dset):
            self.assertDictEqual(r, example_records[i])

    def test_list_dict_equivalent(self):
        """``from_list`` and ``from_dict`` produce the same ``info`` for the same data."""
        example_records = self._create_example_records()
        dset = Dataset.from_list(example_records)
        # Pivot the row-oriented records into one list per column for ``from_dict``.
        dset_from_dict = Dataset.from_dict({k: [r[k] for r in example_records] for k in example_records[0]})
        self.assertEqual(dset.info, dset_from_dict.info)

    def test_uneven_records(self):  # checks what happens with missing columns
        """Records with differing keys: only the first record's keys become columns."""
        uneven_records = [{"col_1": 1}, {"col_2": "x"}]
        dset = Dataset.from_list(uneven_records)
        self.assertDictEqual(dset[0], {"col_1": 1})
        self.assertDictEqual(dset[1], {"col_1": None})  # NB: first record is used for columns

    def test_variable_list_records(self):  # checks if the type can be inferred from the second record
        """An empty list in the first record does not prevent inferring the element type."""
        list_records = [{"col_1": []}, {"col_1": [1, 2]}]
        dset = Dataset.from_list(list_records)
        self.assertEqual(dset.info.features["col_1"], Sequence(Value("int64")))

    def test_create_empty(self):
        """An empty list of records gives a dataset with no rows and no columns."""
        dset = Dataset.from_list([])
        self.assertEqual(len(dset), 0)
        self.assertListEqual(dset.column_names, [])