7
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
14
# Absolute path of the "assets/videos" directory that sits next to this file;
# the tests below join it with individual video file names.
VIDEO_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
    "videos",
)
17
@pytest.mark.skipif(_HAS_GPU_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
class TestVideoGPUDecoder:
    """Compare torchvision's ``VideoReader`` under the "cuda" backend with PyAV.

    Every test selects the "cuda" video backend, opens the same file with both
    ``VideoReader`` and PyAV, and checks that the two agree. The whole class is
    skipped when ``_HAS_GPU_VIDEO_DECODER`` is ``False``; each test is skipped
    individually when ``av`` is ``None``.
    """

    # Upper bound (exclusive) on the mean absolute element-wise difference
    # tolerated between a PyAV frame and the matching VideoReader frame.
    _MAX_MEAN_ABS_DELTA = 0.75

    def _assert_frames_match(self, container, decoder):
        """Decode ``container`` with PyAV and compare every frame with ``decoder``.

        Args:
            container: An open PyAV container, already positioned where
                decoding should start.
            decoder: The ``VideoReader`` that supplies one frame per PyAV frame.

        Raises:
            AssertionError: If the mean absolute difference of any frame pair
                is not below ``_MAX_MEAN_ABS_DELTA``.
        """
        for av_frame in container.decode(container.streams.video[0]):
            expected = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
            actual = next(decoder)["data"]
            # Move the decoder output to the CPU and compare in float so the
            # subtraction is done on like-for-like tensors.
            mean_delta = torch.mean(torch.abs(expected.float() - actual.cpu().float()))
            assert mean_delta < self._MAX_MEAN_ABS_DELTA

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize(
        "video_file",
        [
            # NOTE(review): only the entries visible in the reviewed excerpt
            # are listed; restore any further files the excerpt dropped.
            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
            "v_SoccerJuggling_g23_c01.avi",
            "v_SoccerJuggling_g24_c01.avi",
        ],
    )
    def test_frame_reading(self, video_file):
        """Frames read sequentially from the start must match PyAV's frames."""
        torchvision.set_video_backend("cuda")
        full_path = os.path.join(VIDEO_DIR, video_file)
        decoder = VideoReader(full_path)
        with av.open(full_path) as container:
            self._assert_frames_match(container, decoder)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("keyframes", [True, False])
    @pytest.mark.parametrize(
        "full_path, duration",
        [
            (os.path.join(VIDEO_DIR, x), y)
            for x, y in [
                ("v_SoccerJuggling_g23_c01.avi", 8.0),
                ("v_SoccerJuggling_g24_c01.avi", 8.0),
                ("R6llTwEh07w.mp4", 10.0),
                ("SOX5yA1l24A.mp4", 11.0),
                ("WUzgd7C1pWA.mp4", 11.0),
            ]
        ],
    )
    def test_seek_reading(self, keyframes, full_path, duration):
        """Frames read after a seek must match PyAV's frames after the same seek.

        Args:
            keyframes: Forwarded to ``VideoReader.seek`` as ``keyframes_only``;
                PyAV is given the opposite value as ``any_frame``.
            full_path: Path of the video file to open with both readers.
            duration: Paired with each file in the parametrization; used here
                to derive the seek target.
        """
        torchvision.set_video_backend("cuda")
        decoder = VideoReader(full_path)
        # NOTE(review): the assignment of the seek target is not visible in
        # the reviewed excerpt. The midpoint of `duration` is assumed because
        # `duration` is otherwise unused -- confirm against the full file.
        seek_time = duration / 2
        decoder.seek(seek_time, keyframes_only=keyframes)
        with av.open(full_path) as container:
            # seek_time is scaled by 1e6 for PyAV; presumably the
            # container-level offset is in microseconds -- confirm against
            # the PyAV documentation.
            container.seek(int(seek_time * 1000000), any_frame=not keyframes, backward=False)
            self._assert_frames_match(container, decoder)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize(
        "video_file",
        [
            # NOTE(review): only the entries visible in the reviewed excerpt
            # are listed; restore any further files the excerpt dropped.
            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
            "v_SoccerJuggling_g23_c01.avi",
            "v_SoccerJuggling_g24_c01.avi",
        ],
    )
    def test_metadata(self, video_file):
        """Reported duration and fps must be within 1% of PyAV's stream values."""
        torchvision.set_video_backend("cuda")
        full_path = os.path.join(VIDEO_DIR, video_file)
        decoder = VideoReader(full_path)
        video_metadata = decoder.get_metadata()["video"]
        with av.open(full_path) as container:
            video = container.streams.video[0]
            # Stream duration is multiplied by the stream time base to get
            # the value compared with the metadata's "duration".
            av_duration = float(video.duration * video.time_base)
            assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2)
            assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2)
96
if __name__ == "__main__":
    # Running this file directly hands it to pytest's own runner.
    pytest.main([__file__])