datasets
509 строк · 21.7 Кб
1import shutil
2import textwrap
3
4import numpy as np
5import pytest
6
7from datasets import ClassLabel, Features, Image, Value
8from datasets.data_files import DataFilesDict, get_data_patterns
9from datasets.download.streaming_download_manager import StreamingDownloadManager
10from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder
11
12from ..utils import require_pil
13
14
@pytest.fixture
def cache_dir(tmp_path):
    """Return, as a string, the path ``<tmp_path>/imagefolder_cache_dir``.

    The directory is not created by this fixture.
    """
    cache_path = tmp_path / "imagefolder_cache_dir"
    return str(cache_path)
18
19
@pytest.fixture
def data_files_with_labels_no_metadata(tmp_path, image_file):
    """Create a ``cat``/``dog`` directory layout with one image per class and no metadata file.

    Returns:
        The ``DataFilesDict`` resolved from the patterns detected in the data directory.
    """
    root = tmp_path / "data_files_with_labels_no_metadata"
    root.mkdir(parents=True, exist_ok=True)

    # (class directory, image file name) pairs; every image is a copy of ``image_file``.
    layout = (("cat", "image_cat.jpg"), ("dog", "image_dog.jpg"))
    class_dirs = [root / class_name for class_name, _ in layout]
    for class_dir in class_dirs:
        class_dir.mkdir(parents=True, exist_ok=True)
    for class_dir, (_, file_name) in zip(class_dirs, layout):
        shutil.copyfile(image_file, class_dir / file_name)

    patterns = get_data_patterns(str(root))
    return DataFilesDict.from_patterns(patterns, root.as_posix())
39
40
@pytest.fixture
def image_files_with_labels_and_duplicated_label_key_in_metadata(tmp_path, image_file):
    """Create ``cat``/``dog`` image folders plus a metadata file that also has a ``label`` key.

    The directory names give the labels "cat"/"dog" while ``metadata.jsonl`` carries its own
    ``label`` values "Cat"/"Dog", so both label sources are present for the same images.

    Returns:
        A ``(cat image path, dog image path, metadata path)`` tuple of strings.
    """
    data_dir = tmp_path / "image_files_with_labels_and_label_key_in_metadata"
    data_dir.mkdir(parents=True, exist_ok=True)
    subdir_class_0 = data_dir / "cat"
    subdir_class_0.mkdir(parents=True, exist_ok=True)
    subdir_class_1 = data_dir / "dog"
    subdir_class_1.mkdir(parents=True, exist_ok=True)

    image_filename = subdir_class_0 / "image_cat.jpg"
    shutil.copyfile(image_file, image_filename)
    image_filename2 = subdir_class_1 / "image_dog.jpg"
    shutil.copyfile(image_file, image_filename2)

    # ``data_dir`` already includes ``tmp_path``; joining ``tmp_path`` in front of it again
    # only worked because pathlib drops the left operand when the right one is absolute.
    image_metadata_filename = data_dir / "metadata.jsonl"
    image_metadata = textwrap.dedent(
        """\
        {"file_name": "cat/image_cat.jpg", "caption": "Nice image of a cat", "label": "Cat"}
        {"file_name": "dog/image_dog.jpg", "caption": "Nice image of a dog", "label": "Dog"}
        """
    )
    with open(image_metadata_filename, "w", encoding="utf-8") as f:
        f.write(image_metadata)

    return str(image_filename), str(image_filename2), str(image_metadata_filename)
66
67
@pytest.fixture
def image_file_with_metadata(tmp_path, image_file):
    """Place one image and a ``metadata.jsonl`` captioning it directly in ``tmp_path``.

    Returns:
        An ``(image path, metadata path)`` tuple of strings.
    """
    image_path = tmp_path / "image_rgb.jpg"
    shutil.copyfile(image_file, image_path)

    metadata_path = tmp_path / "metadata.jsonl"
    metadata_text = textwrap.dedent(
        """\
        {"file_name": "image_rgb.jpg", "caption": "Nice image"}
        """
    )
    metadata_path.write_text(metadata_text, encoding="utf-8")

    return str(image_path), str(metadata_path)
81
82
@pytest.fixture
def image_files_with_metadata_that_misses_one_image(tmp_path, image_file):
    """Create two images but a ``metadata.jsonl`` that only describes the first one.

    Returns:
        A ``(first image path, second image path, metadata path)`` tuple of strings.
    """
    first_image = tmp_path / "image_rgb.jpg"
    shutil.copyfile(image_file, first_image)
    second_image = tmp_path / "image_rgb2.jpg"
    shutil.copyfile(image_file, second_image)

    # Only ``image_rgb.jpg`` gets an entry; ``image_rgb2.jpg`` is deliberately left out.
    metadata_path = tmp_path / "metadata.jsonl"
    metadata_text = textwrap.dedent(
        """\
        {"file_name": "image_rgb.jpg", "caption": "Nice image"}
        """
    )
    metadata_path.write_text(metadata_text, encoding="utf-8")

    return str(first_image), str(second_image), str(metadata_path)
98
99
@pytest.fixture(params=["jsonl", "csv"])
def data_files_with_one_split_and_metadata(request, tmp_path, image_file):
    """Create a single-split data directory with three images and one metadata file.

    The fixture is parametrized over the metadata format (``jsonl`` or ``csv``). Two images sit
    at the top of the data directory and a third one inside ``subdir``; the metadata file lists
    a caption for each of them.

    Returns:
        The resolved ``DataFilesDict``, checked to hold one ``train`` split with four files.
    """
    root = tmp_path / "imagefolder_data_dir_with_metadata_one_split"
    root.mkdir(parents=True, exist_ok=True)
    nested = root / "subdir"
    nested.mkdir(parents=True, exist_ok=True)

    destinations = (
        root / "image_rgb.jpg",
        root / "image_rgb2.jpg",
        nested / "image_rgb3.jpg",  # in subdir
    )
    for destination in destinations:
        shutil.copyfile(image_file, destination)

    if request.param == "jsonl":
        metadata_text = textwrap.dedent(
            """\
            {"file_name": "image_rgb.jpg", "caption": "Nice image"}
            {"file_name": "image_rgb2.jpg", "caption": "Nice second image"}
            {"file_name": "subdir/image_rgb3.jpg", "caption": "Nice third image"}
            """
        )
    else:
        metadata_text = textwrap.dedent(
            """\
            file_name,caption
            image_rgb.jpg,Nice image
            image_rgb2.jpg,Nice second image
            subdir/image_rgb3.jpg,Nice third image
            """
        )
    metadata_path = root / f"metadata.{request.param}"
    metadata_path.write_text(metadata_text, encoding="utf-8")

    resolved = DataFilesDict.from_patterns(get_data_patterns(str(root)), root.as_posix())
    # Sanity-check the layout: a single split holding the three images plus the metadata file.
    assert len(resolved) == 1
    assert len(resolved["train"]) == 4
    return resolved
141
142
@pytest.fixture(params=["jsonl", "csv"])
def data_files_with_two_splits_and_metadata(request, tmp_path, image_file):
    """Create ``train`` and ``test`` directories, each with images and its own metadata file.

    The fixture is parametrized over the metadata format (``jsonl`` or ``csv``). ``train`` holds
    two images and ``test`` holds one; each directory gets a metadata file captioning its images.

    Returns:
        The resolved ``DataFilesDict``, checked to hold a ``train`` split with three files and a
        ``test`` split with two files.
    """
    use_jsonl = request.param == "jsonl"
    metadata_name = f"metadata.{request.param}"

    root = tmp_path / "imagefolder_data_dir_with_metadata_two_splits"
    root.mkdir(parents=True, exist_ok=True)
    train_root = root / "train"
    train_root.mkdir(parents=True, exist_ok=True)
    test_root = root / "test"
    test_root.mkdir(parents=True, exist_ok=True)

    shutil.copyfile(image_file, train_root / "image_rgb.jpg")  # train image
    shutil.copyfile(image_file, train_root / "image_rgb2.jpg")  # train image
    shutil.copyfile(image_file, test_root / "image_rgb3.jpg")  # test image

    # Metadata describing the two train images.
    if use_jsonl:
        train_metadata = textwrap.dedent(
            """\
            {"file_name": "image_rgb.jpg", "caption": "Nice train image"}
            {"file_name": "image_rgb2.jpg", "caption": "Nice second train image"}
            """
        )
    else:
        train_metadata = textwrap.dedent(
            """\
            file_name,caption
            image_rgb.jpg,Nice train image
            image_rgb2.jpg,Nice second train image
            """
        )
    (train_root / metadata_name).write_text(train_metadata, encoding="utf-8")

    # Metadata describing the single test image.
    if use_jsonl:
        test_metadata = textwrap.dedent(
            """\
            {"file_name": "image_rgb3.jpg", "caption": "Nice test image"}
            """
        )
    else:
        test_metadata = textwrap.dedent(
            """\
            file_name,caption
            image_rgb3.jpg,Nice test image
            """
        )
    (test_root / metadata_name).write_text(test_metadata, encoding="utf-8")

    resolved = DataFilesDict.from_patterns(get_data_patterns(str(root)), root.as_posix())
    # Each split counts its images plus its own metadata file.
    assert len(resolved) == 2
    assert len(resolved["train"]) == 3
    assert len(resolved["test"]) == 2
    return resolved
202
203
@pytest.fixture
def data_files_with_zip_archives(tmp_path, image_file):
    """Create a data directory whose only file is a zip archive of two images and their metadata.

    The archive is built from a temporary ``archive`` directory holding one image, a flipped
    copy of it inside ``subdir``, and a ``metadata.jsonl``. That directory is removed once
    zipped.

    Returns:
        The resolved ``DataFilesDict``, checked to hold one ``train`` split with one file.
    """
    # Imported locally under an alias so the module-level ``datasets.Image`` is not shadowed.
    from PIL import Image as PILImage
    from PIL import ImageOps

    root = tmp_path / "imagefolder_data_dir_with_zip_archives"
    root.mkdir(parents=True, exist_ok=True)
    staging = root / "archive"
    staging.mkdir(parents=True, exist_ok=True)
    nested = staging / "subdir"
    nested.mkdir(parents=True, exist_ok=True)

    shutil.copyfile(image_file, staging / "image_rgb.jpg")
    # make sure they're two different images
    # Indeed we won't be able to compare the image.filename, since the archive is not extracted in streaming mode
    flipped = ImageOps.flip(PILImage.open(image_file))
    flipped.save(nested / "image_rgb2.jpg")  # in subdir

    metadata_text = textwrap.dedent(
        """\
        {"file_name": "image_rgb.jpg", "caption": "Nice image"}
        {"file_name": "subdir/image_rgb2.jpg", "caption": "Nice second image"}
        """
    )
    (staging / "metadata.jsonl").write_text(metadata_text, encoding="utf-8")

    # Zip the staging directory (the archive is named after it), then drop the unzipped copy.
    shutil.make_archive(staging, "zip", staging)
    shutil.rmtree(str(staging))

    resolved = DataFilesDict.from_patterns(get_data_patterns(str(root)), root.as_posix())

    assert len(resolved) == 1
    assert len(resolved["train"]) == 1
    return resolved
240
241
@require_pil
def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir):
    """Check that labels are inferred correctly from directory names.

    There are no metadata.jsonl files in this test case.
    """
    builder = ImageFolder(data_files=data_files_with_labels_no_metadata, cache_dir=cache_dir, drop_labels=False)
    builder.download_and_prepare()

    expected_features = Features({"image": Image(), "label": ClassLabel(names=["cat", "dog"])})
    assert builder.info.features == expected_features

    examples = list(builder.as_dataset()["train"])
    name_to_id = builder.info.features["label"]._str2int

    # The first example is expected to be the cat, the second the dog.
    assert examples[0]["label"] == name_to_id["cat"]
    assert examples[1]["label"] == name_to_id["dog"]
254
255
@require_pil
@pytest.mark.parametrize("drop_metadata", [None, True, False])
@pytest.mark.parametrize("drop_labels", [None, True, False])
def test_generate_examples_duplicated_label_key(
    image_files_with_labels_and_duplicated_label_key_in_metadata, drop_metadata, drop_labels, cache_dir, caplog
):
    """Check which ``label`` source wins when both directories and metadata provide one.

    * ``drop_labels=False``: labels come from the directory names, and a warning about ignored
      metadata columns is logged unless the metadata is dropped.
    * otherwise, with metadata kept: ``label`` is the string column from the metadata file.
    * otherwise, with ``drop_metadata=True``: only the ``image`` column remains.
    """
    cat_path, dog_path, metadata_path = image_files_with_labels_and_duplicated_label_key_in_metadata
    builder = ImageFolder(
        drop_metadata=drop_metadata,
        drop_labels=drop_labels,
        data_files=[cat_path, dog_path, metadata_path],
        cache_dir=cache_dir,
    )
    builder.download_and_prepare()

    if drop_labels is False:
        # infer labels from directories even if metadata files are found
        warned = any("ignoring metadata columns" in record.msg.lower() for record in caplog.records)
        metadata_kept = drop_metadata is not True
        assert warned is metadata_kept
        train = builder.as_dataset()["train"]
        assert builder.info.features["label"] == ClassLabel(names=["cat", "dog"])
        assert all(example["label"] in builder.info.features["label"]._str2int.values() for example in train)
        return

    train = builder.as_dataset()["train"]
    if drop_metadata is True:
        # drop both labels and metadata
        assert builder.info.features == Features({"image": Image()})
        assert all(example.keys() == {"image"} for example in train)
    else:
        # labels are from metadata
        assert builder.info.features["label"] == Value("string")
        assert all(example["label"] in ["Cat", "Dog"] for example in train)
288
289
@require_pil
@pytest.mark.parametrize("drop_metadata", [None, True, False])
@pytest.mark.parametrize("drop_labels", [None, True, False])
def test_generate_examples_drop_labels(data_files_with_labels_no_metadata, drop_metadata, drop_labels):
    """Check that, without metadata files, labels are added unless ``drop_labels`` is truthy."""
    builder = ImageFolder(
        drop_metadata=drop_metadata, drop_labels=drop_labels, data_files=data_files_with_labels_no_metadata
    )
    first_split = builder._split_generators(StreamingDownloadManager())[0]
    gen_kwargs = first_split.gen_kwargs

    # removing the labels explicitly requires drop_labels=True
    assert gen_kwargs["add_labels"] is not bool(drop_labels)
    assert gen_kwargs["add_metadata"] is False

    expected_columns = {"image"} if drop_labels else {"image", "label"}
    assert all(
        example.keys() == expected_columns and all(val is not None for val in example.values())
        for _, example in builder._generate_examples(**gen_kwargs)
    )
312
313
@require_pil
@pytest.mark.parametrize("drop_metadata", [None, True, False])
@pytest.mark.parametrize("drop_labels", [None, True, False])
def test_generate_examples_drop_metadata(image_file_with_metadata, drop_metadata, drop_labels):
    """Check the generated columns for one image with metadata, across the drop flags."""
    image_path, metadata_path = image_file_with_metadata
    builder = ImageFolder(
        drop_metadata=drop_metadata, drop_labels=drop_labels, data_files={"train": [image_path, metadata_path]}
    )
    gen_kwargs = builder._split_generators(StreamingDownloadManager())[0].gen_kwargs
    with_metadata = gen_kwargs["add_metadata"]
    with_labels = gen_kwargs["add_labels"]

    # since the dataset has metadata, removing the metadata explicitly requires drop_metadata=True
    assert with_metadata is not bool(drop_metadata)
    # since the dataset has metadata, adding the labels explicitly requires drop_labels=False
    assert with_labels is (drop_labels is False)

    examples = [example for _, example in builder._generate_examples(**gen_kwargs)]

    expected_columns = {"image"}
    if with_metadata:
        expected_columns.add("caption")
    if with_labels:
        expected_columns.add("label")

    assert len(examples) == 1
    only_example = examples[0]
    assert only_example.keys() == expected_columns
    for column in expected_columns:
        assert only_example[column] is not None
339
340
@require_pil
@pytest.mark.parametrize("drop_metadata", [None, True, False])
def test_generate_examples_with_metadata_in_wrong_location(image_file, image_file_with_metadata, drop_metadata):
    """Check generation when the metadata file is paired with the shared ``image_file`` fixture.

    The image copied next to the metadata is ignored here; the ``image_file`` fixture is used
    instead (presumably located elsewhere, as the test name suggests). Generation must raise
    ``ValueError`` unless the metadata is dropped, in which case only ``image`` is produced.
    """
    _, metadata_path = image_file_with_metadata
    builder = ImageFolder(drop_metadata=drop_metadata, data_files={"train": [image_file, metadata_path]})
    gen_kwargs = builder._split_generators(StreamingDownloadManager())[0].gen_kwargs
    examples = builder._generate_examples(**gen_kwargs)

    if drop_metadata:
        assert all(
            example.keys() == {"image"} and all(val is not None for val in example.values())
            for _, example in examples
        )
    else:
        with pytest.raises(ValueError):
            list(examples)
356
357
@require_pil
@pytest.mark.parametrize("drop_metadata", [None, True, False])
def test_generate_examples_with_metadata_that_misses_one_image(
    image_files_with_metadata_that_misses_one_image, drop_metadata
):
    """Check generation when the metadata file has no entry for one of the images.

    Generation must raise ``ValueError`` unless the metadata is dropped, in which case every
    example only has a non-null ``image`` column.
    """
    first_image, second_image, metadata_path = image_files_with_metadata_that_misses_one_image

    # The explicit features include ``caption`` only when the metadata is kept.
    feature_spec = {"image": Image()}
    if not drop_metadata:
        feature_spec["caption"] = Value("string")

    builder = ImageFolder(
        drop_metadata=drop_metadata,
        features=Features(feature_spec),
        data_files={"train": [first_image, second_image, metadata_path]},
    )
    gen_kwargs = builder._split_generators(StreamingDownloadManager())[0].gen_kwargs
    examples = builder._generate_examples(**gen_kwargs)

    if drop_metadata:
        assert all(
            example.keys() == {"image"} and all(val is not None for val in example.values())
            for _, example in examples
        )
    else:
        with pytest.raises(ValueError):
            list(examples)
383
384
@require_pil
@pytest.mark.parametrize("streaming", [False, True])
def test_data_files_with_metadata_and_single_split(streaming, cache_dir, data_files_with_one_split_and_metadata):
    """Check that a single-split dataset pairs every image with its own caption.

    Runs both in regular and in streaming mode.
    """
    builder = ImageFolder(data_files=data_files_with_one_split_and_metadata, cache_dir=cache_dir)
    builder.download_and_prepare()
    if streaming:
        datasets = builder.as_streaming_dataset()
    else:
        datasets = builder.as_dataset()

    for split, split_files in data_files_with_one_split_and_metadata.items():
        expected_num_of_images = len(split_files) - 1  # don't count the metadata file
        assert split in datasets
        examples = list(datasets[split])
        assert len(examples) == expected_num_of_images
        # make sure each sample has its own image and metadata
        distinct_filenames = {example["image"].filename for example in examples}
        assert len(distinct_filenames) == expected_num_of_images
        distinct_captions = {example["caption"] for example in examples}
        assert len(distinct_captions) == expected_num_of_images
        assert all(example["caption"] is not None for example in examples)
401
402
@require_pil
@pytest.mark.parametrize("streaming", [False, True])
def test_data_files_with_metadata_and_multiple_splits(streaming, cache_dir, data_files_with_two_splits_and_metadata):
    """Check that, in every split, each image is paired with its own caption.

    Runs both in regular and in streaming mode.
    """
    builder = ImageFolder(data_files=data_files_with_two_splits_and_metadata, cache_dir=cache_dir)
    builder.download_and_prepare()
    if streaming:
        datasets = builder.as_streaming_dataset()
    else:
        datasets = builder.as_dataset()

    for split, split_files in data_files_with_two_splits_and_metadata.items():
        expected_num_of_images = len(split_files) - 1  # don't count the metadata file
        assert split in datasets
        examples = list(datasets[split])
        assert len(examples) == expected_num_of_images
        # make sure each sample has its own image and metadata
        distinct_filenames = {example["image"].filename for example in examples}
        assert len(distinct_filenames) == expected_num_of_images
        distinct_captions = {example["caption"] for example in examples}
        assert len(distinct_captions) == expected_num_of_images
        assert all(example["caption"] is not None for example in examples)
419
420
@require_pil
@pytest.mark.parametrize("streaming", [False, True])
def test_data_files_with_metadata_and_archives(streaming, cache_dir, data_files_with_zip_archives):
    """Check that images and metadata stored inside zip archives are loaded and paired.

    Images are told apart by their first pixel value rather than by file name. Runs both in
    regular and in streaming mode.
    """
    builder = ImageFolder(data_files=data_files_with_zip_archives, cache_dir=cache_dir)
    builder.download_and_prepare()
    if streaming:
        datasets = builder.as_streaming_dataset()
    else:
        datasets = builder.as_dataset()

    for split, archives in data_files_with_zip_archives.items():
        num_of_archives = len(archives)  # the metadata file is inside the archive
        expected_num_of_images = 2 * num_of_archives
        assert split in datasets
        examples = list(datasets[split])
        assert len(examples) == expected_num_of_images
        # make sure each sample has its own image and metadata
        first_pixels = {np.array(example["image"])[0, 0, 0] for example in examples}
        assert len(first_pixels) == expected_num_of_images
        distinct_captions = {example["caption"] for example in examples}
        assert len(distinct_captions) == expected_num_of_images
        assert all(example["caption"] is not None for example in examples)
437
438
@require_pil
def test_data_files_with_wrong_metadata_file_name(cache_dir, tmp_path, image_file):
    """Check that a metadata file named ``bad_metadata.jsonl`` adds no ``caption`` column."""
    root = tmp_path / "data_dir_with_bad_metadata"
    root.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(image_file, root / "image_rgb.jpg")

    misnamed_metadata = root / "bad_metadata.jsonl"  # bad file
    metadata_text = textwrap.dedent(
        """\
        {"file_name": "image_rgb.jpg", "caption": "Nice image"}
        """
    )
    misnamed_metadata.write_text(metadata_text, encoding="utf-8")

    resolved = DataFilesDict.from_patterns(get_data_patterns(str(root)), root.as_posix())
    builder = ImageFolder(data_files=resolved, cache_dir=cache_dir)
    builder.download_and_prepare()
    train = builder.as_dataset(split="train")
    # check that there are no metadata, since the metadata file name doesn't have the right name
    assert "caption" not in train.column_names
459
460
@require_pil
def test_data_files_with_wrong_image_file_name_column_in_metadata_file(cache_dir, tmp_path, image_file):
    """Check that metadata lacking the ``file_name`` column makes preparation raise ``ValueError``."""
    root = tmp_path / "data_dir_with_bad_metadata"
    root.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(image_file, root / "image_rgb.jpg")

    # with bad column "bad_file_name" instead of "file_name"
    metadata_text = textwrap.dedent(
        """\
        {"bad_file_name": "image_rgb.jpg", "caption": "Nice image"}
        """
    )
    (root / "metadata.jsonl").write_text(metadata_text, encoding="utf-8")

    resolved = DataFilesDict.from_patterns(get_data_patterns(str(root)), root.as_posix())
    builder = ImageFolder(data_files=resolved, cache_dir=cache_dir)
    with pytest.raises(ValueError) as exc_info:
        builder.download_and_prepare()
    error_message = str(exc_info.value)
    assert "`file_name` must be present" in error_message
480
481
@require_pil
def test_data_files_with_with_metadata_in_different_formats(cache_dir, tmp_path, image_file):
    """Check that having both ``metadata.jsonl`` and ``metadata.csv`` makes preparation raise ``ValueError``."""
    root = tmp_path / "data_dir_with_metadata_in_different_format"
    root.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(image_file, root / "image_rgb.jpg")

    # The same caption is written twice, once per metadata format.
    jsonl_text = textwrap.dedent(
        """\
        {"file_name": "image_rgb.jpg", "caption": "Nice image"}
        """
    )
    (root / "metadata.jsonl").write_text(jsonl_text, encoding="utf-8")
    csv_text = textwrap.dedent(
        """\
        file_name,caption
        image_rgb.jpg,Nice image
        """
    )
    (root / "metadata.csv").write_text(csv_text, encoding="utf-8")

    resolved = DataFilesDict.from_patterns(get_data_patterns(str(root)), root.as_posix())
    builder = ImageFolder(data_files=resolved, cache_dir=cache_dir)
    with pytest.raises(ValueError) as exc_info:
        builder.download_and_prepare()
    error_message = str(exc_info.value)
    assert "metadata files with different extensions" in error_message
510