ONNX 模型库
返回模型

说明文档

这是 Distil-Whisper-Large 医疗语音识别微调的工作空间。该模型会频繁更改,如果您觉得它对您的需求有用,请复制此空间,因为它会经常更新。

Distil-Whisper: distil-large-v3

Distil-Whisper 在论文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 中提出。

这是 Distil-Whisper 英文系列的第三个也是最后一个版本。它是 OpenAI Whisper large-v3 的知识蒸馏版本,是迄今为止最新、性能最强的 Whisper 模型。

与之前的 Distil-Whisper 模型相比,distil-large-v3 的蒸馏过程经过优化,在使用 OpenAI 顺序长音频算法时具有更优越的长音频转录准确性

结果是一个蒸馏模型,在使用顺序算法和分块算法处理长音频时,其 WER(词错误率)与 large-v3 相差不到 1%,并且在使用顺序算法时比 distil-large-v2 优越 4.8%。该模型也比之前的 Distil-Whisper 模型更快:比 large-v3 快 6.3 倍,比 distil-large-v2 快 1.1 倍。

模型 参数量 / M 相对延迟 短音频 顺序长音频 分块长音频
large-v3 1550 1.0 8.4 10.0 11.0
distil-large-v3 756 6.3 9.7 10.8 10.9
distil-large-v2 756 5.8 10.1 15.6 11.6

由于顺序算法是最流行的 Whisper 库(Whisper cpp、Faster-Whisper、OpenAI Whisper)中"事实上的"转录算法,因此该蒸馏模型设计为与这些库兼容。在使用这些库时,从之前的 Distil-Whisper 检查点切换到 distil-large-v3,您可以期待显著的性能提升。为方便起见,最流行库的权重已经转换完成,入门说明如下。

目录

  1. Transformers 使用方法
  2. 库集成
  3. 模型详情
  4. 许可证

Transformers 使用方法

distil-large-v3 从 4.39 版本开始在 Hugging Face 🤗 Transformers 库中得到支持。要运行该模型,首先安装最新版本的 Transformers。在此示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载一个示例音频数据集:

pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]

短音频转录

该模型可以与 pipeline 类一起使用来转录短音频文件(< 30秒),如下所示:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])

要转录本地音频文件,只需在调用 pipeline 时传递音频文件的路径:

- result = pipe(sample)
+ result = pipe("audio.mp3")

要获取片段级时间戳,请传递参数 return_timestamps=True 并返回 "chunks" 输出:

result = pipe(sample, return_timestamps=True)
print(result["chunks"])

<details>

<summary> 要对生成参数进行更多控制,请直接使用 model + processor API: </summary>

可以将临时生成参数传递给 model.generate,包括用于束搜索的 num_beams、用于片段级时间戳的 return_timestamps,以及用于提示的 prompt_ids。有关更多详细信息,请参阅 文档字符串

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import Audio, load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
sample = dataset[0]["audio"]

input_features = processor(
  sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

input_features = input_features.to(device, dtype=torch_dtype)

gen_kwargs = {
  "max_new_tokens": 128,
  "num_beams": 1,
  "return_timestamps": False,
}

pred_ids = model.generate(input_features, **gen_kwargs)
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=gen_kwargs["return_timestamps"])

print(pred_text)

</details>

顺序长音频

与之前的 Distil-Whisper 版本不同,distil-large-v3 专门设计为与 OpenAI 的顺序长音频转录算法兼容。该算法使用滑动窗口对长音频文件(> 30秒)进行缓冲推理,与分块长音频算法相比,返回更准确的转录结果。

在以下情况下应使用顺序长音频算法:

  1. 转录准确性是最重要的因素,而延迟不太重要
  2. 您正在转录批量长音频文件,在这种情况下,顺序算法的延迟与分块算法相当,同时准确性可提高 0.5% WER

如果您正在转录单个长音频文件且延迟是最重要的因素,则应使用下面描述的分块算法。有关不同算法的详细说明,请参阅 Distil-Whisper 论文的第 5 节。

pipeline 类可以使用顺序算法转录长音频文件,如下所示:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])

<details>

<summary> 要对生成参数进行更多控制,请直接使用 model + processor API: </summary>

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import Audio, load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
sample = dataset[0]["audio"]

inputs = processor(
    sample["array"],
    sampling_rate=sample["sampling_rate"],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)
inputs = inputs.to(device, dtype=torch_dtype)

gen_kwargs = {
    "max_new_tokens": 448,
    "num_beams": 1,
    "condition_on_prev_tokens": False,
    "compression_ratio_threshold": 1.35,  # zlib 压缩比阈值(在 token 空间中)
    "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    "logprob_threshold": -1.0,
    "no_speech_threshold": 0.6,
    "return_timestamps": True,
}

pred_ids = model.generate(**inputs, **gen_kwargs)
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=False)

print(pred_text)

</details>

分块长音频

distil-large-v3 仍与 Transformers 分块长音频算法兼容。当转录单个大型音频文件且需要尽可能快的推理速度时,应使用此算法。在这种情况下,分块算法比 OpenAI 的顺序长音频实现快 9 倍(参见 Distil-Whisper 论文的表 7)。

要启用分块,请将 chunk_length_s 参数传递给 pipeline。对于 distil-large-v3,25 秒的分块长度是最佳的。要对长音频文件进行批处理,请传递参数 batch_size

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=25,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])

推测解码

distil-large-v3 是第一个可以用作 Whisper large-v3 助手模型进行推测解码的 Distil-Whisper 模型。推测解码在数学上确保获得与 Whisper 完全相同的输出,同时速度提高 2 倍。这使其成为现有 Whisper 流水线的完美替代品,因为可以保证相同的输出。

在以下代码片段中,我们将助手 Distil-Whisper 模型独立加载到主 Whisper 流水线中。然后我们将其指定为生成的"助手模型":

from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

assistant_model_id = "distil-whisper/distil-large-v3"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])

有关推测解码的更多详细信息,请参阅博客文章 Speculative Decoding for 2x Faster Whisper Inference

额外的速度和内存优化

您可以对 Distil-Whisper 应用额外的速度和内存优化,以进一步减少推理速度和 VRAM 需求。这些优化主要针对注意力内核,将其从 eager 实现切换到更高效的 flash attention 版本。

Flash Attention 2

如果您的 GPU 支持,我们建议使用 Flash-Attention 2。为此,您首先需要安装 Flash Attention

pip install flash-attn --no-build-isolation

然后将 attn_implementation="flash_attention_2" 传递给 from_pretrained

- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2")

Torch Scale-Product-Attention (SDPA)

如果您的 GPU 不支持 Flash Attention,我们建议使用 PyTorch scaled dot-product attention (SDPA)。此注意力实现对于 PyTorch 2.1.1 或更高版本默认激活。要检查您是否有兼容的 PyTorch 版本,请运行以下 Python 代码片段:

from transformers.utils import is_torch_sdpa_available

print(is_torch_sdpa_available())

如果上述代码返回 True,则表示您安装了有效版本的 PyTorch,SDPA 默认激活。如果返回 False,则需要按照官方说明升级 PyTorch 版本。

安装有效的 PyTorch 版本后,SDPA 默认激活。也可以通过指定 attn_implementation="sdpa" 来显式设置:

- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")

Torch compile

即将推出...

4-bit 和 8-bit 推理

即将推出...

库集成

Whisper.cpp

Distil-Whisper 可以通过 Whisper.cpp 包使用原始的顺序长音频转录算法运行。在 Mac M1 上的初步基准测试中,distil-large-v3 比 Whisper large-v3 快 5 倍以上,同时在长音频上的 WER 差距在 0.8% 以内。

入门步骤:

  1. 克隆 Whisper.cpp 仓库:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
  1. 安装 Hugging Face Hub Python 包:
pip install --upgrade huggingface_hub

并使用以下 Python 代码片段下载 distil-large-v3 的 GGML 权重:

from huggingface_hub import hf_hub_download

hf_hub_download(repo_id='distil-whisper/distil-large-v3-ggml', filename='ggml-distil-large-v3.bin', local_dir='./models')

请注意,如果您没有设置 Python 环境,也可以直接使用 wget 下载权重:

wget https://huggingface.co/distil-whisper/distil-large-v3-ggml/resolve/main/ggml-distil-large-v3.bin -P ./models
  1. 使用提供的示例音频运行推理:
make -j && ./main -m models/ggml-distil-large-v3.bin -f samples/jfk.wav

Faster-Whisper

Faster-Whisper 是使用 CTranslate2 对 Whisper 的重新实现,CTranslate2 是一个快速的 Transformer 模型推理引擎。

首先,按照官方说明安装 Faster-Whisper 包。在此示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载一个示例音频数据集:

pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]

以下代码片段加载 distil-large-v3 模型并对 LibriSpeech ASR 数据集中的一个示例文件运行推理:

import torch
from faster_whisper import WhisperModel
from datasets import load_dataset

# 定义我们的 torch 配置
device = "cuda:0" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if torch.cuda.is_available() else "float32"

# 如果可用则在 GPU 上加载模型,否则在 cpu 上
model = WhisperModel("distil-large-v3", device=device, compute_type=compute_type)

# 加载示例数据集
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[1]["audio"]["path"]

segments, info = model.transcribe(sample, beam_size=1)

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))

要转录本地音频文件,只需将音频文件路径作为 audio 参数传递给 transcribe:

segments, info = model.transcribe("audio.mp3", beam_size=1)

OpenAI Whisper

要以原始 Whisper 格式使用该模型,首先确保已安装 openai-whisper 包。 在此示例中,我们还将安装 🤗 Datasets 以从 Hugging Face Hub 加载一个示例音频数据集:

pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]

以下代码片段演示了如何转录使用 🤗 Datasets 加载的 LibriSpeech 数据集中的示例文件:

from huggingface_hub import hf_hub_download
from datasets import load_dataset
from whisper import load_model, transcribe

model_path = hf_hub_download(repo_id="distil-whisper/distil-large-v3-openai", filename="model.bin")
model = load_model(model_path)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["path"]

pred_out = transcribe(model, audio=sample, language="en")
print(pred_out["text"])

请注意,第一次运行示例时,模型权重将被下载并保存到您的缓存中。之后,您可以重复使用相同的示例,权重将直接从缓存加载,无需再次下载。

要转录本地音频文件,只需将音频文件路径作为 audio 参数传递给 transcribe:

pred_out = transcribe(model, audio=sample, language="en")

Distil-Whisper 模型也可以与 OpenAI Whisper CLI 一起使用。有关详细信息,请参阅以下说明

Transformers.js

Distil-Whisper 可以通过 Transformers.js 完全在您的网络浏览器中运行:

  1. NPM 安装 Transformers.js:
npm i @xenova/transformers
  1. 导入库并使用 pipeline API 进行推理。
import { pipeline } from '@xenova/transformers';

const transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-large-v3');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url);
// { text: " And so, my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }

查看在线 Distil-Whisper Web 演示 亲自体验一下。 正如您将看到的,它在您的浏览器中本地运行:无需服务器!

有关更多信息,请参阅 Transformers.js 文档

Candle

通过与 Hugging Face Candle 🕯️ 的集成,Distil-Whisper 可在 Rust 库 🦀 中使用

优势包括:

  • 优化的 CPU 后端,Linux x86 可选 MKL 支持,Mac 可选 Accelerate 支持
  • Metal 支持,可在 Mac 上高效运行
  • CUDA 后端,可在 GPU 上高效运行,通过 NCCL 进行多 GPU 分布
  • WASM 支持:在浏览器中运行 Distil-Whisper

入门步骤:

  1. 按照此处说明安装 candle-core:https://huggingface.github.io/candle/guide/installation.html
  2. 在本地克隆 candle 仓库:
git clone https://github.com/huggingface/candle.git
  1. 进入 Whisper 示例目录:
cd candle/candle-examples/examples/whisper
  1. 运行示例:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3
  1. 要指定自己的音频文件,添加 --input 标志:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3 --input audio.wav

提示: 要使用 Apple Metal 编译,请在运行示例时指定 metal 功能:

cargo run --example whisper --release --features="symphonia,metal" -- --model distil-large-v3

请注意,如果您遇到以下错误:

error: target `whisper` in package `candle-examples` requires the features: `symphonia`
Consider enabling them by passing, e.g., `--features="symphonia"`

您应该清理您的 cargo 安装:

cargo clean

然后重新编译:

cargo run --example whisper --release --features symphonia -- --model distil-large-v3

模型详情

Distil-Whisper 继承了 Whisper 的编码器-解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列。解码器根据所有先前的 token 和编码器隐藏状态自回归地预测文本 token。因此,编码器只运行一次前向传播,而解码器运行的次数与生成的 token 数量相同。在实践中,这意味着解码器占总推理时间的 90% 以上。因此,为了优化延迟,重点是最小化解码器的推理时间。

为了蒸馏 Whisper 模型,我们减少了解码器层的数量,同时保持编码器不变。编码器(绿色显示)完全从教师模型复制到学生模型,并在训练期间冻结。学生模型的解码器由教师模型解码器层的一个子集组成,这些层从最大间隔的层初始化。然后在 KL 散度和伪标签损失项的加权和上训练模型。

<p align="center"> <img src="https://huggingface.co/datasets/distil-whisper/figures/resolve/main/architecture.png?raw=true" width="600"/> </p>

与 distil-large-v2 的区别

与之前版本的 Distil-Whisper 相比,distil-large-v3 专门针对 OpenAI 顺序长音频转录算法进行了优化。与 distil-large-v2 相比没有架构差异,除了模型层是从最新的 large-v3 模型而不是旧的 large-v2 模型初始化的。差异在于模型的训练方式。

之前的 Distil-Whisper 模型是在平均输入长度为 7 秒的数据上训练的,而原始 Whisper 模型是在 30 秒输入上预训练的。在蒸馏过程中,我们将模型权重的分布转移到训练数据的分布。如果我们的训练数据包含较短的语句(例如平均 7 秒的音频而不是 30 秒),那么预测的分布就会转移到这个较短的上下文长度。在推理时,distil-large-v2 的最佳上下文窗口是这两个值的插值:15 秒。超过这个时间,distil-large-v2 模型的预测基本不准确,特别是时间戳预测。然而,顺序长音频算法使用 30 秒滑动窗口进行推理,窗口根据最后预测的时间戳移动。由于最后一个时间戳通常出现在 15 秒标记之后,因此预测精度低,导致长音频转录经常失败。

为了保持 Whisper 转录滑动 30 秒窗口的能力(顺序解码就是这样做的),我们需要确保 distil-large-v3 的上下文长度也是 30 秒。这主要通过四种策略实现:

  1. 将训练数据集中的音频样本打包到 30 秒: 由于模型在打包到 30 秒的音频数据上进行预训练和蒸馏,distil-large-v3 现在与 Whisper 在相同的理想上下文窗口上运行,预测准确的时间戳长达 30 秒。
  2. 冻结解码器输入嵌入: 我们使用与原始模型相同的输入嵌入表示,该表示设计用于处理比之前 Distil-Whisper 迭代更长的上下文长度。
  3. 在训练期间使用更长的最大上下文长度: 不是在 128 的最大目标长度上训练,而是在 256 的最大长度上训练。这有助于 distil-large-v3 转录 30 秒片段,其中 token 数量可能超过 128。
  4. 将提示条件附加到 50% 的训练样本: 使模型能够与 condition_on_prev_tokens 参数一起使用,以及长达 448 个 token 的上下文窗口。

还有更多的技巧被用来提高 distil-large-v3 在顺序解码算法下的性能,这些将在即将发布的博客文章中详细解释。

评估

以下代码片段演示了如何使用流式模式在 LibriSpeech validation-clean 数据集上评估 Distil-Whisper 模型,这意味着无需将音频数据下载到本地设备。

首先,我们需要安装所需的包,包括 🤗 Datasets 用于流式传输和加载音频数据,以及 🤗 Evaluate 用于执行 WER 计算:

pip install --upgrade pip
pip install --upgrade transformers datasets[audio] evaluate jiwer

然后可以使用以下示例端到端运行评估:

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm

# 定义我们的 torch 配置
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

# 加载模型 + 处理器
model =  AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# 使用流式模式加载数据集
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)

# 定义评估指标
wer_metric = load("wer")

def inference(batch):
    # 1. 预处理音频数据为 log-mel 频谱图输入
    audio = [sample["array"] for sample in batch["audio"]]
    input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
    input_features = input_features.to(device, dtype=torch_dtype)
    
    # 2. 自回归生成预测的 token id
    pred_ids = model.generate(input_features, max_new_tokens=128)
    
    # 3. 将 token id 解码为最终转录
    batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
    batch["reference"] = batch["text"]
    return batch

# 批量大小 16 推理
dataset = dataset.map(function=inference, batched=True, batch_size=16)

all_transcriptions = []
all_references = []

# 遍历数据集并运行推理
for result in tqdm(dataset, desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])

# 规范化预测和参考
all_transcriptions = [processor.normalize(transcription) for transcription in all_transcriptions]
all_references = [processor.normalize(reference) for reference in all_references]

# 计算 WER 指标
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)

打印输出:

2.428920763531516

预期用途

Distil-Whisper 旨在作为 Whisper large-v3 在英语语音识别中的直接替代品。特别是,它在分布外(OOD)测试数据上实现了可比的 WER 结果,同时在短音频和长音频上都快 6 倍。

数据

Distil-Whisper 在来自 Hugging Face Hub 上九个开源、许可宽松的语音数据集的 22,000 小时音频数据上进行训练:

数据集 大小 / h 说话者 领域 许可证
People's Speech 12,000 未知 Internet Archive CC-BY-SA-4.0
Common Voice 13 3,000 未知 旁白维基百科 CC0-1.0
GigaSpeech 2,500 未知 有声读物, 播客, YouTube apache-2.0
Fisher 1,960 11,900 电话对话 LDC
LibriSpeech 960 2,480 有声读物 CC-BY-4.0
VoxPopuli 540 1,310 欧洲议会 CC0
TED-LIUM 450 2,030 TED 演讲 CC-BY-NC-ND 3.0
SwitchBoard 260 540 电话对话 LDC
AMI 100 未知 会议 CC-BY-4.0
总计 21,770 18,260+

合并的数据集涵盖 10 个不同的领域和超过 50k 说话者。该数据集的多样性对于确保蒸馏模型对音频分布和噪声具有鲁棒性至关重要。

然后使用 Whisper large-v3 模型对音频数据进行伪标注:我们使用 Whisper 为训练集中的所有音频生成预测,并将这些作为训练期间的目标标签。使用伪标签确保转录在各数据集之间格式一致,并在训练期间提供序列级蒸馏信号。

WER 过滤器

Whisper 伪标签预测可能会出现错误转录和幻觉。为了确保我们只在准确的伪标签上进行训练,我们在训练期间采用简单的 WER 启发式方法。首先,我们规范化 Whisper 伪标签和每个数据集提供的真实标签。然后计算这些标签之间的 WER。如果 WER 超过指定阈值,我们丢弃该训练样本。否则,我们保留它用于训练。

Distil-Whisper 论文的第 9.2 节展示了此过滤器对提高蒸馏模型下游性能的有效性。我们还将 Distil-Whisper 对幻觉的鲁棒性部分归因于此过滤器。

训练

该模型训练了 80,000 个优化步骤(或 11 个 epoch),批量大小为 256。Tensorboard 训练日志可在以下位置找到:https://huggingface.co/distil-whisper/distil-large-v3/tensorboard?params=scalars#frame

结果

蒸馏模型在分布外(OOD)短音频上的 WER 与 Whisper large-v3 相差在 1.5% 以内,在顺序长音频解码上相差在 1% 以内,在分块长音频上比 large-v3 优越 0.1%。这种性能提升归因于较低的幻觉。

有关评估结果的详细数据集细分,请参阅 Distil-Whisper 论文的表 16 和 17

Distil-Whisper 还作为 OpenASR 排行榜的一部分在 ESB 基准数据集上进行了评估,其 WER 与 Whisper 相差在 0.2% 以内。

复现 Distil-Whisper

复现 Distil-Whisper 的训练和评估代码可在 Distil-Whisper 仓库中找到:https://github.com/huggingface/distil-whisper/tree/main/training

此代码将很快更新,以包含与 distil-large-v2 的区别部分中描述的训练更新。

许可证

Distil-Whisper 继承了 OpenAI Whisper 模型的 MIT 许可证

引用

如果您使用此模型,请考虑引用 Distil-Whisper 论文

@misc{gandhi2023distilwhisper,
      title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling}, 
      author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
      year={2023},
      eprint={2311.00430},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

致谢

IsGarrido/Whisper-Medicalv1

作者 IsGarrido

automatic-speech-recognition transformers
↓ 1 ♥ 0

创建时间: 2026-02-17 14:14:15+00:00

更新时间: 2026-02-17 14:14:15+00:00

在 Hugging Face 上查看

文件 (26)

.gitattributes
README.md
added_tokens.json
config.json
flax_model.msgpack
generation_config.json
merges.txt
model.fp32.safetensors
model.safetensors
normalizer.json
onnx/decoder_model.onnx ONNX
onnx/decoder_model_merged.onnx ONNX
onnx/decoder_model_merged_quantized.onnx ONNX
onnx/decoder_model_quantized.onnx ONNX
onnx/decoder_with_past_model.onnx ONNX
onnx/decoder_with_past_model_quantized.onnx ONNX
onnx/encoder_model.onnx ONNX
onnx/encoder_model.onnx_data
onnx/encoder_model_quantized.onnx ONNX
preprocessor_config.json
runs/first_50k_steps/events.out.tfevents.1707818488.t1v-n-d928564b-w-0.763446.0.v2
runs/last_30k_steps/events.out.tfevents.1708076214.t1v-n-d928564b-w-0.5978.0.v2
special_tokens_map.json
tokenizer.json
tokenizer_config.json
vocab.json