返回模型
说明文档
更多信息请参见此处。
如何创建此转换
使用下面的脚本将 Voxtral 的语言模型解码器转换为 ONNX 格式,并额外输出各层的注意力权重(以及 KV 缓存),以便后续推导位置信息(词级时间戳)。
<details>
import torch
from torch import nn
from transformers import VoxtralForConditionalGeneration
from transformers.cache_utils import DynamicCache
import os
import onnx
# Load Voxtral for export. attn_implementation="eager" is requested
# presumably so the attention probabilities can be returned later
# (output_attentions=True in the wrapper).
model_id = "mistralai/Voxtral-Mini-3B-2507"
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if _cuda_available else torch.float32
_load_kwargs = dict(
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="eager",
)
model = VoxtralForConditionalGeneration.from_pretrained(model_id, **_load_kwargs)
model = model.to(device)
model.eval()
class DecoderONNXWrapper(nn.Module):
    """Flat-tensor adapter around the Voxtral language model for ONNX export.

    The KV cache arrives as ``2 * num_hidden_layers`` positional tensors
    ordered ``(key_0, value_0, key_1, value_1, ...)``. The result is one flat
    tuple: ``logits``, then the updated key/value tensor of every layer in the
    same interleaved order, then each layer's attention weights.
    """

    def __init__(self, language_model):
        super().__init__()
        self.language_model = language_model

    def forward(self, inputs_embeds, attention_mask, *past_key_value_tensors):
        layer_count = self.language_model.config.num_hidden_layers
        # Regroup the flat inputs into per-layer (key, value) pairs.
        pairs = tuple(
            (past_key_value_tensors[2 * layer], past_key_value_tensors[2 * layer + 1])
            for layer in range(layer_count)
        )
        cache = DynamicCache.from_legacy_cache(past_key_values=pairs)
        result = self.language_model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=cache,
            output_attentions=True,
            use_cache=True,
        )
        present = result.past_key_values
        flat = [result.logits]
        for key_tensor, value_tensor in zip(present.key_cache, present.value_cache):
            flat += [key_tensor, value_tensor]
        flat += list(result.attentions)
        return tuple(flat)
# Dummy inputs used only to trace the graph: values are random, only the
# shapes and dtypes matter.
batch_size = 1
seq_len = 128
past_seq_len = 100
text_config = model.config.text_config
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size
head_dim = text_config.head_dim
num_kv_heads = text_config.num_key_value_heads
inputs_embeds = torch.randn((batch_size, seq_len, hidden_size), dtype=torch_dtype, device=device)
attention_mask_4d = torch.ones((batch_size, 1, seq_len, past_seq_len + seq_len), dtype=torch_dtype, device=device)
past_key_value_flat_tuple = tuple(
    torch.randn((batch_size, num_kv_heads, past_seq_len, head_dim), dtype=torch_dtype, device=device)
    for _ in range(num_layers * 2)
)
dummy_inputs = (inputs_embeds, attention_mask_4d) + past_key_value_flat_tuple
output_path = "decoder_model_attentive_unpacked.onnx"
kv_suffixes = ("key", "value")
input_names = ["inputs_embeds", "attention_mask"] + [
    f"past_key_values.{i}.{kv}" for i in range(num_layers) for kv in kv_suffixes
]
output_names = (
    ["logits"]
    + [f"present.{i}.{kv}" for i in range(num_layers) for kv in kv_suffixes]
    + [f"attention.{i}" for i in range(num_layers)]
)
# Declare every dynamic axis explicitly by name, so that every sequence-length
# dimension is symbolic in the exported signature, including logits axis 1.
dynamic_axes = {
    "inputs_embeds": {1: "sequence_length"},
    "attention_mask": {2: "sequence_length", 3: "total_sequence_length"},
    "logits": {1: "sequence_length"},
}
for i in range(num_layers):
    for kv in kv_suffixes:
        dynamic_axes[f"past_key_values.{i}.{kv}"] = {2: "past_sequence_length"}
        dynamic_axes[f"present.{i}.{kv}"] = {2: "total_sequence_length"}
    dynamic_axes[f"attention.{i}"] = {2: "sequence_length", 3: "total_sequence_length"}
wrapped_model = DecoderONNXWrapper(model.language_model)
wrapped_model.eval()
with torch.no_grad():
    torch.onnx.export(
        wrapped_model,
        dummy_inputs,
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=17,
    )
# Re-save the exported graph so that all weights are stored in one external
# data file referenced by "decoder_model_attentive.onnx".
import onnx
onnx_model = onnx.load(output_path, load_external_data=True)
data_file_location = "decoder_model_attentive.onnx_data"
_final_model_path = "decoder_model_attentive.onnx"
onnx.save_model(
    onnx_model,
    _final_model_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=data_file_location,
)
# Clean up: delete the loose files whose names start with "language_" or
# "onnx_" from the working directory. These are presumably the per-tensor
# external-data files written by torch.onnx.export; the weights were just
# re-saved into a single .onnx_data file.
cleanup_dir = "."
for fname in os.listdir(cleanup_dir):
    if fname.startswith(("language_", "onnx_")):
        os.remove(os.path.join(cleanup_dir, fname))
</details>
下面的脚本将对给定的 ONNX 文件进行量化(需要同时提供数据文件)。
<details>
import os
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.utils.constant import FP32
import onnx
import logging
logging.basicConfig(level=logging.INFO)
model_dir = "."
model_fp32 = 'decoder_model_attentive.onnx'
model_quantized = 'decoder_model_attentive_q4_weight_only_inc.onnx'
input_model_path = os.path.join(model_dir, model_fp32)
output_model_path = os.path.join(model_dir, model_quantized)
# onnx.checker.check_model returns None and reports problems by raising, so
# its return value cannot be used as a validity flag. Passing the path
# (rather than a loaded ModelProto) allows checking models that keep their
# weights in an external .onnx_data file.
try:
    onnx.checker.check_model(input_model_path)
except onnx.checker.ValidationError as e:
    print(f"Error: Original model '{input_model_path}' is not a valid ONNX model.")
    print(e)
    raise SystemExit(1)
except Exception as e:
    print(f"Failed to load or check original model '{input_model_path}': {e}")
    print("Please ensure the original model file exists and is not corrupted.")
    raise SystemExit(1)
print(f"Original model '{input_model_path}' is valid.")
# 4-bit weight-only quantization: RTN, asymmetric scheme, group size 32,
# applied to every op type matched by the ".*" pattern.
_weight_settings = {
    "bits": 4,
    "algorithm": ["RTN"],
    "scheme": ["asym"],
    "group_size": 32,
}
config = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={".*": {"weight": _weight_settings}},
)
print(f"\nAttempting to quantize '{model_fp32}' to 4-bit weight-only using Neural Compressor...")
try:
    q_model = quantization.fit(input_model_path, config)
    q_model.save(output_model_path)
    print(f"Model successfully quantized and saved to {output_model_path}")
except Exception as e:
    for _line in (
        f"Error during Neural Compressor weight-only quantization: {e}",
        "Please ensure Neural Compressor is installed (`pip install neural_compressor`)",
        "and that your ONNX Runtime version is compatible.",
    ):
        print(_line)
</details>
推理并提取词级时间戳
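- 从本仓库下载解码器的两个文件(`*.onnx` 和对应的 `*.onnx_data`;下面的代码默认使用 `decoder_model_attentive_q4_weight_only_inc.onnx` 及其数据文件)
- 安装依赖,例如 `ipython`(此外还需要 `onnxruntime`、`huggingface_hub`、`tokenizers`、`soundfile`、`librosa`、`matplotlib` 和 `torch`)
- 下载 `audio.wav`(任意你想转录的音频文件)
- 然后运行以下 Python 代码: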
import os
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
import soundfile as sf
import librosa
import logging
import matplotlib.pyplot as plt
from IPython.display import display, Audio
import torch
def _create_4d_causal_attention_mask(input_shape, past_sequence_length, dtype=np.float32):
    """Build an additive causal attention mask.

    Args:
        input_shape: ``(batch_size, sequence_length)`` of the new tokens.
        past_sequence_length: number of already-cached positions.
        dtype: NumPy float dtype of the result.

    Returns:
        Array of shape ``(batch, 1, sequence_length, past + sequence_length)``.
        Query row ``r`` sits at absolute position ``past + r`` and may attend
        to every key position ``<= past + r``. Allowed entries are ``0.0``;
        masked entries hold ``np.finfo(dtype).min``.
    """
    batch_size, sequence_length = input_shape
    total_length = past_sequence_length + sequence_length
    query_positions = np.arange(past_sequence_length, total_length)[:, None]
    key_positions = np.arange(total_length)[None, :]
    allowed = key_positions <= query_positions
    mask_2d = np.where(allowed, 0.0, np.finfo(dtype).min).astype(dtype)
    target_shape = (batch_size, 1, sequence_length, total_length)
    # copy() turns the read-only broadcast view into a normal writable array.
    return np.broadcast_to(mask_2d, target_shape).copy()
# Run configuration.
repo_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX"
audio_file_path = "audio.wav"
custom_decoder_path = "decoder_model_attentive_q4_weight_only_inc.onnx"
max_generation_tokens = 999
eos_token_id = 2
# Fetch the encoder, embedding model and tokenizer from onnx-community.
# NOTE(review): decoder_model_merged_q4.* is downloaded but not used below;
# the local custom decoder is used instead.
print(f"Downloading base model files from {repo_id}...")
_download_patterns = [
    "onnx/audio_encoder_q4.*",
    "onnx/embed_tokens_q4.*",
    "onnx/decoder_model_merged_q4.*",
    "tokenizer.json",
]
local_dir = snapshot_download(repo_id=repo_id, repo_type="model", allow_patterns=_download_patterns)
onnx_dir = os.path.join(local_dir, "onnx")
tok = Tokenizer.from_file(os.path.join(local_dir, "tokenizer.json"))
# Hard-coded special-token ids used to assemble the prompt (names suggest
# BOS, [INST], begin-audio, audio placeholder, [/INST] — verify against
# tokenizer.json).
bos_id, inst_id, baud_id, aud_id, einst_id = 1, 3, 25, 24, 4
ae_path = os.path.join(onnx_dir, "audio_encoder_q4.onnx")
embed_path = os.path.join(onnx_dir, "embed_tokens_q4.onnx")
if not os.path.exists(custom_decoder_path):
    raise FileNotFoundError(f"Custom ONNX decoder not found at '{custom_decoder_path}'.")
sess_opts = ort.SessionOptions()
session_providers = ["CPUExecutionProvider"]
ae_sess = ort.InferenceSession(ae_path, sess_options=sess_opts, providers=session_providers)
embed_sess = ort.InferenceSession(embed_path, sess_options=sess_opts, providers=session_providers)
dec_sess = ort.InferenceSession(custom_decoder_path, sess_options=sess_opts, providers=session_providers)
# One "past_key_values.<n>.key" input exists per decoder layer.
num_decoder_layers = len([meta for meta in dec_sess.get_inputs() if meta.name.endswith(".key")])
print(f"Detected {num_decoder_layers} decoder layers.")
def extract_mel_features_for_chunk(audio_chunk, sampling_rate=16000, n_fft=400, hop_length=160, n_mels=128, target_length=3000):
    """Compute normalized log-mel features for one audio chunk.

    The chunk is fixed to exactly 30 s of samples at ``sampling_rate``
    (``librosa.util.fix_length``) and turned into an ``n_mels`` mel spectrogram.
    That spectrogram is log10-compressed, clamped to within 8.0 of its maximum,
    rescaled with ``(x + 4) / 4`` and truncated to ``target_length`` frames.
    The normalization appears to follow Whisper's feature extractor; confirm
    if exact parity matters.

    Returns:
        float32 array of shape ``(n_mels, <= target_length)``.
    """
    fixed = librosa.util.fix_length(audio_chunk, size=sampling_rate * 30)
    mel = librosa.feature.melspectrogram(
        y=fixed, sr=sampling_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    log_mel = np.log10(np.maximum(mel, 1e-10))
    dynamic_range_floor = log_mel.max() - 8.0
    log_mel = np.maximum(log_mel, dynamic_range_floor)
    normalized = (log_mel + 4.0) / 4.0
    return normalized[:, :target_length].astype(np.float32)
def process_long_audio(audio_path, session, sampling_rate=16000):
    """Encode an audio file of any length with the ONNX audio encoder.

    The file is down-mixed to mono and resampled to ``sampling_rate`` if
    needed. It is then split into consecutive 30-second chunks; the last chunk
    is length-fixed by ``extract_mel_features_for_chunk``. Each chunk's mel
    features are run through ``session``, and the per-chunk encoder outputs
    are concatenated along axis 0.

    Raises:
        ValueError: if the file contains no audio samples.
    """
    y, sr = sf.read(audio_path)
    if len(y) == 0:
        # Without this check np.concatenate([]) fails later with an opaque error.
        raise ValueError(f"Audio file '{audio_path}' contains no samples.")
    if y.ndim > 1:
        y = y.mean(axis=1)
    if sr != sampling_rate:
        y = librosa.resample(y, orig_sr=sr, target_sr=sampling_rate)
    chunk_samples = 30 * sampling_rate
    num_chunks = int(np.ceil(len(y) / chunk_samples))
    input_name = session.get_inputs()[0].name  # loop-invariant
    all_audio_embeds = []
    print(f"Processing in {num_chunks} chunk(s)...")
    for i in range(num_chunks):
        chunk = y[i * chunk_samples:(i + 1) * chunk_samples]
        mel_features = extract_mel_features_for_chunk(chunk, sampling_rate)
        all_audio_embeds.append(session.run(None, {input_name: mel_features[None, :]})[0])
    return np.concatenate(all_audio_embeds, axis=0)
# Encode the audio, then overwrite the embeddings of the audio placeholder
# tokens in the prompt with the encoder output.
if not os.path.exists(audio_file_path):
    raise FileNotFoundError(f"Audio file '{audio_file_path}' not found.")
audio_embeds_raw = process_long_audio(audio_file_path, ae_sess)
batch_size = 1
audio_output_frames = audio_embeds_raw.shape[0] // batch_size
audio_embeds = audio_embeds_raw.reshape(batch_size, audio_output_frames, -1)
text_instruction_ids = tok.encode("Transcribe.", add_special_tokens=False).ids
# Prompt layout: bos, inst, baud, aud x N, "Transcribe.", einst
prompt_tokens = [bos_id, inst_id, baud_id, *([aud_id] * audio_output_frames), *text_instruction_ids, einst_id]
initial_sequence_length = len(prompt_tokens)
prompt_ids = np.array([prompt_tokens], dtype=np.int64)
# The audio placeholders start right after the first three prompt tokens.
_audio_slot = slice(3, 3 + audio_output_frames)
inputs_embeds = embed_sess.run(None, {"input_ids": prompt_ids})[0]
inputs_embeds[0, _audio_slot, :] = audio_embeds[0]
inputs_embeds = inputs_embeds.astype(np.float32)
# Greedy autoregressive decoding with an explicit KV cache. Step 0 feeds the
# whole prompt; later steps feed only the previously generated token.
generated_ids = []
past_key_values = None
current_past_len = 0
# Empty cache tensors are (batch, kv_heads, 0, head_dim). kv_heads and
# head_dim are read from the decoder's declared input shape (static in the
# export) instead of being hard-coded.
_kv_input_shape = next(
    meta.shape for meta in dec_sess.get_inputs() if meta.name == "past_key_values.0.key"
)
_empty_kv_shape = (batch_size, _kv_input_shape[1], 0, _kv_input_shape[3])
for i in range(max_generation_tokens):
    dec_inputs = {}
    if i == 0:
        dec_inputs["inputs_embeds"] = inputs_embeds
        attention_mask = _create_4d_causal_attention_mask((batch_size, initial_sequence_length), 0)
    else:
        last_token_id = np.array([[generated_ids[-1]]], dtype=np.int64)
        dec_inputs["inputs_embeds"] = embed_sess.run(None, {"input_ids": last_token_id})[0].astype(np.float32)
        # The new token attends to every cached position plus itself.
        attention_mask = _create_4d_causal_attention_mask((batch_size, 1), current_past_len)
    dec_inputs["attention_mask"] = attention_mask
    for layer in range(num_decoder_layers):
        if past_key_values:
            key = past_key_values[layer * 2].astype(np.float32)
            value = past_key_values[layer * 2 + 1].astype(np.float32)
        else:
            key = np.zeros(_empty_kv_shape, dtype=np.float32)
            value = np.zeros(_empty_kv_shape, dtype=np.float32)
        dec_inputs[f"past_key_values.{layer}.key"] = key
        dec_inputs[f"past_key_values.{layer}.value"] = value
    outputs = dec_sess.run(None, dec_inputs)
    logits, past_key_values = outputs[0], outputs[1:1 + num_decoder_layers * 2]
    # Track the cache length so that the next step's mask spans all cached keys.
    current_past_len = past_key_values[0].shape[2]
    # Plain int: the tokenizer expects Python ints, not NumPy scalars.
    next_token_id = int(np.argmax(logits[0, -1, :]))
    if next_token_id == eos_token_id:
        break
    generated_ids.append(next_token_id)
    print(tok.decode([next_token_id]), end="", flush=True)
# Re-run the decoder once over prompt + generated tokens, starting from an
# empty cache, so that each layer returns attention over the full sequence.
full_sequence_ids = np.array([prompt_tokens + generated_ids], dtype=np.int64)
full_embeds = embed_sess.run(None, {"input_ids": full_sequence_ids})[0]
full_embeds[0, 3:3 + audio_output_frames, :] = audio_embeds[0]
full_embeds = full_embeds.astype(np.float32)
alignment_inputs = {
    "inputs_embeds": full_embeds,
    "attention_mask": _create_4d_causal_attention_mask(full_embeds.shape[:2], 0),
}
# Empty cache tensors (batch, kv_heads, 0, head_dim). kv_heads and head_dim
# come from the decoder's declared input shape rather than being hard-coded.
_align_kv_shape = next(
    meta.shape for meta in dec_sess.get_inputs() if meta.name == "past_key_values.0.key"
)
_align_empty_kv = (batch_size, _align_kv_shape[1], 0, _align_kv_shape[3])
for l in range(num_decoder_layers):
    alignment_inputs[f"past_key_values.{l}.key"] = np.zeros(_align_empty_kv, dtype=np.float32)
    alignment_inputs[f"past_key_values.{l}.value"] = np.zeros(_align_empty_kv, dtype=np.float32)
alignment_outputs = dec_sess.run(None, alignment_inputs)
# Output order is: logits, present key/value per layer, attention per layer.
attentions = [torch.from_numpy(attn) for attn in alignment_outputs[1 + num_decoder_layers * 2:]]
# Average attention over heads and over layers [start_layer, end_layer).
# Keep rows from the first generated position onward and columns for the
# audio frames, then sharpen each row with a low-temperature softmax.
text_start_idx = len(prompt_tokens)
audio_end_idx = 3 + audio_output_frames
start_layer, end_layer = 10, 20
_per_layer = [
    attentions[layer][0].mean(dim=0)[text_start_idx:, 3:audio_end_idx]
    for layer in range(start_layer, end_layer)
]
layer_attentions = [block_attn for block_attn in _per_layer if block_attn.numel() > 0]
if not layer_attentions:
    raise ValueError("Could not extract any valid attention weights. The generated text might be empty.")
avg_attentions = torch.stack(layer_attentions).mean(dim=0)
temperature = 0.1
weights = torch.nn.functional.softmax(avg_attentions / temperature, dim=1).cpu().numpy()
# Save a heat-map of the sharpened token-to-audio alignment matrix.
_fig = plt.figure(figsize=(10, 10))
plt.imshow(weights, aspect="auto", origin="lower", cmap="viridis")
plt.xlabel("Audio Frames")
plt.ylabel("Generated Text Tokens")
plt.title("Audio-to-Text Alignment Matrix (Sharpened)")
plt.colorbar()
_fig.savefig("alignment_matrix.png")
# DTW between audio frames (rows of the cost matrix) and generated tokens
# (columns); higher attention gives lower cost.
cost_matrix = -weights.T
D, wp = librosa.sequence.dtw(C=cost_matrix.astype(np.float32), backtrack=True)
# Reverse the path so it runs from start to end, then keep, for each token,
# the first frame it is paired with.
wp = wp[::-1]
token_to_frame_map = {}
for frame_idx, token_idx in wp:
    token_to_frame_map.setdefault(token_idx, frame_idx)
# Group generated tokens into words: a token whose decoded text begins with a
# space starts a new word.
word_groups = []
current_word_tokens = []
for token_id in generated_ids:
    starts_new_word = tok.decode([token_id]).startswith(" ")
    if starts_new_word and current_word_tokens:
        word_groups.append(current_word_tokens)
        current_word_tokens = []
    current_word_tokens.append(token_id)
if current_word_tokens:
    word_groups.append(current_word_tokens)
# Convert the token->frame alignment into per-word start/end times.
# process_long_audio feeds the encoder one fixed 30 s window per chunk, so
# the encoder frames span 30 s * number_of_chunks. The chunk count is derived
# from the file duration, so long (>30 s) audio is timed correctly too.
_audio_info = sf.info(audio_file_path)
_num_audio_chunks = max(1, int(np.ceil(_audio_info.frames / _audio_info.samplerate / 30.0)))
EFFECTIVE_AUDIO_DURATION = 30.0 * _num_audio_chunks
AUDIO_TIME_PER_FRAME = EFFECTIVE_AUDIO_DURATION / audio_output_frames
results = []
token_idx_counter = 0
previous_word_end_time = token_to_frame_map.get(0, 0) * AUDIO_TIME_PER_FRAME
for word_group in word_groups:
    # Advance the token counter before any `continue`, so that skipped
    # (whitespace-only) groups do not shift later words' token indices.
    first_token_idx = token_idx_counter
    token_idx_counter += len(word_group)
    word_text = tok.decode(word_group).strip()
    if not word_text:
        continue
    start_time = previous_word_end_time
    last_token_in_word_idx = first_token_idx + len(word_group) - 1
    end_frame = token_to_frame_map.get(last_token_in_word_idx, 0)
    # Clamp so that a word never ends before it starts.
    end_time = max(start_time, end_frame * AUDIO_TIME_PER_FRAME)
    results.append({"word": word_text, "start": start_time, "end": end_time})
    previous_word_end_time = end_time
for res in results:
    print(f"[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
# Play each word's audio segment (notebook display) to check the timestamps by ear.
SAMPLING_RATE = 16000
y, sr = librosa.load(audio_file_path, sr=SAMPLING_RATE)
if not results:
    print("No words were transcribed to verify.")
else:
    for res in results:
        first_sample = int(res['start'] * SAMPLING_RATE)
        last_sample = int(res['end'] * SAMPLING_RATE)
        segment = y[first_sample:last_sample]
        print(f"\n[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
        if len(segment) == 0:
            print(" (No audio for this segment)")
        else:
            display(Audio(segment, rate=SAMPLING_RATE))
urroxyz/Voxtral-Mini-3B-2507_timestamped
作者 urroxyz
audio-text-to-text
transformers.js
↓ 0
♥ 3
创建时间: 2025-07-25 01:03:23+00:00
更新时间: 2025-07-27 00:47:21+00:00
在 Hugging Face 上查看文件 (6)
.gitattributes
README.md
decoder_model_attentive.onnx
ONNX
decoder_model_attentive.onnx_data
decoder_model_attentive_q4_weight_only_inc.onnx
ONNX
decoder_model_attentive_q4_weight_only_inc.onnx_data