快捷方式

如何采样视频片段

在此示例中,我们将学习如何从视频中采样视频片段。片段通常表示帧的序列或批量,并通常作为输入传递给视频模型。

首先是一些样板代码:我们将从网络下载一个视频,并定义一个绘图工具。您可以忽略这部分,直接跳到下面的创建解码器

from typing import Optional
import torch
import requests


# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content


def plot(frames: torch.Tensor, title : Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()

创建解码器

从视频中采样片段总是始于创建一个VideoDecoder 对象。如果您还不熟悉VideoDecoder,请快速查看:使用 VideoDecoder 解码视频

from torchcodec.decoders import VideoDecoder

# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)

采样基础知识

我们现在可以使用解码器来采样片段。我们先来看一个简单的示例:所有其他采样器都遵循类似的 API 和原则。我们将使用clips_at_random_indices() 来采样从随机索引开始的片段。

from torchcodec.samplers import clips_at_random_indices

# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)

clips = clips_at_random_indices(
    decoder,
    num_clips=5,
    num_frames_per_clip=4,
    num_indices_between_frames=3,
)
clips
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

采样器的输出是片段的序列,表示为一个FrameBatch 对象。在此对象中,我们有不同的字段:

  • data:一个表示帧数据的 5D uint8 张量。其形状为 (num_clips, num_frames_per_clip, …),其中 … 取决于VideoDecoderdimension_order 参数,可以是 (C, H, W) 或 (H, W, C)。这通常是传递给模型的输入。

  • pts_seconds:一个形状为 (num_clips, num_frames_per_clip) 的 2D float 张量,给出每个片段中每帧的起始时间戳(秒)。

  • duration_seconds:一个形状为 (num_clips, num_frames_per_clip) 的 2D float 张量,给出每个片段中每帧的持续时间(秒)。

plot(clips[0].data)
sampling

片段的索引和操作

片段是FrameBatch 对象,它们支持原生 pytorch 索引语义(包括高级索引)。这使得基于给定条件过滤片段变得容易。例如,从上面的片段中,我们可以轻松过滤出在特定时间戳后开始的片段:

tensor([11.3600, 10.2000,  9.8000,  9.6000,  8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

注意

在给定时间戳后获取片段更自然、更高效的方法是依赖于采样范围参数,我们将在后续的高级参数:采样范围中介绍。

基于索引和基于时间的采样器

到目前为止,我们使用了clips_at_random_indices()。Torchcodec 支持额外的采样器,它们主要分为两类:

基于索引的采样器

基于时间的采样器

所有这些采样器都遵循类似的 API,并且基于时间的采样器具有与基于索引的采样器类似的参数。这两种采样器类型通常在速度方面提供可比较的性能。

注意

是使用基于时间的采样器更好还是基于索引的采样器更好?基于索引的采样器可以说 API 稍微简单一些,其行为可能更容易理解和控制,因为索引具有离散性。对于具有恒定帧率 (fps) 的视频,基于索引的采样器与基于时间的采样器行为完全相同。然而,对于具有可变帧率的视频(这种情况经常发生),依赖索引可能会对视频中的某些区域进行欠采样或过采样,这在训练模型时可能导致不良的副作用。使用基于时间的采样器可确保沿时间维度的采样特性均匀。

高级参数:采样范围

有时,我们可能不想从整个视频中采样片段。我们可能只对在较小时间间隔内开始的片段感兴趣。在采样器中,sampling_range_startsampling_range_end 参数控制采样范围:它们定义了我们允许片段开始的位置。需要记住两点重要事项:

  • sampling_range_end 是一个开区间上界:片段只能在 [sampling_range_start, sampling_range_end) 范围内开始。

  • 由于这些参数定义了片段可以开始的位置,因此片段可能包含 sampling_range_end 之后帧!

from torchcodec.samplers import clips_at_regular_timestamps

clips = clips_at_regular_timestamps(
    decoder,
    seconds_between_clip_starts=1,
    num_frames_per_clip=4,
    seconds_between_frames=0.5,
    sampling_range_start=2,
    sampling_range_end=5
)
clips
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
        [3.0000, 3.4800, 4.0000, 4.4800],
        [4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

高级参数:策略

根据视频的长度或持续时间以及采样参数,采样器可能会尝试采样超出视频末尾的帧。`policy` 参数定义了如何用有效帧替换这些无效帧。

from torchcodec.samplers import clips_at_random_timestamps

end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)

我们在上面看到视频在 13.8 秒处结束。采样器尝试在时间戳 [13.28, 13.68, 14.08, …] 处采样帧,但 14.08 是一个无效的时间戳,超出了视频末尾。使用“repeat_last”策略(这是默认策略),采样器简单地重复 13.68 秒处的最后一帧来构建片段。

另一种策略是“wrap”:采样器会环绕片段,并根据需要重复前几个有效帧。

torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)

默认情况下,sampling_range_end 的值会自动设置,以确保采样器不会尝试采样超出视频末尾的帧:默认值确保片段在视频结束前足够早开始。这意味着默认情况下,`policy` 参数很少起作用,大多数用户可能不需要过多担心它。

脚本总运行时间: (0 分 0.624 秒)

画廊由 Sphinx-Gallery 生成


© 版权所有 2023-至今, TorchCodec 贡献者。

使用 Sphinx 并由 Read the Docs 提供的主题构建。

文档

访问 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

寻找开发资源并获得问题解答

查看资源