# vision · Fork 0 / test_video_reader.py
# 1254 lines · 43.4 KB
1
import collections
2
import math
3
import os
4
from fractions import Fraction
5

6
import numpy as np
7
import pytest
8
import torch
9
import torchvision.io as io
10
from common_utils import assert_equal
11
from numpy.random import randint
12
from pytest import approx
13
from torchvision import set_video_backend
14
from torchvision.io import _HAS_CPU_VIDEO_DECODER
15

16

17
try:
    import av
except ImportError:
    # PyAV is not installed; classes guarded by `av is None` are skipped.
    av = None
else:
    # PyAV imported fine, but it must also pass TorchVision's version check.
    try:
        io.video._check_av_available()
    except ImportError:
        av = None
24

25

26
# Directory holding the video fixtures, resolved relative to this test file.
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

# Field names of the per-video expectations record (GroundTruth).
CheckerConfig = [
    "duration",
    "video_fps",
    "audio_sample_rate",
    # For some videos (e.g. HMDB51 videos) the decoded audio frames and pts differ
    # slightly between the TorchVision decoder and the PyAv decoder, so these two
    # flags let a video opt out of those comparisons.
    "check_aframes",
    "check_aframe_pts",
]
# namedtuple accepts the sequence of field names directly.
GroundTruth = collections.namedtuple("GroundTruth", CheckerConfig)

# Default checker config: placeholder expectations with every audio comparison enabled.
all_check_config = GroundTruth(
    duration=0, video_fps=0, audio_sample_rate=0, check_aframes=True, check_aframe_pts=True
)
46

47
# Expected metadata for each video fixture in VIDEO_DIR, keyed by file name.
# Positional GroundTruth fields:
#   (duration, video_fps, audio_sample_rate, check_aframes, check_aframe_pts)
# An audio_sample_rate of None means no audio data is expected from the file.
# For the three .mp4 files PyAv misses one audio frame at the beginning (pts=0),
# so their audio frames and audio pts are not compared against PyAv.
test_videos = {
    "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(2.0, 30.0, None, True, True),
    "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(2.0, 30.0, None, True, True),
    "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(2.0, 30.0, None, True, True),
    "v_SoccerJuggling_g23_c01.avi": GroundTruth(8.0, 29.97, None, True, True),
    "v_SoccerJuggling_g24_c01.avi": GroundTruth(8.0, 29.97, None, True, True),
    "R6llTwEh07w.mp4": GroundTruth(10.0, 30.0, 44100, False, False),
    "SOX5yA1l24A.mp4": GroundTruth(11.0, 29.97, 48000, False, False),
    "WUzgd7C1pWA.mp4": GroundTruth(11.0, 29.97, 48000, False, False),
}
108

109

110
# Reference decoding output: video frames / pts / time base, then audio frames / pts / time base.
DecoderResult = collections.namedtuple(
    "DecoderResult",
    ["vframes", "vframe_pts", "vtimebase", "aframes", "aframe_pts", "atimebase"],
)

# av_seek_frame is imprecise, so the decoder is asked to seek this many seconds
# earlier than the requested timestamp.
SEEK_FRAME_MARGIN = 0.25
115

116

117
def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
118
    """
119
    Args:
120
        container: pyav container
121
        start_pts/end_pts: the starting/ending Presentation TimeStamp where
122
            frames are read
123
        stream: pyav stream
124
        stream_name: a dictionary of streams. For example, {"video": 0} means
125
            video stream at stream index 0
126
        buffer_size: pts of frames decoded by PyAv is not guaranteed to be in
127
            ascending order. We need to decode more frames even when we meet end
128
            pts
129
    """
130
    # seeking in the stream is imprecise. Thus, seek to an earlier PTS by a margin
131
    margin = 1
132
    seek_offset = max(start_pts - margin, 0)
133

134
    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
135
    frames = {}
136
    buffer_count = 0
137
    for frame in container.decode(**stream_name):
138
        if frame.pts < start_pts:
139
            continue
140
        if frame.pts <= end_pts:
141
            frames[frame.pts] = frame
142
        else:
143
            buffer_count += 1
144
            if buffer_count >= buffer_size:
145
                break
146
    result = [frames[pts] for pts in sorted(frames)]
147

148
    return result
149

150

151
def _get_timebase_by_av_module(full_path):
    """Return ``(video_time_base, audio_time_base)`` of ``full_path`` as read by PyAV.

    The audio time base is None when the file has no audio stream. The first
    video stream is indexed unconditionally, so the file must contain one.
    """
    # Use the container as a context manager so the file handle is released
    # even if reading the stream metadata raises.
    with av.open(full_path) as container:
        video_time_base = container.streams.video[0].time_base
        if container.streams.audio:
            audio_time_base = container.streams.audio[0].time_base
        else:
            audio_time_base = None
    return video_time_base, audio_time_base
159

160

161
def _fraction_to_tensor(fraction):
162
    ret = torch.zeros([2], dtype=torch.int32)
163
    ret[0] = fraction.numerator
164
    ret[1] = fraction.denominator
165
    return ret
166

167

168
def _decode_frames_by_av_module(
    full_path,
    video_start_pts=0,
    video_end_pts=None,
    audio_start_pts=0,
    audio_end_pts=None,
):
    """
    Use PyAv to decode video frames. This provides a reference for our decoder
    to compare the decoding results.
    Input arguments:
        full_path: video file path
        video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where
            video frames are read; an end of None means "no upper bound"
        audio_start_pts/audio_end_pts: the same range for the audio stream
    Returns:
        DecoderResult. The video fields hold the RGB frames stacked into one tensor,
        their pts, and the stream time base. The audio fields hold the per-frame
        sample arrays concatenated along axis 1 and then transposed, their pts, and
        the stream time base. Fields of a stream that is absent (or yields no frames)
        are empty tensors.
    """
    if video_end_pts is None:
        video_end_pts = float("inf")
    if audio_end_pts is None:
        audio_end_pts = float("inf")

    video_frames = []
    vtimebase = torch.zeros([0], dtype=torch.int32)
    audio_frames = []
    atimebase = torch.zeros([0], dtype=torch.int32)

    # The context manager closes the container even when seeking/decoding raises.
    with av.open(full_path) as container:
        if container.streams.video:
            video_frames = _read_from_stream(
                container,
                video_start_pts,
                video_end_pts,
                container.streams.video[0],
                {"video": 0},
            )
            # container.streams.video[0].average_rate is not a reliable estimator of
            # frame rate. It can be wrong for certain codec, such as VP80
            # So we do not return video fps here
            vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)

        if container.streams.audio:
            audio_frames = _read_from_stream(
                container,
                audio_start_pts,
                audio_end_pts,
                container.streams.audio[0],
                {"audio": 0},
            )
            atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)

    # np.stack raises on an empty list, so a file that yields no video frames gets
    # an empty uint8 tensor instead of crashing.
    if video_frames:
        vframes = torch.as_tensor(np.stack([frame.to_rgb().to_ndarray() for frame in video_frames]))
    else:
        vframes = torch.empty((0,), dtype=torch.uint8)
    vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)

    aframes = [frame.to_ndarray() for frame in audio_frames]
    if aframes:
        # Join the per-frame sample arrays along axis 1, then transpose.
        aframes = torch.as_tensor(np.transpose(np.concatenate(aframes, axis=1)))
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)
    aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64)

    return DecoderResult(
        vframes=vframes,
        vframe_pts=vframe_pts,
        vtimebase=vtimebase,
        aframes=aframes,
        aframe_pts=aframe_pts,
        atimebase=atimebase,
    )
239

240

241
def _pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
242
    """convert pts between different time bases
243
    Args:
244
        pts: presentation timestamp, float
245
        timebase_from: original timebase. Fraction
246
        timebase_to: new timebase. Fraction
247
        round_func: rounding function.
248
    """
249
    new_pts = Fraction(pts, 1) * timebase_from / timebase_to
250
    return int(round_func(new_pts))
251

252

253
def _get_video_tensor(video_dir, video_file):
254
    """open a video file, and represent the video data by a PT tensor"""
255
    full_path = os.path.join(video_dir, video_file)
256

257
    assert os.path.exists(full_path), "File not found: %s" % full_path
258

259
    with open(full_path, "rb") as fp:
260
        video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8)
261

262
    return full_path, video_tensor
263

264

265
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg")
class TestVideoReader:
    def check_separate_decoding_result(self, tv_result, config):
        """Sanity-check one TorchVision decoder result against the expected ``config``.

        The stream durations (scaled by their time base) and the video fps must be
        within 0.5 of the expected values. The audio sample rate must match exactly
        whenever audio was returned. Video and audio frame pts must be strictly
        increasing.
        """
        _, vframe_pts, vtimebase, vfps, vduration, _, aframe_pts, atimebase, asample_rate, aduration = tv_result

        video_seconds = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
        assert video_seconds == approx(config.duration, abs=0.5)
        assert vfps.item() == approx(config.video_fps, abs=0.5)

        has_audio = asample_rate.numel() > 0
        if has_audio:
            assert asample_rate.item() == config.audio_sample_rate
            audio_seconds = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
            assert audio_seconds == approx(config.duration, abs=0.5)

        # Each pts sequence must be strictly increasing (trivially true for < 2 entries).
        for pts in (vframe_pts, aframe_pts):
            for earlier, later in zip(pts[:-1], pts[1:]):
                assert earlier < later
    def check_probe_result(self, result, config):
        """Check probed stream metadata (durations, fps, sample rate) against ``config``.

        ``result`` unpacks to (vtimebase, vfps, vduration, atimebase, asample_rate,
        aduration). Durations are scaled by their time base and must be within 0.5
        of ``config.duration``.
        """
        vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result

        def to_seconds(duration, timebase):
            return duration.item() * Fraction(timebase[0].item(), timebase[1].item())

        assert to_seconds(vduration, vtimebase) == approx(config.duration, abs=0.5)
        assert vfps.item() == approx(config.video_fps, abs=0.5)
        if asample_rate.numel() == 0:
            # No audio information was returned; nothing more to check.
            return
        assert asample_rate.item() == config.audio_sample_rate
        assert to_seconds(aduration, atimebase) == approx(config.duration, abs=0.5)
    def check_meta_result(self, result, config):
        """Check a metadata object's durations, fps and audio sample rate against ``config``."""
        tolerance = 0.5
        assert result.video_duration == approx(config.duration, abs=tolerance)
        assert result.video_fps == approx(config.video_fps, abs=tolerance)
        if not (result.has_audio > 0):
            return
        assert result.audio_sample_rate == config.audio_sample_rate
        assert result.audio_duration == approx(config.duration, abs=tolerance)
    def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
        """
        Compare decoding results from two sources.
        Args:
            tv_result: decoding results from TorchVision decoder
            ref_result: reference decoding results, either a DecoderResult (e.g.
                        from the PyAv decoder) or a list in the same layout as
                        ``tv_result`` (e.g. from the TorchVision decoder with
                        getPtsOnly = 1)
            config: config of decoding results checker; its ``check_aframes`` and
                    ``check_aframe_pts`` flags enable the audio comparisons
        """
        (
            vframes,
            vframe_pts,
            vtimebase,
            _vfps,
            _vduration,
            aframes,
            aframe_pts,
            atimebase,
            _asample_rate,
            _aduration,
        ) = tv_result
        if isinstance(ref_result, list):
            # the ref_result is from new video_reader decoder
            ref_result = DecoderResult(
                vframes=ref_result[0],
                vframe_pts=ref_result[1],
                vtimebase=ref_result[2],
                aframes=ref_result[5],
                aframe_pts=ref_result[6],
                atimebase=ref_result[7],
            )

        if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
            mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
            assert mean_delta == approx(0.0, abs=8.0)

        # Both decoders must agree on the number of video frames. Checking the shape
        # first stops a length-1 tensor from silently broadcasting against the other.
        # Skipping the mean for empty pts avoids comparing NaN (the mean of nothing).
        assert vframe_pts.shape == ref_result.vframe_pts.shape
        if vframe_pts.numel() > 0:
            mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
            assert mean_delta == approx(0.0, abs=1.0)

        assert_equal(vtimebase, ref_result.vtimebase)

        if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
            # Audio stream is available and audio frames were returned by both decoders.
            assert_equal(aframes, ref_result.aframes)

        if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
            # Audio stream is available.
            assert_equal(aframe_pts, ref_result.aframe_pts)

            assert_equal(atimebase, ref_result.atimebase)
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_stress_test_read_video_from_file(self, test_video):
        """
        Decode the same video file many times to help detect memory leaks in the
        decoder. Skipped by default because it takes a long time to run.
        """
        pytest.skip(
            # The fragments are implicitly concatenated, so each needs a trailing
            # space to keep the sentences apart in the skip reason.
            "This stress test will iteratively decode the same set of videos. "
            "It helps to detect memory leak but it takes lots of time to run. "
            "By default, it is disabled"
        )
        num_iter = 10000
        # video related
        width, height, min_dimension, max_dimension = 0, 0, 0, 0
        video_start_pts, video_end_pts = 0, -1
        video_timebase_num, video_timebase_den = 0, 1
        # audio related
        samples, channels = 0, 0
        audio_start_pts, audio_end_pts = 0, -1
        audio_timebase_num, audio_timebase_den = 0, 1

        # The path does not change between iterations, so build it once.
        full_path = os.path.join(VIDEO_DIR, test_video)
        for _i in range(num_iter):
            # decode all frames using new decoder
            torch.ops.video_reader.read_video_from_file(
                full_path,
                SEEK_FRAME_MARGIN,
                0,  # getPtsOnly
                1,  # readVideoStream
                width,
                height,
                min_dimension,
                max_dimension,
                video_start_pts,
                video_end_pts,
                video_timebase_num,
                video_timebase_den,
                1,  # readAudioStream
                samples,
                channels,
                audio_start_pts,
                audio_end_pts,
                audio_timebase_num,
                audio_timebase_den,
            )
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_read_video_from_file(self, test_video, config):
        """
        Decode every frame of a video file with the TorchVision decoder, then check
        the result on its own and against PyAv's decoding of the same file.
        """
        full_path = os.path.join(VIDEO_DIR, test_video)

        # pass 1: TorchVision decoder, both streams, native size, whole time range
        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            0,  # width
            0,  # height
            0,  # min_dimension
            0,  # max_dimension
            0,  # video_start_pts
            -1,  # video_end_pts
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            0,  # samples
            0,  # channels
            0,  # audio_start_pts
            -1,  # audio_end_pts
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        # pass 2: PyAv reference decoding of the same file
        pyav_result = _decode_frames_by_av_module(full_path)

        self.check_separate_decoding_result(tv_result, config)
        self.compare_decoding_result(tv_result, pyav_result, config)
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
    def test_read_video_from_file_read_single_stream_only(
        self, test_video, config, read_video_stream, read_audio_stream
    ):
        """
        Test the case when decoder starts with a video file to decode frames and
        reads only one of its streams. With (1, 0) only the video stream is read
        and audio is ignored; with (0, 1) only the audio stream is read and video
        is ignored. Outputs of the stream that is not read must be empty. Audio
        outputs must also be empty when the file is not expected to carry audio
        (``config.audio_sample_rate is None``).
        """
        # video related
        width, height, min_dimension, max_dimension = 0, 0, 0, 0
        video_start_pts, video_end_pts = 0, -1
        video_timebase_num, video_timebase_den = 0, 1
        # audio related
        samples, channels = 0, 0
        audio_start_pts, audio_end_pts = 0, -1
        audio_timebase_num, audio_timebase_den = 0, 1

        full_path = os.path.join(VIDEO_DIR, test_video)
        # decode all frames using new decoder
        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            read_video_stream,
            width,
            height,
            min_dimension,
            max_dimension,
            video_start_pts,
            video_end_pts,
            video_timebase_num,
            video_timebase_den,
            read_audio_stream,
            samples,
            channels,
            audio_start_pts,
            audio_end_pts,
            audio_timebase_num,
            audio_timebase_den,
        )

        (
            vframes,
            vframe_pts,
            vtimebase,
            vfps,
            vduration,
            aframes,
            aframe_pts,
            atimebase,
            asample_rate,
            aduration,
        ) = tv_result

        assert (vframes.numel() > 0) is bool(read_video_stream)
        assert (vframe_pts.numel() > 0) is bool(read_video_stream)
        assert (vtimebase.numel() > 0) is bool(read_video_stream)
        assert (vfps.numel() > 0) is bool(read_video_stream)

        expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
        assert (aframes.numel() > 0) is bool(expect_audio_data)
        assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
        assert (atimebase.numel() > 0) is bool(expect_audio_data)
        assert (asample_rate.numel() > 0) is bool(expect_audio_data)
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_file_rescale_min_dimension(self, test_video):
        """
        Decode a video file with only ``min_dimension`` set and check that the
        shorter side of the decoded frames equals it.
        """
        min_dimension = 128
        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            0,  # width
            0,  # height
            min_dimension,
            0,  # max_dimension
            0,  # video_start_pts
            -1,  # video_end_pts
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            0,  # samples
            0,  # channels
            0,  # audio_start_pts
            -1,  # audio_end_pts
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        # Dims 1 and 2 of the video frames tensor are height and width.
        frame_height, frame_width = tv_result[0].size(1), tv_result[0].size(2)
        assert min_dimension == min(frame_height, frame_width)
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_file_rescale_max_dimension(self, test_video):
        """
        Test the case when decoder starts with a video file to decode frames, and
        video max dimension between height and width is set.
        """
        # video related
        width, height, min_dimension, max_dimension = 0, 0, 0, 85
        video_start_pts, video_end_pts = 0, -1
        video_timebase_num, video_timebase_den = 0, 1
        # audio related
        samples, channels = 0, 0
        audio_start_pts, audio_end_pts = 0, -1
        audio_timebase_num, audio_timebase_den = 0, 1

        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            height,
            min_dimension,
            max_dimension,
            video_start_pts,
            video_end_pts,
            video_timebase_num,
            video_timebase_den,
            1,  # readAudioStream
            samples,
            channels,
            audio_start_pts,
            audio_end_pts,
            audio_timebase_num,
            audio_timebase_den,
        )
        # The longer of height (dim 1) and width (dim 2) must equal max_dimension.
        assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
        """
        Test the case when decoder starts with a video file to decode frames, and
        both the video min dimension and max dimension between height and width
        are set.
        """
        # video related
        width, height, min_dimension, max_dimension = 0, 0, 64, 85
        video_start_pts, video_end_pts = 0, -1
        video_timebase_num, video_timebase_den = 0, 1
        # audio related
        samples, channels = 0, 0
        audio_start_pts, audio_end_pts = 0, -1
        audio_timebase_num, audio_timebase_den = 0, 1

        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            height,
            min_dimension,
            max_dimension,
            video_start_pts,
            video_end_pts,
            video_timebase_num,
            video_timebase_den,
            1,  # readAudioStream
            samples,
            channels,
            audio_start_pts,
            audio_end_pts,
            audio_timebase_num,
            audio_timebase_den,
        )
        # Shorter and longer of height (dim 1) and width (dim 2) must match the limits.
        assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
        assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_file_rescale_width(self, test_video):
        """
        Decode a video file with only the output ``width`` set and check the width
        of the decoded frames.
        """
        width = 256
        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            0,  # height
            0,  # min_dimension
            0,  # max_dimension
            0,  # video_start_pts
            -1,  # video_end_pts
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            0,  # samples
            0,  # channels
            0,  # audio_start_pts
            -1,  # audio_end_pts
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        vframes = tv_result[0]
        assert vframes.size(2) == width  # dim 2 of the frames tensor is the width
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_file_rescale_height(self, test_video):
        """
        Decode a video file with only the output ``height`` set and check the height
        of the decoded frames.
        """
        height = 224
        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            0,  # width
            height,
            0,  # min_dimension
            0,  # max_dimension
            0,  # video_start_pts
            -1,  # video_end_pts
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            0,  # samples
            0,  # channels
            0,  # audio_start_pts
            -1,  # audio_end_pts
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        vframes = tv_result[0]
        assert vframes.size(1) == height  # dim 1 of the frames tensor is the height
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_file_rescale_width_and_height(self, test_video):
        """
        Decode a video file with both the output ``width`` and ``height`` set and
        check both sides of the decoded frames.
        """
        width, height = 320, 240
        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            height,
            0,  # min_dimension
            0,  # max_dimension
            0,  # video_start_pts
            -1,  # video_end_pts
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            0,  # samples
            0,  # channels
            0,  # audio_start_pts
            -1,  # audio_end_pts
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        vframes = tv_result[0]
        # Dims 1 and 2 of the frames tensor are height and width.
        assert vframes.size(1) == height
        assert vframes.size(2) == width
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("samples", [9600, 96000])
    def test_read_video_from_file_audio_resampling(self, test_video, samples):
        """
        Decode a video file while resampling its audio to ``samples`` and check the
        reported sample rate, that aframes dimension 1 has size 1, and that the
        number of audio samples matches the duration covered by the audio pts
        (within 10% of the sample rate).
        """
        full_path = os.path.join(VIDEO_DIR, test_video)

        tv_result = torch.ops.video_reader.read_video_from_file(
            full_path,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            0,  # width
            0,  # height
            0,  # min_dimension
            0,  # max_dimension
            0,  # video_start_pts
            -1,  # video_end_pts
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            samples,
            0,  # channels
            0,  # audio_start_pts
            -1,  # audio_end_pts
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        _, _, _, _, _, aframes, aframe_pts, atimebase, asample_rate, _ = tv_result

        if aframes.numel() == 0:
            # No audio was decoded; nothing to check.
            return
        rate = asample_rate.item()
        assert samples == rate
        assert 1 == aframes.size(1)
        # Time covered by the audio pts, using the stream time base.
        duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
        assert aframes.size(0) == approx(int(duration * rate), abs=0.1 * rate)
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_compare_read_video_from_memory_and_file(self, test_video, config):
        """
        Test that decoding a video from an in-memory tensor and decoding it from its
        file path produce the same result.

        Both passes use identical decoder arguments: all frames (start pts 0, end pts -1)
        of the video and audio streams, with every size/format argument left at 0.
        Each result is validated against ``config`` on its own before the two results
        are compared with each other.
        """
        # video related
        width, height, min_dimension, max_dimension = 0, 0, 0, 0
        video_start_pts, video_end_pts = 0, -1
        video_timebase_num, video_timebase_den = 0, 1
        # audio related
        samples, channels = 0, 0
        audio_start_pts, audio_end_pts = 0, -1
        audio_timebase_num, audio_timebase_den = 0, 1

        # Every argument after the input source is shared by the memory and file
        # passes; building the sequence once guarantees both passes stay in sync.
        decode_args = (
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            height,
            min_dimension,
            max_dimension,
            video_start_pts,
            video_end_pts,
            video_timebase_num,
            video_timebase_den,
            1,  # readAudioStream
            samples,
            channels,
            audio_start_pts,
            audio_end_pts,
            audio_timebase_num,
            audio_timebase_den,
        )

        full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)

        # pass 1: decode all frames from the in-memory tensor using cpp decoder
        tv_result_memory = torch.ops.video_reader.read_video_from_memory(video_tensor, *decode_args)
        self.check_separate_decoding_result(tv_result_memory, config)
        # pass 2: decode all frames from file
        tv_result_file = torch.ops.video_reader.read_video_from_file(full_path, *decode_args)
        self.check_separate_decoding_result(tv_result_file, config)
        # finally, compare results decoded from memory and file
        self.compare_decoding_result(tv_result_memory, tv_result_file)
890

891
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_read_video_from_memory(self, test_video, config):
        """
        Decode a video that is already held in memory with the video_reader op, then
        check the result against ``config`` and against a PyAV decode of the same file.
        """
        # width, height, min_dimension, max_dimension
        resize_args = [0, 0, 0, 0]
        # video stream: [start_pts, end_pts] and [timebase_num, timebase_den]
        video_pts_window, video_timebase = [0, -1], [0, 1]
        # audio stream: [samples, channels], [start_pts, end_pts], [timebase_num, timebase_den]
        audio_format, audio_pts_window, audio_timebase = [0, 0], [0, -1], [0, 1]

        full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)

        # first decode: every frame, via the cpp decoder reading from memory
        memory_result = torch.ops.video_reader.read_video_from_memory(
            video_tensor,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            *resize_args,
            *video_pts_window,
            *video_timebase,
            1,  # readAudioStream
            *audio_format,
            *audio_pts_window,
            *audio_timebase,
        )
        # second decode: the reference result produced by PyAV
        reference_result = _decode_frames_by_av_module(full_path)

        self.check_separate_decoding_result(memory_result, config)
        self.compare_decoding_result(memory_result, reference_result, config)
934

935
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_read_video_from_memory_get_pts_only(self, test_video, config):
        """
        Decode an in-memory video twice, once fully and once asking only for pts, and
        check that the pts-only run returns no frame data while its pts agree with the
        full decode.
        """
        _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)

        def decode(get_pts_only):
            # Whole-range decode of both streams; only the getPtsOnly flag varies.
            return torch.ops.video_reader.read_video_from_memory(
                video_tensor,
                SEEK_FRAME_MARGIN,
                get_pts_only,
                1,  # readVideoStream
                0,  # width
                0,  # height
                0,  # min_dimension
                0,  # max_dimension
                0,  # video_start_pts
                -1,  # video_end_pts
                0,  # video_timebase_num
                1,  # video_timebase_den
                1,  # readAudioStream
                0,  # samples
                0,  # channels
                0,  # audio_start_pts
                -1,  # audio_end_pts
                0,  # audio_timebase_num
                1,  # audio_timebase_den
            )

        # full decode: frame data plus pts; element 3 is the video fps
        full_result = decode(0)
        assert abs(config.video_fps - full_result[3].item()) < 0.01

        # pts-only decode: the video (0) and audio (5) frame tensors must be empty
        pts_only_result = decode(1)
        assert not pts_only_result[0].numel()
        assert not pts_only_result[5].numel()
        self.compare_decoding_result(full_result, pts_only_result)
1003

1004
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
    def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
        """
        Test the case when video is already in memory, and decoder reads data in memory.
        In addition, decoder takes meaningful start- and end PTS as input, and decodes
        frames within that interval.

        A window of ``num_frames`` consecutive frames is chosen at random from a full
        decode and re-decoded with the video_reader op using that window's start/end pts.
        When PyAV decodes the same number of frames for the window, the two results are
        compared. The test is skipped when the video has too few frames for the window.
        """
        full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
        # video related
        width, height, min_dimension, max_dimension = 0, 0, 0, 0
        video_start_pts, video_end_pts = 0, -1
        video_timebase_num, video_timebase_den = 0, 1
        # audio related
        samples, channels = 0, 0
        audio_start_pts, audio_end_pts = 0, -1
        audio_timebase_num, audio_timebase_den = 0, 1
        # pass 1: decode all frames using new decoder
        tv_result = torch.ops.video_reader.read_video_from_memory(
            video_tensor,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            height,
            min_dimension,
            max_dimension,
            video_start_pts,
            video_end_pts,
            video_timebase_num,
            video_timebase_den,
            1,  # readAudioStream
            samples,
            channels,
            audio_start_pts,
            audio_end_pts,
            audio_timebase_num,
            audio_timebase_den,
        )
        # Only the video pts/timebase/fps and the audio timebase of the full decode are used below.
        _, vframe_pts, vtimebase, vfps, _, _, _, atimebase, _, _ = tv_result
        assert abs(config.video_fps - vfps.item()) < 0.01

        num_decoded = vframe_pts.size(0)
        start_pts_ind_max = num_decoded - num_frames
        if start_pts_ind_max <= 0:
            # Report this combination as skipped instead of passing without testing a range.
            pytest.skip(f"{test_video} has {num_decoded} frames; a {num_frames}-frame window needs more")
        # randomly pick start pts; numpy's randint excludes its upper bound, so the
        # start index lies in [0, start_pts_ind_max - 1]
        start_pts_ind = randint(0, start_pts_ind_max)
        end_pts_ind = start_pts_ind + num_frames - 1
        video_start_pts = vframe_pts[start_pts_ind]
        video_end_pts = vframe_pts[end_pts_ind]

        video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
        video_timebase = Fraction(video_timebase_num.item(), video_timebase_den.item())
        if len(atimebase) > 0:
            # when audio stream is available, express the same window in audio pts;
            # start is rounded down and end rounded up
            audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
            audio_timebase = Fraction(audio_timebase_num.item(), audio_timebase_den.item())
            audio_start_pts = _pts_convert(video_start_pts.item(), video_timebase, audio_timebase, math.floor)
            audio_end_pts = _pts_convert(video_end_pts.item(), video_timebase, audio_timebase, math.ceil)

        # pass 2: decode frames in the randomly generated range
        tv_result = torch.ops.video_reader.read_video_from_memory(
            video_tensor,
            SEEK_FRAME_MARGIN,
            0,  # getPtsOnly
            1,  # readVideoStream
            width,
            height,
            min_dimension,
            max_dimension,
            video_start_pts,
            video_end_pts,
            video_timebase_num,
            video_timebase_den,
            1,  # readAudioStream
            samples,
            channels,
            audio_start_pts,
            audio_end_pts,
            audio_timebase_num,
            audio_timebase_den,
        )

        # pass 3: decode frames in range using PyAv, converting the window into PyAV's
        # timebases (start rounded down, end rounded up)
        video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
        video_timebase_pyav = Fraction(video_timebase_av.numerator, video_timebase_av.denominator)
        video_start_pts_av = _pts_convert(video_start_pts.item(), video_timebase, video_timebase_pyav, math.floor)
        video_end_pts_av = _pts_convert(video_end_pts.item(), video_timebase, video_timebase_pyav, math.ceil)
        if audio_timebase_av:
            audio_timebase_pyav = Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator)
            audio_start_pts = _pts_convert(video_start_pts.item(), video_timebase, audio_timebase_pyav, math.floor)
            audio_end_pts = _pts_convert(video_end_pts.item(), video_timebase, audio_timebase_pyav, math.ceil)

        pyav_result = _decode_frames_by_av_module(
            full_path,
            video_start_pts_av,
            video_end_pts_av,
            audio_start_pts,
            audio_end_pts,
        )

        assert tv_result[0].size(0) == num_frames
        if pyav_result.vframes.size(0) == num_frames:
            # if PyAv decodes a different number of video frames, skip
            # comparing the decoding results between Torchvision video reader
            # and PyAv
            self.compare_decoding_result(tv_result, pyav_result, config)
1149

1150
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_probe_video_from_file(self, test_video, config):
        """
        Probe a video through its file path and validate the reported metadata against ``config``.
        """
        self.check_probe_result(
            torch.ops.video_reader.probe_video_from_file(os.path.join(VIDEO_DIR, test_video)),
            config,
        )
1158

1159
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_probe_video_from_memory(self, test_video, config):
        """
        Probe a video from its in-memory tensor and validate the reported metadata against ``config``.
        """
        _path, encoded = _get_video_tensor(VIDEO_DIR, test_video)
        metadata = torch.ops.video_reader.probe_video_from_memory(encoded)
        self.check_probe_result(metadata, config)
1167

1168
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_probe_video_from_memory_script(self, test_video, config):
        """
        Compile ``io._probe_video_from_memory`` with TorchScript, run it on an in-memory
        video and validate the returned metadata against ``config``.
        """
        probe_fn = torch.jit.script(io._probe_video_from_memory)
        assert probe_fn is not None

        _path, encoded = _get_video_tensor(VIDEO_DIR, test_video)
        self.check_meta_result(probe_fn(encoded), config)
1176

1177
    @pytest.mark.parametrize("test_video", test_videos.keys())
    def test_read_video_from_memory_scripted(self, test_video):
        """
        Compile ``io._read_video_from_memory`` with TorchScript and check that the
        scripted function can decode an in-memory video without raising.
        """
        # whole-range [start_pts, end_pts] windows for each stream
        video_pts_window = [0, -1]
        audio_pts_window = [0, -1]

        read_fn = torch.jit.script(io._read_video_from_memory)
        assert read_fn is not None

        _path, encoded = _get_video_tensor(VIDEO_DIR, test_video)

        # decode all frames using cpp decoder; the returned frames are not inspected yet
        read_fn(
            encoded,
            SEEK_FRAME_MARGIN,
            1,  # readVideoStream
            0,  # width
            0,  # height
            0,  # min_dimension
            0,  # max_dimension
            video_pts_window,
            0,  # video_timebase_num
            1,  # video_timebase_den
            1,  # readAudioStream
            0,  # samples
            0,  # channels
            audio_pts_window,
            0,  # audio_timebase_num
            1,  # audio_timebase_den
        )
        # FUTURE: check value of video / audio frames
1216

1217
    def test_invalid_file(self):
        """
        Reading a nonexistent path must raise ``RuntimeError`` under both the
        "video_reader" and the "pyav" backends.

        The backend that was selected before the test is restored afterwards, so the
        backend switch done here does not affect tests that run later.
        """
        from torchvision import get_video_backend

        previous_backend = get_video_backend()
        try:
            for backend in ("video_reader", "pyav"):
                set_video_backend(backend)
                with pytest.raises(RuntimeError):
                    io.read_video("foo.mp4")
        finally:
            set_video_backend(previous_backend)
1225

1226
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
    @pytest.mark.parametrize("start_offset", [0, 500])
    @pytest.mark.parametrize("end_offset", [3000, None])
    def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
        """Test if audio frames are returned with pts unit.

        Videos without an audio stream are not checked. For the others,
        ``io.read_video(..., pts_unit="pts")`` under ``backend`` must return audio whose
        first two dimensions are non-zero. The previously selected backend is restored
        afterwards.
        """
        from torchvision import get_video_backend

        full_path = os.path.join(VIDEO_DIR, test_video)
        # The container is only used to look for an audio stream; close it right away
        # instead of leaking one open file per parametrization.
        with av.open(full_path) as container:
            has_audio = bool(container.streams.audio)
        if has_audio:
            previous_backend = get_video_backend()
            set_video_backend(backend)
            try:
                _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
            finally:
                set_video_backend(previous_backend)
            assert all(dimension > 0 for dimension in audio.shape[:2])
1238

1239
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
    @pytest.mark.parametrize("start_offset", [0, 0.1])
    @pytest.mark.parametrize("end_offset", [0.3, None])
    def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
        """Test if audio frames are returned with sec unit.

        Videos without an audio stream are not checked. For the others,
        ``io.read_video(..., pts_unit="sec")`` under ``backend`` must return audio whose
        first two dimensions are non-zero. The previously selected backend is restored
        afterwards.
        """
        from torchvision import get_video_backend

        full_path = os.path.join(VIDEO_DIR, test_video)
        # The container is only used to look for an audio stream; close it right away
        # instead of leaking one open file per parametrization.
        with av.open(full_path) as container:
            has_audio = bool(container.streams.audio)
        if has_audio:
            previous_backend = get_video_backend()
            set_video_backend(backend)
            try:
                _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
            finally:
                set_video_backend(previous_backend)
            assert all(dimension > 0 for dimension in audio.shape[:2])
1251

1252

1253
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits non-zero on failures.
    raise SystemExit(pytest.main([__file__]))
1255

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.