4
from fractions import Fraction
9
import torchvision.io as io
10
from common_utils import assert_equal
11
from numpy.random import randint
12
from pytest import approx
13
from torchvision import set_video_backend
14
from torchvision.io import _HAS_CPU_VIDEO_DECODER
21
io.video._check_av_available()
26
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
37
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
39
all_check_config = GroundTruth(
44
check_aframe_pts=True,
48
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
51
audio_sample_rate=None,
53
check_aframe_pts=True,
55
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
58
audio_sample_rate=None,
60
check_aframe_pts=True,
62
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
65
audio_sample_rate=None,
67
check_aframe_pts=True,
69
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
72
audio_sample_rate=None,
74
check_aframe_pts=True,
76
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
79
audio_sample_rate=None,
81
check_aframe_pts=True,
83
"R6llTwEh07w.mp4": GroundTruth(
86
audio_sample_rate=44100,
89
check_aframe_pts=False,
91
"SOX5yA1l24A.mp4": GroundTruth(
94
audio_sample_rate=48000,
97
check_aframe_pts=False,
99
"WUzgd7C1pWA.mp4": GroundTruth(
102
audio_sample_rate=48000,
105
check_aframe_pts=False,
110
DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase")
114
SEEK_FRAME_MARGIN = 0.25
117
def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
120
container: pyav container
121
start_pts/end_pts: the starting/ending Presentation TimeStamp where
124
stream_name: a dictionary of streams. For example, {"video": 0} means
125
video stream at stream index 0
126
buffer_size: pts of frames decoded by PyAv is not guaranteed to be in
127
ascending order. We need to decode more frames even when we meet end
132
seek_offset = max(start_pts - margin, 0)
134
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
137
for frame in container.decode(**stream_name):
138
if frame.pts < start_pts:
140
if frame.pts <= end_pts:
141
frames[frame.pts] = frame
144
if buffer_count >= buffer_size:
146
result = [frames[pts] for pts in sorted(frames)]
151
def _get_timebase_by_av_module(full_path):
    """Return (video_time_base, audio_time_base) for *full_path* as PyAV sees them.

    audio_time_base is None when the file has no audio stream.
    """
    container = av.open(full_path)
    streams = container.streams
    audio_time_base = streams.audio[0].time_base if streams.audio else None
    return streams.video[0].time_base, audio_time_base
161
def _fraction_to_tensor(fraction):
162
ret = torch.zeros([2], dtype=torch.int32)
163
ret[0] = fraction.numerator
164
ret[1] = fraction.denominator
168
def _decode_frames_by_av_module(
176
Use PyAv to decode video frames. This provides a reference for our decoder
177
to compare the decoding results.
179
full_path: video file path
180
video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where
183
if video_end_pts is None:
184
video_end_pts = float("inf")
185
if audio_end_pts is None:
186
audio_end_pts = float("inf")
187
container = av.open(full_path)
190
vtimebase = torch.zeros([0], dtype=torch.int32)
191
if container.streams.video:
192
video_frames = _read_from_stream(
196
container.streams.video[0],
202
vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)
205
atimebase = torch.zeros([0], dtype=torch.int32)
206
if container.streams.audio:
207
audio_frames = _read_from_stream(
211
container.streams.audio[0],
214
atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
217
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
218
vframes = torch.as_tensor(np.stack(vframes))
220
vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)
222
aframes = [frame.to_ndarray() for frame in audio_frames]
224
aframes = np.transpose(np.concatenate(aframes, axis=1))
225
aframes = torch.as_tensor(aframes)
227
aframes = torch.empty((1, 0), dtype=torch.float32)
229
aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64)
231
return DecoderResult(
233
vframe_pts=vframe_pts,
236
aframe_pts=aframe_pts,
241
def _pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
242
"""convert pts between different time bases
244
pts: presentation timestamp, float
245
timebase_from: original timebase. Fraction
246
timebase_to: new timebase. Fraction
247
round_func: rounding function.
249
new_pts = Fraction(pts, 1) * timebase_from / timebase_to
250
return int(round_func(new_pts))
253
def _get_video_tensor(video_dir, video_file):
254
"""open a video file, and represent the video data by a PT tensor"""
255
full_path = os.path.join(video_dir, video_file)
257
assert os.path.exists(full_path), "File not found: %s" % full_path
259
with open(full_path, "rb") as fp:
260
video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8)
262
return full_path, video_tensor
265
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
266
@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg")
267
class TestVideoReader:
268
def check_separate_decoding_result(self, tv_result, config):
269
"""check the decoding results from TorchVision decoder"""
283
video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
284
assert video_duration == approx(config.duration, abs=0.5)
286
assert vfps.item() == approx(config.video_fps, abs=0.5)
288
if asample_rate.numel() > 0:
289
assert asample_rate.item() == config.audio_sample_rate
290
audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
291
assert audio_duration == approx(config.duration, abs=0.5)
294
for i in range(len(vframe_pts) - 1):
295
assert vframe_pts[i] < vframe_pts[i + 1]
297
if len(aframe_pts) > 1:
299
for i in range(len(aframe_pts) - 1):
300
assert aframe_pts[i] < aframe_pts[i + 1]
302
def check_probe_result(self, result, config):
    """Validate probed stream metadata against the ground-truth config.

    `result` is the raw tuple returned by the probe ops:
    (vtimebase, vfps, vduration, atimebase, asample_rate, aduration).
    Durations are stored in pts units, so they are scaled by the
    corresponding time base before comparison.
    """
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
    assert vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) == approx(config.duration, abs=0.5)
    assert vfps.item() == approx(config.video_fps, abs=0.5)
    if asample_rate.numel() > 0:
        # An audio stream was detected: sample rate and duration must match too.
        assert asample_rate.item() == config.audio_sample_rate
        assert aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) == approx(config.duration, abs=0.5)
312
def check_meta_result(self, result, config):
    """Validate a parsed metadata object against the ground-truth config."""
    assert result.video_duration == approx(config.duration, abs=0.5)
    assert result.video_fps == approx(config.video_fps, abs=0.5)
    if not result.has_audio:
        return
    # Audio stream present: its sample rate and duration must match as well.
    assert result.audio_sample_rate == config.audio_sample_rate
    assert result.audio_duration == approx(config.duration, abs=0.5)
319
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
321
Compare decoding results from two sources.
323
tv_result: decoding results from TorchVision decoder
324
ref_result: reference decoding results which can be from either PyAv
325
decoder or TorchVision decoder with getPtsOnly = 1
326
config: config of decoding results checker
340
if isinstance(ref_result, list):
342
ref_result = DecoderResult(
343
vframes=ref_result[0],
344
vframe_pts=ref_result[1],
345
vtimebase=ref_result[2],
346
aframes=ref_result[5],
347
aframe_pts=ref_result[6],
348
atimebase=ref_result[7],
351
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
352
mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
353
assert mean_delta == approx(0.0, abs=8.0)
355
mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
356
assert mean_delta == approx(0.0, abs=1.0)
358
assert_equal(vtimebase, ref_result.vtimebase)
360
if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
361
"""Audio stream is available and audio frame is required to return
363
assert_equal(aframes, ref_result.aframes)
365
if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
366
"""Audio stream is available"""
367
assert_equal(aframe_pts, ref_result.aframe_pts)
369
assert_equal(atimebase, ref_result.atimebase)
371
@pytest.mark.parametrize("test_video", test_videos.keys())
372
def test_stress_test_read_video_from_file(self, test_video):
374
"This stress test will iteratively decode the same set of videos."
375
"It helps to detect memory leak but it takes lots of time to run."
376
"By default, it is disabled"
380
width, height, min_dimension, max_dimension = 0, 0, 0, 0
381
video_start_pts, video_end_pts = 0, -1
382
video_timebase_num, video_timebase_den = 0, 1
384
samples, channels = 0, 0
385
audio_start_pts, audio_end_pts = 0, -1
386
audio_timebase_num, audio_timebase_den = 0, 1
388
for _i in range(num_iter):
389
full_path = os.path.join(VIDEO_DIR, test_video)
392
torch.ops.video_reader.read_video_from_file(
414
@pytest.mark.parametrize("test_video,config", test_videos.items())
415
def test_read_video_from_file(self, test_video, config):
417
Test the case when decoder starts with a video file to decode frames.
420
width, height, min_dimension, max_dimension = 0, 0, 0, 0
421
video_start_pts, video_end_pts = 0, -1
422
video_timebase_num, video_timebase_den = 0, 1
424
samples, channels = 0, 0
425
audio_start_pts, audio_end_pts = 0, -1
426
audio_timebase_num, audio_timebase_den = 0, 1
428
full_path = os.path.join(VIDEO_DIR, test_video)
431
tv_result = torch.ops.video_reader.read_video_from_file(
453
pyav_result = _decode_frames_by_av_module(full_path)
455
self.check_separate_decoding_result(tv_result, config)
457
self.compare_decoding_result(tv_result, pyav_result, config)
459
@pytest.mark.parametrize("test_video,config", test_videos.items())
460
@pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
461
def test_read_video_from_file_read_single_stream_only(
462
self, test_video, config, read_video_stream, read_audio_stream
465
Test the case when decoder starts with a video file to decode frames, and
466
only reads video stream and ignores audio stream
469
width, height, min_dimension, max_dimension = 0, 0, 0, 0
470
video_start_pts, video_end_pts = 0, -1
471
video_timebase_num, video_timebase_den = 0, 1
473
samples, channels = 0, 0
474
audio_start_pts, audio_end_pts = 0, -1
475
audio_timebase_num, audio_timebase_den = 0, 1
477
full_path = os.path.join(VIDEO_DIR, test_video)
479
tv_result = torch.ops.video_reader.read_video_from_file(
514
assert (vframes.numel() > 0) is bool(read_video_stream)
515
assert (vframe_pts.numel() > 0) is bool(read_video_stream)
516
assert (vtimebase.numel() > 0) is bool(read_video_stream)
517
assert (vfps.numel() > 0) is bool(read_video_stream)
519
expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
520
assert (aframes.numel() > 0) is bool(expect_audio_data)
521
assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
522
assert (atimebase.numel() > 0) is bool(expect_audio_data)
523
assert (asample_rate.numel() > 0) is bool(expect_audio_data)
525
@pytest.mark.parametrize("test_video", test_videos.keys())
526
def test_read_video_from_file_rescale_min_dimension(self, test_video):
528
Test the case when decoder starts with a video file to decode frames, and
529
video min dimension between height and width is set.
532
width, height, min_dimension, max_dimension = 0, 0, 128, 0
533
video_start_pts, video_end_pts = 0, -1
534
video_timebase_num, video_timebase_den = 0, 1
536
samples, channels = 0, 0
537
audio_start_pts, audio_end_pts = 0, -1
538
audio_timebase_num, audio_timebase_den = 0, 1
540
full_path = os.path.join(VIDEO_DIR, test_video)
542
tv_result = torch.ops.video_reader.read_video_from_file(
563
assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
565
@pytest.mark.parametrize("test_video", test_videos.keys())
566
def test_read_video_from_file_rescale_max_dimension(self, test_video):
568
Test the case when decoder starts with a video file to decode frames, and
569
video min dimension between height and width is set.
572
width, height, min_dimension, max_dimension = 0, 0, 0, 85
573
video_start_pts, video_end_pts = 0, -1
574
video_timebase_num, video_timebase_den = 0, 1
576
samples, channels = 0, 0
577
audio_start_pts, audio_end_pts = 0, -1
578
audio_timebase_num, audio_timebase_den = 0, 1
580
full_path = os.path.join(VIDEO_DIR, test_video)
582
tv_result = torch.ops.video_reader.read_video_from_file(
603
assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
605
@pytest.mark.parametrize("test_video", test_videos.keys())
606
def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
608
Test the case when decoder starts with a video file to decode frames, and
609
video min dimension between height and width is set.
612
width, height, min_dimension, max_dimension = 0, 0, 64, 85
613
video_start_pts, video_end_pts = 0, -1
614
video_timebase_num, video_timebase_den = 0, 1
616
samples, channels = 0, 0
617
audio_start_pts, audio_end_pts = 0, -1
618
audio_timebase_num, audio_timebase_den = 0, 1
620
full_path = os.path.join(VIDEO_DIR, test_video)
622
tv_result = torch.ops.video_reader.read_video_from_file(
643
assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
644
assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
646
@pytest.mark.parametrize("test_video", test_videos.keys())
647
def test_read_video_from_file_rescale_width(self, test_video):
649
Test the case when decoder starts with a video file to decode frames, and
653
width, height, min_dimension, max_dimension = 256, 0, 0, 0
654
video_start_pts, video_end_pts = 0, -1
655
video_timebase_num, video_timebase_den = 0, 1
657
samples, channels = 0, 0
658
audio_start_pts, audio_end_pts = 0, -1
659
audio_timebase_num, audio_timebase_den = 0, 1
661
full_path = os.path.join(VIDEO_DIR, test_video)
663
tv_result = torch.ops.video_reader.read_video_from_file(
684
assert tv_result[0].size(2) == width
686
@pytest.mark.parametrize("test_video", test_videos.keys())
687
def test_read_video_from_file_rescale_height(self, test_video):
689
Test the case when decoder starts with a video file to decode frames, and
693
width, height, min_dimension, max_dimension = 0, 224, 0, 0
694
video_start_pts, video_end_pts = 0, -1
695
video_timebase_num, video_timebase_den = 0, 1
697
samples, channels = 0, 0
698
audio_start_pts, audio_end_pts = 0, -1
699
audio_timebase_num, audio_timebase_den = 0, 1
701
full_path = os.path.join(VIDEO_DIR, test_video)
703
tv_result = torch.ops.video_reader.read_video_from_file(
724
assert tv_result[0].size(1) == height
726
@pytest.mark.parametrize("test_video", test_videos.keys())
727
def test_read_video_from_file_rescale_width_and_height(self, test_video):
729
Test the case when decoder starts with a video file to decode frames, and
730
both video height and width are set.
733
width, height, min_dimension, max_dimension = 320, 240, 0, 0
734
video_start_pts, video_end_pts = 0, -1
735
video_timebase_num, video_timebase_den = 0, 1
737
samples, channels = 0, 0
738
audio_start_pts, audio_end_pts = 0, -1
739
audio_timebase_num, audio_timebase_den = 0, 1
741
full_path = os.path.join(VIDEO_DIR, test_video)
743
tv_result = torch.ops.video_reader.read_video_from_file(
764
assert tv_result[0].size(1) == height
765
assert tv_result[0].size(2) == width
767
@pytest.mark.parametrize("test_video", test_videos.keys())
768
@pytest.mark.parametrize("samples", [9600, 96000])
769
def test_read_video_from_file_audio_resampling(self, test_video, samples):
771
Test the case when decoder starts with a video file to decode frames, and
772
audio waveform are resampled
775
width, height, min_dimension, max_dimension = 0, 0, 0, 0
776
video_start_pts, video_end_pts = 0, -1
777
video_timebase_num, video_timebase_den = 0, 1
780
audio_start_pts, audio_end_pts = 0, -1
781
audio_timebase_num, audio_timebase_den = 0, 1
783
full_path = os.path.join(VIDEO_DIR, test_video)
785
tv_result = torch.ops.video_reader.read_video_from_file(
818
if aframes.numel() > 0:
819
assert samples == asample_rate.item()
820
assert 1 == aframes.size(1)
822
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
823
assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
825
@pytest.mark.parametrize("test_video,config", test_videos.items())
826
def test_compare_read_video_from_memory_and_file(self, test_video, config):
828
Test the case when video is already in memory, and decoder reads data in memory
831
width, height, min_dimension, max_dimension = 0, 0, 0, 0
832
video_start_pts, video_end_pts = 0, -1
833
video_timebase_num, video_timebase_den = 0, 1
835
samples, channels = 0, 0
836
audio_start_pts, audio_end_pts = 0, -1
837
audio_timebase_num, audio_timebase_den = 0, 1
839
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
842
tv_result_memory = torch.ops.video_reader.read_video_from_memory(
863
self.check_separate_decoding_result(tv_result_memory, config)
865
tv_result_file = torch.ops.video_reader.read_video_from_file(
887
self.check_separate_decoding_result(tv_result_file, config)
889
self.compare_decoding_result(tv_result_memory, tv_result_file)
891
@pytest.mark.parametrize("test_video,config", test_videos.items())
892
def test_read_video_from_memory(self, test_video, config):
894
Test the case when video is already in memory, and decoder reads data in memory
897
width, height, min_dimension, max_dimension = 0, 0, 0, 0
898
video_start_pts, video_end_pts = 0, -1
899
video_timebase_num, video_timebase_den = 0, 1
901
samples, channels = 0, 0
902
audio_start_pts, audio_end_pts = 0, -1
903
audio_timebase_num, audio_timebase_den = 0, 1
905
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
908
tv_result = torch.ops.video_reader.read_video_from_memory(
930
pyav_result = _decode_frames_by_av_module(full_path)
932
self.check_separate_decoding_result(tv_result, config)
933
self.compare_decoding_result(tv_result, pyav_result, config)
935
@pytest.mark.parametrize("test_video,config", test_videos.items())
936
def test_read_video_from_memory_get_pts_only(self, test_video, config):
938
Test the case when video is already in memory, and decoder reads data in memory.
939
Compare frame pts between decoding for pts only and full decoding
940
for both pts and frame data
943
width, height, min_dimension, max_dimension = 0, 0, 0, 0
944
video_start_pts, video_end_pts = 0, -1
945
video_timebase_num, video_timebase_den = 0, 1
947
samples, channels = 0, 0
948
audio_start_pts, audio_end_pts = 0, -1
949
audio_timebase_num, audio_timebase_den = 0, 1
951
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
954
tv_result = torch.ops.video_reader.read_video_from_memory(
975
assert abs(config.video_fps - tv_result[3].item()) < 0.01
978
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
1000
assert not tv_result_pts_only[0].numel()
1001
assert not tv_result_pts_only[5].numel()
1002
self.compare_decoding_result(tv_result, tv_result_pts_only)
1004
@pytest.mark.parametrize("test_video,config", test_videos.items())
1005
@pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
1006
def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
1008
Test the case when video is already in memory, and decoder reads data in memory.
1009
In addition, decoder takes meaningful start- and end PTS as input, and decode
1010
frames within that interval
1012
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
1014
width, height, min_dimension, max_dimension = 0, 0, 0, 0
1015
video_start_pts, video_end_pts = 0, -1
1016
video_timebase_num, video_timebase_den = 0, 1
1018
samples, channels = 0, 0
1019
audio_start_pts, audio_end_pts = 0, -1
1020
audio_timebase_num, audio_timebase_den = 0, 1
1022
tv_result = torch.ops.video_reader.read_video_from_memory(
1055
assert abs(config.video_fps - vfps.item()) < 0.01
1057
start_pts_ind_max = vframe_pts.size(0) - num_frames
1058
if start_pts_ind_max <= 0:
1061
start_pts_ind = randint(0, start_pts_ind_max)
1062
end_pts_ind = start_pts_ind + num_frames - 1
1063
video_start_pts = vframe_pts[start_pts_ind]
1064
video_end_pts = vframe_pts[end_pts_ind]
1066
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
1067
if len(atimebase) > 0:
1069
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
1070
audio_start_pts = _pts_convert(
1071
video_start_pts.item(),
1072
Fraction(video_timebase_num.item(), video_timebase_den.item()),
1073
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
1076
audio_end_pts = _pts_convert(
1077
video_end_pts.item(),
1078
Fraction(video_timebase_num.item(), video_timebase_den.item()),
1079
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
1084
tv_result = torch.ops.video_reader.read_video_from_memory(
1107
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
1109
video_start_pts_av = _pts_convert(
1110
video_start_pts.item(),
1111
Fraction(video_timebase_num.item(), video_timebase_den.item()),
1112
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
1115
video_end_pts_av = _pts_convert(
1116
video_end_pts.item(),
1117
Fraction(video_timebase_num.item(), video_timebase_den.item()),
1118
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
1121
if audio_timebase_av:
1122
audio_start_pts = _pts_convert(
1123
video_start_pts.item(),
1124
Fraction(video_timebase_num.item(), video_timebase_den.item()),
1125
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
1128
audio_end_pts = _pts_convert(
1129
video_end_pts.item(),
1130
Fraction(video_timebase_num.item(), video_timebase_den.item()),
1131
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
1135
pyav_result = _decode_frames_by_av_module(
1143
assert tv_result[0].size(0) == num_frames
1144
if pyav_result.vframes.size(0) == num_frames:
1148
self.compare_decoding_result(tv_result, pyav_result, config)
1150
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_file(self, test_video, config):
    """Probe a video file on disk and validate the reported metadata."""
    full_path = os.path.join(VIDEO_DIR, test_video)
    self.check_probe_result(torch.ops.video_reader.probe_video_from_file(full_path), config)
1159
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory(self, test_video, config):
    """Probe a video held in memory and validate the reported metadata."""
    _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
    self.check_probe_result(torch.ops.video_reader.probe_video_from_memory(video_tensor), config)
1168
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory_script(self, test_video, config):
    """TorchScript-compile the in-memory probe and validate its metadata output."""
    scripted_fun = torch.jit.script(io._probe_video_from_memory)
    assert scripted_fun is not None
    _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
    self.check_meta_result(scripted_fun(video_tensor), config)
1177
@pytest.mark.parametrize("test_video", test_videos.keys())
1178
def test_read_video_from_memory_scripted(self, test_video):
1180
Test the case when video is already in memory, and decoder reads data in memory
1183
width, height, min_dimension, max_dimension = 0, 0, 0, 0
1184
video_start_pts, video_end_pts = 0, -1
1185
video_timebase_num, video_timebase_den = 0, 1
1187
samples, channels = 0, 0
1188
audio_start_pts, audio_end_pts = 0, -1
1189
audio_timebase_num, audio_timebase_den = 0, 1
1191
scripted_fun = torch.jit.script(io._read_video_from_memory)
1192
assert scripted_fun is not None
1194
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
1205
[video_start_pts, video_end_pts],
1211
[audio_start_pts, audio_end_pts],
1217
def test_invalid_file(self):
    """Both backends must raise RuntimeError for a nonexistent file."""
    for backend in ("video_reader", "pyav"):
        set_video_backend(backend)
        with pytest.raises(RuntimeError):
            io.read_video("foo.mp4")
1226
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
@pytest.mark.parametrize("start_offset", [0, 500])
@pytest.mark.parametrize("end_offset", [3000, None])
def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
    """Test if audio frames are returned with pts unit."""
    full_path = os.path.join(VIDEO_DIR, test_video)
    # Probe for an audio stream with PyAV; close the container afterwards
    # (the original leaked the open handle).
    with av.open(full_path) as container:
        has_audio = bool(container.streams.audio)
    if has_audio:
        set_video_backend(backend)
        _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
        # Both channel and sample dimensions of the waveform must be non-empty.
        assert all(dimension > 0 for dimension in audio.shape[:2])
1239
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
@pytest.mark.parametrize("start_offset", [0, 0.1])
@pytest.mark.parametrize("end_offset", [0.3, None])
def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
    """Test if audio frames are returned with sec unit."""
    full_path = os.path.join(VIDEO_DIR, test_video)
    # Probe for an audio stream with PyAV; close the container afterwards
    # (the original leaked the open handle).
    with av.open(full_path) as container:
        has_audio = bool(container.streams.audio)
    if has_audio:
        set_video_backend(backend)
        _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
        # Both channel and sample dimensions of the waveform must be non-empty.
        assert all(dimension > 0 for dimension in audio.shape[:2])
1253
# Standard entry point: allow running this test module directly via pytest.
if __name__ == "__main__":
1254
pytest.main([__file__])