返回模型
说明文档
TinyStories 词汇级 LSTM (ONNX)
一个紧凑的 10.9 MB 词汇级 LSTM 语言模型,在 TinyStories 数据集上训练。能够以极低的计算量生成简短、连贯的儿童故事。
- 词汇量:5,004(词汇级,包含
<PAD>、<UNK>、<SOS>、<EOS>) - 架构:2层 LSTM,256 个隐藏单元,128 维嵌入
- 最大序列长度:50 个 token
- 格式:ONNX(兼容 CPU/GPU 上的 ONNX Runtime)
- 模型大小:约 11 MB
在 T4 GPU (x2) 上,使用 50 万条 TinyStories 样本进行训练,耗时不到 5 分钟。
使用方法
1. 安装依赖
pip install onnxruntime numpy
2. 从本仓库下载文件
3. 运行推理
import numpy as np
import onnxruntime as ort
# --- 加载词汇表 ---
with open("vocab.txt", "r") as f:
vocab = [line.strip() for line in f]
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}
# 特殊 token
SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]
# --- 分词器(简单的词汇级)---
def tokenize(text):
import re
# 转小写并分割标点符号
text = re.sub(r'([.,!?])', r' \1 ', text.lower())
return text.split()
# --- 加载 ONNX 模型 ---
ort_session = ort.InferenceSession("tinystories_lstm.onnx")
# --- 文本生成函数 ---
def generate_text(prompt, max_new_tokens=30, temperature=0.8):
# 对提示词进行分词
tokens = tokenize(prompt)
input_ids = [SOS_IDX] + [
word2idx.get(t, UNK_IDX) for t in tokens
]
# 填充至长度 1(我们将以自回归方式生成)
current_seq = input_ids.copy()
for _ in range(max_new_tokens):
# 将当前序列填充至长度 50(模型需要固定长度)
padded = current_seq + [PAD_IDX] * (50 - len(current_seq))
if len(padded) > 50:
padded = padded[-50:] # 如果太长则截断
input_tensor = np.array([padded], dtype=np.int64)
# 运行模型
outputs = ort_session.run(None, {"input": input_tensor})
logits = outputs[0] # 形状: (1, 50, 5004)
# 获取最后一个非填充 token 的 logits
last_pos = min(len(current_seq) - 1, 49)
next_token_logits = logits[0, last_pos, :] / temperature
# Softmax + 采样
probs = np.exp(next_token_logits - np.max(next_token_logits))
probs = probs / np.sum(probs)
next_token = np.random.choice(len(probs), p=probs)
if next_token == EOS_IDX:
break
current_seq.append(next_token)
# 解码
words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
return " ".join(words).replace(" .", ".").replace(" ,", ",")
# --- 示例用法 ---
if __name__ == "____main__":
prompt = "once upon a time"
story = generate_text(prompt, max_new_tokens=40, temperature=0.7)
print(f"Prompt: {prompt}")
print(f"Story: {story}")
示例输出
Prompt: once upon a time
Story: once upon a time there was a little girl named lily. she loved to play in the garden. one day she found a magic flower that could talk!
训练细节
- 数据集:
roneneldan/TinyStories(50 万条训练样本) - 优化器:Adam(学习率=0.002)
- 批量大小:128
- 训练轮数:2
- 硬件:NVIDIA T4 (x2)
- 训练时间:约 5 分钟
局限性
- 词汇级建模 → 无法很好地处理词汇表之外的词
- 固定上下文窗口(50 个 token)
- 无集束搜索(使用基本采样)
phmd/TinyStories-LSTM-5.5M
作者 phmd
text-generation
↓ 0
♥ 0
创建时间: 2025-10-11 20:18:33+00:00
更新时间: 2025-10-12 11:09:51+00:00
在 Hugging Face 上查看文件 (5)
.gitattributes
README.md
config.json
tinystories_lstm.onnx
ONNX
vocab.txt