# vision / Fork: 0 / test_image.py
# 1048 lines · 39.6 KB
1
import concurrent.futures
2
import glob
3
import io
4
import os
5
import re
6
import sys
7
from pathlib import Path
8

9
import numpy as np
10
import pytest
11
import requests
12
import torch
13
import torchvision.transforms.v2.functional as F
14
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
15
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
16
from torchvision.io.image import (
17
    _decode_avif,
18
    _decode_heic,
19
    decode_gif,
20
    decode_image,
21
    decode_jpeg,
22
    decode_png,
23
    decode_webp,
24
    encode_jpeg,
25
    encode_png,
26
    ImageReadMode,
27
    read_file,
28
    read_image,
29
    write_file,
30
    write_jpeg,
31
    write_png,
32
)
33

34
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
# Keep only the purely numeric release components so that development builds
# such as "11.1.0.dev0" do not crash test collection with a ValueError.
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".") if x.isdigit())
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")

# Hacky way of figuring out whether we compiled with libavif/libheif (those are
# currently disabled by default): feed non-image bytes to the private decoders
# and inspect the error. A "not compiled with ... support" error means the codec
# is unavailable; any other error (e.g. a decoding failure) means it is built in.
try:
    _decode_avif(torch.arange(10, dtype=torch.uint8))
except Exception as e:
    DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)
else:
    # The probe unexpectedly decoded; the codec is clearly available. Without
    # this branch the flag would be left undefined.
    DECODE_AVIF_ENABLED = True

try:
    _decode_heic(torch.arange(10, dtype=torch.uint8))
except Exception as e:
    DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)
else:
    DECODE_HEIC_ENABLED = True
58

59

60
def _get_safe_image_name(name):
61
    # Used when we need to change the pytest "id" for an "image path" parameter.
62
    # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
63
    # and this creates issues when the test is running in a different machine than where it was collected
64
    # (typically, in fb internal infra)
65
    return name.split(os.path.sep)[-1]
66

67

68
def get_images(directory, img_ext):
    """Lazily yield files under ``directory`` (recursively) whose names end in ``img_ext``.

    Files whose immediate parent directory is ``damaged_jpeg`` or ``jpeg_write``
    are skipped. This is a generator, so the directory assertion only runs when
    iteration starts.
    """
    assert os.path.isdir(directory)
    excluded_parents = ("damaged_jpeg", "jpeg_write")
    pattern = directory + f"/**/*{img_ext}"
    yield from (
        candidate
        for candidate in glob.glob(pattern, recursive=True)
        if candidate.split(os.sep)[-2] not in excluded_parents
    )
74

75

76
def pil_read_image(img_path):
    """Load ``img_path`` with Pillow and return its pixels as a tensor.

    The tensor keeps the array layout numpy produces from the Pillow image; no
    channel reordering is done here.
    """
    with Image.open(img_path) as pil_img:
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)
79

80

81
def normalize_dimensions(img_pil):
    """Move a Pillow-derived tensor to a channels-first layout.

    A 3-d tensor is permuted from ``(H, W, C)`` to ``(C, H, W)``. Any other tensor
    (in this file, a 2-d grayscale image) gets a leading axis of size 1.
    """
    if img_pil.dim() != 3:
        return img_pil.unsqueeze(0)
    return img_pil.permute(2, 0, 1)
87

88

89
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
    """Decode each test JPEG with torchvision and compare it with Pillow's decoding.

    For CMYK images read in UNCHANGED mode, the Pillow reference is inverted
    (``255 - x``) to match what libjpeg returns.
    """
    with Image.open(img_path) as pil_img:
        invert_reference = pil_img.mode == "CMYK" and mode == ImageReadMode.UNCHANGED
        converted = pil_img if pil_mode is None else pil_img.convert(pil_mode)
        reference = torch.from_numpy(np.array(converted))
    if invert_reference:
        # flip the colors to match libjpeg
        reference = 255 - reference
    reference = normalize_dimensions(reference)

    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    decoded = decoder(read_file(img_path), mode=mode)

    # Pillow and LibJPEG may differ slightly per pixel, so only bound the mean
    # absolute difference instead of requiring exact equality.
    mean_abs_diff = (decoded.type(torch.float32) - reference).abs().mean().item()
    assert mean_abs_diff < 2
124

125

126
@pytest.mark.parametrize("codec", ["png", "jpeg"])
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
    """``decode_image(..., apply_exif_orientation=True)`` must match Pillow's ``exif_transpose``.

    Writes a random RGB image carrying EXIF orientation tag 0x0112 set to
    ``orientation``, then compares torchvision's output with Pillow's transposed image.
    """
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, codec.upper(), exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Close the Pillow file handle deterministically instead of leaking it until
    # garbage collection (an open handle can block tmpdir cleanup on Windows).
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))
    torch.testing.assert_close(expected, output)
144

145

146
@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    """A JPEG with a malformed EXIF payload of ``size`` bytes must decode like Pillow does.

    Inspired from a PIL test:
    https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    """
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Close the Pillow file handle deterministically instead of leaking it.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))
    torch.testing.assert_close(expected, output)
163

164

165
def test_decode_bad_huffman_images():
    """Sanity check: a JPEG with a bad Huffman encoding still decodes without raising."""
    bad_huffman_path = os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")
    decode_jpeg(read_file(bad_huffman_path))
169

170

171
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(corrupt_path, id=_get_safe_image_name(corrupt_path))
        for corrupt_path in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    """Corrupt or truncated JPEGs must make ``decode_jpeg`` raise a RuntimeError.

    ``corrupt34`` must report an incomplete/truncated image; every other
    ``corrupt*`` fixture must report an unsupported marker type.
    """
    encoded = read_file(img_path)
    expected_message = "Image is incomplete or truncated" if "corrupt34" in img_path else "Unsupported marker type"
    with pytest.raises(RuntimeError, match=expected_message):
        decode_jpeg(encoded)
187

188

189
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
    """Decode each fake-data PNG with torchvision and compare it with Pillow's decoding.

    Files ending in ``16.png`` must decode to uint16 and are rescaled to uint8 for
    the comparison, because Pillow yields uint8 for them. UNCHANGED mode must match
    exactly. Converted modes allow an absolute tolerance of 1.
    """
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    with Image.open(img_path) as pil_img:
        if pil_mode is not None:
            pil_img = pil_img.convert(pil_mode)
        expected = torch.from_numpy(np.array(pil_img))
    expected = normalize_dimensions(expected)

    actual = decoder(read_file(img_path), mode=mode)
    if img_path.endswith("16.png"):
        assert actual.dtype == torch.uint16
        # PIL converts 16 bits pngs to uint8
        actual = F.to_dtype(actual, torch.uint8, scale=True)

    tolerance = 1 if pil_mode is not None else 0

    if pil_mode == "LA" and PILLOW_VERSION >= (8, 3):
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        actual, expected = actual[0], expected[0]

    torch.testing.assert_close(actual, expected, atol=tolerance, rtol=0)
237

238

239
def test_decode_png_errors():
    """``decode_png`` must reject malformed PNG fixtures with descriptive RuntimeErrors."""
    cases = [
        (os.path.join(DAMAGED_PNG, "sigsegv.png"), "Out of bound read in decode_png"),
        (os.path.join(TOOSMALL_PNG, "heapbof.png"), "Content is too small for png"),
    ]
    for png_path, message in cases:
        with pytest.raises(RuntimeError, match=message):
            decode_png(read_file(png_path))
244

245

246
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
    """Round trip: ``encode_png`` output, decoded by Pillow, must equal the source pixels.

    The fixtures are permuted with ``permute(2, 0, 1)``, so they are assumed to
    be 3-d (H, W, C) images.
    """
    # Use context managers so the Pillow images are closed instead of leaked.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)
    encode = torch.jit.script(encode_png) if scripted else encode_png
    png_buf = encode(img_pil, compression_level=6)

    with Image.open(io.BytesIO(bytes(png_buf.tolist()))) as rec_pil:
        rec_img = torch.from_numpy(np.array(rec_pil)).permute(2, 0, 1)

    assert_equal(img_pil, rec_img)
263

264

265
def test_encode_png_errors():
    """``encode_png`` must validate the input dtype, compression level and channel count."""
    cases = [
        ("Input tensor dtype should be uint8", lambda: encode_png(torch.empty((3, 100, 100), dtype=torch.float32))),
        (
            "Compression level should be between 0 and 9",
            lambda: encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1),
        ),
        (
            "Compression level should be between 0 and 9",
            lambda: encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10),
        ),
        (
            "The number of channels should be 1 or 3, got: 5",
            lambda: encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)),
        ),
    ]
    for message, call in cases:
        with pytest.raises(RuntimeError, match=message):
            call()
277

278

279
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
    """``write_png`` must produce a file that Pillow reads back as the source pixels."""
    # Both Pillow images are opened in context managers so their file handles
    # (including the one on the tmpdir output) are closed deterministically.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write = torch.jit.script(write_png) if scripted else write_png
    write(img_pil, torch_png, compression_level=6)

    with Image.open(torch_png) as saved_pil:
        saved_image = torch.from_numpy(np.array(saved_pil)).permute(2, 0, 1)

    assert_equal(img_pil, saved_image)
297

298

299
def test_read_image():
    """Eager and TorchScripted ``read_image`` must return identical tensors.

    Only the TorchScript path is exercised here. Decoding correctness is
    covered by the other tests.
    """
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    eager_out = read_image(first_jpeg)
    scripted_read_image = torch.jit.script(read_image)
    torch.testing.assert_close(eager_out, scripted_read_image(first_jpeg), atol=0, rtol=0)
305

306

307
@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
    """``read_file`` returns the raw bytes as a uint8 tensor and errors on a missing path."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "test1.bin")
    Path(fpath).write_bytes(content)

    reader = torch.jit.script(read_file) if scripted else read_file
    data = reader(fpath)
    os.unlink(fpath)
    assert_equal(data, torch.tensor(list(content), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
322

323

324
def test_read_file_non_ascii(tmpdir):
    """``read_file`` must handle a file name containing non-ASCII characters."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "日本語(Japanese).bin")
    Path(fpath).write_bytes(content)

    data = read_file(fpath)
    os.unlink(fpath)
    assert_equal(data, torch.tensor(list(content), dtype=torch.uint8))
334

335

336
@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
    """``write_file`` must write a uint8 tensor's bytes to disk unchanged."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "test1.bin")
    writer = torch.jit.script(write_file) if scripted else write_file
    writer(fpath, torch.tensor(list(content), dtype=torch.uint8))

    saved_content = Path(fpath).read_bytes()
    os.unlink(fpath)
    assert content == saved_content
348

349

350
def test_write_file_non_ascii(tmpdir):
    """``write_file`` must handle a destination name containing non-ASCII characters."""
    content = b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, "日本語(Japanese).bin")
    write_file(fpath, torch.tensor(list(content), dtype=torch.uint8))

    saved_content = Path(fpath).read_bytes()
    os.unlink(fpath)
    assert content == saved_content
360

361

362
@pytest.mark.parametrize("shape", [(27, 27), (60, 60), (105, 105)])
def test_read_1_bit_png(shape, tmpdir):
    """A PNG saved from a boolean mask must decode to 0/255 values matching the mask."""
    rng = np.random.RandomState(0)
    png_path = os.path.join(tmpdir, f"test_{shape}.png")
    mask = rng.rand(*shape) > 0.5
    Image.fromarray(mask).save(png_path)

    expected = normalize_dimensions(torch.as_tensor(mask * 255, dtype=torch.uint8))
    assert_equal(read_image(png_path), expected)
379

380

381
@pytest.mark.parametrize("shape", [(27, 27), (60, 60), (105, 105)])
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY])
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    """Reading the same boolean-mask PNG twice in ``mode`` must give identical tensors."""
    mask = np.random.RandomState(0).rand(*shape) > 0.5
    png_path = os.path.join(tmpdir, f"test_{shape}.png")
    Image.fromarray(mask).save(png_path)
    first, second = (read_image(png_path, mode) for _ in range(2))
    assert_equal(first, second)
405

406

407
def test_read_interlaced_png():
    """Two PNG fixtures with different interlace settings must decode to identical tensors."""
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        # Sanity check that the fixtures differ in interlacing. Compare by value:
        # the previous `is` comparison only worked because small ints are cached.
        assert im1.info.get("interlace") != im2.info.get("interlace")
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)
414

415

416
@needs_cuda
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpegs_cuda(mode, scripted):
    """Batched CUDA decoding, run concurrently from several threads, must match CPU decoding.

    CMYK fixtures are excluded. Outputs are compared by mean absolute pixel difference.
    """
    encoded_images = [read_file(path) for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path]
    decoded_images_cpu = decode_jpeg(encoded_images, mode=mode)
    decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg

    # test multithreaded decoding
    # in the current version we prevent this by using a lock but we still want to test it
    num_workers = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
    decoded_images_threaded = [future.result() for future in futures]

    assert len(decoded_images_threaded) == num_workers
    for decoded_batch in decoded_images_threaded:
        assert len(decoded_batch) == len(encoded_images)
        for on_cuda, on_cpu in zip(decoded_batch, decoded_images_cpu):
            assert on_cuda.shape == on_cpu.shape
            assert on_cuda.dtype == on_cpu.dtype == torch.uint8
            mean_abs_diff = (on_cuda.cpu().float() - on_cpu.cpu().float()).abs().mean()
            assert mean_abs_diff < 2
443

444

445
@needs_cuda
def test_decode_image_cuda_raises():
    """``decode_image`` must raise a RuntimeError for random uint8 data placed on a CUDA device."""
    cuda_bytes = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(cuda_bytes)
450

451

452
@needs_cuda
def test_decode_jpeg_cuda_device_param():
    """``decode_jpeg`` must accept several CUDA device spellings and give identical results.

    It must also leave the current CUDA device and stream unchanged.
    """
    path = next(p for p in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in p)
    data = read_file(path)
    device_before = torch.cuda.current_device()
    stream_before = torch.cuda.current_stream()
    devices = ["cuda", torch.device("cuda")]
    devices.extend(torch.device(f"cuda:{idx}") for idx in range(torch.cuda.device_count()))

    results = [decode_jpeg(data, device=dev) for dev in devices]
    assert len(results) == len(devices)
    reference = results[0].cpu()
    for result in results:
        assert torch.all(result.cpu() == reference)
    assert device_before == torch.cuda.current_device()
    assert stream_before == torch.cuda.current_stream()
468

469

470
@needs_cuda
def test_decode_jpeg_cuda_errors():
    """``decode_jpeg`` and the raw CUDA op must reject malformed inputs with clear errors.

    Each case is an (exception type, message regex, thunk) triple. Every thunk
    builds its tensors lazily, inside the ``pytest.raises`` context.
    """
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))

    def u8(shape, **kwargs):
        return torch.empty(shape, dtype=torch.uint8, **kwargs)

    cases = [
        (
            RuntimeError,
            "Expected a non empty 1-dimensional tensor",
            lambda: decode_jpeg(data.reshape(-1, 1), device="cuda"),
        ),
        (ValueError, "must be tensors", lambda: decode_jpeg([1, 2, 3])),
        (ValueError, "Input tensor must be a CPU tensor", lambda: decode_jpeg(data.to("cuda"), device="cuda")),
        (RuntimeError, "Expected a torch.uint8 tensor", lambda: decode_jpeg(data.to(torch.float), device="cuda")),
        (
            RuntimeError,
            "Expected the device parameter to be a cuda device",
            lambda: torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu"),
        ),
        (ValueError, "Input tensor must be a CPU tensor", lambda: decode_jpeg(u8((100,), device="cuda"))),
        (
            ValueError,
            "Input list must contain tensors on CPU",
            lambda: decode_jpeg([u8((100,), device="cuda"), u8((100,), device="cuda")]),
        ),
        (
            ValueError,
            "Input list must contain tensors on CPU",
            lambda: decode_jpeg([u8((100,), device="cuda"), u8((100,), device="cuda")], device="cuda"),
        ),
        (
            ValueError,
            "Input list must contain tensors on CPU",
            lambda: decode_jpeg([u8((100,), device="cpu"), u8((100,), device="cuda")], device="cuda"),
        ),
        (
            RuntimeError,
            "Expected a torch.uint8 tensor",
            lambda: decode_jpeg([u8((100,)), torch.empty((100,), dtype=torch.float32)], device="cuda"),
        ),
        (
            RuntimeError,
            "Expected a non empty 1-dimensional tensor",
            lambda: decode_jpeg([u8((100,)), u8((1, 100))], device="cuda"),
        ),
        (
            RuntimeError,
            "Error while decoding JPEG images",
            lambda: decode_jpeg([u8((100,)), u8((100,))], device="cuda"),
        ),
        (ValueError, "Input list must contain at least one element", lambda: decode_jpeg([], device="cuda")),
    ]
    for exc_type, message, call in cases:
        with pytest.raises(exc_type, match=message):
            call()
542

543

544
def test_encode_jpeg_errors():
    """``encode_jpeg`` must validate dtype, quality, channel count and dimensionality."""
    quality_message = "Image quality should be a positive number between 1 and 100"
    rank_message = "Input data should be a 3-dimensional tensor"
    cases = [
        (
            RuntimeError,
            "Input tensor dtype should be uint8",
            lambda: encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)),
        ),
        (ValueError, quality_message, lambda: encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)),
        (ValueError, quality_message, lambda: encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)),
        (
            RuntimeError,
            "The number of channels should be 1 or 3, got: 5",
            lambda: encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)),
        ),
        (RuntimeError, rank_message, lambda: encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))),
        (RuntimeError, rank_message, lambda: encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))),
    ]
    for exc_type, message, call in cases:
        with pytest.raises(exc_type, match=message):
            call()
563

564

565
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
    """``encode_jpeg`` must be byte-identical to Pillow's JPEG encoder at quality 75.

    Checked for both the decoded tensor and a contiguous copy of it.
    """
    img = read_image(img_path)

    buf = io.BytesIO()
    F.to_pil_image(img).save(buf, format="JPEG", quality=75)
    expected = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)

    encoder = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
    for candidate in (img, img.contiguous()):
        assert_equal(encoder(candidate, quality=75), expected)
584

585

586
@needs_cuda
def test_encode_jpeg_cuda_device_param():
    """``encode_jpeg`` must give identical bytes for inputs moved to each CUDA device spelling.

    It must also leave the current CUDA device and stream unchanged.
    """
    path = next(p for p in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in p)
    data = read_image(path)

    device_before = torch.cuda.current_device()
    stream_before = torch.cuda.current_stream()
    devices = ["cuda", torch.device("cuda")]
    devices += [torch.device(f"cuda:{idx}") for idx in range(torch.cuda.device_count())]

    results = [encode_jpeg(data.to(device=dev)) for dev in devices]
    assert len(results) == len(devices)
    reference = results[0].cpu()
    for result in results:
        assert torch.all(result.cpu() == reference)
    assert device_before == torch.cuda.current_device()
    assert stream_before == torch.cuda.current_stream()
604

605

606
@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
    """Encode on CUDA, re-decode on CPU, and compare with the source pixels.

    ``contiguous`` selects a contiguous (True) or channels-last (False) input layout.
    CMYK and single-channel images are expected failures for the CUDA encoder.
    """
    decoded_image_tv = read_image(img_path)

    if "cmyk" in img_path:
        pytest.xfail("Encoding a CMYK jpeg isn't supported")
    if decoded_image_tv.shape[0] == 1:
        # It is the CUDA *encoder* that rejects 1-channel input. For more detail as to why
        # check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
        pytest.xfail("Encoding a grayscale jpeg isn't supported")

    # Script only after the xfail checks so expected failures don't pay the compile cost.
    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=memory_format)[0]
    encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
    decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

    # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
    # instead, we re-decode the encoded image and compare to the original
    abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
    assert abs_mean_diff < 3
633

634

635
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
    """Batched ``encode_jpeg`` on ``device`` must round-trip close to the source images.

    Concurrent calls from several threads must also produce identical bytes.
    CMYK and single-channel fixtures are excluded.
    """
    if device == "cpu" and IS_MACOS:
        pytest.skip("https://github.com/pytorch/vision/issues/8031")

    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_images_tv = []
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
        if "cmyk" in jpeg_path:
            continue
        decoded_image = read_image(jpeg_path)
        if decoded_image.shape[0] == 1:
            # Skipped on every device, presumably because the CUDA encoder rejects 1-channel input.
            continue
        decoded_images_tv.append(decoded_image[None].contiguous(memory_format=memory_format)[0])

    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
    encoded_jpegs = encode_fn(decoded_images_tv_device, quality=75)
    roundtrip_images = [decode_jpeg(img.cpu()) for img in encoded_jpegs]

    for original, roundtrip in zip(decoded_images_tv, roundtrip_images):
        abs_mean_diff = (original.float() - roundtrip.float()).abs().mean().item()
        assert abs_mean_diff < 3

    # test multithreaded encoding
    # in the current version we prevent this by using a lock but we still want to test it
    num_workers = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
    encoded_images_threaded = [future.result() for future in futures]
    assert len(encoded_images_threaded) == num_workers
    for encoded_images in encoded_images_threaded:
        assert len(decoded_images_tv_device) == len(encoded_images)
        for i, (encoded_image, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
            # make sure all the threads produce identical outputs
            assert torch.all(encoded_image == encoded_images_threaded[0][i])

            # make sure the outputs are identical or close enough to baseline
            redecoded = decode_jpeg(encoded_image.cpu())
            assert redecoded.shape == decoded_image_tv.shape
            assert redecoded.dtype == decoded_image_tv.dtype
            assert (redecoded.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3
683

684

685
@needs_cuda
def test_single_encode_jpeg_cuda_errors():
    """Single-tensor CUDA ``encode_jpeg`` must validate dtype, channel count and rank."""
    rank_message = "Input data should be a 3-dimensional tensor"
    cases = [
        (
            "Input tensor dtype should be uint8",
            lambda: encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32, device="cuda")),
        ),
        (
            "The number of channels should be 3, got: 5",
            lambda: encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8, device="cuda")),
        ),
        (
            "The number of channels should be 3, got: 1",
            lambda: encode_jpeg(torch.empty((1, 100, 100), dtype=torch.uint8, device="cuda")),
        ),
        (rank_message, lambda: encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8, device="cuda"))),
        (rank_message, lambda: encode_jpeg(torch.empty((100, 100), dtype=torch.uint8, device="cuda"))),
    ]
    for message, call in cases:
        with pytest.raises(RuntimeError, match=message):
            call()
701

702

703
@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
    """Batched CUDA ``encode_jpeg`` must validate every tensor in the list and its device.

    Each case is an (exception type, message regex, thunk) triple. Every thunk
    builds its tensors lazily, inside the ``pytest.raises`` context.
    """

    def img(*shape, dtype=torch.uint8, device="cuda"):
        return torch.empty(shape, dtype=dtype, device=device)

    rank_message = "Input data should be a 3-dimensional tensor"
    same_device_message = "All input tensors must be on the same CUDA device when encoding with nvjpeg"
    cases = [
        (
            RuntimeError,
            "Input tensor dtype should be uint8",
            lambda: encode_jpeg([img(3, 100, 100), img(3, 100, 100, dtype=torch.float32)]),
        ),
        (
            RuntimeError,
            "The number of channels should be 3, got: 5",
            lambda: encode_jpeg([img(3, 100, 100), img(5, 100, 100)]),
        ),
        (
            RuntimeError,
            "The number of channels should be 3, got: 1",
            lambda: encode_jpeg([img(3, 100, 100), img(1, 100, 100)]),
        ),
        (RuntimeError, rank_message, lambda: encode_jpeg([img(3, 100, 100), img(1, 3, 100, 100)])),
        (RuntimeError, rank_message, lambda: encode_jpeg([img(3, 100, 100), img(100, 100)])),
        (
            RuntimeError,
            "Input tensor should be on CPU",
            lambda: encode_jpeg([img(3, 100, 100, device="cpu"), img(3, 100, 100)]),
        ),
        (
            RuntimeError,
            same_device_message,
            lambda: encode_jpeg([img(3, 100, 100), img(3, 100, 100, device="cpu")]),
        ),
    ]
    if torch.cuda.device_count() >= 2:
        cases.append(
            (
                RuntimeError,
                same_device_message,
                lambda: encode_jpeg([img(3, 100, 100, device="cuda:0"), img(3, 100, 100, device="cuda:1")]),
            )
        )
    cases.append(
        (
            ValueError,
            "encode_jpeg requires at least one input tensor when a list is passed",
            lambda: encode_jpeg([]),
        )
    )

    for exc_type, message, call in cases:
        with pytest.raises(exc_type, match=message):
            call()
776

777

778
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
    """Check that ``write_jpeg`` produces exactly the same file bytes as PIL.

    The image is encoded at quality=75 by both torchvision (optionally through
    TorchScript) and PIL, each into its own file under ``tmpdir``, and the raw
    file contents are compared.
    """
    out_dir = Path(tmpdir)
    tensor_img = read_image(img_path)
    reference_img = F.to_pil_image(tensor_img)

    torch_jpeg = str(out_dir / "torch.jpg")
    pil_jpeg = str(out_dir / "pil.jpg")

    writer = write_jpeg
    if scripted:
        writer = torch.jit.script(write_jpeg)
    writer(tensor_img, torch_jpeg, quality=75)
    reference_img.save(pil_jpeg, quality=75)

    torch_bytes = Path(torch_jpeg).read_bytes()
    pil_bytes = Path(pil_jpeg).read_bytes()
    assert_equal(torch_bytes, pil_bytes)
803

804

805
def test_pathlib_support(tmpdir):
    """Smoke-test that ``pathlib.Path`` arguments are accepted by the file I/O helpers."""
    jpeg_path = Path(next(get_images(ENCODE_JPEG, ".jpg")))

    # Reading helpers.
    for reader in (read_file, read_image):
        reader(jpeg_path)

    # Writing helpers; every call targets the same output path.
    write_path = Path(tmpdir) / "whatever"
    img = torch.randint(0, 10, size=(3, 4, 4), dtype=torch.uint8)
    write_file(write_path, data=img.flatten())
    write_jpeg(img, write_path)
    write_png(img, write_path)
819

820

821
@pytest.mark.parametrize(
    "name", ("gifgrid", "fire", "porsche", "treescap", "treescap-interlaced", "solid2", "x-trans", "earth")
)
@pytest.mark.parametrize("scripted", (True, False))
def test_decode_gif(tmpdir, name, scripted):
    """Compare ``decode_gif`` against PIL frame by frame on downloaded GIFs.

    Test images come from GIFLIB
    (https://sourceforge.net/p/giflib/code/ci/master/tree/pic/); we assert that
    PIL and torchvision decoded outputs are equal. "welcome2" is not tested
    because PIL and GIFLIB disagree on what the background color should be
    (likely a difference in the way they handle transparency?).

    The 'earth' image is from Wikipedia, licensed under CC BY-SA 3.0
    (https://creativecommons.org/licenses/by-sa/3.0/). It allows properly
    testing transparency, TOP-LEFT offsets, and disposal modes.
    """
    path = tmpdir / f"{name}.gif"
    if name == "earth":
        if IN_OSS_CI:
            # TODO: Fix this... one day.
            pytest.skip("Skipping 'earth' test as it's flaky on OSS CI")
        url = "https://upload.wikimedia.org/wikipedia/commons/2/2c/Rotating_earth_%28large%29.gif"
    else:
        url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"

    # Bound the network wait, and fail with a clear HTTP error instead of
    # saving an error page and failing later with an obscure decode error.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(path, "wb") as out_file:
        out_file.write(response.content)

    encoded_bytes = read_file(path)
    decode = torch.jit.script(decode_gif) if scripted else decode_gif
    tv_out = decode(encoded_bytes)
    if tv_out.ndim == 3:
        # Add a leading frame dimension to 3-D outputs (presumably single-frame
        # GIFs) so they can be iterated like the 4-D multi-frame case.
        tv_out = tv_out[None]

    assert tv_out.is_contiguous(memory_format=torch.channels_last)

    # For some reason, not using Image.open() as a CM causes "ResourceWarning: unclosed file"
    with Image.open(path) as pil_img:
        pil_seq = ImageSequence.Iterator(pil_img)

        for pil_frame, tv_frame in zip(pil_seq, tv_out):
            pil_frame = F.pil_to_tensor(pil_frame.convert("RGB"))
            torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)
863

864

865
# (decoder, regex) pairs consumed by test_decode_bad_encoded_data: the regex must
# match the RuntimeError each decoder raises on random, non-image bytes. The
# AVIF/HEIC decoders are only included when torchvision was built with them.
decode_fun_and_match = [
    (decode_png, "Content is not png"),
    (decode_jpeg, "Not a JPEG file"),
    (decode_gif, re.escape("DGifOpenFileName() failed - 103")),
    (decode_webp, "WebPGetFeatures failed."),
]
decode_fun_and_match += [
    (optional_fun, optional_match)
    for is_enabled, optional_fun, optional_match in (
        (DECODE_AVIF_ENABLED, _decode_avif, "BMFF parsing failed"),
        (DECODE_HEIC_ENABLED, _decode_heic, "Invalid input: No 'ftyp' box"),
    )
    if is_enabled
]
875

876

877
@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
def test_decode_bad_encoded_data(decode_fun, match):
    """Decoders must reject malformed inputs with explicit RuntimeErrors.

    The first three cases exercise generic input validation (rank, dtype,
    contiguity). The last one feeds 100 random bytes and expects the
    decoder-specific ``match`` message.
    """
    garbage = torch.randint(0, 256, (100,), dtype=torch.uint8)
    cases = (
        (garbage[None], "Input tensor must be 1-dimensional"),
        (garbage.float(), "Input tensor must have uint8 data type"),
        (garbage[::2], "Input tensor must be contiguous"),
        (garbage, match),
    )
    for bad_input, expected_message in cases:
        with pytest.raises(RuntimeError, match=expected_message):
            decode_fun(bad_input)
888

889

890
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_webp(decode_fun, scripted):
    """Decode the fake-data WebP file and sanity-check the resulting tensor."""
    webp_path = next(get_images(FAKEDATA_DIR, ".webp"))
    encoded_bytes = read_file(webp_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    img = decoder(encoded_bytes)

    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    # In-place write: make sure the image buffer wasn't freed by the underlying decoding lib.
    img += 123
900

901

902
# This test is skipped by default because it requires webp images that we're not
# including within the repo. The test images were downloaded manually from the
# different pages of https://developers.google.com/speed/webp/gallery
@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        # Note that converting an RGBA image to RGB leads to bad results because the
        # transparent pixels aren't necessarily set to "black" or "white", they can be
        # random stuff. This is consistent with PIL results.
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name)
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    """Check that torchvision's WebP decoding matches PIL exactly for each read mode."""
    encoded_bytes = read_file(filename)
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    img = decode_fun(encoded_bytes, mode=mode)
    assert img[None].is_contiguous(memory_format=torch.channels_last)

    # Open the image as a context manager so the file handle is closed
    # (otherwise: "ResourceWarning: unclosed file"). convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(filename) as pil_img:
        from_pil = F.pil_to_tensor(pil_img.convert(pil_mode))
    assert_equal(img, from_pil)
    img += 123  # make sure image buffer wasn't freed by underlying decoding lib
931

932

933
@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
    """Decode the fake-data AVIF file and sanity-check the resulting tensor."""
    avif_path = next(get_images(FAKEDATA_DIR, ".avif"))
    encoded_bytes = read_file(avif_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    img = decoder(encoded_bytes)

    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    # In-place write: make sure the image buffer wasn't freed by the underlying decoding lib.
    img += 123
944

945

946
# Decoders exercised by test_decode_avif_heic_against_pil; only those that
# torchvision was built with are included.
# Note: decode_image is deliberately left out: it fails because some of these
# files have a (valid) signature we don't recognize. We should probably use
# libmagic...
decode_funs = [
    optional_fun
    for is_enabled, optional_fun in (
        (DECODE_AVIF_ENABLED, _decode_avif),
        (DECODE_HEIC_ENABLED, _decode_heic),
    )
    if is_enabled
]
953

954

955
@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.")
@pytest.mark.parametrize("decode_fun", decode_funs)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
    "mode, pil_mode",
    (
        (ImageReadMode.RGB, "RGB"),
        (ImageReadMode.RGB_ALPHA, "RGBA"),
        (ImageReadMode.UNCHANGED, None),
    ),
)
@pytest.mark.parametrize(
    "filename",
    # The directory of libavif test files can be overridden through the
    # AVIF_TEST_IMAGES_DIR environment variable; it defaults to the location the
    # test was originally written against.
    Path(os.environ.get("AVIF_TEST_IMAGES_DIR", "/home/nicolashug/dev/libavif/tests/data/")).glob("*.avif"),
    ids=lambda p: p.name,
)
def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename):
    """Compare torchvision's AVIF/HEIC decoders against PIL on libavif's test files.

    Files that either decoder is known not to handle (see the error messages
    listed below) are skipped rather than failed.
    """
    if "reversed_dimg_order" in str(filename):
        # Pillow properly decodes this one, but we don't (order of parts of the
        # image is wrong). This is due to a bug that was recently fixed in
        # libavif. Hopefully this test will end up passing soon with a new
        # libavif version https://github.com/AOMediaCodec/libavif/issues/2311
        pytest.xfail("libavif bug: https://github.com/AOMediaCodec/libavif/issues/2311")
    # Imported only for its side effect (presumably registering AVIF support with PIL).
    import pillow_avif  # noqa

    encoded_bytes = read_file(filename)
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    try:
        img = decode_fun(encoded_bytes, mode=mode)
    except RuntimeError as e:
        if any(
            s in str(e)
            for s in (
                "BMFF parsing failed",
                "avifDecoderParse failed: ",
                "file contains more than one image",
                "no 'ispe' property",
                "'iref' has double references",
                "Invalid image grid",
            )
        ):
            pytest.skip(reason="Expected failure, that's OK")
        raise
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    if mode == ImageReadMode.RGB:
        assert img.shape[0] == 3
    if mode == ImageReadMode.RGB_ALPHA:
        assert img.shape[0] == 4

    if img.dtype == torch.uint16:
        # Rescale high-bit-depth output to uint8 before comparing with the PIL tensor.
        img = F.to_dtype(img, dtype=torch.uint8, scale=True)
    try:
        # Context manager so the file handle is closed; convert() returns a new,
        # fully loaded image.
        with Image.open(filename) as pil_img:
            from_pil = F.pil_to_tensor(pil_img.convert(pil_mode))
    except RuntimeError as e:
        if "Invalid image grid" in str(e):
            pytest.skip(reason="PIL failure")
        raise

    # Scripted functions expose `.name` rather than `__name__`.
    is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
    if mode == ImageReadMode.RGB and not is_decode_heic:
        # We don't compare torchvision's AVIF against PIL for RGB because
        # results look pretty different on RGBA images (other images are fine).
        # The result on torchvision basically just plainly ignores the alpha
        # channel, resulting in transparent pixels looking dark. PIL seems to be
        # using a sort of k-nn thing (take a look at the resulting images).
        return
    if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
        return

    torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
1032

1033

1034
@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
    """Decode the fake-data HEIC file and sanity-check the resulting tensor."""
    heic_path = next(get_images(FAKEDATA_DIR, ".heic"))
    encoded_bytes = read_file(heic_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    img = decoder(encoded_bytes)

    assert img.shape == (3, 100, 100)
    assert img[None].is_contiguous(memory_format=torch.channels_last)
    # In-place write: make sure the image buffer wasn't freed by the underlying decoding lib.
    img += 123
1045

1046

1047
if __name__ == "__main__":
    # Forward pytest's exit code so that running this file directly reports
    # test failures to the calling shell/CI instead of always exiting with 0.
    sys.exit(pytest.main([__file__]))
1049

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.