说明文档
Distil-Whisper: distil-large-v3
Distil-Whisper 模型的提出源于论文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling。
这是 Distil-Whisper 英语系列的第三版也是最后一版。它是基于 OpenAI 最新的 Whisper large-v3(目前性能最强的 Whisper 模型)进行知识蒸馏得到的版本。
与之前的 Distil-Whisper 模型相比,distil-large-v3 的蒸馏过程经过调整,在使用 OpenAI 的顺序长音频算法时能够提供更出色的长音频转录准确率。
结果是一个蒸馏模型,在使用顺序算法和分块算法处理长音频时,其 WER 表现与 large-v3 相差在 1% 以内,而使用顺序算法时比 distil-large-v2 提升了 4.8%。该模型也比之前的 Distil-Whisper 模型更快:比 large-v3 快 6.3 倍,比 distil-large-v2 快 1.1 倍。
| 模型 | 参数 / M | 相对延迟 | 短音频 | 顺序长音频 | 分块长音频 |
|---|---|---|---|---|---|
| large-v3 | 1550 | 1.0 | 8.4 | 10.0 | 11.0 |
| distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 | 10.9 |
| distil-large-v2 | 756 | 5.8 | 10.1 | 15.6 | 11.6 |
由于顺序算法是当前最流行的 Whisper 库(Whisper cpp、Faster-Whisper、OpenAI Whisper)中使用的事实标准转录算法,因此这个蒸馏模型设计与这些库兼容。当您从之前的 Distil-Whisper 检查点切换到 distil-large-v3 时,使用这些库可以获得显著的性能提升。为方便起见,最流行库的权重已经转换完成,下面的说明将帮助您快速上手。
目录
Transformers 使用方法
distil-large-v3 支持 Hugging Face 🤗 Transformers 4.39 及以上版本。要运行该模型,首先请安装最新版本的 Transformers。在本示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载一个测试音频数据集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
短音频转录
该模型可以与 pipeline 类结合使用,来转录短音频文件(小于 30 秒),如下所示:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
要转录本地音频文件,只需在调用 pipeline 时传入音频文件的路径即可:
- result = pipe(sample)
+ result = pipe("audio.mp3")
要获取片段级别的时间戳,请传入参数 return_timestamps=True 并返回 "chunks" 输出:
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
<details>
<summary> 如需对生成参数进行更多控制,请直接使用 model + processor API:</summary>
可以将临时生成参数传递给 model.generate,包括用于束搜索的 num_beams、用于片段级时间戳的 return_timestamps,以及用于提示的 prompt_ids。更多详情请参阅 文档字符串。
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import Audio, load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
sample = dataset[0]["audio"]
input_features = processor(
sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features
input_features = input_features.to(device, dtype=torch_dtype)
gen_kwargs = {
"max_new_tokens": 128,
"num_beams": 1,
"return_timestamps": False,
}
pred_ids = model.generate(input_features, **gen_kwargs)
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=gen_kwargs["return_timestamps"])
print(pred_text)
</details>
顺序长音频
与之前的 Distil-Whisper 版本不同,distil-large-v3 专门设计用于与 OpenAI 的顺序长音频转录算法兼容。该算法使用滑动窗口对长音频文件(大于 30 秒)进行缓冲推理,与分块长音频算法相比,返回更准确的转录结果。
在以下任一情况下应使用顺序长音频算法:
- 转录准确性是最重要的因素,而延迟是次要考量
- 您正在转录批量长音频文件,这种情况下顺序方法的延迟与分块方法相当,同时准确性高达 0.5% WER
如果您正在转录单个长音频文件,而延迟是最重要的因素,则应使用下述的分块算法。有关不同算法的详细解释,请参阅 Distil-Whisper 论文 的第 5 节。
可以使用 pipeline 类结合顺序算法来转录长音频,如下所示:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
<details>
<summary> 如需对生成参数进行更多控制,请直接使用 model + processor API:</summary>
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import Audio, load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
sample = dataset[0]["audio"]
inputs = processor(
sample["array"],
sampling_rate=sample["sampling_rate"],
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
)
inputs = inputs.to(device, dtype=torch_dtype)
gen_kwargs = {
"max_new_tokens": 448,
"num_beams": 1,
"condition_on_prev_tokens": False,
"compression_ratio_threshold": 1.35, # zlib 压缩比阈值(token 空间)
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"logprob_threshold": -1.0,
"no_speech_threshold": 0.6,
"return_timestamps": True,
}
pred_ids = model.generate(**inputs, **gen_kwargs)
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=False)
print(pred_text)
</details>
分块长音频
distil-large-v3 仍然与 Transformers 分块长音频算法兼容。当转录单个大型音频文件且需要最快的推理速度时,应使用此算法。在这种情况下,分块算法比 OpenAI 的顺序长音频实现快达 9 倍(见 Distil-Whisper 论文 的表 7)。
要启用分块,请将 chunk_length_s 参数传递给 pipeline。对于 distil-large-v3,25 秒的分块长度是最优的。要在长音频上启用批处理,请传入参数 batch_size:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=25,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
推测解码
distil-large-v3 是第一个可以与 Whisper large-v3 配合使用进行推测解码的 Distil-Whisper 模型。推测解码在数学上确保获得与 Whisper 完全相同的输出,同时速度提高 2 倍。这使其成为现有 Whisper pipeline 的完美即插即用替代品,因为保证输出相同。
在以下代码片段中,我们将助手 Distil-Whisper 模型独立加载到主 Whisper pipeline。然后将其指定为生成的"助手模型":
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-large-v3"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
有关推测解码的更多详情,请参阅博客文章 Speculative Decoding for 2x Faster Whisper Inference。
额外的速度和内存优化
您可以对 Distil-Whisper 应用额外的速度和内存优化,以进一步降低推理速度和 VRAM 需求。这些优化主要针对注意力内核,将其从 eager 实现转换为更高效的 flash attention 版本。
Flash Attention 2
如果您的 GPU 支持,我们建议使用 Flash-Attention 2。首先需要安装 Flash Attention:
pip install flash-attn --no-build-isolation
然后将 attn_implementation="flash_attention_2" 传递给 from_pretrained:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2")
Torch Scale-Product-Attention (SDPA)
如果您的 GPU 不支持 Flash Attention,我们建议使用 PyTorch 缩放点积注意力 (SDPA)。对于 PyTorch 2.1.1 及以上版本,此注意力实现默认启用。要检查您是否有兼容的 PyTorch 版本,请运行以下 Python 代码片段:
from transformers.utils import is_torch_sdpa_available
print(is_torch_sdpa_available())
如果上述返回 True,则您安装了有效版本的 PyTorch,SDPA 默认启用。如果返回 False,则需要根据官方说明升级您的 PyTorch 版本。
安装有效的 PyTorch 版本后,SDPA 默认启用。也可以通过指定 attn_implementation="sdpa" 显式设置:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
有关如何使用 SDPA 的更多信息,请参阅 Transformers SDPA 文档。
Torch compile
即将推出...
4-bit 和 8-bit 推理
即将推出...
库集成
Whisper.cpp
Distil-Whisper 可以使用原始顺序长音频转录算法通过 Whisper.cpp 包运行。在 Mac M1 的临时基准测试中,distil-large-v3 比 Whisper large-v3 快 5 倍以上,同时在长音频上的 WER 表现相差在 0.8% 以内。
开始步骤:
- 克隆 Whisper.cpp 仓库:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
- 安装 Hugging Face Hub Python 包:
pip install --upgrade huggingface_hub
然后使用以下 Python 代码片段下载 distil-large-v3 的 GGML 权重:
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id='distil-whisper/distil-large-v3-ggml', filename='ggml-distil-large-v3.bin', local_dir='./models')
请注意,如果您没有设置 Python 环境,也可以直接使用 wget 下载权重:
wget https://huggingface.co/distil-whisper/distil-large-v3-ggml/resolve/main/ggml-distil-large-v3.bin -P ./models
- 使用提供的示例音频运行推理:
make -j && ./main -m models/ggml-distil-large-v3.bin -f samples/jfk.wav
Faster-Whisper
Faster-Whisper 是使用 CTranslate2(用于 Transformer 模型的高效推理引擎)重新实现的 Whisper。
首先,按照官方说明安装 Faster-Whisper 包。在本示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载测试音频数据集:
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
以下代码片段加载 distil-large-v3 模型并对 LibriSpeech ASR 数据集中的示例文件运行推理:
import torch
from faster_whisper import WhisperModel
from datasets import load_dataset
# 定义我们的 torch 配置
device = "cuda:0" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if torch.cuda.is_available() else "float32"
# 如果有 GPU 则在 GPU 上加载模型,否则使用 CPU
model = WhisperModel("distil-large-v3", device=device, compute_type=compute_type)
# 加载测试数据集作为示例
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[1]["audio"]["path"]
segments, info = model.transcribe(sample, beam_size=1)
for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
要转录本地音频文件,只需将音频文件的路径作为 audio 参数传递给 transcribe:
segments, info = model.transcribe("audio.mp3", beam_size=1)
OpenAI Whisper
要以原始 Whisper 格式使用该模型,首先确保已安装 openai-whisper 包。在本示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载测试音频数据集:
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
以下代码片段演示如何使用 🤗 Datasets 加载的 LibriSpeech 数据集中的示例文件进行转录:
from huggingface_hub import hf_hub_download
from datasets import load_dataset
from whisper import load_model, transcribe
model_path = hf_hub_download(repo_id="distil-whisper/distil-large-v3-openai", filename="model.bin")
model = load_model(model_path)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["path"]
pred_out = transcribe(model, audio=sample, language="en")
print(pred_out["text"])
请注意,第一次运行此示例时,模型权重将被下载并保存到缓存中。之后,您可以重复使用同一示例,权重将直接从缓存加载,无需再次下载。
要转录本地音频文件,只需将音频文件的路径作为 audio 参数传递给 transcribe:
pred_out = transcribe(model, audio=sample, language="en")
Distil-Whisper 模型也可以与 OpenAI Whisper CLI 结合使用。详情请参阅以下说明。
Transformers.js
Distil-Whisper 可以使用 Transformers.js 完全在您的网络浏览器中运行:
- 从 NPM 安装 Transformers.js:
npm i @xenova/transformers
- 导入库并使用 pipeline API 进行推理。
import { pipeline } from '@xenova/transformers';
const transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-large-v3');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url);
// { text: " And so, my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }
查看在线 Distil-Whisper Web Demo 亲自体验。如您所见,它在本地浏览器中运行:无需服务器!
有关更多信息,请参阅 Transformers.js 文档。
Candle
通过与 Hugging Face Candle 🕯️ 的集成,Distil-Whisper 也可以在 Rust 库 🦀 中使用。
优势:
- 针对 Linux x86 的可选 MKL 支持以及针对 Mac 的 Accelerate 优化的 CPU 后端
- 用于高效运行在 Mac 上的 Metal 支持
- 用于高效运行在 GPU 上的 CUDA 后端,通过 NCCL 实现多 GPU 分布
- WASM 支持:在浏览器中运行 Distil-Whisper
开始步骤:
- 按照此处的说明安装
candle-core - 在本地克隆
candle仓库:
git clone https://github.com/huggingface/candle.git
- 进入 Whisper 示例目录:
cd candle/candle-examples/examples/whisper
- 运行示例:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3
- 要指定自己的音频文件,添加
--input标志:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3 --input audio.wav
提示: 要使用 Apple Metal 编译,在运行示例时指定 metal 特性:
cargo run --example whisper --release --features="symphonia,metal" -- --model distil-large-v3
请注意,如果遇到以下错误:
error: target `whisper` in package `candle-examples` requires the features: `symphonia`
Consider enabling them by passing, e.g., `--features="symphonia"`
您应该清理您的 cargo 安装:
cargo clean
然后重新编译:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3
模型详情
Distil-Whisper 继承了 Whisper 的编码器-解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列。解码器根据所有先前的 token 和编码器的隐藏状态自回归地预测文本 token。因此,编码器只运行一次前向传播,而解码器需要运行与生成的 token 数量相同的次数。在实践中,这意味着解码器占用了总推理时间的 90% 以上。因此,为了优化延迟,重点是尽量减少解码器的推理时间。
为了蒸馏 Whisper 模型,我们减少了解码器层的数量,同时保持编码器不变。编码器(图中绿色部分)完全从教师模型复制到学生模型,并在训练期间冻结。学生的解码器由教师解码器层的子集组成,这些层从间隔最大的层初始化。然后使用 KL 散度和伪标签损失项的加权和来训练模型。
<p align="center"> <img src="https://huggingface.co/datasets/distil-whisper/figures/resolve/main/architecture.png?raw=true" width="600"/> </p>
与 distil-large-v2 的差异
与之前的 Distil-Whisper 版本相比,distil-large-v3 专门针对 OpenAI 顺序长音频转录算法进行设计。与 distil-large-v2 相比没有架构差异,只是模型层从最新的 large-v3 模型初始化,而不是从较早的 large-v2 模型初始化。差异在于模型的训练方式。
之前的 Distil-Whisper 模型在平均输入长度为 7 秒的数据上进行训练,而原始 Whisper 模型在 30 秒输入上进行预训练。在蒸馏期间,我们转移模型权重的分布以匹配训练数据的分布。如果我们的训练数据包含较短的语音(例如平均 7 秒而非 30 秒的音频),则预测分布会转移到这个较短的上下文长度。在推理时,distil-large-v2 的最优上下文窗口是这两个值的插值:15 秒。超过这个时间后,distil-large-v2 的预测在很大程度上是不准确的,特别是时间戳预测。然而,顺序长音频算法使用 30 秒的滑动窗口进行推理,窗口根据最后预测的时间戳进行移动。由于最后的时间戳通常出现在 15 秒之后,它的预测准确率较低,导致长音频转录经常失败。
为了保留 Whisper 转录滑动 30 秒窗口的能力(如同顺序解码中所做的那样),我们需要确保 distil-large-v3 的上下文长度也是 30 秒。这主要通过四种策略实现:
- 将训练数据集中的音频样本打包到 30 秒: 由于模型在打包到 30 秒的音频数据上进行预训练和蒸馏,distil-large-v3 现在与 Whisper 的理想上下文窗口相同运行,预测准确的时间戳最多至 30 秒。
- 冻结解码器输入嵌入: 我们使用与原始模型相同的输入嵌入表示,其设计用于处理比之前 Distil-Whisper 迭代更长的上下文长度。
- 在训练期间使用更长的最大上下文长度: 不是在最大目标长度 128 上训练,而是在最大 256 上训练。这有助于 distil-Whisper 转录 30 秒的片段,其中 token 数量可能超过 128。
- 在 50% 的训练样本中添加提示条件: 使模型能够与
condition_on_prev_tokens参数一起使用,上下文窗口最多 448 个 token。
还有进一步的技巧用于提高 distil-large-v3 在顺序解码算法下的性能,这些将在即将发布的博客文章中详细解释。
评估
以下代码片段演示如何在 LibriSpeech validation-clean 数据集上使用流式模式评估 Distil-Whisper 模型,这意味着无需将音频数据下载到本地设备。
首先,我们需要安装所需的包,包括用于流式加载音频数据的 🤗 Datasets 和用于计算 WER 的 🤗 Evaluate:
pip install --upgrade pip
pip install --upgrade transformers datasets[audio] evaluate jiwer
然后可以使用以下示例进行端到端评估:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# 定义我们的 torch 配置
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
# 加载模型 + 处理器
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# 使用流式模式加载数据集
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# 定义评估指标
wer_metric = load("wer")
def inference(batch):
# 1. 预处理音频数据为 log-mel 频谱图输入
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. 自回归生成预测的 token id
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. 将 token id 解码为最终转录
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
# 批量大小 16 推理
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# 遍历数据集并运行推理
for result in tqdm(dataset, desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# 规范化预测和参考
all_transcriptions = [processor.normalize(transcription) for transcription in all_transcriptions]
all_references = [processor.normalize(reference) for reference in all_references]
# 计算 WER 指标
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
输出:
2.428920763531516
预期用途
Distil-Whisper 旨在作为 Whisper large-v3 在英语语音识别方面的即插即用替代品。特别是,它在分布外(OOD)测试数据上实现了相当的 WER 结果,同时在短音频和长音频上都快 6 倍。
数据
Distil-Whisper 基于 Hugging Face Hub 上 9 个开源、许可授权的语音数据集的 22,000 小时音频数据进行训练:
| 数据集 | 时长 / h | 说话者 | 领域 | 许可证 |
|---|---|---|---|---|
| People's Speech | 12,000 | unknown | Internet Archive | CC-BY-SA-4.0 |
| Common Voice 13 | 3,000 | unknown | Narrated Wikipedia | CC0-1.0 |
| GigaSpeech | 2,500 | unknown | Audiobook, podcast, YouTube | apache-2.0 |
| Fisher | 1,960 | 11,900 | 电话对话 | LDC |
| LibriSpeech | 960 | 2,480 | 有声书 | CC-BY-4.0 |
| VoxPopuli | 540 | 1,310 | 欧洲议会 | CC0 |
| TED-LIUM | 450 | 2,030 | TED 演讲 | CC-BY-NC-ND 3.0 |
| SwitchBoard | 260 | 540 | 电话对话 | LDC |
| AMI | 100 | unknown | 会议 | CC-BY-4.0 |
| 总计 | 21,770 | 18,260+ |
组合数据集涵盖 10 个不同领域和超过 50,000 名说话者。数据集的多样性对于确保蒸馏模型对音频分布和噪声具有鲁棒性至关重要。
然后使用 Whisper large-v3 模型对音频数据进行伪标记:我们使用 Whisper 为训练集中的所有音频生成预测,并在训练中将这些作为目标标签。使用伪标记可确保转录在各数据集之间格式一致,并在训练期间提供序列级蒸馏信号。
WER 过滤器
Whisper 伪标签预测存在转录错误和幻觉。为确保我们仅在准确的伪标记上进行训练,我们在训练期间采用简单的 WER 启发式方法。首先,我们规范化 Whisper 伪标记和各数据集提供的真实标签。然后计算这些标签之间的 WER。如果 WER 超过指定阈值,我们丢弃该训练样本。否则,将其保留用于训练。
Distil-Whisper 论文 的第 9.2 节证明了此过滤器在提高蒸馏模型下游性能方面的有效性。我们还将 Distil-Whisper 对幻觉的鲁棒性部分归因于此过滤器。
训练
模型训练了 80,000 个优化步骤(或 11 个 epoch),批量大小为 256。Tensorboard 训练日志可在以下地址查看:https://huggingface.co/distil-whisper/distil-large-v3/tensorboard?params=scalars#frame
结果
蒸馏模型在分布外(OOD)短音频上与 Whisper large-v3 的 WER 差距在 1.5% 以内,在顺序长音频解码上差距在 1% 以内,在分块长音频上比 large-v3 高 0.1%。这一性能提升归因于更低的幻觉率。
有关评估结果的详细逐数据集分解,请参阅 Distil-Whisper 论文 的表 16 和 17。
Distil-Whisper 也在 ESB 基准 数据集上进行了评估,作为 OpenASR 排行榜 的一部分,其 WER 表现与 Whisper 相差在 0.2% 以内。
复现 Distil-Whisper
用于复现 Distil-Whisper 的训练和评估代码可在 Distil-Whisper 仓库中获取:https://github.com/huggingface/distil-whil-whisper/tree/main/training
此代码将很快更新,以包含 与 distil-large-v2 的差异 部分所述的训练更新。
许可证
Distil-Whisper 继承自 OpenAI Whisper 模型的 MIT 许可证。
引用
如果您使用此模型,请考虑引用 Distil-Whisper 论文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致谢
- OpenAI 提供的 Whisper 模型,特别感谢 Jong Wook Kim 提供的原始代码库 和训练讨论
- Hugging Face 🤗 Transformers 提供的模型集成
- Georgi Gerganov 提供的 Whisper cpp 集成
- Systran 团队 提供的 Faster-Whisper 集成
- Joshua Lochner 提供的 Transformers.js 集成
- Laurent Mazare 提供的 Candle 集成
- Vaibhav Srivastav 提供的 Distil-Whisper 发行版
- Google 的 TPU Research Cloud (TRC) 计划提供的 Cloud TPU v4 计算资源
- Raghav Sonavane 提供的 LibriSpeech 数据集上的 Distil-Whisper 早期迭代
distil-whisper/distil-large-v3
作者 distil-whisper
创建时间: 2024-03-21 12:10:42+00:00
更新时间: 2025-03-06 17:22:45+00:00
在 Hugging Face 上查看