ONNX 模型库
返回模型

说明文档

更多信息请参见此处

如何创建此转换

使用下面的脚本将 Voxtral 的语言模型解码器导出为 ONNX 格式,并在输出中额外暴露每一层的注意力权重(用于后续推断每个词在音频中的时间位置)。

<details>

import torch
from torch import nn
from transformers import VoxtralForConditionalGeneration
from transformers.cache_utils import DynamicCache
import os
import onnx

model_id = "mistralai/Voxtral-Mini-3B-2507"

# Half precision on GPU, full precision on CPU.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
torch_dtype = torch.float16 if _cuda_available else torch.float32

# Eager attention is requested explicitly — presumably because fused/SDPA
# kernels do not return the attention probabilities the export must expose.
model = VoxtralForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="eager",
)
model.to(device).eval()

class DecoderONNXWrapper(nn.Module):
    """Give a causal-LM decoder a flat, tensor-only signature for ONNX export.

    Inputs: ``inputs_embeds``, ``attention_mask`` and then the KV cache as a flat
    sequence ``(key_0, value_0, key_1, value_1, ...)``, one pair per layer.

    Outputs: a flat tuple ``(logits, present_key_0, present_value_0, ...,
    attention_0, attention_1, ...)``, so every tensor can be given its own ONNX
    output name.
    """

    def __init__(self, language_model):
        super().__init__()
        self.language_model = language_model

    def forward(self, inputs_embeds, attention_mask, *past_key_value_tensors):
        layer_count = self.language_model.config.num_hidden_layers
        # Regroup the flat tensors into the legacy ((k, v), (k, v), ...) layout.
        legacy_past = tuple(
            (past_key_value_tensors[2 * layer], past_key_value_tensors[2 * layer + 1])
            for layer in range(layer_count)
        )
        cache = DynamicCache.from_legacy_cache(past_key_values=legacy_past)

        outputs = self.language_model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=cache,
            output_attentions=True,
            use_cache=True,
        )

        # NOTE(review): key_cache/value_cache are deprecated in newer
        # transformers releases; confirm against the installed version.
        present = outputs.past_key_values
        flat_present = [
            tensor
            for key, value in zip(present.key_cache, present.value_cache)
            for tensor in (key, value)
        ]
        return (outputs.logits, *flat_present, *outputs.attentions)

batch_size = 1
seq_len = 128
past_seq_len = 100
text_config = model.config.text_config
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size
head_dim = text_config.head_dim
num_kv_heads = text_config.num_key_value_heads

# Dummy tensors used only to trace the graph; their values are irrelevant.
inputs_embeds = torch.randn((batch_size, seq_len, hidden_size), dtype=torch_dtype, device=device)
attention_mask_4d = torch.ones((batch_size, 1, seq_len, past_seq_len + seq_len), dtype=torch_dtype, device=device)
past_key_value_flat_tuple = tuple(
    torch.randn((batch_size, num_kv_heads, past_seq_len, head_dim), dtype=torch_dtype, device=device)
    for _ in range(num_layers * 2)
)
dummy_inputs = (inputs_embeds, attention_mask_4d) + past_key_value_flat_tuple

output_path = "decoder_model_attentive_unpacked.onnx"
past_names = [f"past_key_values.{i}.{kv}" for i in range(num_layers) for kv in ["key", "value"]]
present_names = [f"present.{i}.{kv}" for i in range(num_layers) for kv in ["key", "value"]]
attention_names = [f"attention.{i}" for i in range(num_layers)]
input_names = ["inputs_embeds", "attention_mask"] + past_names
output_names = ["logits"] + present_names + attention_names

dynamic_axes = {
    "inputs_embeds": {1: "sequence_length"},
    "attention_mask": {2: "sequence_length", 3: "total_sequence_length"},
    # Logits are (batch, sequence_length, vocab). Without a dynamic axis here the
    # traced dummy length (128) would be declared as a fixed output dimension,
    # although inference feeds both the full prompt and single tokens.
    "logits": {1: "sequence_length"},
}
# KV tensors are (batch, kv_heads, seq, head_dim): the cache fed in spans the
# past tokens, the cache returned spans past + new tokens.
dynamic_axes.update({name: {2: "past_sequence_length"} for name in past_names})
dynamic_axes.update({name: {2: "total_sequence_length"} for name in present_names})
# Attention maps are (batch, heads, query_len, key_len).
dynamic_axes.update({name: {2: "sequence_length", 3: "total_sequence_length"} for name in attention_names})

wrapped_model = DecoderONNXWrapper(model.language_model)
wrapped_model.eval()

with torch.no_grad():
    torch.onnx.export(
        wrapped_model,
        dummy_inputs,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=17,
    )

import onnx

# Reload the traced model together with its external tensors, then write it
# back with every tensor consolidated into a single .onnx_data sidecar file.
data_file_location = "decoder_model_attentive.onnx_data"
onnx_model = onnx.load(output_path, load_external_data=True)
onnx.save_model(
    onnx_model,
    "decoder_model_attentive.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=data_file_location,
)

清理导出过程中生成的中间文件:

# Delete the intermediate files left in the working directory by the export
# (names starting with "language_" or "onnx_" — presumably the per-tensor
# external-data files written by torch.onnx.export). The consolidated copy of
# the weights lives in decoder_model_attentive.onnx_data.
# The same directory is listed and joined so the removed path is the one checked.
cleanup_dir = "."
for fname in os.listdir(cleanup_dir):
    fpath = os.path.join(cleanup_dir, fname)
    if fname.startswith(("language_", "onnx_")) and os.path.isfile(fpath):
        os.remove(fpath)

</details>

下面的脚本使用 Intel Neural Compressor 对给定的 ONNX 文件进行 4 位仅权重(weight-only,RTN 算法)量化(需要同时提供 .onnx 文件及其对应的 .onnx_data 数据文件)。

<details>

import os
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.utils.constant import FP32
import onnx
import logging
logging.basicConfig(level=logging.INFO)

model_dir = "."
model_fp32 = 'decoder_model_attentive.onnx'
model_quantized = 'decoder_model_attentive_q4_weight_only_inc.onnx'

input_model_path = os.path.join(model_dir, model_fp32)
output_model_path = os.path.join(model_dir, model_quantized)

# onnx.checker.check_model returns None and reports problems by raising, so its
# return value cannot be used as a validity flag. It is given the path rather
# than a loaded ModelProto so the checker can handle models larger than 2 GB
# whose weights live in an external .onnx_data file.
try:
    onnx.checker.check_model(input_model_path)
except Exception as e:
    print(f"Failed to load or check original model '{input_model_path}': {e}")
    print("Please ensure the original model file exists and is not corrupted.")
    # Exit with a non-zero status so callers can detect the failure.
    raise SystemExit(1) from e
print(f"Original model '{input_model_path}' is valid.")

# 4-bit asymmetric round-to-nearest (RTN) weight-only settings with
# group_size=32; the ".*" key presumably applies them to every op type.
weight_config = {
    "bits": 4,
    "algorithm": ["RTN"],
    "scheme": ["asym"],
    "group_size": 32,
}
config = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={".*": {"weight": weight_config}},
)

print(f"\nAttempting to quantize '{model_fp32}' to 4-bit weight-only using Neural Compressor...")

try:
    q_model = quantization.fit(input_model_path, config)
    q_model.save(output_model_path)
except Exception as e:
    print(f"Error during Neural Compressor weight-only quantization: {e}")
    print("Please ensure Neural Compressor is installed (`pip install neural_compressor`)")
    print("and that your ONNX Runtime version is compatible.")
else:
    print(f"Model successfully quantized and saved to {output_model_path}")

</details>

推理并提取词级时间戳

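  1. 从本仓库下载量化后的解码器文件 `decoder_model_attentive_q4_weight_only_inc.onnx` 及其数据文件 `decoder_model_attentive_q4_weight_only_inc.onnx_data`,放在当前目录下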
  2. 安装所需依赖:onnxruntime、huggingface_hub、tokenizers、soundfile、librosa、matplotlib、ipython 和 torch
  3. 在当前目录下准备一个名为 audio.wav 的音频文件(任意你想转录的音频)
  4. 然后运行以下 Python 代码:
import os
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
import soundfile as sf
import librosa
import logging
import matplotlib.pyplot as plt
from IPython.display import display, Audio
import torch

def _create_4d_causal_attention_mask(input_shape, past_sequence_length, dtype=np.float32):
    """Build an additive causal attention mask for a decoder step.

    Args:
        input_shape: ``(batch_size, sequence_length)`` of the tokens fed this step.
        past_sequence_length: number of positions already held in the KV cache.
        dtype: floating dtype of the returned mask.

    Returns:
        Array of shape ``(batch_size, 1, sequence_length, past + sequence_length)``.
        Entries are ``0.0`` where the query may attend to the key and
        ``np.finfo(dtype).min`` where it may not.
    """
    batch_size, sequence_length = input_shape
    total_sequence_length = past_sequence_length + sequence_length

    # Query i sits at absolute position past_sequence_length + i and may attend
    # to every key j at or before that position. Comparing position vectors
    # builds only the (sequence_length, total) slice that is actually needed,
    # instead of a full (total, total) triangle on every decoding step.
    query_positions = np.arange(past_sequence_length, total_sequence_length)[:, None]
    key_positions = np.arange(total_sequence_length)[None, :]
    allowed = key_positions <= query_positions

    causal_mask = np.zeros((batch_size, 1, sequence_length, total_sequence_length), dtype=dtype)
    # Broadcast the (seq, total) pattern over batch and the singleton head axis.
    causal_mask[...] = np.where(allowed, 0.0, np.finfo(dtype).min)
    return causal_mask

repo_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX"
audio_file_path = "audio.wav"
custom_decoder_path = "decoder_model_attentive_q4_weight_only_inc.onnx"
max_generation_tokens = 999
eos_token_id = 2

# Check for the local decoder first so a missing file fails before any download.
if not os.path.exists(custom_decoder_path):
    raise FileNotFoundError(f"Custom ONNX decoder not found at '{custom_decoder_path}'.")

print(f"Downloading base model files from {repo_id}...")
# Only the audio encoder, the token-embedding model and the tokenizer are used
# from the hub repo. Decoding runs on the custom attention-exposing decoder
# above, so the repo's own decoder_model_merged_q4 files are not fetched.
local_dir = snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    allow_patterns=["onnx/audio_encoder_q4.*", "onnx/embed_tokens_q4.*", "tokenizer.json"],
)
onnx_dir = os.path.join(local_dir, "onnx")
tok = Tokenizer.from_file(os.path.join(local_dir, "tokenizer.json"))
# Special-token ids used to build the prompt (presumably BOS, [INST],
# [BEGIN_AUDIO], [AUDIO], [/INST] — verify against tokenizer.json).
bos_id, inst_id, baud_id, aud_id, einst_id = 1, 3, 25, 24, 4
ae_path = os.path.join(onnx_dir, "audio_encoder_q4.onnx")
embed_path = os.path.join(onnx_dir, "embed_tokens_q4.onnx")
sess_opts = ort.SessionOptions()
session_providers = ["CPUExecutionProvider"]
ae_sess = ort.InferenceSession(ae_path, sess_options=sess_opts, providers=session_providers)
embed_sess = ort.InferenceSession(embed_path, sess_options=sess_opts, providers=session_providers)
dec_sess = ort.InferenceSession(custom_decoder_path, sess_options=sess_opts, providers=session_providers)
# The decoder has one "past_key_values.<layer>.key" input per layer.
num_decoder_layers = sum(1 for inp in dec_sess.get_inputs() if inp.name.endswith(".key"))
print(f"Detected {num_decoder_layers} decoder layers.")

def extract_mel_features_for_chunk(audio_chunk, sampling_rate=16000, n_fft=400, hop_length=160, n_mels=128, target_length=3000):
    """Return normalized log-mel features for one chunk of mono audio.

    Processing steps:
    1. Pad or truncate the chunk to exactly 30 seconds at ``sampling_rate``.
    2. Convert it to an ``n_mels``-band mel spectrogram.
    3. Apply log10 compression.
    4. Clamp to 8 decades below the peak.
    5. Rescale with ``(x + 4) / 4``.
    6. Cut to the first ``target_length`` frames.

    The recipe looks like the Whisper feature pipeline (presumably what the
    Voxtral audio encoder expects — verify).

    Returns:
        float32 array of shape ``(n_mels, <= target_length)``.
    """
    window = librosa.util.fix_length(audio_chunk, size=30 * sampling_rate)
    mel = librosa.feature.melspectrogram(
        y=window,
        sr=sampling_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    log_mel = np.log10(np.maximum(mel, 1e-10))
    floor = log_mel.max() - 8.0
    log_mel = np.maximum(log_mel, floor)
    normalized = (log_mel + 4.0) / 4.0
    return normalized[:, :target_length].astype(np.float32)

def process_long_audio(audio_path, session, sampling_rate=16000):
    """Encode an audio file of any length with the audio-encoder session.

    Processing steps:
    1. Read the file and downmix it to mono.
    2. Resample it to ``sampling_rate``.
    3. Split it into consecutive 30-second chunks; the last chunk is padded by
       ``extract_mel_features_for_chunk``.
    4. Run each chunk's mel features through ``session``.

    Args:
        audio_path: path to an audio file readable by soundfile.
        session: ONNX Runtime session whose first input takes the mel features.
        sampling_rate: target sample rate in Hz.

    Returns:
        The first output of ``session`` for every chunk, concatenated along axis 0.

    Raises:
        ValueError: if the file contains no samples.
    """
    y, sr = sf.read(audio_path)
    if y.ndim > 1:
        y = y.mean(axis=1)  # downmix multi-channel audio to mono
    if sr != sampling_rate:
        y = librosa.resample(y, orig_sr=sr, target_sr=sampling_rate)

    chunk_samples = 30 * sampling_rate
    num_chunks = int(np.ceil(len(y) / chunk_samples))
    if num_chunks == 0:
        raise ValueError(f"Audio file '{audio_path}' contains no samples.")

    input_name = session.get_inputs()[0].name  # loop-invariant
    all_audio_embeds = []
    print(f"Processing in {num_chunks} chunk(s)...")
    for i in range(num_chunks):
        chunk = y[i * chunk_samples:(i + 1) * chunk_samples]
        mel_features = extract_mel_features_for_chunk(chunk, sampling_rate)
        # Add a leading batch axis of size 1 before running the encoder.
        all_audio_embeds.append(session.run(None, {input_name: mel_features[None, :]})[0])
    return np.concatenate(all_audio_embeds, axis=0)

if not os.path.exists(audio_file_path):
    raise FileNotFoundError(f"Audio file '{audio_file_path}' not found.")

audio_embeds_raw = process_long_audio(audio_file_path, ae_sess)
batch_size = 1
audio_output_frames = audio_embeds_raw.shape[0] // batch_size
audio_embeds = audio_embeds_raw.reshape(batch_size, audio_output_frames, -1)

# Prompt layout: 3 leading special tokens, one audio placeholder token per
# encoder frame, the "Transcribe." instruction, then the closing special token.
text_instruction_ids = tok.encode("Transcribe.", add_special_tokens=False).ids
prompt_prefix = [bos_id, inst_id, baud_id]
prompt_tokens = prompt_prefix + [aud_id] * audio_output_frames + text_instruction_ids + [einst_id]
initial_sequence_length = len(prompt_tokens)

prompt_ids = np.array([prompt_tokens], dtype=np.int64)
inputs_embeds = embed_sess.run(None, {"input_ids": prompt_ids})[0]
# Overwrite the placeholder embeddings with the real audio-encoder embeddings.
audio_slot = slice(len(prompt_prefix), len(prompt_prefix) + audio_output_frames)
inputs_embeds[0, audio_slot, :] = audio_embeds[0]
inputs_embeds = inputs_embeds.astype(np.float32)

# Greedy decoding with an explicit KV cache.
generated_ids = []
past_key_values = None
current_past_len = 0
for step in range(max_generation_tokens):
    dec_inputs = {}
    if step == 0:
        # Prefill: run the whole prompt (text embeddings with audio spliced in) at once.
        dec_inputs["inputs_embeds"] = inputs_embeds
        attention_mask = _create_4d_causal_attention_mask((batch_size, initial_sequence_length), 0)
    else:
        # Incremental step: feed only the embedding of the last generated token.
        last_token_id = np.array([[generated_ids[-1]]], dtype=np.int64)
        dec_inputs["inputs_embeds"] = embed_sess.run(None, {"input_ids": last_token_id})[0].astype(np.float32)
        attention_mask = _create_4d_causal_attention_mask((batch_size, 1), current_past_len)
    dec_inputs["attention_mask"] = attention_mask

    for layer in range(num_decoder_layers):
        if past_key_values is None:
            # Empty cache for the first step. NOTE(review): 8 KV heads x head_dim
            # 128 are hard-coded and assumed to match the exported decoder;
            # confirm against dec_sess.get_inputs() if the model changes.
            past_key = past_value = np.zeros((batch_size, 8, 0, 128), dtype=np.float32)
        else:
            past_key = past_key_values[2 * layer]
            past_value = past_key_values[2 * layer + 1]
        # asarray converts to float32 without copying a cache that already is float32.
        dec_inputs[f"past_key_values.{layer}.key"] = np.asarray(past_key, dtype=np.float32)
        dec_inputs[f"past_key_values.{layer}.value"] = np.asarray(past_value, dtype=np.float32)

    outputs = dec_sess.run(None, dec_inputs)
    logits = outputs[0]
    past_key_values = outputs[1:1 + num_decoder_layers * 2]
    # The returned cache spans every position processed so far (axis 2 is the
    # sequence axis). The next step's mask must cover exactly that many keys
    # plus the new token.
    current_past_len = past_key_values[0].shape[2]

    next_token_id = int(np.argmax(logits[0, -1, :]))
    if next_token_id == eos_token_id:
        break
    generated_ids.append(next_token_id)
    print(tok.decode([next_token_id]), end="", flush=True)
print()  # finish the streamed transcript line

# Re-run the decoder once over prompt + transcript with an empty cache so the
# attention maps for every generated token come out of a single pass.
full_sequence_ids = np.array([prompt_tokens + generated_ids], dtype=np.int64)
full_embeds = embed_sess.run(None, {"input_ids": full_sequence_ids})[0]
audio_offset = 3  # audio placeholders start after the 3 leading special tokens
full_embeds[0, audio_offset:audio_offset + audio_output_frames, :] = audio_embeds[0]
full_embeds = full_embeds.astype(np.float32)

alignment_inputs = {
    "inputs_embeds": full_embeds,
    "attention_mask": _create_4d_causal_attention_mask(full_embeds.shape[:2], 0),
}
# Zero-length cache for every layer. NOTE(review): 8 KV heads x head_dim 128
# are assumed to match the exported decoder.
empty_kv = np.zeros((batch_size, 8, 0, 128), dtype=np.float32)
alignment_inputs.update(
    {
        f"past_key_values.{layer}.{kind}": empty_kv
        for layer in range(num_decoder_layers)
        for kind in ("key", "value")
    }
)

alignment_outputs = dec_sess.run(None, alignment_inputs)
# Outputs are ordered logits, then 2 cache tensors per layer, then one attention map per layer.
first_attention_output = 1 + num_decoder_layers * 2
attentions = [torch.from_numpy(attn) for attn in alignment_outputs[first_attention_output:]]

text_start_idx = len(prompt_tokens)
audio_end_idx = 3 + audio_output_frames
start_layer, end_layer = 10, 20

layer_attentions = []
for layer_idx in range(start_layer, end_layer):
    # Batch item 0, averaged over heads; assumed (heads, query, key) before the mean.
    head_avg = attentions[layer_idx][0].mean(dim=0)
    # Rows: generated-token queries. Columns: audio-frame keys.
    text_to_audio = head_avg[text_start_idx:, 3:audio_end_idx]
    if text_to_audio.numel() == 0:
        continue
    layer_attentions.append(text_to_audio)

if not layer_attentions:
    raise ValueError("Could not extract any valid attention weights. The generated text might be empty.")

avg_attentions = torch.stack(layer_attentions).mean(dim=0)
temperature = 0.1
# Low-temperature softmax sharpens each token's distribution over audio frames.
weights = torch.nn.functional.softmax(avg_attentions / temperature, dim=1).cpu().numpy()

# Save a heat map of the sharpened token-to-frame alignment weights.
fig, ax = plt.subplots(figsize=(10, 10))
image = ax.imshow(weights, aspect="auto", origin="lower", cmap="viridis")
ax.set_xlabel("Audio Frames")
ax.set_ylabel("Generated Text Tokens")
ax.set_title("Audio-to-Text Alignment Matrix (Sharpened)")
fig.colorbar(image)
fig.savefig("alignment_matrix.png")

# DTW minimizes cost, so negate the (frames x tokens) weights to follow the
# path of maximal attention.
cost_matrix = -weights.T
_accumulated_cost, wp = librosa.sequence.dtw(C=cost_matrix.astype(np.float32), backtrack=True)
wp = wp[::-1]  # backtracked path runs end-to-start; put it in ascending order

# Map each token index to the first audio frame the path reaches it at.
token_to_frame_map = {}
for frame_idx, token_idx in wp:
    token_to_frame_map.setdefault(token_idx, frame_idx)

# Split the generated token ids into words: a token whose decoded text starts
# with a space begins a new word.
word_groups = []
current_word_tokens = []
for token_id in generated_ids:
    starts_new_word = tok.decode([token_id]).startswith(" ")
    if starts_new_word and current_word_tokens:
        word_groups.append(current_word_tokens)
        current_word_tokens = []
    current_word_tokens.append(token_id)
if current_word_tokens:
    word_groups.append(current_word_tokens)

# Every chunk is padded to a full 30 s window before encoding
# (process_long_audio / extract_mel_features_for_chunk). The encoder frames
# therefore span num_chunks * 30 s, not a fixed 30 s. The chunk count is
# recomputed from the file duration the same way process_long_audio does,
# up to resampling rounding.
CHUNK_SECONDS = 30.0
_num_audio_chunks = int(np.ceil(sf.info(audio_file_path).duration / CHUNK_SECONDS))
EFFECTIVE_AUDIO_DURATION = CHUNK_SECONDS * _num_audio_chunks
AUDIO_TIME_PER_FRAME = EFFECTIVE_AUDIO_DURATION / audio_output_frames

results = []
token_idx_counter = 0
previous_word_end_time = token_to_frame_map.get(0, 0) * AUDIO_TIME_PER_FRAME

for word_group in word_groups:
    # Advance the token index before any `continue`, so skipped (empty) words
    # do not shift every later word's position in the DTW path.
    token_idx_counter += len(word_group)
    word_text = tok.decode(word_group).strip()
    if not word_text:
        continue

    # A word starts where the previous one ended and ends at the frame first
    # aligned with its last token (never before its start).
    start_time = previous_word_end_time
    last_token_in_word_idx = token_idx_counter - 1
    end_frame = token_to_frame_map.get(last_token_in_word_idx, 0)
    end_time = max(start_time, end_frame * AUDIO_TIME_PER_FRAME)
    results.append({"word": word_text, "start": start_time, "end": end_time})
    previous_word_end_time = end_time

for res in results:
    print(f"[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")

# Play back the audio span of every timestamped word for manual verification.
SAMPLING_RATE = 16000
y, sr = librosa.load(audio_file_path, sr=SAMPLING_RATE)
if not results:
    print("No words were transcribed to verify.")
for res in results:
    start, end, word = res['start'], res['end'], res['word']
    audio_snippet = y[int(start * SAMPLING_RATE):int(end * SAMPLING_RATE)]
    print(f"\n[{start: >6.2f}s -> {end: >6.2f}s] {word}")
    if len(audio_snippet) == 0:
        print("   (No audio for this segment)")
    else:
        display(Audio(audio_snippet, rate=SAMPLING_RATE))

urroxyz/Voxtral-Mini-3B-2507_timestamped

作者 urroxyz

audio-text-to-text transformers.js
↓ 0 ♥ 3

创建时间: 2025-07-25 01:03:23+00:00

更新时间: 2025-07-27 00:47:21+00:00

在 Hugging Face 上查看

文件 (6)

.gitattributes
README.md
decoder_model_attentive.onnx ONNX
decoder_model_attentive.onnx_data
decoder_model_attentive_q4_weight_only_inc.onnx ONNX
decoder_model_attentive_q4_weight_only_inc.onnx_data