• 文档 >
  • StreamWriter 基本用法 >
  • 旧版本(稳定版)
快捷方式

StreamWriter 基本用法

作者: Moto Hira

本教程介绍如何使用 torchaudio.io.StreamWriter 将音频/视频数据编码并保存到各种格式/目的地。

注意

本教程需要 FFmpeg 库。有关详细信息,请参阅 FFmpeg 依赖项

警告

TorchAudio 会动态加载系统上安装的兼容 FFmpeg 库。支持的格式类型(媒体格式、编码器、编码器选项等)取决于这些库。

要检查可用的复用器和编码器,可以使用以下命令

ffmpeg -muxers
ffmpeg -encoders

准备

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

from torchaudio.io import StreamWriter

print("FFmpeg library versions")
for k, v in torchaudio.utils.ffmpeg_utils.get_versions().items():
    print(f"  {k}: {v}")
2.5.0
2.5.0
FFmpeg library versions
  libavcodec: (60, 3, 100)
  libavdevice: (60, 1, 100)
  libavfilter: (9, 3, 100)
  libavformat: (60, 3, 100)
  libavutil: (58, 2, 100)
import io
import os
import tempfile

from IPython.display import Audio, Video

from torchaudio.utils import download_asset

SAMPLE_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
WAVEFORM, SAMPLE_RATE = torchaudio.load(SAMPLE_PATH, channels_first=False)
NUM_FRAMES, NUM_CHANNELS = WAVEFORM.shape

_BASE_DIR = tempfile.TemporaryDirectory()


def get_path(filename):
    return os.path.join(_BASE_DIR.name, filename)

基本用法

要使用 StreamWriter 将 Tensor 数据保存到媒体格式,需要三个步骤

  1. 指定输出

  2. 配置流

  3. 写入数据

以下代码演示了如何将音频数据保存为 WAV 文件。

# 1. Define the destination. (local file in this case)
path = get_path("test.wav")
s = StreamWriter(path)
# 2. Configure the stream. (8kHz, Stereo WAV)
s.add_audio_stream(
    sample_rate=SAMPLE_RATE,
    num_channels=NUM_CHANNELS,
)
# 3. Write the data
with s.open():
    s.write_audio_chunk(0, WAVEFORM)
Audio(path)


现在,我们更详细地了解每个步骤。

写入目的地

StreamWriter 支持不同类型的写入目的地

  1. 本地文件

  2. 类文件对象

  3. 流协议(例如 RTMP 和 UDP)

  4. 媒体设备(扬声器和视频播放器)†

† 对于媒体设备,请参阅 StreamWriter 高级用法

本地文件

StreamWriter 支持将媒体保存到本地文件。

StreamWriter(dst="audio.wav")

StreamWriter(dst="audio.mp3")

这也适用于静止图像和视频。

StreamWriter(dst="image.jpeg")

StreamWriter(dst="video.mpeg")

类文件对象

您也可以传递一个类文件对象。类文件对象必须实现符合 io.RawIOBase.writewrite 方法。

# Open the local file as fileobj
with open("audio.wav", "wb") as dst:
    StreamWriter(dst=dst)
# In-memory encoding
buffer = io.BytesIO()
StreamWriter(dst=buffer)

流协议

您可以使用流协议来流式传输媒体

# Real-Time Messaging Protocol
StreamWriter(dst="rtmp://localhost:1234/live/app", format="flv")

# UDP
StreamWriter(dst="udp://localhost:48550", format="mpegts")

配置输出流

指定目的地后,下一步是配置流。对于典型的音频和静止图像情况,只需要一个流,但对于带有音频的视频,至少需要配置两个流(一个用于音频,另一个用于视频)。

音频流

可以使用 add_audio_stream() 方法添加音频流。

对于写入常规音频文件,至少需要 sample_ratenum_channels

s = StreamWriter("audio.wav")
s.add_audio_stream(sample_rate=8000, num_channels=2)

默认情况下,音频流期望输入波形张量为 torch.float32 类型。在这种情况下,数据将被编码为 WAV 格式的默认编码格式,即 16 位有符号整数线性 PCM。StreamWriter 会在内部转换采样格式。

如果编码器支持多种采样格式,并且您想更改编码器采样格式,可以使用 encoder_format 选项。

在以下示例中,StreamWriter 期望输入波形张量的数据类型为 torch.float32,但它会在编码时将其转换为 16 位有符号整数。

s = StreamWriter("audio.mp3")
s.add_audio_stream(
    ...,
    encoder="libmp3lame",   # "libmp3lame" is often the default encoder for mp3,
                            # but specifying it manually, for the sake of illustration.

    encoder_format="s16p",  # "libmp3lame" encoder supports the following sample format.
                            #  - "s16p" (16-bit signed integer)
                            #  - "s32p" (32-bit signed integer)
                            #  - "fltp" (32-bit floating point)
)

如果您的波形张量的数据类型不是 torch.float32,您可以提供 format 选项来更改期望的数据类型。

以下示例配置 StreamWriter 以期望 torch.int16 类型的张量。

# Audio data passed to StreamWriter must be torch.int16
s.add_audio_stream(..., format="s16")

下图说明了 formatencoder_format 选项在音频流中是如何工作的。

https://download.pytorch.org/torchaudio/tutorial-assets/streamwriter-format-audio.png

视频流

要添加静止图像或视频流,可以使用 add_video_stream() 方法。

至少需要 frame_rateheightwidth

s = StreamWriter("video.mp4")
s.add_video_stream(frame_rate=10, height=96, width=128)

对于静止图像,请使用 frame_rate=1

s = StreamWriter("image.png")
s.add_video_stream(frame_rate=1, ...)

与音频流类似,您可以提供 formatencoder_format 选项来控制输入数据和编码的格式。

以下示例以 YUV422 格式编码视频数据。

s = StreamWriter("video.mov")
s.add_video_stream(
    ...,
    encoder="libx264",  # libx264 supports different YUV formats, such as
                        # yuv420p yuvj420p yuv422p yuvj422p yuv444p yuvj444p nv12 nv16 nv21

    encoder_format="yuv422p",  # StreamWriter will convert the input data to YUV422 internally
)

YUV 格式通常用于视频编码。许多 YUV 格式由色度通道组成,色度通道的平面大小与亮度通道的平面大小不同。这使得难以直接用 torch.Tensor 类型来表达。因此,StreamWriter 会自动将输入视频张量转换为目标格式。

StreamWriter 期望输入图像张量为 4 维 (time, channel, height, width),并且数据类型为 torch.uint8

默认的颜色通道是 RGB。也就是说,三个颜色通道分别对应红色、绿色和蓝色。如果您的输入具有不同的颜色通道,例如 BGR 和 YUV,您可以使用 format 选项来指定。

以下示例指定了 BGR 格式。

s.add_video_stream(..., format="bgr24")
                   # Image data passed to StreamWriter must have
                   # three color channels representing Blue Green Red.
                   #
                   # The shape of the input tensor has to be
                   # (time, channel==3, height, width)

下图说明了 formatencoder_format 选项在视频流中是如何工作的。

https://download.pytorch.org/torchaudio/tutorial-assets/streamwriter-format-video.png

写入数据

配置完流后,下一步是打开输出位置并开始写入数据。

使用 open() 方法打开目的地,然后使用 write_audio_chunk() 和/或 write_video_chunk() 写入数据。

音频张量应具有 (time, channels) 的形状,视频/图像张量应具有 (time, channels, height, width) 的形状。

通道、高度和宽度必须与使用 "format" 选项指定的相应流的配置匹配。

表示静止图像的张量在时间维度上只能有一帧,但音频和视频张量在时间维度上可以有任意数量的帧。

以下代码片段说明了这一点;

例如) 音频

# Configure stream
s = StreamWriter(dst=get_path("audio.wav"))
s.add_audio_stream(sample_rate=SAMPLE_RATE, num_channels=NUM_CHANNELS)

# Write data
with s.open():
    s.write_audio_chunk(0, WAVEFORM)

例如) 图像

# Image config
height = 96
width = 128

# Configure stream
s = StreamWriter(dst=get_path("image.png"))
s.add_video_stream(frame_rate=1, height=height, width=width, format="rgb24")

# Generate image
chunk = torch.randint(256, (1, 3, height, width), dtype=torch.uint8)

# Write data
with s.open():
    s.write_video_chunk(0, chunk)

例如) 无音频的视频

# Video config
frame_rate = 30
height = 96
width = 128

# Configure stream
s = StreamWriter(dst=get_path("video.mp4"))
s.add_video_stream(frame_rate=frame_rate, height=height, width=width, format="rgb24")

# Generate video chunk (3 seconds)
time = int(frame_rate * 3)
chunk = torch.randint(256, (time, 3, height, width), dtype=torch.uint8)

# Write data
with s.open():
    s.write_video_chunk(0, chunk)

例如) 带音频的视频

要写入带有音频的视频,必须配置单独的流。

# Configure stream
s = StreamWriter(dst=get_path("video.mp4"))
s.add_audio_stream(sample_rate=SAMPLE_RATE, num_channels=NUM_CHANNELS)
s.add_video_stream(frame_rate=frame_rate, height=height, width=width, format="rgb24")

# Generate audio/video chunk (3 seconds)
time = int(SAMPLE_RATE * 3)
audio_chunk = torch.randn((time, NUM_CHANNELS))
time = int(frame_rate * 3)
video_chunk = torch.randint(256, (time, 3, height, width), dtype=torch.uint8)

# Write data
with s.open():
    s.write_audio_chunk(0, audio_chunk)
    s.write_video_chunk(1, video_chunk)

逐块写入数据

写入数据时,可以沿时间维度拆分数据并以较小的块写入。

# Write data in one-go
dst1 = io.BytesIO()
s = StreamWriter(dst=dst1, format="mp3")
s.add_audio_stream(SAMPLE_RATE, NUM_CHANNELS)
with s.open():
    s.write_audio_chunk(0, WAVEFORM)
# Write data in smaller chunks
dst2 = io.BytesIO()
s = StreamWriter(dst=dst2, format="mp3")
s.add_audio_stream(SAMPLE_RATE, NUM_CHANNELS)
with s.open():
    for start in range(0, NUM_FRAMES, SAMPLE_RATE):
        end = start + SAMPLE_RATE
        s.write_audio_chunk(0, WAVEFORM[start:end, ...])
# Check that the contents are same
dst1.seek(0)
bytes1 = dst1.read()

print(f"bytes1: {len(bytes1)}")
print(f"{bytes1[:10]}...{bytes1[-10:]}\n")

dst2.seek(0)
bytes2 = dst2.read()

print(f"bytes2: {len(bytes2)}")
print(f"{bytes2[:10]}...{bytes2[-10:]}\n")

assert bytes1 == bytes2

import matplotlib.pyplot as plt
bytes1: 10700
b'ID3\x04\x00\x00\x00\x00\x00"'...b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'

bytes2: 10700
b'ID3\x04\x00\x00\x00\x00\x00"'...b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'

示例 - 频谱可视化器

在本节中,我们使用 StreamWriter 创建音频的频谱可视化并将其保存为视频文件。

要创建频谱可视化,我们使用 torchaudio.transforms.Spectrogram 获取音频的频谱表示,使用 matplotplib 生成其可视化的光栅图像,然后使用 StreamWriter 将它们转换为带有原始音频的视频。

import torchaudio.transforms as T

准备数据

首先,我们准备频谱数据。我们使用 Spectrogram

我们调整 hop_length 使得频谱的一帧对应于一帧视频。

frame_rate = 20
n_fft = 4000

trans = T.Spectrogram(
    n_fft=n_fft,
    hop_length=SAMPLE_RATE // frame_rate,  # One FFT per one video frame
    normalized=True,
    power=1,
)
specs = trans(WAVEFORM.T)[0].T

生成的频谱如下所示。

spec_db = T.AmplitudeToDB(stype="magnitude", top_db=80)(specs.T)
_ = plt.imshow(spec_db, aspect="auto", origin="lower")
streamwriter basic tutorial

准备画布

我们使用 matplotlib 可视化每帧的频谱。我们创建了一个辅助函数,它绘制频谱数据并生成图形的光栅图像。

fig, ax = plt.subplots(figsize=[3.2, 2.4])
ax.set_position([0, 0, 1, 1])
ax.set_facecolor("black")
ncols, nrows = fig.canvas.get_width_height()


def _plot(data):
    ax.clear()
    x = list(range(len(data)))
    R, G, B = 238 / 255, 76 / 255, 44 / 255
    for coeff, alpha in [(0.8, 0.7), (1, 1)]:
        d = data**coeff
        ax.fill_between(x, d, -d, color=[R, G, B, alpha])
    xlim = n_fft // 2 + 1
    ax.set_xlim([-1, n_fft // 2 + 1])
    ax.set_ylim([-1, 1])
    ax.text(
        xlim,
        0.95,
        f"Created with TorchAudio\n{torchaudio.__version__}",
        color="white",
        ha="right",
        va="top",
        backgroundcolor="black",
    )
    fig.canvas.draw()
    frame = torch.frombuffer(fig.canvas.tostring_rgb(), dtype=torch.uint8)
    return frame.reshape(nrows, ncols, 3).permute(2, 0, 1)


# sphinx_gallery_defer_figures

写入视频

最后,我们使用 StreamWriter 并写入视频。我们一次处理一秒钟的音频和视频帧。

s = StreamWriter(get_path("example.mp4"))
s.add_audio_stream(sample_rate=SAMPLE_RATE, num_channels=NUM_CHANNELS)
s.add_video_stream(frame_rate=frame_rate, height=nrows, width=ncols)

with s.open():
    i = 0
    # Process by second
    for t in range(0, NUM_FRAMES, SAMPLE_RATE):
        # Write audio chunk
        s.write_audio_chunk(0, WAVEFORM[t : t + SAMPLE_RATE, :])

        # write 1 second of video chunk
        frames = [_plot(spec) for spec in specs[i : i + frame_rate]]
        if frames:
            s.write_video_chunk(1, torch.stack(frames))
        i += frame_rate

plt.close(fig)
/pytorch/audio/examples/tutorials/streamwriter_basic_tutorial.py:566: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.
  frame = torch.frombuffer(fig.canvas.tostring_rgb(), dtype=torch.uint8)
/pytorch/audio/examples/tutorials/streamwriter_basic_tutorial.py:566: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /opt/conda/conda-bld/pytorch_1727971112454/work/torch/csrc/utils/tensor_new.cpp:1560.)
  frame = torch.frombuffer(fig.canvas.tostring_rgb(), dtype=torch.uint8)

结果

结果如下所示。

Video(get_path("example.mp4"), embed=True)


仔细观看视频,可以观察到“s”的声音 (curiosity, besides, this) 在更高频率侧(视频右侧)分配了更多能量。

标签: torchaudio.io

脚本的总运行时间: ( 0 分钟 7.371 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源