说明文档
Distil-Whisper: distil-small.en
Distil-Whisper 模型在论文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 中提出。 它是 Whisper 模型的蒸馏版本,速度快 6 倍,体积小 49%,在分布外评估集上的 WER 表现仅差 1%。
这是 distil-small.en 的仓库,它是 Whisper small.en 的蒸馏变体。 它是最小的 Distil-Whisper 检查点,仅有 166M 参数,是内存受限应用(如设备端部署)的理想选择。
对于大多数其他应用,推荐使用 distil-medium.en 或 distil-large-v2 检查点,因为它们速度更快且 WER 表现更好:
| 模型 | 参数 / M | 相对延迟 ↑ | 短格式 WER ↓ | 长格式 WER ↓ |
|---|---|---|---|---|
| large-v3 | 1550 | 1.0 | 8.4 | 11.0 |
| large-v2 | 1550 | 1.0 | 9.1 | 11.7 |
| distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 |
| distil-large-v2 | 756 | 5.8 | 10.1 | 11.6 |
| distil-medium.en | 394 | 6.8 | 11.1 | 12.4 |
| distil-small.en | 166 | 5.6 | 12.1 | 12.8 |
注意: Distil-Whisper 目前仅支持英语语音识别。我们正在与社区合作蒸馏其他语言的 Whisper。如果您有兴趣蒸馏您自己语言的 Whisper,请查看提供的训练代码。我们将在 Distil-Whisper 仓库 中更新多语言检查点!
为什么 distil-small.en 比 distil-large-v2 慢?
虽然 distil-medium.en 和 distil-large-v2 都使用两个解码器层,但 distil-small.en 使用四个。使用更多解码器层可以提高模型的 WER 表现,但代价是推理速度变慢。我们发现四层是使 distil-small.en 获得合理 WER 表现所需的最少层数,它在 WER 方面与 Whisper large-v2 的差距在 3% 以内,同时快 5.6 倍。当我们尝试只用两层进行蒸馏时,模型比 large-v2 差 5% 以上,尽管速度快了 7.8 倍。我们将两层 small.en 模型的蒸馏工作留待将来完成。
使用方法
Distil-Whisper 从 Hugging Face 🤗 Transformers 4.35 版本起得到支持。要运行模型,首先安装最新版本的 Transformers 库。在本示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载示例音频数据集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
短格式转录
可以使用 pipeline 类来转录短格式音频文件(< 30 秒),如下所示:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
要转录本地音频文件,只需在调用 pipeline 时传入音频文件的路径:
- result = pipe(sample)
+ result = pipe("audio.mp3")
长格式转录
Distil-Whisper 使用分块算法来转录长格式音频文件(> 30 秒)。实际上,这种分块长格式算法比 OpenAI 在 Whisper 论文中提出的顺序算法快 9 倍(参见 Distil-Whisper 论文 的表 7)。
要启用分块,请将 chunk_length_s 参数传递给 pipeline。对于 Distil-Whisper,15 秒的分块长度是最佳的。要激活批处理,请传入 batch_size 参数:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=15,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
<!--- Tip: The pipeline can also be used to transcribe an audio file from a remote URL, for example:
result = pipe("https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav")
--->
推测解码
Distil-Whisper 可以作为 Whisper 的辅助模型用于推测解码。 推测解码在数学上确保获得与 Whisper 完全相同的输出,同时速度提高 2 倍。这使其成为现有 Whisper pipeline 的完美即插即用替代品,因为可以保证相同的输出。
在下面的代码片段中,我们将辅助 Distil-Whisper 模型独立加载到主 Whisper pipeline。然后我们将其指定为生成的"辅助模型":
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-small.en"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
额外的速度与内存优化
您可以为 Distil-Whisper 应用额外的速度和内存优化,我们将在下面介绍。
Flash Attention
如果您的 GPU 支持,我们建议使用 Flash-Attention 2。 为此,您首先需要安装 Flash Attention:
pip install flash-attn --no-build-isolation
然后只需将 use_flash_attention_2=True 传递给 from_pretrained:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=True)
Torch Scale-Product-Attention (SDPA)
如果您的 GPU 不支持 Flash Attention,我们建议使用 BetterTransformers。 为此,您首先需要安装 optimum:
pip install --upgrade optimum
然后在使用模型之前将其转换为 "BetterTransformer" 模型:
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = model.to_bettertransformer()
在 openai-whisper 中运行 Distil-Whisper
要以原始 Whisper 格式使用模型,首先确保已安装 openai-whisper 包:
pip install --upgrade openai-whisper
以下代码片段演示如何使用 🤗 Datasets 加载的 LibriSpeech 数据集样本进行转录:
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from whisper import load_model, transcribe
distil_small_en = hf_hub_download(repo_id="distil-whisper/distil-small.en", filename="original-model.bin")
model = load_model(distil_small_en)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sample = torch.from_numpy(sample).float()
pred_out = transcribe(model, audio=sample)
print(pred_out["text"])
请注意,第一次运行此示例时,模型权重将被下载并保存到您的缓存中。之后, 您可以重复使用同一示例,权重将直接从缓存加载,无需再次下载。
要转录本地音频文件,只需将音频文件的路径作为 transcribe 的 audio 参数传入:
pred_out = transcribe(model, audio="audio.mp3")
Whisper.cpp
Distil-Whisper 可以使用原始顺序长格式转录算法从 Whisper.cpp 仓库运行。在 Mac M1 的临时基准测试中,distil-small.en 比 large-v2 快 4 倍以上,同时在长格式音频上的 WER 表现相差 1.4% 以内。
入门步骤:
- 克隆 Whisper.cpp 仓库:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
- 从 Hugging Face Hub 下载
distil-small.en的 ggml 权重:
python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='distil-whisper/distil-small.en', filename='ggml-distil-small.en.bin', local_dir='./models')"
请注意,如果您没有安装 huggingface_hub 包,也可以使用 wget 下载权重:
wget https://huggingface.co/distil-whisper/distil-small.en/resolve/main/ggml-distil-small.en.bin -P ./models
- 使用提供的示例音频运行推理:
make -j && ./main -m models/ggml-distil-small.en.bin -f samples/jfk.wav
Transformers.js
Distil-Whisper 甚至可以在您的网页浏览器中与 Transformers.js 一起运行:
- 从 NPM 安装 Transformers.js:
npm i @xenova/transformers
- 导入库并使用 pipeline API 进行推理。
import { pipeline } from '@xenova/transformers';
const transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-small.en');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url);
// { text: " And so my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }
查看在线 Distil-Whisper Web demo 自己尝试一下。如您所见,它在本地浏览器中运行:无需服务器!
查看文档了解更多信息。
Candle
即将推出!
<!---
Through an integration with Hugging Face Candle 🕯️, Distil-Whisper is now available in the Rust library 🦀
Benefit from:
- Optimised CPU backend with optional MKL support for x86 and Accelerate for Macs
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL
- WASM support: run Distil-Whisper in a browser
Steps for getting started:
- Install
candle-coreas explained here - Clone the
candlerepository locally:
git clone https://github.com/huggingface/candle.git
- Enter the example directory for Whisper:
cd candle/candle-examples/examples/whisper
- Run an example:
cargo run --example whisper --release -- --model distil-small.en
- To specify your own audio file, add the
--inputflag:
cargo run --example whisper --release -- --model distil-small.en --input audio.wav
--->
8位与4位量化
即将推出!
模型详情
Distil-Whisper 继承自 Whisper 的编码器-解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列。解码器根据所有先前的 token 和编码器的隐藏状态自回归预测文本 token。因此,编码器只需运行一次,而解码器需要运行与生成的 token 数量相同的次数。实际上,这意味着解码器占据了总推理时间的 90% 以上。因此,为了优化延迟,重点是尽量减少解码器的推理时间。
为了蒸馏 Whisper 模型,我们在保持编码器不变的情况下减少解码器层的数量。编码器(图中绿色部分)完全从教师模型复制到学生模型,并在训练期间冻结。学生的解码器由教师解码器层的子集组成,这些层从最大间隔层初始化。然后使用 KL 散度和伪标签损失项的加权和对模型进行训练。
<p align="center"> <img src="https://huggingface.co/datasets/distil-whisper/figures/resolve/main/architecture.png?raw=true" width="600"/> </p>
评估
以下代码片段演示如何使用流式模式在 LibriSpeech validation.clean 数据集上评估 Distil-Whisper 模型,这意味着无需将音频数据下载到本地设备。
首先,我们需要安装所需的包,包括 🤗 Datasets 来流式加载音频数据,以及 🤗 Evaluate 来执行 WER 计算:
pip install --upgrade pip
pip install --upgrade transformers datasets[audio] evaluate jiwer
然后可以使用以下示例进行端到端评估:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
输出:
3.4326070294536297
预期用途
Distil-Whisper 旨在作为 Whisper 在英语语音识别方面的即插即用替代品。特别是在分布外测试数据上,它实现了相当的 WER 表现,同时在短格式和长格式音频上都快 6 倍。
数据
Distil-Whisper 在来自 Hugging Face Hub 的 9 个开源、宽松许可的语音数据集上训练了 22,000 小时的音频数据:
| 数据集 | 大小 / h | 说话者 | 领域 | 许可证 |
|---|---|---|---|---|
| People's Speech | 12,000 | unknown | Internet Archive | CC-BY-SA-4.0 |
| Common Voice 13 | 3,000 | unknown | Narrated Wikipedia | CC0-1.0 |
| GigaSpeech | 2,500 | unknown | Audiobook, podcast, YouTube | apache-2.0 |
| Fisher | 1,960 | 11,900 | Telephone conversations | LDC |
| LibriSpeech | 960 | 2,480 | Audiobooks | CC-BY-4.0 |
| VoxPopuli | 540 | 1,310 | European Parliament | CC0 |
| TED-LIUM | 450 | 2,030 | TED talks | CC-BY-NC-ND 3.0 |
| SwitchBoard | 260 | 540 | Telephone conversations | LDC |
| AMI | 100 | unknown | Meetings | CC-BY-4.0 |
| 总计 | 21,770 | 18,260+ |
组合数据集涵盖 10 个不同领域和超过 50,000 名说话者。数据集的多样性对于确保蒸馏模型对音频分布和噪声具有鲁棒性至关重要。
然后使用 Whisper large-v2 模型对音频数据进行伪标签:我们使用 Whisper 为训练集中的所有音频生成预测,并在训练中将这些作为目标标签。使用伪标签确保转录在数据集中格式一致,并在训练期间提供序列级蒸馏信号。
WER 过滤
Whisper 伪标签预测存在转录错误和幻觉。为了确保我们只在准确的伪标签上进行训练,我们在训练期间采用简单的 WER 启发式方法。首先,我们对每个数据集提供的 Whisper 伪标签和真实标签进行标准化。然后计算这些标签之间的 WER。如果 WER 超过指定阈值,我们丢弃该训练样本。否则,我们保留它进行训练。
Distil-Whisper 论文 的第 9.2 节证明了这种过滤对于提高蒸馏模型的下游性能的有效性。我们还将 Distil-Whisper 对幻觉的鲁棒性部分归因于此过滤器。
训练
该模型训练了 50,000 个优化步骤(或 12 个 epoch),批量大小为 2056。Tensorboard 训练日志可在以下地址找到:https://huggingface.co/distil-whisper/distil-small.en/tensorboard?params=scalars#frame
结果
蒸馏模型在分布外(OOD)短格式音频上的 WER 表现与 Whisper 相差在 1% 以内,在 OOD 长格式音频上的表现优于 Whisper 0.1%。这种性能提升归因于更低的幻觉率。
有关评估结果的详细按数据集分类,请参阅 Distil-Whisper 论文 的表 16 和 17。
Distil-Whisper 也在 ESB 基准 数据集上作为 OpenASR 排行榜 的一部分进行了评估,其 WER 表现与 Whisper 相差在 0.2% 以内。
复现 Distil-Whisper
用于复现 Distil-Whisper 的训练和评估代码可在 Distil-Whisper 仓库中获取:https://github.com/huggingface/distil-whisper/tree/main/training
许可证
Distil-Whisper 继承自 OpenAI Whisper 模型的 MIT 许可证。
引用
如果您使用此模型,请考虑引用 Distil-Whisper 论文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致谢
- OpenAI 提供的 Whisper 模型和原始代码库
- Hugging Face 🤗 Transformers 提供的模型集成
- 谷歌的 TPU Research Cloud (TRC) 计划提供的 Cloud TPU v4s
@rsonavane在 LibriSpeech 数据集上发布的 Distil-Whisper 早期版本
Distil-Whisper: distil-small.en
Distil-Whisper 模型在论文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 中提出。 它是 Whisper 模型的蒸馏版本,速度快 6 倍,体积小 49%,在分布外评估集上的 WER 表现仅差 1%。
这是 distil-small.en 的仓库,它是 Whisper small.en 的蒸馏变体。 它是最小的 Distil-Whisper 检查点,仅有 166M 参数,是内存受限应用(如设备端部署)的理想选择。
对于大多数其他应用,推荐使用 distil-medium.en 或 distil-large-v2 检查点,因为它们速度更快且 WER 表现更好:
| 模型 | 参数 / M | 相对延迟 ↑ | 短格式 WER ↓ | 长格式 WER ↓ |
|---|---|---|---|---|
| large-v3 | 1550 | 1.0 | 8.4 | 11.0 |
| large-v2 | 1550 | 1.0 | 9.1 | 11.7 |
| distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 |
| distil-large-v2 | 756 | 5.8 | 10.1 | 11.6 |
| distil-medium.en | 394 | 6.8 | 11.1 | 12.4 |
| distil-small.en | 166 | 5.6 | 12.1 | 12.8 |
注意: Distil-Whisper 目前仅支持英语语音识别。我们正在与社区合作蒸馏其他语言的 Whisper。如果您有兴趣蒸馏您自己语言的 Whisper,请查看提供的训练代码。我们将在 Distil-Whisper 仓库 中更新多语言检查点!
为什么 distil-small.en 比 distil-large-v2 慢?
虽然 distil-medium.en 和 distil-large-v2 都使用两个解码器层,但 distil-small.en 使用四个。使用更多解码器层可以提高模型的 WER 表现,但代价是推理速度变慢。我们发现四层是使 distil-small.en 获得合理 WER 表现所需的最少层数,它在 WER 方面与 Whisper large-v2 的差距在 3% 以内,同时快 5.6 倍。当我们尝试只用两层进行蒸馏时,模型比 large-v2 差 5% 以上,尽管速度快了 7.8 倍。我们将两层 small.en 模型的蒸馏工作留待将来完成。
使用方法
Distil-Whisper 从 Hugging Face 🤗 Transformers 4.35 版本起得到支持。要运行模型,首先安装最新版本的 Transformers 库。在本示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载示例音频数据集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
短格式转录
可以使用 pipeline 类来转录短格式音频文件(< 30 秒),如下所示:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
要转录本地音频文件,只需在调用 pipeline 时传入音频文件的路径:
- result = pipe(sample)
+ result = pipe("audio.mp3")
长格式转录
Distil-Whisper 使用分块算法来转录长格式音频文件(> 30 秒)。实际上,这种分块长格式算法比 OpenAI 在 Whisper 论文中提出的顺序算法快 9 倍(参见 Distil-Whisper 论文 的表 7)。
要启用分块,请将 chunk_length_s 参数传递给 pipeline。对于 Distil-Whisper,15 秒的分块长度是最佳的。要激活批处理,请传入 batch_size 参数:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=15,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
推测解码
Distil-Whisper 可以作为 Whisper 的辅助模型用于推测解码。 推测解码在数学上确保获得与 Whisper 完全相同的输出,同时速度提高 2 倍。这使其成为现有 Whisper pipeline 的完美即插即用替代品,因为可以保证相同的输出。
在下面的代码片段中,我们将辅助 Distil-Whisper 模型独立加载到主 Whisper pipeline。然后我们将其指定为生成的"辅助模型":
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-small.en"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
额外的速度与内存优化
您可以为 Distil-Whisper 应用额外的速度和内存优化,我们将在下面介绍。
Flash Attention
如果您的 GPU 支持,我们建议使用 Flash-Attention 2。 为此,您首先需要安装 Flash Attention:
pip install flash-attn --no-build-isolation
然后只需将 use_flash_attention_2=True 传递给 from_pretrained:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=True)
Torch Scale-Product-Attention (SDPA)
如果您的 GPU 不支持 Flash Attention,我们建议使用 BetterTransformers。 为此,您首先需要安装 optimum:
pip install --upgrade optimum
然后在使用模型之前将其转换为 "BetterTransformer" 模型:
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = model.to_bettertransformer()
在 openai-whisper 中运行 Distil-Whisper
要以原始 Whisper 格式使用模型,首先确保已安装 openai-whisper 包:
pip install --upgrade openai-whisper
以下代码片段演示如何使用 🤗 Datasets 加载的 LibriSpeech 数据集样本进行转录:
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from whisper import load_model, transcribe
distil_small_en = hf_hub_download(repo_id="distil-whisper/distil-small.en", filename="original-model.bin")
model = load_model(distil_small_en)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sample = torch.from_numpy(sample).float()
pred_out = transcribe(model, audio=sample)
print(pred_out["text"])
请注意,第一次运行此示例时,模型权重将被下载并保存到您的缓存中。之后, 您可以重复使用同一示例,权重将直接从缓存加载,无需再次下载。
要转录本地音频文件,只需将音频文件的路径作为 transcribe 的 audio 参数传入:
pred_out = transcribe(model, audio="audio.mp3")
Whisper.cpp
Distil-Whisper 可以使用原始顺序长格式转录算法从 Whisper.cpp 仓库运行。在 Mac M1 的临时基准测试中,distil-small.en 比 large-v2 快 4 倍以上,同时在长格式音频上的 WER 表现相差 1.4% 以内。
入门步骤:
- 克隆 Whisper.cpp 仓库:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
- 从 Hugging Face Hub 下载
distil-small.en的 ggml 权重:
python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='distil-whisper/distil-small.en', filename='ggml-distil-small.en.bin', local_dir='./models')"
请注意,如果您没有安装 huggingface_hub 包,也可以使用 wget 下载权重:
wget https://huggingface.co/distil-whisper/distil-small.en/resolve/main/ggml-distil-small.en.bin -P ./models
- 使用提供的示例音频运行推理:
make -j && ./main -m models/ggml-distil-small.en.bin -f samples/jfk.wav
Transformers.js
Distil-Whisper 甚至可以在您的网页浏览器中与 Transformers.js 一起运行:
- 从 NPM 安装 Transformers.js:
npm i @xenova/transformers
- 导入库并使用 pipeline API 进行推理。
import { pipeline } from '@xenova/transformers';
const transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-small.en');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url);
// { text: " And so my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }
查看在线 Distil-Whisper Web demo 自己尝试一下。如您所见,它在本地浏览器中运行:无需服务器!
查看文档了解更多信息。
Candle
即将推出!
8位与4位量化
即将推出!
模型详情
Distil-Whisper 继承自 Whisper 的编码器-解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列。解码器根据所有先前的 token 和编码器的隐藏状态自回归预测文本 token。因此,编码器只需运行一次,而解码器需要运行与生成的 token 数量相同的次数。实际上,这意味着解码器占据了总推理时间的 90% 以上。因此,为了优化延迟,重点是尽量减少解码器的推理时间。
为了蒸馏 Whisper 模型,我们在保持编码器不变的情况下减少解码器层的数量。编码器(图中绿色部分)完全从教师模型复制到学生模型,并在训练期间冻结。学生的解码器由教师解码器层的子集组成,这些层从最大间隔层初始化。然后使用 KL 散度和伪标签损失项的加权和对模型进行训练。
<p align="center"> <img src="https://huggingface.co/datasets/distil-whisper/figures/resolve/main/architecture.png?raw=true" width="600"/> </p>
评估
以下代码片段演示如何使用流式模式在 LibriSpeech validation.clean 数据集上评估 Distil-Whisper 模型,这意味着无需将音频数据下载到本地设备。
首先,我们需要安装所需的包,包括 🤗 Datasets 来流式加载音频数据,以及 🤗 Evaluate 来执行 WER 计算:
pip install --upgrade pip
pip install --upgrade transformers datasets[audio] evaluate jiwer
然后可以使用以下示例进行端到端评估:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
输出:
3.4326070294536297
预期用途
Distil-Whisper 旨在作为 Whisper 在英语语音识别方面的即插即用替代品。特别是在分布外测试数据上,它实现了相当的 WER 表现,同时在短格式和长格式音频上都快 6 倍。
数据
Distil-Whisper 在来自 Hugging Face Hub 的 9 个开源、宽松许可的语音数据集上训练了 22,000 小时的音频数据:
| 数据集 | 大小 / h | 说话者 | 领域 | 许可证 |
|---|---|---|---|---|
| People's Speech | 12,000 | unknown | Internet Archive | CC-BY-SA-4.0 |
| Common Voice 13 | 3,000 | unknown | Narrated Wikipedia | CC0-1.0 |
| GigaSpeech | 2,500 | unknown | Audiobook, podcast, YouTube | apache-2.0 |
| Fisher | 1,960 | 11,900 | Telephone conversations | LDC |
| LibriSpeech | 960 | 2,480 | Audiobooks | CC-BY-4.0 |
| VoxPopuli | 540 | 1,310 | European Parliament | CC0 |
| TED-LIUM | 450 | 2,030 | TED talks | CC-BY-NC-ND 3.0 |
| SwitchBoard | 260 | 540 | Telephone conversations | LDC |
| AMI | 100 | unknown | Meetings | CC-BY-4.0 |
| 总计 | 21,770 | 18,260+ |
组合数据集涵盖 10 个不同领域和超过 50,000 名说话者。数据集的多样性对于确保蒸馏模型对音频分布和噪声具有鲁棒性至关重要。
然后使用 Whisper large-v2 模型对音频数据进行伪标签:我们使用 Whisper 为训练集中的所有音频生成预测,并在训练中将这些作为目标标签。使用伪标签确保转录在数据集中格式一致,并在训练期间提供序列级蒸馏信号。
WER 过滤
Whisper 伪标签预测存在转录错误和幻觉。为了确保我们只在准确的伪标签上进行训练,我们在训练期间采用简单的 WER 启发式方法。首先,我们对每个数据集提供的 Whisper 伪标签和真实标签进行标准化。然后计算这些标签之间的 WER。如果 WER 超过指定阈值,我们丢弃该训练样本。否则,我们保留它进行训练。
Distil-Whisper 论文 的第 9.2 节证明了这种过滤对于提高蒸馏模型的下游性能的有效性。我们还将 Distil-Whisper 对幻觉的鲁棒性部分归因于此过滤器。
训练
该模型训练了 50,000 个优化步骤(或 12 个 epoch),批量大小为 2056。Tensorboard 训练日志可在以下地址找到:https://huggingface.co/distil-whisper/distil-small.en/tensorboard?params=scalars#frame
结果
蒸馏模型在分布外(OOD)短格式音频上的 WER 表现与 Whisper 相差在 1% 以内,在 OOD 长格式音频上的表现优于 Whisper 0.1%。这种性能提升归因于更低的幻觉率。
有关评估结果的详细按数据集分类,请参阅 Distil-Whisper 论文 的表 16 和 17。
Distil-Whisper 也在 ESB 基准 数据集上作为 OpenASR 排行榜 的一部分进行了评估,其 WER 表现与 Whisper 相差在 0.2% 以内。
复现 Distil-Whisper
用于复现 Distil-Whisper 的训练和评估代码可在 Distil-Whisper 仓库中获取:https://github.com/huggingface/distil-whisper/tree/main/training
许可证
Distil-Whisper 继承自 OpenAI Whisper 模型的 MIT 许可证。
引用
如果您使用此模型,请考虑引用 Distil-Whisper 论文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致谢
- OpenAI 提供的 Whisper 模型和原始代码库
- Hugging Face 🤗 Transformers 提供的模型集成
- 谷歌的 TPU Research Cloud (TRC) 计划提供的 Cloud TPU v4s
@rsonavane在 LibriSpeech 数据集上发布的 Distil-Whisper 早期版本
distil-whisper/distil-small.en
作者 distil-whisper
创建时间: 2023-12-06 11:35:48+00:00
更新时间: 2024-03-25 12:09:13+00:00
在 Hugging Face 上查看