Files
myWiki/papers/goru-one-pass-to-reason-2025.md

4.1 KiB
Raw Blame History

title, authors, venue, arxiv, code, dataset, year, type, tags
title authors venue arxiv code dataset year type tags
One-Pass to Reason: 多轮推理的高效单遍微调 Ritesh Goru, Shanay Mehta, Prateek Jain (DevRev) ICML 2025 Workshop — Efficient Systems for Foundational Models 2504.18246 https://github.com/devrev/One-Pass-to-Reason https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning 2025 paper
efficient-fine-tuning
multi-turn-reasoning
attention-mask

One-Pass to Reason

核心思想:通过 token 复制 + 分块稀疏注意力掩码,将多轮推理对话的 N 遍训练压缩为单遍,时间复杂度从 O(N³) 降至 O(N²)。

问题背景

推理模型(如 DeepSeek-R1遵循行业惯例生成推理 token → 输出回复 → 在后续轮次中丢弃推理 token。这导致多轮对话微调时,每个对话需要 N 次独立前向传播N = 对话轮数)。

两个核心约束:

  1. visibility-constraint:推理 token 在生成时必须可见,但在后续轮次中必须隐藏
  2. position-id-discrepancy:回复 token 在生成时紧跟推理 token但在上下文中的位置紧接着人类消息

方法

Token 复制 (token-duplication)

将每个助手回复的 response token 复制为两份:

  • ri_in(上下文副本):不关注推理 token作为后续轮次的纯上下文
  • ri_out(生成副本):关注推理 token参与 loss 计算

分块稀疏注意力掩码 (block-sparse-attention)

定义每种 token 类型hi, ti, ri_in, ri_out的可见性规则

  • hi → A(H<i) — 人类消息只看历史
  • ti → A(H<i, hi) — 推理 token 看历史+当前人类消息
  • ri_in → A(H<i, hi) — 上下文副本不看推理
  • ri_out → A(H<i, hi, ti) — 生成副本看全部包括推理

位置 ID 策略

s_ti = s_ri_in = e_hi + 1   # 推理和上下文副本从人类消息后开始
s_ri_out = e_ti + 1          # 生成副本从推理后开始
s_h_{i+1} = e_ri_in + 1     # 下一轮人类消息从上下文副本后开始

理论保证(one-pass-fine-tuning

1-Pass 与 N-Pass 的 loss 完全等价

L_{\text{N-Pass}}(c) = L_{\text{1-Pass}}(c)

证明分三部分:位置编码等价 → 注意力模式等价 → loss 函数等价。

复杂度分析

方法 时间复杂度 空间复杂度
N-Pass O(N³²d) O(N²)
1-Pass O(N²²d) O(N²)

响应 token 复制带来约 33% 的额外内存开销(因为 ri 被存了两份),但渐进复杂度相同。

实验结果

在 Qwen-3 (4B / 8B / 32B) 上使用 QLoRA + 8×H100

训练加速Flex-Pack-1-Pass vs FA2-Pack-N-Pass

  • 4B: 1.05×
  • 8B: 1.21×
  • 32B: 1.22×

vs FlexAttention N-Pass1.44×1.54×

深度扩展:对话越长加速越明显(验证了 O(N²) vs O(N³) 理论优势)

K-Pass 中间方案 (k-pass-training)

  • K=1最快+33% 内存
  • K=21.30×1.37× 加速,+20% 内存
  • K>4收益递减

数据集

mathchatsync-reasoning:首个公开的多轮推理数据集,基于 MathChatSync用 GPT-4.1-mini 为每个助手回复生成推理 token。

实现细节

  • 基于 LLaMA-Factory (llama-factory)
  • 使用 flex-attentionFlashAttention-2 不支持自定义掩码)
  • 掩码生成在 GPU 上向量化执行,用卡诺图化简布尔逻辑
  • 支持序列打包 (sequence-packing) 叠加自定义掩码

关键洞察

  1. 从 O(N³) 到 O(N²) 的复杂度降低意味着:对话越长,单遍训练的优势越大
  2. Token 复制的本质是用空间换时间:多存一份 response 换来一个数量级的加速
  3. K-Pass 提供了一个优雅的连续统从完全节省内存N-Pass到完全节省时间1-Pass

相关概念