1
import concurrent.futures
7
from pathlib import Path
13
import torchvision.transforms.v2.functional as F
14
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
15
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
16
from torchvision.io.image import (
34
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
35
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
36
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
37
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
38
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
39
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
40
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
41
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
42
IS_WINDOWS = sys.platform in ("win32", "cygwin")
43
IS_MACOS = sys.platform == "darwin"
44
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
45
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")
47
# Hacky way of figuring out whether we compiled with libavif/libheif (those are
48
# currently disabled by default)
50
_decode_avif(torch.arange(10, dtype=torch.uint8))
52
DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)
55
_decode_heic(torch.arange(10, dtype=torch.uint8))
57
DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)
60
def _get_safe_image_name(name):
61
# Used when we need to change the pytest "id" for an "image path" parameter.
62
# If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
63
# and this creates issues when the test is running in a different machine than where it was collected
64
# (typically, in fb internal infra)
65
return name.split(os.path.sep)[-1]
68
def get_images(directory, img_ext):
69
assert os.path.isdir(directory)
70
image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
71
for path in image_paths:
72
if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
76
def pil_read_image(img_path):
    """Load ``img_path`` with Pillow and return its pixel data as a tensor.

    The image is opened as a context manager so it is closed again once the
    pixels have been copied into a NumPy array.
    """
    with Image.open(img_path) as pil_img:
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)
81
def normalize_dimensions(img_pil):
82
if len(img_pil.shape) == 3:
83
img_pil = img_pil.permute(2, 0, 1)
85
img_pil = img_pil.unsqueeze(0)
89
@pytest.mark.parametrize(
91
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
93
@pytest.mark.parametrize(
96
(None, ImageReadMode.UNCHANGED),
97
("L", ImageReadMode.GRAY),
98
("RGB", ImageReadMode.RGB),
101
@pytest.mark.parametrize("scripted", (False, True))
102
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
103
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
105
with Image.open(img_path) as img:
106
is_cmyk = img.mode == "CMYK"
107
if pil_mode is not None:
108
img = img.convert(pil_mode)
109
img_pil = torch.from_numpy(np.array(img))
110
if is_cmyk and mode == ImageReadMode.UNCHANGED:
111
# flip the colors to match libjpeg
112
img_pil = 255 - img_pil
114
img_pil = normalize_dimensions(img_pil)
115
data = read_file(img_path)
117
decode_fun = torch.jit.script(decode_fun)
118
img_ljpeg = decode_fun(data, mode=mode)
120
# Permit a small variation on pixel values to account for implementation
121
# differences between Pillow and LibJPEG.
122
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
123
assert abs_mean_diff < 2
126
@pytest.mark.parametrize("codec", ["png", "jpeg"])
127
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
128
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
129
fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
130
t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
131
im = F.to_pil_image(t)
133
exif[0x0112] = orientation # set exif orientation
134
im.save(fp, codec.upper(), exif=exif.tobytes())
137
output = decode_image(data, apply_exif_orientation=True)
139
pimg = Image.open(fp)
140
pimg = ImageOps.exif_transpose(pimg)
142
expected = F.pil_to_tensor(pimg)
143
torch.testing.assert_close(expected, output)
146
@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
147
def test_invalid_exif(tmpdir, size):
148
# Inspired from a PIL test:
149
# https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
150
fp = os.path.join(tmpdir, "invalid_exif.jpg")
151
t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
152
im = F.to_pil_image(t)
153
im.save(fp, "JPEG", exif=b"1" * size)
156
output = decode_image(data, apply_exif_orientation=True)
158
pimg = Image.open(fp)
159
pimg = ImageOps.exif_transpose(pimg)
161
expected = F.pil_to_tensor(pimg)
162
torch.testing.assert_close(expected, output)
165
def test_decode_bad_huffman_images():
    """Sanity check that a JPEG with a bad Huffman encoding can be decoded."""
    # Nothing is asserted on the output: the decode call only has to complete
    # without raising.
    bad_huffman_path = os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")
    encoded = read_file(bad_huffman_path)
    decode_jpeg(encoded)
171
@pytest.mark.parametrize(
174
pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
175
for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
178
def test_damaged_corrupt_images(img_path):
179
# Truncated images should raise an exception
180
data = read_file(img_path)
181
if "corrupt34" in img_path:
182
match_message = "Image is incomplete or truncated"
184
match_message = "Unsupported marker type"
185
with pytest.raises(RuntimeError, match=match_message):
189
@pytest.mark.parametrize(
191
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
193
@pytest.mark.parametrize(
196
(None, ImageReadMode.UNCHANGED),
197
("L", ImageReadMode.GRAY),
198
("LA", ImageReadMode.GRAY_ALPHA),
199
("RGB", ImageReadMode.RGB),
200
("RGBA", ImageReadMode.RGB_ALPHA),
203
@pytest.mark.parametrize("scripted", (False, True))
204
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
205
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
208
decode_fun = torch.jit.script(decode_fun)
210
with Image.open(img_path) as img:
211
if pil_mode is not None:
212
img = img.convert(pil_mode)
213
img_pil = torch.from_numpy(np.array(img))
215
img_pil = normalize_dimensions(img_pil)
217
if img_path.endswith("16.png"):
218
data = read_file(img_path)
219
img_lpng = decode_fun(data, mode=mode)
220
assert img_lpng.dtype == torch.uint16
221
# PIL converts 16 bits pngs to uint8
222
img_lpng = F.to_dtype(img_lpng, torch.uint8, scale=True)
224
data = read_file(img_path)
225
img_lpng = decode_fun(data, mode=mode)
227
tol = 0 if pil_mode is None else 1
229
if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
230
# Avoid checking the transparency channel until
231
# https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
233
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
234
img_lpng, img_pil = img_lpng[0], img_pil[0]
236
torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
239
def test_decode_png_errors():
    """Check that damaged PNG files make ``decode_png`` raise ``RuntimeError``.

    Each case is (expected message pattern, asset directory, file name); the
    file is read and decoded inside the ``pytest.raises`` context.
    """
    failure_cases = (
        ("Out of bound read in decode_png", DAMAGED_PNG, "sigsegv.png"),
        ("Content is too small for png", TOOSMALL_PNG, "heapbof.png"),
    )
    for message, folder, file_name in failure_cases:
        with pytest.raises(RuntimeError, match=message):
            decode_png(read_file(os.path.join(folder, file_name)))
246
@pytest.mark.parametrize(
248
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
250
@pytest.mark.parametrize("scripted", (True, False))
251
def test_encode_png(img_path, scripted):
    """Encode an image with ``encode_png`` and check the round trip is exact.

    The reference pixels are read with Pillow, encoded by torchvision (eager or
    TorchScript depending on ``scripted``), decoded again with Pillow and
    compared for equality.

    Args:
        img_path: Path of the PNG file used as input.
        scripted: Whether to run ``encode_png`` through ``torch.jit.script``.
    """
    # Copy the pixels out while the file is open; the context manager then
    # closes the handle instead of leaving it to the garbage collector
    # (which triggers "ResourceWarning: unclosed file").
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image))
    # Move the last axis first (assumes Pillow yields a height x width x channel array).
    img_pil = img_pil.permute(2, 0, 1)
    encode = torch.jit.script(encode_png) if scripted else encode_png
    png_buf = encode(img_pil, compression_level=6)

    # Decode the encoded bytes with Pillow to obtain the reconstructed image.
    with Image.open(io.BytesIO(bytes(png_buf.tolist()))) as rec_pil:
        rec_img = torch.from_numpy(np.array(rec_pil))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)
265
def test_encode_png_errors():
    """Check that ``encode_png`` rejects invalid inputs with ``RuntimeError``.

    Each case is (expected message pattern, tensor shape, tensor dtype, extra
    keyword arguments passed to ``encode_png``).
    """
    rejected_inputs = (
        ("Input tensor dtype should be uint8", (3, 100, 100), torch.float32, {}),
        ("Compression level should be between 0 and 9", (3, 100, 100), torch.uint8, {"compression_level": -1}),
        ("Compression level should be between 0 and 9", (3, 100, 100), torch.uint8, {"compression_level": 10}),
        ("The number of channels should be 1 or 3, got: 5", (5, 100, 100), torch.uint8, {}),
    )
    for message, shape, dtype, extra_kwargs in rejected_inputs:
        with pytest.raises(RuntimeError, match=message):
            encode_png(torch.empty(shape, dtype=dtype), **extra_kwargs)
279
@pytest.mark.parametrize(
281
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
283
@pytest.mark.parametrize("scripted", (True, False))
284
def test_write_png(img_path, tmpdir, scripted):
    """Write an image with ``write_png`` and check the saved file matches it.

    Args:
        img_path: Path of the PNG file used as input.
        tmpdir: Temporary directory that receives the written file.
        scripted: Whether to run ``write_png`` through ``torch.jit.script``.
    """
    # Copy the pixels out while the file is open; the context manager then
    # closes the handle instead of leaving it to the garbage collector
    # (which triggers "ResourceWarning: unclosed file").
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image))
    # Move the last axis first (assumes Pillow yields a height x width x channel array).
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write = torch.jit.script(write_png) if scripted else write_png
    write(img_pil, torch_png, compression_level=6)
    # Read the written file back with Pillow, again closing it deterministically.
    with Image.open(torch_png) as saved_pil:
        saved_image = torch.from_numpy(np.array(saved_pil))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)
299
def test_read_image():
    """Check that ``read_image`` gives the same result eagerly and scripted."""
    # Only TorchScript compatibility is exercised here; the functionality
    # itself is somewhat tested already in other tests.
    image_path = next(get_images(IMAGE_ROOT, ".jpg"))
    eager_result = read_image(image_path)
    scripted_read_image = torch.jit.script(read_image)
    scripted_result = scripted_read_image(image_path)
    torch.testing.assert_close(eager_result, scripted_result, atol=0, rtol=0)
307
@pytest.mark.parametrize("scripted", (True, False))
308
def test_read_file(tmpdir, scripted):
309
fname, content = "test1.bin", b"TorchVision\211\n"
310
fpath = os.path.join(tmpdir, fname)
311
with open(fpath, "wb") as f:
314
fun = torch.jit.script(read_file) if scripted else read_file
316
expected = torch.tensor(list(content), dtype=torch.uint8)
318
assert_equal(data, expected)
320
with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
324
def test_read_file_non_ascii(tmpdir):
325
fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
326
fpath = os.path.join(tmpdir, fname)
327
with open(fpath, "wb") as f:
330
data = read_file(fpath)
331
expected = torch.tensor(list(content), dtype=torch.uint8)
333
assert_equal(data, expected)
336
@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
    """Write bytes with ``write_file`` and check the file holds exactly them.

    Args:
        tmpdir: Temporary directory that receives the written file.
        scripted: Whether to run ``write_file`` through ``torch.jit.script``.
    """
    payload = b"TorchVision\211\n"
    target = os.path.join(tmpdir, "test1.bin")
    payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)
    if scripted:
        writer = torch.jit.script(write_file)
    else:
        writer = write_file
    writer(target, payload_tensor)

    # Read the file back in binary mode so the bytes are compared verbatim.
    with open(target, "rb") as stream:
        on_disk = stream.read()

    assert payload == on_disk
350
def test_write_file_non_ascii(tmpdir):
    """Check that ``write_file`` works when the file name is not ASCII.

    Args:
        tmpdir: Temporary directory that receives the written file.
    """
    payload = b"TorchVision\211\n"
    target = os.path.join(tmpdir, "日本語(Japanese).bin")
    payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)
    write_file(target, payload_tensor)

    # Read the file back in binary mode so the bytes are compared verbatim.
    with open(target, "rb") as stream:
        on_disk = stream.read()

    assert payload == on_disk
362
@pytest.mark.parametrize(
370
def test_read_1_bit_png(shape, tmpdir):
371
np_rng = np.random.RandomState(0)
372
image_path = os.path.join(tmpdir, f"test_{shape}.png")
373
pixels = np_rng.rand(*shape) > 0.5
374
img = Image.fromarray(pixels)
376
img1 = read_image(image_path)
377
img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
378
assert_equal(img1, img2)
381
@pytest.mark.parametrize(
389
@pytest.mark.parametrize(
392
ImageReadMode.UNCHANGED,
396
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
397
np_rng = np.random.RandomState(0)
398
image_path = os.path.join(tmpdir, f"test_{shape}.png")
399
pixels = np_rng.rand(*shape) > 0.5
400
img = Image.fromarray(pixels)
402
img1 = read_image(image_path, mode)
403
img2 = read_image(image_path, mode)
404
assert_equal(img1, img2)
407
def test_read_interlaced_png():
    """Check that two PNG assets differing in interlacing decode identically."""
    paths = list(get_images(INTERLACED_PNG, ".png"))
    first_path = paths[0]
    second_path = paths[1]
    # Guard the premise of the test: the two files must not report the same
    # "interlace" info entry, otherwise the comparison below proves nothing.
    with Image.open(first_path) as first_pil, Image.open(second_path) as second_pil:
        assert first_pil.info.get("interlace") is not second_pil.info.get("interlace")
    decoded_first = read_image(first_path)
    decoded_second = read_image(second_path)
    assert_equal(decoded_first, decoded_second)
417
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
418
@pytest.mark.parametrize("scripted", (False, True))
419
def test_decode_jpegs_cuda(mode, scripted):
421
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
422
if "cmyk" in jpeg_path:
424
encoded_image = read_file(jpeg_path)
425
encoded_images.append(encoded_image)
426
decoded_images_cpu = decode_jpeg(encoded_images, mode=mode)
427
decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
429
# test multithreaded decoding
430
# in the current version we prevent this by using a lock but we still want to test it
433
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
434
futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
435
decoded_images_threaded = [future.result() for future in futures]
436
assert len(decoded_images_threaded) == num_workers
437
for decoded_images in decoded_images_threaded:
438
assert len(decoded_images) == len(encoded_images)
439
for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu):
440
assert decoded_image_cuda.shape == decoded_image_cpu.shape
441
assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8
442
assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2
446
def test_decode_image_cuda_raises():
447
data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
448
with pytest.raises(RuntimeError):
453
def test_decode_jpeg_cuda_device_param():
454
path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
455
data = read_file(path)
456
current_device = torch.cuda.current_device()
457
current_stream = torch.cuda.current_stream()
458
num_devices = torch.cuda.device_count()
459
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
461
for device in devices:
462
results.append(decode_jpeg(data, device=device))
463
assert len(results) == len(devices)
464
for result in results:
465
assert torch.all(result.cpu() == results[0].cpu())
466
assert current_device == torch.cuda.current_device()
467
assert current_stream == torch.cuda.current_stream()
471
def test_decode_jpeg_cuda_errors():
472
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
473
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
474
decode_jpeg(data.reshape(-1, 1), device="cuda")
475
with pytest.raises(ValueError, match="must be tensors"):
476
decode_jpeg([1, 2, 3])
477
with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
478
decode_jpeg(data.to("cuda"), device="cuda")
479
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
480
decode_jpeg(data.to(torch.float), device="cuda")
481
with pytest.raises(RuntimeError, match="Expected the device parameter to be a cuda device"):
482
torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu")
483
with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
485
torch.empty((100,), dtype=torch.uint8, device="cuda"),
487
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
490
torch.empty((100,), dtype=torch.uint8, device="cuda"),
491
torch.empty((100,), dtype=torch.uint8, device="cuda"),
495
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
498
torch.empty((100,), dtype=torch.uint8, device="cuda"),
499
torch.empty((100,), dtype=torch.uint8, device="cuda"),
504
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
507
torch.empty((100,), dtype=torch.uint8, device="cpu"),
508
torch.empty((100,), dtype=torch.uint8, device="cuda"),
513
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
516
torch.empty((100,), dtype=torch.uint8),
517
torch.empty((100,), dtype=torch.float32),
522
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
525
torch.empty((100,), dtype=torch.uint8),
526
torch.empty((1, 100), dtype=torch.uint8),
531
with pytest.raises(RuntimeError, match="Error while decoding JPEG images"):
534
torch.empty((100,), dtype=torch.uint8),
535
torch.empty((100,), dtype=torch.uint8),
540
with pytest.raises(ValueError, match="Input list must contain at least one element"):
541
decode_jpeg([], device="cuda")
544
def test_encode_jpeg_errors():
    """Check that ``encode_jpeg`` rejects invalid inputs with the expected errors.

    Each case is (exception type, expected message pattern, tensor shape,
    tensor dtype, extra keyword arguments passed to ``encode_jpeg``).
    """
    quality_message = "Image quality should be a positive number between 1 and 100"
    ndim_message = "Input data should be a 3-dimensional tensor"
    rejected_inputs = (
        (RuntimeError, "Input tensor dtype should be uint8", (3, 100, 100), torch.float32, {}),
        (ValueError, quality_message, (3, 100, 100), torch.uint8, {"quality": -1}),
        (ValueError, quality_message, (3, 100, 100), torch.uint8, {"quality": 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5", (5, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_message, (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, ndim_message, (100, 100), torch.uint8, {}),
    )
    for error_type, message, shape, dtype, extra_kwargs in rejected_inputs:
        with pytest.raises(error_type, match=message):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra_kwargs)
565
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
566
@pytest.mark.parametrize(
568
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
570
@pytest.mark.parametrize("scripted", (True, False))
571
def test_encode_jpeg(img_path, scripted):
572
img = read_image(img_path)
574
pil_img = F.to_pil_image(img)
576
pil_img.save(buf, format="JPEG", quality=75)
578
encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)
580
encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
581
for src_img in [img, img.contiguous()]:
582
encoded_jpeg_torch = encode(src_img, quality=75)
583
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
587
def test_encode_jpeg_cuda_device_param():
588
path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
590
data = read_image(path)
592
current_device = torch.cuda.current_device()
593
current_stream = torch.cuda.current_stream()
594
num_devices = torch.cuda.device_count()
595
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
597
for device in devices:
598
results.append(encode_jpeg(data.to(device=device)))
599
assert len(results) == len(devices)
600
for result in results:
601
assert torch.all(result.cpu() == results[0].cpu())
602
assert current_device == torch.cuda.current_device()
603
assert current_stream == torch.cuda.current_stream()
607
@pytest.mark.parametrize(
609
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
611
@pytest.mark.parametrize("scripted", (False, True))
612
@pytest.mark.parametrize("contiguous", (False, True))
613
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
614
decoded_image_tv = read_image(img_path)
615
encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
617
if "cmyk" in img_path:
618
pytest.xfail("Encoding a CMYK jpeg isn't supported")
619
if decoded_image_tv.shape[0] == 1:
620
pytest.xfail("Decoding a grayscale jpeg isn't supported")
621
# For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
623
decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.contiguous_format)[0]
625
decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=torch.channels_last)[0]
626
encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
627
decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())
629
# the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
630
# instead, we re-decode the encoded image and compare to the original
631
abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
632
assert abs_mean_diff < 3
635
@pytest.mark.parametrize("device", cpu_and_cuda())
636
@pytest.mark.parametrize("scripted", (True, False))
637
@pytest.mark.parametrize("contiguous", (True, False))
638
def test_encode_jpegs_batch(scripted, contiguous, device):
639
if device == "cpu" and IS_MACOS:
640
pytest.skip("https://github.com/pytorch/vision/issues/8031")
641
decoded_images_tv = []
642
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
643
if "cmyk" in jpeg_path:
645
decoded_image = read_image(jpeg_path)
646
if decoded_image.shape[0] == 1:
649
decoded_image = decoded_image[None].contiguous(memory_format=torch.contiguous_format)[0]
651
decoded_image = decoded_image[None].contiguous(memory_format=torch.channels_last)[0]
652
decoded_images_tv.append(decoded_image)
654
encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
656
decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
657
encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
658
encoded_jpegs_tv_device = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device]
660
for original, encoded_decoded in zip(decoded_images_tv, encoded_jpegs_tv_device):
661
c, h, w = original.shape
662
abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item()
663
assert abs_mean_diff < 3
665
# test multithreaded decoding
666
# in the current version we prevent this by using a lock but we still want to test it
668
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
669
futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
670
encoded_images_threaded = [future.result() for future in futures]
671
assert len(encoded_images_threaded) == num_workers
672
for encoded_images in encoded_images_threaded:
673
assert len(decoded_images_tv_device) == len(encoded_images)
674
for i, (encoded_image_cuda, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
675
# make sure all the threads produce identical outputs
676
assert torch.all(encoded_image_cuda == encoded_images_threaded[0][i])
678
# make sure the outputs are identical or close enough to baseline
679
decoded_cuda_encoded_image = decode_jpeg(encoded_image_cuda.cpu())
680
assert decoded_cuda_encoded_image.shape == decoded_image_tv.shape
681
assert decoded_cuda_encoded_image.dtype == decoded_image_tv.dtype
682
assert (decoded_cuda_encoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3
686
def test_single_encode_jpeg_cuda_errors():
    """Check that ``encode_jpeg`` rejects invalid single CUDA tensors.

    Every input is allocated on the "cuda" device and must raise
    ``RuntimeError``. Each case is (expected message pattern, tensor shape,
    tensor dtype).
    """
    rejected_inputs = (
        ("Input tensor dtype should be uint8", (3, 100, 100), torch.float32),
        ("The number of channels should be 3, got: 5", (5, 100, 100), torch.uint8),
        ("The number of channels should be 3, got: 1", (1, 100, 100), torch.uint8),
        ("Input data should be a 3-dimensional tensor", (1, 3, 100, 100), torch.uint8),
        ("Input data should be a 3-dimensional tensor", (100, 100), torch.uint8),
    )
    for message, shape, dtype in rejected_inputs:
        with pytest.raises(RuntimeError, match=message):
            encode_jpeg(torch.empty(shape, dtype=dtype, device="cuda"))
704
def test_batch_encode_jpegs_cuda_errors():
705
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
708
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
709
torch.empty((3, 100, 100), dtype=torch.float32, device="cuda"),
713
with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 5"):
716
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
717
torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda"),
721
with pytest.raises(RuntimeError, match="The number of channels should be 3, got: 1"):
724
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
725
torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda"),
729
with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
732
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
733
torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"),
737
with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
740
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
741
torch.empty((100, 100), dtype=torch.uint8, device="cuda"),
745
with pytest.raises(RuntimeError, match="Input tensor should be on CPU"):
748
torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
749
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
754
RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
758
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda"),
759
torch.empty((3, 100, 100), dtype=torch.uint8, device="cpu"),
763
if torch.cuda.device_count() >= 2:
765
RuntimeError, match="All input tensors must be on the same CUDA device when encoding with nvjpeg"
769
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:0"),
770
torch.empty((3, 100, 100), dtype=torch.uint8, device="cuda:1"),
774
with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
778
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
779
@pytest.mark.parametrize(
781
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
783
@pytest.mark.parametrize("scripted", (True, False))
784
def test_write_jpeg(img_path, tmpdir, scripted):
785
tmpdir = Path(tmpdir)
786
img = read_image(img_path)
787
pil_img = F.to_pil_image(img)
789
torch_jpeg = str(tmpdir / "torch.jpg")
790
pil_jpeg = str(tmpdir / "pil.jpg")
792
write = torch.jit.script(write_jpeg) if scripted else write_jpeg
793
write(img, torch_jpeg, quality=75)
794
pil_img.save(pil_jpeg, quality=75)
796
with open(torch_jpeg, "rb") as f:
797
torch_bytes = f.read()
799
with open(pil_jpeg, "rb") as f:
802
assert_equal(torch_bytes, pil_bytes)
805
def test_pathlib_support(tmpdir):
806
# Just make sure pathlib.Path is supported where relevant
808
jpeg_path = Path(next(get_images(ENCODE_JPEG, ".jpg")))
811
read_image(jpeg_path)
813
write_path = Path(tmpdir) / "whatever"
814
img = torch.randint(0, 10, size=(3, 4, 4), dtype=torch.uint8)
816
write_file(write_path, data=img.flatten())
817
write_jpeg(img, write_path)
818
write_png(img, write_path)
821
@pytest.mark.parametrize(
822
"name", ("gifgrid", "fire", "porsche", "treescap", "treescap-interlaced", "solid2", "x-trans", "earth")
824
@pytest.mark.parametrize("scripted", (True, False))
825
def test_decode_gif(tmpdir, name, scripted):
826
# Using test images from GIFLIB
827
# https://sourceforge.net/p/giflib/code/ci/master/tree/pic/, we assert PIL
828
# and torchvision decoded outputs are equal.
829
# We're not testing against "welcome2" because PIL and GIFLIB disagree on what
830
# the background color should be (likely a difference in the way they handle
832
# 'earth' image is from wikipedia, licensed under CC BY-SA 3.0
833
# https://creativecommons.org/licenses/by-sa/3.0/
834
# it allows to properly test for transparency, TOP-LEFT offsets, and
837
path = tmpdir / f"{name}.gif"
840
# TODO: Fix this... one day.
841
pytest.skip("Skipping 'earth' test as it's flaky on OSS CI")
842
url = "https://upload.wikimedia.org/wikipedia/commons/2/2c/Rotating_earth_%28large%29.gif"
844
url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"
845
with open(path, "wb") as f:
846
f.write(requests.get(url).content)
848
encoded_bytes = read_file(path)
849
f = torch.jit.script(decode_gif) if scripted else decode_gif
850
tv_out = f(encoded_bytes)
852
tv_out = tv_out[None]
854
assert tv_out.is_contiguous(memory_format=torch.channels_last)
856
# For some reason, not using Image.open() as a CM causes "ResourceWarning: unclosed file"
857
with Image.open(path) as pil_img:
858
pil_seq = ImageSequence.Iterator(pil_img)
860
for pil_frame, tv_frame in zip(pil_seq, tv_out):
861
pil_frame = F.pil_to_tensor(pil_frame.convert("RGB"))
862
torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)
865
decode_fun_and_match = [
866
(decode_png, "Content is not png"),
867
(decode_jpeg, "Not a JPEG file"),
868
(decode_gif, re.escape("DGifOpenFileName() failed - 103")),
869
(decode_webp, "WebPGetFeatures failed."),
871
if DECODE_AVIF_ENABLED:
872
decode_fun_and_match.append((_decode_avif, "BMFF parsing failed"))
873
if DECODE_HEIC_ENABLED:
874
decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box"))
877
@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
def test_decode_bad_encoded_data(decode_fun, match):
    """Check that a decoder raises ``RuntimeError`` on invalid encoded input.

    Args:
        decode_fun: Decoding function under test.
        match: Message pattern expected when ``decode_fun`` is given random
            bytes that satisfy the input checks but are not a valid image.
    """
    encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
    # (expected message pattern, offending input); the last entry is the
    # random payload itself, which only the decoder-specific check rejects.
    bad_inputs = (
        ("Input tensor must be 1-dimensional", encoded_data[None]),
        ("Input tensor must have uint8 data type", encoded_data.float()),
        ("Input tensor must be contiguous", encoded_data[::2]),
        (match, encoded_data),
    )
    for message, bad_input in bad_inputs:
        with pytest.raises(RuntimeError, match=message):
            decode_fun(bad_input)
890
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
891
@pytest.mark.parametrize("scripted", (False, True))
892
def test_decode_webp(decode_fun, scripted):
893
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp")))
895
decode_fun = torch.jit.script(decode_fun)
896
img = decode_fun(encoded_bytes)
897
assert img.shape == (3, 100, 100)
898
assert img[None].is_contiguous(memory_format=torch.channels_last)
899
img += 123 # make sure image buffer wasn't freed by underlying decoding lib
902
# This test is skipped by default because it requires webp images that we're not
903
# including within the repo. The test images were downloaded manually from the
904
# different pages of https://developers.google.com/speed/webp/gallery
905
@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set")
906
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
907
@pytest.mark.parametrize("scripted", (False, True))
908
@pytest.mark.parametrize(
911
# Note that converting an RGBA image to RGB leads to bad results because the
912
# transparent pixels aren't necessarily set to "black" or "white", they can be
913
# random stuff. This is consistent with PIL results.
914
(ImageReadMode.RGB, "RGB"),
915
(ImageReadMode.RGB_ALPHA, "RGBA"),
916
(ImageReadMode.UNCHANGED, None),
919
@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name)
920
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
921
encoded_bytes = read_file(filename)
923
decode_fun = torch.jit.script(decode_fun)
924
img = decode_fun(encoded_bytes, mode=mode)
925
assert img[None].is_contiguous(memory_format=torch.channels_last)
927
pil_img = Image.open(filename).convert(pil_mode)
928
from_pil = F.pil_to_tensor(pil_img)
929
assert_equal(img, from_pil)
930
img += 123 # make sure image buffer wasn't freed by underlying decoding lib
933
@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
    """Decode the first ``.avif`` file found under FAKEDATA_DIR.

    The decoder is run both eagerly and through ``torch.jit.script``
    (depending on ``scripted``). The test checks the output shape, that the
    batched view is contiguous in channels-last format, and that the returned
    tensor can still be written to.
    """
    encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".avif")))
    # Only script when requested, otherwise the eager path is never exercised.
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    img = decode_fun(encoded_bytes)
    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
946
# Decoders exercised by test_decode_avif_heic_against_pil; only those that
# torchvision was actually built with are included.
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
decode_funs = []
if DECODE_AVIF_ENABLED:
    decode_funs.append(_decode_avif)
if DECODE_HEIC_ENABLED:
    decode_funs.append(_decode_heic)
955
@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.")
@pytest.mark.parametrize("decode_fun", decode_funs)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
@pytest.mark.parametrize(
    "filename",
    # The test files are not part of the repo. AVIF_TEST_IMAGES_DIR can point to
    # a local copy of libavif's tests/data directory; the previous hard-coded
    # location is kept as the default. Sorted so that the parametrization order
    # does not depend on the filesystem listing order.
    sorted(Path(os.environ.get("AVIF_TEST_IMAGES_DIR", "/home/nicolashug/dev/libavif/tests/data/")).glob("*.avif")),
    ids=lambda p: p.name,
)
def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    """Compare our AVIF/HEIC decoding with Pillow's on each ``*.avif`` test file.

    The file is decoded with ``decode_fun`` (eagerly or scripted) and with
    Pillow (through the ``pillow_avif`` plugin). Files that raise one of a
    known set of decoding errors are skipped. uint16 results are rescaled to
    uint8 before the comparison, which tolerates an absolute difference of 3.
    """
    if "reversed_dimg_order" in str(filename):
        # Pillow properly decodes this one, but we don't (order of parts of the
        # image is wrong). This is due to a bug that was recently fixed in
        # libavif. Hopefully this test will end up passing soon with a new
        # libavif version https://github.com/AOMediaCodec/libavif/issues/2311
        pytest.xfail("libavif decodes the image parts in the wrong order")
    import pillow_avif  # noqa

    encoded_bytes = read_file(filename)
    # Only script when requested, otherwise the eager path is never exercised.
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    try:
        img = decode_fun(encoded_bytes, mode=mode)
    except RuntimeError as e:
        # Some of the test files are expected to be rejected by the decoder.
        expected_errors = (
            "BMFF parsing failed",
            "avifDecoderParse failed: ",
            "file contains more than one image",
            "no 'ispe' property",
            "'iref' has double references",
            "Invalid image grid",
        )
        if any(msg in str(e) for msg in expected_errors):
            pytest.skip(reason="Expected failure, that's OK")
        raise

    assert img[None].is_contiguous(memory_format=torch.channels_last)
    if mode == ImageReadMode.RGB:
        assert img.shape[0] == 3
    if mode == ImageReadMode.RGB_ALPHA:
        assert img.shape[0] == 4

    if img.dtype == torch.uint16:
        img = F.to_dtype(img, dtype=torch.uint8, scale=True)
    try:
        # Use a context manager so the file handle opened by PIL gets closed.
        with Image.open(filename) as pil_file:
            from_pil = F.pil_to_tensor(pil_file.convert(pil_mode))
    except RuntimeError as e:
        if "Invalid image grid" in str(e):
            pytest.skip(reason="PIL failure")
        raise

    # Debugging aid: dump our result next to PIL's for visual inspection. Only
    # done when the output directory exists, so the test doesn't fail on
    # machines without it. AVIF_TEST_DEBUG_OUT_DIR overrides the location.
    debug_dir = os.environ.get("AVIF_TEST_DEBUG_OUT_DIR", "/home/nicolashug/out_images")
    if os.path.isdir(debug_dir):
        from torchvision.utils import make_grid

        g = make_grid([img, from_pil])
        F.to_pil_image(g).save(os.path.join(debug_dir, f"{filename.name}.{pil_mode}.png"))

    # Scripted functions may not expose __name__, hence the fallback on `name`.
    is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
    if mode == ImageReadMode.RGB and not is_decode_heic:
        # We don't compare torchvision's AVIF against PIL for RGB because
        # results look pretty different on RGBA images (other images are fine).
        # The result on torchvision basically just plainly ignores the alpha
        # channel, resulting in transparent pixels looking dark. PIL seems to be
        # using a sort of k-nn thing (Take a look at the resulting images)
        return
    if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
        # NOTE(review): this file is excluded from the comparison for the HEIC
        # decoder; the reason isn't recorded here - TODO confirm and document.
        return

    torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
1034
@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
    """Decode the first ``.heic`` file found under FAKEDATA_DIR.

    The decoder is run both eagerly and through ``torch.jit.script``
    (depending on ``scripted``). The test checks the output shape, that the
    batched view is contiguous in channels-last format, and that the returned
    tensor can still be written to.
    """
    encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic")))
    # Only script when requested, otherwise the eager path is never exercised.
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    img = decode_fun(encoded_bytes)
    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
1047
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports test failures to the calling shell.
    sys.exit(pytest.main([__file__]))