vision

Форк
0
/
test_video_gpu_decoder.py 
97 строк · 3.7 Кб
1
import math
2
import os
3

4
import pytest
5
import torch
6
import torchvision
7
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
8

9
try:
10
    import av
11
except ImportError:
12
    av = None
13

14
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
15

16

17
@pytest.mark.skipif(_HAS_GPU_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
18
class TestVideoGPUDecoder:
19
    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
20
    @pytest.mark.parametrize(
21
        "video_file",
22
        [
23
            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
24
            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
25
            "v_SoccerJuggling_g23_c01.avi",
26
            "v_SoccerJuggling_g24_c01.avi",
27
            "R6llTwEh07w.mp4",
28
            "SOX5yA1l24A.mp4",
29
            "WUzgd7C1pWA.mp4",
30
        ],
31
    )
32
    def test_frame_reading(self, video_file):
33
        torchvision.set_video_backend("cuda")
34
        full_path = os.path.join(VIDEO_DIR, video_file)
35
        decoder = VideoReader(full_path)
36
        with av.open(full_path) as container:
37
            for av_frame in container.decode(container.streams.video[0]):
38
                av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
39
                vision_frames = next(decoder)["data"]
40
                mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
41
                assert mean_delta < 0.75
42

43
    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
44
    @pytest.mark.parametrize("keyframes", [True, False])
45
    @pytest.mark.parametrize(
46
        "full_path, duration",
47
        [
48
            (os.path.join(VIDEO_DIR, x), y)
49
            for x, y in [
50
                ("v_SoccerJuggling_g23_c01.avi", 8.0),
51
                ("v_SoccerJuggling_g24_c01.avi", 8.0),
52
                ("R6llTwEh07w.mp4", 10.0),
53
                ("SOX5yA1l24A.mp4", 11.0),
54
                ("WUzgd7C1pWA.mp4", 11.0),
55
            ]
56
        ],
57
    )
58
    def test_seek_reading(self, keyframes, full_path, duration):
59
        torchvision.set_video_backend("cuda")
60
        decoder = VideoReader(full_path)
61
        time = duration / 2
62
        decoder.seek(time, keyframes_only=keyframes)
63
        with av.open(full_path) as container:
64
            container.seek(int(time * 1000000), any_frame=not keyframes, backward=False)
65
            for av_frame in container.decode(container.streams.video[0]):
66
                av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
67
                vision_frames = next(decoder)["data"]
68
                mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
69
                assert mean_delta < 0.75
70

71
    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
72
    @pytest.mark.parametrize(
73
        "video_file",
74
        [
75
            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
76
            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
77
            "v_SoccerJuggling_g23_c01.avi",
78
            "v_SoccerJuggling_g24_c01.avi",
79
            "R6llTwEh07w.mp4",
80
            "SOX5yA1l24A.mp4",
81
            "WUzgd7C1pWA.mp4",
82
        ],
83
    )
84
    def test_metadata(self, video_file):
85
        torchvision.set_video_backend("cuda")
86
        full_path = os.path.join(VIDEO_DIR, video_file)
87
        decoder = VideoReader(full_path)
88
        video_metadata = decoder.get_metadata()["video"]
89
        with av.open(full_path) as container:
90
            video = container.streams.video[0]
91
            av_duration = float(video.duration * video.time_base)
92
            assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2)
93
            assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2)
94

95

96
if __name__ == "__main__":
97
    pytest.main([__file__])
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.