Video API
This example illustrates some of the APIs that torchvision offers for videos, together with an example of how to build datasets and more.
1. Introduction: building a new video object and examining the properties
First we select a video to test the object out. For the sake of argument we are using one from the kinetics400 dataset. To create it, we need to define the path and the stream we want to use.
Chosen video statistics:
- WUzgd7C1pWA.mp4
  - source: kinetics-400
  - video: H-264, MPEG-4 AVC (part 10) (avc1), fps: 29.97
  - audio: MPEG AAC audio (mp4a), sample rate: 48 kHz
import torch
import torchvision
from torchvision.datasets.utils import download_url
torchvision.set_video_backend("video_reader")
# Download the sample video
download_url(
"https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
".",
"WUzgd7C1pWA.mp4"
)
video_path = "./WUzgd7C1pWA.mp4"
Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/WUzgd7C1pWA.mp4 to ./WUzgd7C1pWA.mp4
Streams are defined in a similar fashion to torch devices. We encode them as strings of the form stream_type:stream_id, where stream_type is a string and stream_id is a long int. The constructor accepts passing only a stream_type, in which case the stream is auto-discovered. Firstly, let's get the metadata for our particular video:
stream = "video"
video = torchvision.io.VideoReader(video_path, stream)
video.get_metadata()
{'video': {'duration': [10.9109], 'fps': [29.97002997002997]}, 'audio': {'duration': [10.9], 'framerate': [48000.0]}, 'subtitles': {'duration': []}, 'cc': {'duration': []}}
Here we can see that the video has two streams - a video stream and an audio stream. Currently available stream types are ['video', 'audio']. Each descriptor consists of two parts: the stream type (e.g. 'video') and a unique stream id (determined by the video encoding). In this way, if the video container contains multiple streams of the same type, users can access the one they want. If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.
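If we instead wanted to pin the decoder to a particular stream, the descriptor can also carry an explicit id. A minimal sketch (the id 0 is an assumption; actual ids depend on how the container was encoded):
video_explicit = torchvision.io.VideoReader(video_path, "video:0")  # explicit "stream_type:stream_id" descriptor
print(video_explicit.get_metadata())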
Let's read all the frames from the audio stream. By default, the return value of next(video_reader) is a dict containing the following fields:
- data: a torch.Tensor holding the decoded frame
- pts: the presentation timestamp (a float, in seconds) of this particular frame
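For instance, we can peek at a single frame. A minimal sketch on a fresh reader (so the iteration below is not affected; peek is just an illustrative name):
peek = torchvision.io.VideoReader(video_path, "video")
frame = next(peek)  # dict with 'data' (torch.Tensor) and 'pts' (float, seconds)
print(frame['data'].shape, frame['pts'])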
metadata = video.get_metadata()
video.set_current_stream("audio")
frames = [] # we are going to save the frames here.
ptss = [] # pts is a presentation timestamp in seconds (float) of each frame
for frame in video:
frames.append(frame['data'])
ptss.append(frame['pts'])
print("PTS for first five frames ", ptss[:5])
print("Total number of frames: ", len(frames))
approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf)
print("Read data size: ", frames[0].size(0) * len(frames))
PTS for first five frames [0.0, 0.021332999999999998, 0.042667, 0.064, 0.08533299999999999]
Total number of frames: 511
Approx total number of datapoints we can expect: 523200.0
Read data size: 523264
But what if we only want to read a certain time segment of the video? That can be done easily by combining our seek function with the fact that each call to next returns the presentation timestamp of the returned frame in seconds.
Given that our implementation relies on Python iterators, we can leverage itertools to simplify the process and make it more pythonic.
For example, if we wanted to read ten frames starting from the second second:
import itertools
video.set_current_stream("video")
frames = [] # we are going to save the frames here.
# We seek into the second second of the video and use islice to get 10 frames from there
for frame in itertools.islice(video.seek(2), 10):
    frames.append(frame['data'])
print("Total number of frames: ", len(frames))
Total number of frames: 10
Or if we wanted to read from the 2nd to the 5th second, we seek into the second second of the video and then utilize itertools.takewhile to get the correct number of frames:
video.set_current_stream("video")
frames = [] # we are going to save the frames here.
video = video.seek(2)
for frame in itertools.takewhile(lambda x: x['pts'] <= 5, video):
frames.append(frame['data'])
print("Total number of frames: ", len(frames))
approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())
Total number of frames: 90
We can expect approx: 89.91008991008991
Tensor size: torch.Size([3, 256, 340])
2. Building a sample read_video function
We can utilize the methods above to build a read_video function that follows the same API as the existing read_video function.
def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
if end is None:
end = float("inf")
if end < start:
raise ValueError(
"end time should be larger than start time, got "
f"start time={start} and end time={end}"
)
video_frames = torch.empty(0)
video_pts = []
if read_video:
video_object.set_current_stream("video")
frames = []
for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
frames.append(frame['data'])
video_pts.append(frame['pts'])
if len(frames) > 0:
video_frames = torch.stack(frames, 0)
audio_frames = torch.empty(0)
audio_pts = []
if read_audio:
video_object.set_current_stream("audio")
frames = []
for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
frames.append(frame['data'])
audio_pts.append(frame['pts'])
if len(frames) > 0:
audio_frames = torch.cat(frames, 0)
return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()
# Total number of frames should be 327 for video and 523264 datapoints for audio
vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size())
torch.Size([327, 3, 256, 340]) torch.Size([523264, 1])
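Because the function accepts start and end arguments, we can also decode only a time window. A minimal sketch (times are in seconds; the resulting tensor sizes depend on the clip):
vf, af, info, meta = example_read_video(video, start=2, end=5)
print(vf.size(), af.size())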
3. Building an example randomly sampled dataset (can be applied to the training dataset of kinetics400)
Cool, so now we can use the same principle to make a sample dataset. We suggest trying out an iterable dataset for this purpose. Here, we are going to build an example dataset that reads randomly selected clips from the videos.
Make a sample dataset:
import os
os.makedirs("./dataset", exist_ok=True)
os.makedirs("./dataset/1", exist_ok=True)
os.makedirs("./dataset/2", exist_ok=True)
Download the videos:
from torchvision.datasets.utils import download_url
download_url(
"https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
"./dataset/1", "WUzgd7C1pWA.mp4"
)
download_url(
"https://github.com/pytorch/vision/blob/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true",
"./dataset/1",
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
)
download_url(
"https://github.com/pytorch/vision/blob/main/test/assets/videos/SOX5yA1l24A.mp4?raw=true",
"./dataset/2",
"SOX5yA1l24A.mp4"
)
download_url(
"https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true",
"./dataset/2",
"v_SoccerJuggling_g23_c01.avi"
)
download_url(
"https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true",
"./dataset/2",
"v_SoccerJuggling_g24_c01.avi"
)
Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/WUzgd7C1pWA.mp4 to ./dataset/1/WUzgd7C1pWA.mp4
Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi to ./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi
Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/SOX5yA1l24A.mp4 to ./dataset/2/SOX5yA1l24A.mp4
Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi to ./dataset/2/v_SoccerJuggling_g23_c01.avi
Downloading https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi to ./dataset/2/v_SoccerJuggling_g24_c01.avi
Housekeeping and utilities:
import os
import random
from torchvision.datasets.folder import make_dataset
from torchvision import transforms as t
def _find_classes(dir):
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def get_samples(root, extensions=(".mp4", ".avi")):
_, class_to_idx = _find_classes(root)
return make_dataset(root, class_to_idx, extensions=extensions)
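As a quick sanity check of the helper, a sketch (ordering of samples depends on the filesystem):
samples = get_samples("./dataset")
print(len(samples), samples[0])  # list of (video path, class index) pairs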
We are going to define the dataset and some basic arguments. We assume the structure of the FolderDataset, and add the following parameters:
- clip_len: length of a clip in frames
- frame_transform: transform applied to every frame individually
- video_transform: transform applied to the whole video sequence
Note
We actually add the epoch size as a parameter because using the IterableDataset class allows us to naturally oversample clips or images from each video if needed (see the sketch after the class definition below).
class RandomDataset(torch.utils.data.IterableDataset):
def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()
self.samples = get_samples(root)
# Allow for temporal jittering
if epoch_size is None:
epoch_size = len(self.samples)
self.epoch_size = epoch_size
self.clip_len = clip_len
self.frame_transform = frame_transform
self.video_transform = video_transform
def __iter__(self):
for i in range(self.epoch_size):
# Get random sample
path, target = random.choice(self.samples)
# Get video object
vid = torchvision.io.VideoReader(path, "video")
metadata = vid.get_metadata()
video_frames = [] # video frame buffer
# Seek and return frames
max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
start = random.uniform(0., max_seek)
for frame in itertools.islice(vid.seek(start), self.clip_len):
video_frames.append(self.frame_transform(frame['data']))
current_pts = frame['pts']
# Stack it into a tensor
video = torch.stack(video_frames, 0)
if self.video_transform:
video = self.video_transform(video)
output = {
'path': path,
'video': video,
'target': target,
'start': start,
'end': current_pts}
yield output
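As noted above, epoch_size lets us oversample the same files. A sketch (assuming the ./dataset folder created earlier; a frame_transform is passed because it is applied to every decoded frame):
oversampled = RandomDataset("./dataset", epoch_size=20, frame_transform=t.Resize((112, 112)))
print(sum(1 for _ in oversampled))  # 20 randomly drawn clips from only 5 files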
Given a path of videos in a folder structure, i.e.:
- dataset
  - class 1
    - file 0
    - file 1
    - ...
  - class 2
    - file 0
    - file 1
    - ...
  - ...
We can generate a dataloader and test the dataset.
transforms = [t.Resize((112, 112))]
frame_transform = t.Compose(transforms)
dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=12)
data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
for batch in loader:
for i in range(len(batch['path'])):
data['video'].append(batch['path'][i])
data['start'].append(batch['start'][i].item())
data['end'].append(batch['end'][i].item())
data['tensorsize'].append(batch['video'][i].size())
print(data)
{'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi', './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi', './dataset/1/WUzgd7C1pWA.mp4', './dataset/2/SOX5yA1l24A.mp4', './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi'], 'start': [1.203051270746008, 0.5760754748483161, 10.23800984898201, 9.709010060342672, 1.256045985643026], 'end': [1.733333, 1.0999999999999999, 10.744067, 10.2102, 1.766667], 'tensorsize': [torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112])]}
4. Data visualization
Example of visualizing the video:
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 12))
for i in range(16):
plt.subplot(4, 4, i + 1)
plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
plt.axis("off")
Clean up the video and dataset:
import os
import shutil
os.remove("./WUzgd7C1pWA.mp4")
shutil.rmtree("./dataset")
Total running time of the script: (0 minutes 4.939 seconds)