如何采样视频片段¶
在此示例中,我们将学习如何从视频中采样视频片段。片段通常表示帧的序列或批量,并通常作为输入传递给视频模型。
首先是一些样板代码:我们将从网络下载一个视频,并定义一个绘图工具。您可以忽略这部分,直接跳到下面的创建解码器。
from typing import Optional
import torch
import requests
# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
raise RuntimeError(f"Failed to download video. {response.status_code = }.")
raw_video_bytes = response.content
def plot(frames: torch.Tensor, title : Optional[str] = None):
try:
from torchvision.utils import make_grid
from torchvision.transforms.v2.functional import to_pil_image
import matplotlib.pyplot as plt
except ImportError:
print("Cannot plot, please run `pip install torchvision matplotlib`")
return
plt.rcParams["savefig.bbox"] = 'tight'
fig, ax = plt.subplots()
ax.imshow(to_pil_image(make_grid(frames)))
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if title is not None:
ax.set_title(title)
plt.tight_layout()
创建解码器¶
从视频中采样片段总是始于创建一个VideoDecoder
对象。如果您还不熟悉VideoDecoder
,请快速查看:使用 VideoDecoder 解码视频。
from torchcodec.decoders import VideoDecoder
# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)
采样基础知识¶
我们现在可以使用解码器来采样片段。我们先来看一个简单的示例:所有其他采样器都遵循类似的 API 和原则。我们将使用clips_at_random_indices()
来采样从随机索引开始的片段。
from torchcodec.samplers import clips_at_random_indices
# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)
clips = clips_at_random_indices(
decoder,
num_clips=5,
num_frames_per_clip=4,
num_indices_between_frames=3,
)
clips
FrameBatch:
data (shape): torch.Size([5, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[10.2000, 10.3200, 10.4400, 10.5600],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 9.6000, 9.7200, 9.8400, 9.9600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
采样器的输出是片段的序列,表示为一个FrameBatch
对象。在此对象中,我们有不同的字段:
data
:一个表示帧数据的 5D uint8 张量。其形状为 (num_clips, num_frames_per_clip, …),其中 … 取决于VideoDecoder
的dimension_order
参数,可以是 (C, H, W) 或 (H, W, C)。这通常是传递给模型的输入。pts_seconds
:一个形状为 (num_clips, num_frames_per_clip) 的 2D float 张量,给出每个片段中每帧的起始时间戳(秒)。duration_seconds
:一个形状为 (num_clips, num_frames_per_clip) 的 2D float 张量,给出每个片段中每帧的持续时间(秒)。
plot(clips[0].data)

片段的索引和操作¶
片段是FrameBatch
对象,它们支持原生 pytorch 索引语义(包括高级索引)。这使得基于给定条件过滤片段变得容易。例如,从上面的片段中,我们可以轻松过滤出在特定时间戳后开始的片段:
clip_starts = clips.pts_seconds[:, 0]
clip_starts
tensor([11.3600, 10.2000, 9.8000, 9.6000, 8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
data (shape): torch.Size([5, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[10.2000, 10.3200, 10.4400, 10.5600],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 9.6000, 9.7200, 9.8400, 9.9600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
data (shape): torch.Size([3, 4, 3, 360, 640])
pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
[ 9.8000, 9.9200, 10.0400, 10.1600],
[ 8.4400, 8.5600, 8.6800, 8.8000]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
注意
在给定时间戳后获取片段更自然、更高效的方法是依赖于采样范围参数,我们将在后续的高级参数:采样范围中介绍。
基于索引和基于时间的采样器¶
到目前为止,我们使用了clips_at_random_indices()
。Torchcodec 支持额外的采样器,它们主要分为两类:
基于索引的采样器
基于时间的采样器
所有这些采样器都遵循类似的 API,并且基于时间的采样器具有与基于索引的采样器类似的参数。这两种采样器类型通常在速度方面提供可比较的性能。
注意
是使用基于时间的采样器更好还是基于索引的采样器更好?基于索引的采样器可以说 API 稍微简单一些,其行为可能更容易理解和控制,因为索引具有离散性。对于具有恒定帧率 (fps) 的视频,基于索引的采样器与基于时间的采样器行为完全相同。然而,对于具有可变帧率的视频(这种情况经常发生),依赖索引可能会对视频中的某些区域进行欠采样或过采样,这在训练模型时可能导致不良的副作用。使用基于时间的采样器可确保沿时间维度的采样特性均匀。
高级参数:采样范围¶
有时,我们可能不想从整个视频中采样片段。我们可能只对在较小时间间隔内开始的片段感兴趣。在采样器中,sampling_range_start
和 sampling_range_end
参数控制采样范围:它们定义了我们允许片段开始的位置。需要记住两点重要事项:
sampling_range_end
是一个开区间上界:片段只能在 [sampling_range_start, sampling_range_end) 范围内开始。由于这些参数定义了片段可以开始的位置,因此片段可能包含
sampling_range_end
之后帧!
from torchcodec.samplers import clips_at_regular_timestamps
clips = clips_at_regular_timestamps(
decoder,
seconds_between_clip_starts=1,
num_frames_per_clip=4,
seconds_between_frames=0.5,
sampling_range_start=2,
sampling_range_end=5
)
clips
FrameBatch:
data (shape): torch.Size([3, 4, 3, 360, 640])
pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
[3.0000, 3.4800, 4.0000, 4.4800],
[4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400],
[0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
高级参数:策略¶
根据视频的长度或持续时间以及采样参数,采样器可能会尝试采样超出视频末尾的帧。`policy` 参数定义了如何用有效帧替换这些无效帧。
from torchcodec.samplers import clips_at_random_timestamps
end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
decoder,
num_clips=1,
num_frames_per_clip=5,
seconds_between_frames=0.4,
sampling_range_start=end_of_video - 1,
sampling_range_end=end_of_video,
policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)
我们在上面看到视频在 13.8 秒处结束。采样器尝试在时间戳 [13.28, 13.68, 14.08, …] 处采样帧,但 14.08 是一个无效的时间戳,超出了视频末尾。使用“repeat_last”策略(这是默认策略),采样器简单地重复 13.68 秒处的最后一帧来构建片段。
另一种策略是“wrap”:采样器会环绕片段,并根据需要重复前几个有效帧。
torch.manual_seed(0)
clips = clips_at_random_timestamps(
decoder,
num_clips=1,
num_frames_per_clip=5,
seconds_between_frames=0.4,
sampling_range_start=end_of_video - 1,
sampling_range_end=end_of_video,
policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)
默认情况下,sampling_range_end
的值会自动设置,以确保采样器不会尝试采样超出视频末尾的帧:默认值确保片段在视频结束前足够早开始。这意味着默认情况下,`policy` 参数很少起作用,大多数用户可能不需要过多担心它。
脚本总运行时间: (0 分 0.624 秒)