ONNX 模型库
返回模型

说明文档

TinyStories 词汇级 LSTM (ONNX)

一个紧凑的 10.9 MB 词汇级 LSTM 语言模型,在 TinyStories 数据集上训练。能够以极低的计算量生成简短、连贯的儿童故事。

  • 词汇量:5,004(词汇级,包含 <PAD><UNK><SOS><EOS>
  • 架构:2层 LSTM,256 个隐藏单元,128 维嵌入
  • 最大序列长度:50 个 token
  • 格式:ONNX(兼容 CPU/GPU 上的 ONNX Runtime)
  • 模型大小:约 11 MB

在 T4 GPU (x2) 上,使用 50 万条 TinyStories 样本进行训练,耗时不到 5 分钟。

使用方法

1. 安装依赖

pip install onnxruntime numpy

2. 从本仓库下载文件

3. 运行推理

import numpy as np
import onnxruntime as ort

# --- 加载词汇表 ---
with open("vocab.txt", "r") as f:
    vocab = [line.strip() for line in f]
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 特殊 token
SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]

# --- 分词器(简单的词汇级)---
def tokenize(text):
    import re
    # 转小写并分割标点符号
    text = re.sub(r'([.,!?])', r' \1 ', text.lower())
    return text.split()

# --- 加载 ONNX 模型 ---
ort_session = ort.InferenceSession("tinystories_lstm.onnx")

# --- 文本生成函数 ---
def generate_text(prompt, max_new_tokens=30, temperature=0.8):
    # 对提示词进行分词
    tokens = tokenize(prompt)
    input_ids = [SOS_IDX] + [
        word2idx.get(t, UNK_IDX) for t in tokens
    ]
    
    # 填充至长度 1(我们将以自回归方式生成)
    current_seq = input_ids.copy()
    
    for _ in range(max_new_tokens):
        # 将当前序列填充至长度 50(模型需要固定长度)
        padded = current_seq + [PAD_IDX] * (50 - len(current_seq))
        if len(padded) > 50:
            padded = padded[-50:]  # 如果太长则截断
        
        input_tensor = np.array([padded], dtype=np.int64)
        
        # 运行模型
        outputs = ort_session.run(None, {"input": input_tensor})
        logits = outputs[0]  # 形状: (1, 50, 5004)
        
        # 获取最后一个非填充 token 的 logits
        last_pos = min(len(current_seq) - 1, 49)
        next_token_logits = logits[0, last_pos, :] / temperature
        
        # Softmax + 采样
        probs = np.exp(next_token_logits - np.max(next_token_logits))
        probs = probs / np.sum(probs)
        next_token = np.random.choice(len(probs), p=probs)
        
        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)
    
    # 解码
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    return " ".join(words).replace(" .", ".").replace(" ,", ",")

# --- 示例用法 ---
if __name__ == "____main__":
    prompt = "once upon a time"
    story = generate_text(prompt, max_new_tokens=40, temperature=0.7)
    print(f"Prompt: {prompt}")
    print(f"Story:  {story}")

示例输出

Prompt: once upon a time
Story:  once upon a time there was a little girl named lily. she loved to play in the garden. one day she found a magic flower that could talk!

训练细节

  • 数据集:roneneldan/TinyStories(50 万条训练样本)
  • 优化器:Adam(学习率=0.002)
  • 批量大小:128
  • 训练轮数:2
  • 硬件:NVIDIA T4 (x2)
  • 训练时间:约 5 分钟

局限性

  • 词汇级建模 → 无法很好地处理词汇表之外的词
  • 固定上下文窗口(50 个 token)
  • 无集束搜索(使用基本采样)

phmd/TinyStories-LSTM-5.5M

作者 phmd

text-generation
↓ 0 ♥ 0

创建时间: 2025-10-11 20:18:33+00:00

更新时间: 2025-10-12 11:09:51+00:00

在 Hugging Face 上查看

文件 (5)

.gitattributes
README.md
config.json
tinystories_lstm.onnx ONNX
vocab.txt