ONNX 模型库
返回模型

说明文档

基于 Wav2Vec 2.0 的语音结束检测

语音结束检测模型基于 Meta AI 的开源 Wav2Vec 2.0 模型。它使用卷积特征编码器,将原始音频输入块转换为潜在的语音表示,并使用 transformer 来捕获整个表示序列中的信息。这有助于模型区分不同的音高下降,以及语调中的最终延长(及其后的停顿),从而区分何时发生语音结束事件——就像我们人类一样。

训练数据

训练数据由 Mozilla Firefox 基金会的 Common Voice 16.0 英语音频数据集构建而成。该数据集采用宽松的 CC0 1.0 许可证。

为了训练 wav2vec 2.0 模型进行语音结束检测,我们需要一个足够大的数据集,其中包含语音结束和非语音结束的样本。由于没有任何包含此类现成样本的开源数据集,我们需要自行构建一个。Common Voice 数据集由仅包含一个口语句子的音频样本组成。

不幸的是,音频样本的开头和结尾有额外的噪声/空白音频。为了去除这些并仅捕获对应于口语句子的音频,我们需要句子的时间戳,或者更好的是,词级别的时间戳。这是借助 whisperX 实现的。这样我们就可以捕获句子的开始和结束时间,并删除其前后的任何内容。

清理样本后,我们通过随机抽样验证了该程序的正确性。之后,我们将音频样本的最后 700/704 毫秒标记为语音结束事件,其之前的所有内容标记为非语音结束。

最后,我们还通过向两个方向移动 700/704 毫秒窗口,向数据集添加了重叠片段。

输入

该模型使用 700 和 704 毫秒(11x64ms)的原始音频输入进行训练。采样率为 16kHz。在实验中,我们测试了不同的长度(300ms、500ms 和 1 秒),700/704ms 被证明是性能足够好和最短块之间的平衡点。

输出

该模型将每个音频输入分类为 2 个类别 - eos(id: 0)和 not_eos(id: 1)。

使用方法

from transformers import Wav2Vec2Processor, AutoConfig
import onnxruntime as rt
import torch
import torch.nn.functional as F
import numpy as np
import os
import torchaudio


class EndOfSpeechDetection:
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return processor, config, session

    def predict(self, segment, file_type="pcm"):
        if file_type == "pcm":
            # pcm files
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            # wave files
            speech_array, _ = torchaudio.load(segment)
            speech_array = speech_array[0].numpy()

        features = self.processor(
            speech_array, sampling_rate=16000, return_tensors="pt", padding=True
        )
        input_values = features.input_values
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
        )[0]
        softmax_output = F.softmax(torch.tensor(outputs), dim=1)

        both_classes_with_prob = {
            self.config.id2label[i]: softmax_output[0][i].item()
            for i in range(len(softmax_output[0]))
        }

        return both_classes_with_prob


if __name__ == "__main__":
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")
    print(eos.predict("some.pcm", file_type="pcm"))

延迟(和内存)优化

  • 知识蒸馏
  • Onnx 格式权重
    • 权重已转换为 Onnx 格式(以优化 CPU 和 GPU 性能)
    • 在 AMD Instinct MI100 GPU 上测试 - 每 704ms 音频块的推理时间低于 10ms

评估

在 8120 个测试样本上的准确率为 0.95。

类别 精确率 召回率 f1分数 样本数
eos 0.94 0.95 0.95 4060
not_eos 0.95 0.94 0.95 4060

telnyx/wav2vec2-end-of-speech-detection

作者 telnyx

audio-classification
↓ 0 ♥ 6

创建时间: 2024-09-13 00:00:52+00:00

更新时间: 2024-09-13 07:39:06+00:00

在 Hugging Face 上查看

文件 (17)

.gitattributes
5sec_audio.wav
README.md
eos-model-onnx/config.json
eos-model-onnx/model.onnx ONNX
eos-model-onnx/preprocessor_config.json
eos-model-onnx/special_tokens_map.json
eos-model-onnx/tokenizer_config.json
eos-model-onnx/vocab.json
inference.py
segments/segment_0.wav
segments/segment_1.wav
segments/segment_2.wav
segments/segment_3.wav
segments/segment_4.wav
segments/segment_5.wav
segments/segment_6.wav