Files
myWiki/papers/goru-one-pass-to-reason-2025.md

107 lines
4.1 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

---
title: "One-Pass to Reason: 多轮推理的高效单遍微调"
authors: "Ritesh Goru, Shanay Mehta, Prateek Jain (DevRev)"
venue: "ICML 2025 Workshop — Efficient Systems for Foundational Models"
arxiv: "2504.18246"
code: "https://github.com/devrev/One-Pass-to-Reason"
dataset: "https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning"
year: 2025
type: paper
tags: [efficient-fine-tuning, multi-turn-reasoning, attention-mask]
---
# One-Pass to Reason
> **核心思想**:通过 token 复制 + 分块稀疏注意力掩码,将多轮推理对话的 N 遍训练压缩为单遍,时间复杂度从 O(N³) 降至 O(N²)。
## 问题背景
推理模型(如 DeepSeek-R1遵循行业惯例生成推理 token → 输出回复 → 在后续轮次中**丢弃推理 token**。这导致多轮对话微调时,每个对话需要 N 次独立前向传播N = 对话轮数)。
两个核心约束:
1. **[[visibility-constraint|可见性约束]]**:推理 token 在生成时必须可见,但在后续轮次中必须隐藏
2. **[[position-id-discrepancy|位置 ID 偏差]]**:回复 token 在生成时紧跟推理 token但在上下文中的位置紧接着人类消息
## 方法
### Token 复制 ([[token-duplication]])
将每个助手回复的 response token 复制为两份:
- **ri_in**(上下文副本):不关注推理 token作为后续轮次的纯上下文
- **ri_out**(生成副本):关注推理 token参与 loss 计算
### 分块稀疏注意力掩码 ([[block-sparse-attention]])
定义每种 token 类型hi, ti, ri_in, ri_out的可见性规则
- `hi → A(H<i)` — 人类消息只看历史
- `ti → A(H<i, hi)` — 推理 token 看历史+当前人类消息
- `ri_in → A(H<i, hi)` — 上下文副本不看推理
- `ri_out → A(H<i, hi, ti)` — 生成副本看全部包括推理
### 位置 ID 策略
```python
s_ti = s_ri_in = e_hi + 1 # 推理和上下文副本从人类消息后开始
s_ri_out = e_ti + 1 # 生成副本从推理后开始
s_h_{i+1} = e_ri_in + 1 # 下一轮人类消息从上下文副本后开始
```
### 理论保证([[one-pass-fine-tuning|Theorem 2.1]]
**1-Pass 与 N-Pass 的 loss 完全等价**
$$L_{\text{N-Pass}}(c) = L_{\text{1-Pass}}(c)$$
证明分三部分:位置编码等价 → 注意力模式等价 → loss 函数等价。
## 复杂度分析
| 方法 | 时间复杂度 | 空间复杂度 |
|------|-----------|-----------|
| N-Pass | O(N³²d) | O(N²) |
| **1-Pass** | **O(N²²d)** | O(N²) |
响应 token 复制带来约 33% 的额外内存开销(因为 ri 被存了两份),但渐进复杂度相同。
## 实验结果
在 Qwen-3 (4B / 8B / 32B) 上使用 QLoRA + 8×H100
**训练加速**Flex-Pack-1-Pass vs FA2-Pack-N-Pass
- 4B: **1.05×**
- 8B: **1.21×**
- 32B: **1.22×**
vs FlexAttention N-Pass**1.44×1.54×**
**深度扩展**:对话越长加速越明显(验证了 O(N²) vs O(N³) 理论优势)
**K-Pass 中间方案** ([[k-pass-training]])
- K=1最快+33% 内存
- K=21.30×1.37× 加速,+20% 内存
- K>4收益递减
## 数据集
**[[mathchatsync-reasoning|MathChatSync Reasoning]]**:首个公开的多轮推理数据集,基于 MathChatSync用 GPT-4.1-mini 为每个助手回复生成推理 token。
## 实现细节
- 基于 LLaMA-Factory ([[llama-factory]])
- 使用 [[flex-attention|PyTorch FlexAttention]]FlashAttention-2 不支持自定义掩码)
- 掩码生成在 GPU 上向量化执行,用卡诺图化简布尔逻辑
- 支持序列打包 ([[sequence-packing]]) 叠加自定义掩码
## 关键洞察
1. **从 O(N³) 到 O(N²)** 的复杂度降低意味着:对话越长,单遍训练的优势越大
2. Token 复制的本质是**用空间换时间**:多存一份 response 换来一个数量级的加速
3. K-Pass 提供了一个优雅的连续统从完全节省内存N-Pass到完全节省时间1-Pass
## 相关概念
- [[deepseek-r1]] — 典型推理模型
- [[qlora]] — 实验所用的高效微调方法
- [[flash-attention]] — 快速注意力实现
- [[llama-factory]] — 微调框架
- [[multi-turn-reasoning]] — 多轮推理训练问题域