107 lines
4.1 KiB
Markdown
107 lines
4.1 KiB
Markdown
---
|
||
title: "One-Pass to Reason: 多轮推理的高效单遍微调"
|
||
authors: "Ritesh Goru, Shanay Mehta, Prateek Jain (DevRev)"
|
||
venue: "ICML 2025 Workshop — Efficient Systems for Foundational Models"
|
||
arxiv: "2504.18246"
|
||
code: "https://github.com/devrev/One-Pass-to-Reason"
|
||
dataset: "https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning"
|
||
year: 2025
|
||
type: paper
|
||
tags: [efficient-fine-tuning, multi-turn-reasoning, attention-mask]
|
||
---
|
||
|
||
# One-Pass to Reason
|
||
|
||
> **核心思想**:通过 token 复制 + 分块稀疏注意力掩码,将多轮推理对话的 N 遍训练压缩为单遍,时间复杂度从 O(N³) 降至 O(N²)。
|
||
|
||
## 问题背景
|
||
|
||
推理模型(如 DeepSeek-R1)遵循行业惯例:生成推理 token → 输出回复 → 在后续轮次中**丢弃推理 token**。这导致多轮对话微调时,每个对话需要 N 次独立前向传播(N = 对话轮数)。
|
||
|
||
两个核心约束:
|
||
1. **[[visibility-constraint|可见性约束]]**:推理 token 在生成时必须可见,但在后续轮次中必须隐藏
|
||
2. **[[position-id-discrepancy|位置 ID 偏差]]**:回复 token 在生成时紧跟推理 token,但在上下文中的位置紧接着人类消息
|
||
|
||
## 方法
|
||
|
||
### Token 复制 ([[token-duplication]])
|
||
|
||
将每个助手回复的 response token 复制为两份:
|
||
- **ri_in**(上下文副本):不关注推理 token,作为后续轮次的纯上下文
|
||
- **ri_out**(生成副本):关注推理 token,参与 loss 计算
|
||
|
||
### 分块稀疏注意力掩码 ([[block-sparse-attention]])
|
||
|
||
定义每种 token 类型(hi, ti, ri_in, ri_out)的可见性规则:
|
||
- `hi → A(H<i)` — 人类消息只看历史
|
||
- `ti → A(H<i, hi)` — 推理 token 看历史+当前人类消息
|
||
- `ri_in → A(H<i, hi)` — 上下文副本不看推理
|
||
- `ri_out → A(H<i, hi, ti)` — 生成副本看全部包括推理
|
||
|
||
### 位置 ID 策略
|
||
|
||
```python
|
||
s_ti = s_ri_in = e_hi + 1 # 推理和上下文副本从人类消息后开始
|
||
s_ri_out = e_ti + 1 # 生成副本从推理后开始
|
||
s_h_{i+1} = e_ri_in + 1 # 下一轮人类消息从上下文副本后开始
|
||
```
|
||
|
||
### 理论保证([[one-pass-fine-tuning|Theorem 2.1]])
|
||
|
||
**1-Pass 与 N-Pass 的 loss 完全等价**:
|
||
$$L_{\text{N-Pass}}(c) = L_{\text{1-Pass}}(c)$$
|
||
|
||
证明分三部分:位置编码等价 → 注意力模式等价 → loss 函数等价。
|
||
|
||
## 复杂度分析
|
||
|
||
| 方法 | 时间复杂度 | 空间复杂度 |
|
||
|------|-----------|-----------|
|
||
| N-Pass | O(N³ℓ²d) | O(Nℓ²) |
|
||
| **1-Pass** | **O(N²ℓ²d)** | O(Nℓ²) |
|
||
|
||
响应 token 复制带来约 33% 的额外内存开销(因为 ri 被存了两份),但渐进复杂度相同。
|
||
|
||
## 实验结果
|
||
|
||
在 Qwen-3 (4B / 8B / 32B) 上使用 QLoRA + 8×H100:
|
||
|
||
**训练加速**(Flex-Pack-1-Pass vs FA2-Pack-N-Pass):
|
||
- 4B: **1.05×**
|
||
- 8B: **1.21×**
|
||
- 32B: **1.22×**
|
||
|
||
vs FlexAttention N-Pass:**1.44×–1.54×**
|
||
|
||
**深度扩展**:对话越长加速越明显(验证了 O(N²) vs O(N³) 理论优势)
|
||
|
||
**K-Pass 中间方案** ([[k-pass-training]]):
|
||
- K=1:最快,+33% 内存
|
||
- K=2:1.30×–1.37× 加速,+20% 内存
|
||
- K>4:收益递减
|
||
|
||
## 数据集
|
||
|
||
**[[mathchatsync-reasoning|MathChatSync Reasoning]]**:首个公开的多轮推理数据集,基于 MathChatSync,用 GPT-4.1-mini 为每个助手回复生成推理 token。
|
||
|
||
## 实现细节
|
||
|
||
- 基于 LLaMA-Factory ([[llama-factory]])
|
||
- 使用 [[flex-attention|PyTorch FlexAttention]](FlashAttention-2 不支持自定义掩码)
|
||
- 掩码生成在 GPU 上向量化执行,用卡诺图化简布尔逻辑
|
||
- 支持序列打包 ([[sequence-packing]]) 叠加自定义掩码
|
||
|
||
## 关键洞察
|
||
|
||
1. **从 O(N³) 到 O(N²)** 的复杂度降低意味着:对话越长,单遍训练的优势越大
|
||
2. Token 复制的本质是**用空间换时间**:多存一份 response 换来一个数量级的加速
|
||
3. K-Pass 提供了一个优雅的连续统:从完全节省内存(N-Pass)到完全节省时间(1-Pass)
|
||
|
||
## 相关概念
|
||
|
||
- [[deepseek-r1]] — 典型推理模型
|
||
- [[qlora]] — 实验所用的高效微调方法
|
||
- [[flash-attention]] — 快速注意力实现
|
||
- [[llama-factory]] — 微调框架
|
||
- [[multi-turn-reasoning]] — 多轮推理训练问题域
|