4.1 KiB
4.1 KiB
title, authors, venue, arxiv, code, dataset, year, type, tags
| title | authors | venue | arxiv | code | dataset | year | type | tags | |||
|---|---|---|---|---|---|---|---|---|---|---|---|
| One-Pass to Reason: 多轮推理的高效单遍微调 | Ritesh Goru, Shanay Mehta, Prateek Jain (DevRev) | ICML 2025 Workshop — Efficient Systems for Foundational Models | 2504.18246 | https://github.com/devrev/One-Pass-to-Reason | https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning | 2025 | paper |
|
One-Pass to Reason
核心思想:通过 token 复制 + 分块稀疏注意力掩码,将多轮推理对话的 N 遍训练压缩为单遍,时间复杂度从 O(N³) 降至 O(N²)。
问题背景
推理模型(如 DeepSeek-R1)遵循行业惯例:生成推理 token → 输出回复 → 在后续轮次中丢弃推理 token。这导致多轮对话微调时,每个对话需要 N 次独立前向传播(N = 对话轮数)。
两个核心约束:
- visibility-constraint:推理 token 在生成时必须可见,但在后续轮次中必须隐藏
- position-id-discrepancy:回复 token 在生成时紧跟推理 token,但在上下文中的位置紧接着人类消息
方法
Token 复制 (token-duplication)
将每个助手回复的 response token 复制为两份:
- ri_in(上下文副本):不关注推理 token,作为后续轮次的纯上下文
- ri_out(生成副本):关注推理 token,参与 loss 计算
分块稀疏注意力掩码 (block-sparse-attention)
定义每种 token 类型(hi, ti, ri_in, ri_out)的可见性规则:
hi → A(H<i)— 人类消息只看历史ti → A(H<i, hi)— 推理 token 看历史+当前人类消息ri_in → A(H<i, hi)— 上下文副本不看推理ri_out → A(H<i, hi, ti)— 生成副本看全部包括推理
位置 ID 策略
s_ti = s_ri_in = e_hi + 1 # 推理和上下文副本从人类消息后开始
s_ri_out = e_ti + 1 # 生成副本从推理后开始
s_h_{i+1} = e_ri_in + 1 # 下一轮人类消息从上下文副本后开始
理论保证(one-pass-fine-tuning)
1-Pass 与 N-Pass 的 loss 完全等价:
L_{\text{N-Pass}}(c) = L_{\text{1-Pass}}(c)
证明分三部分:位置编码等价 → 注意力模式等价 → loss 函数等价。
复杂度分析
| 方法 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| N-Pass | O(N³ℓ²d) | O(Nℓ²) |
| 1-Pass | O(N²ℓ²d) | O(Nℓ²) |
响应 token 复制带来约 33% 的额外内存开销(因为 ri 被存了两份),但渐进复杂度相同。
实验结果
在 Qwen-3 (4B / 8B / 32B) 上使用 QLoRA + 8×H100:
训练加速(Flex-Pack-1-Pass vs FA2-Pack-N-Pass):
- 4B: 1.05×
- 8B: 1.21×
- 32B: 1.22×
vs FlexAttention N-Pass:1.44×–1.54×
深度扩展:对话越长加速越明显(验证了 O(N²) vs O(N³) 理论优势)
K-Pass 中间方案 (k-pass-training):
- K=1:最快,+33% 内存
- K=2:1.30×–1.37× 加速,+20% 内存
- K>4:收益递减
数据集
mathchatsync-reasoning:首个公开的多轮推理数据集,基于 MathChatSync,用 GPT-4.1-mini 为每个助手回复生成推理 token。
实现细节
- 基于 LLaMA-Factory (llama-factory)
- 使用 flex-attention(FlashAttention-2 不支持自定义掩码)
- 掩码生成在 GPU 上向量化执行,用卡诺图化简布尔逻辑
- 支持序列打包 (sequence-packing) 叠加自定义掩码
关键洞察
- 从 O(N³) 到 O(N²) 的复杂度降低意味着:对话越长,单遍训练的优势越大
- Token 复制的本质是用空间换时间:多存一份 response 换来一个数量级的加速
- K-Pass 提供了一个优雅的连续统:从完全节省内存(N-Pass)到完全节省时间(1-Pass)
相关概念
- deepseek-r1 — 典型推理模型
- qlora — 实验所用的高效微调方法
- flash-attention — 快速注意力实现
- llama-factory — 微调框架
- multi-turn-reasoning — 多轮推理训练问题域