---
title: "One-Pass to Reason: Token Duplication and Block-Sparse Mask for Efficient Fine-Tuning on Multi-Turn Reasoning"
authors: "Ritesh Goru, Shanay Mehta, Prateek Jain"
venue: "ICML 2025 Workshop: 3rd Workshop on Efficient Systems for Foundational Models"
year: 2025
arxiv: "2504.18246"
code: "https://github.com/devrev/One-Pass-to-Reason"
dataset: "https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning"
type: paper
tags: [efficient-fine-tuning, multi-turn-reasoning, attention-mask, token-duplication, single-pass-training]
---
## Abstract
Fine-tuning Large Language Models (LLMs) on multi-turn reasoning datasets requires N (number of turns) separate forward passes per conversation due to reasoning token visibility constraints, as reasoning tokens for a turn are discarded in subsequent turns. We propose duplicating response tokens along with a custom attention mask to enable single-pass processing of entire conversations. We prove our method produces identical losses to the N-pass approach while reducing time complexity from O(N³) to O(N²) and maintaining the same memory complexity for a transformer based model. Our approach achieves significant training speedup while preserving accuracy. Our implementation is available online.
## Core Problem
Reasoning models (e.g., DeepSeek-R1) generate internal reasoning tokens, produce a response, and then discard the reasoning tokens from context in subsequent turns. This creates:
1. **Visibility Constraints**: Reasoning tokens must be visible while their own turn is generated but hidden from all subsequent turns; a static attention mask over the original (unduplicated) sequence cannot satisfy both
2. **Position ID Discrepancy**: Response tokens follow reasoning tokens during generation but directly follow human messages in later context
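
The two constraints above can be seen by writing out the inputs of the baseline N-pass approach. The following is a minimal sketch, assuming each turn is a `(human, reasoning, response)` triple of token lists; the function name `n_pass_sequences` and the placeholder tokens are hypothetical and not taken from the paper's code.

```python
def n_pass_sequences(turns):
    """Build one training sequence per turn, as in the N-pass baseline.

    turns: list of (human, reasoning, response) token lists.
    Each sequence holds the visible context of earlier turns (human and
    response tokens only) followed by the full current turn.
    """
    sequences = []
    context = []
    for human, reasoning, response in turns:
        # The current turn still contains its reasoning tokens.
        sequences.append(context + human + reasoning + response)
        # Later turns see the response but not the reasoning.
        context = context + human + response
    return sequences


turns = [
    (["h1"], ["t1a", "t1b"], ["r1"]),
    (["h2"], ["t2"], ["r2"]),
]
passes = n_pass_sequences(turns)
first_pass_index = passes[0].index("r1")   # response follows the reasoning
second_pass_index = passes[1].index("r1")  # response directly follows h1
```

In this toy conversation, `r1` sits at a different index in the two sequences, which is the position ID discrepancy, and the second sequence contains no `t1` tokens, which is the visibility constraint.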
## Method
1. **Token Duplication**: Duplicate each turn's response tokens so that the context copy `r_i^in` does not attend to that turn's reasoning tokens, while the generation copy `r_i^out` does
2. **Custom Block-Sparse Attention Mask**: Single mask with visibility rules per token type
3. **Strategic Position ID Assignment**: Maintains correct relative positions equivalent to N-pass
4. **Theorem 2.1**: Proves loss equivalence L_N-Pass(c) = L_1-Pass(c)
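
The first three steps can be sketched in plain Python. This is one reading of the rules summarized above, not the authors' implementation: the block order (human, reasoning, context copy, generation copy), the names `Block`, `layout`, `build_mask`, and `position_ids`, and the exact visibility rules are all assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Block:
    kind: str    # "h" human, "t" reasoning, "r_in" context copy, "r_out" generation copy
    turn: int
    length: int


def layout(turn_lengths):
    """turn_lengths: list of (human_len, reasoning_len, response_len)."""
    blocks = []
    for i, (h, t, r) in enumerate(turn_lengths):
        # The response appears twice: once for context, once for generation.
        blocks += [Block("h", i, h), Block("t", i, t),
                   Block("r_in", i, r), Block("r_out", i, r)]
    return blocks


def block_visible(q, k):
    """May tokens of block q attend to tokens of a different block k?"""
    if k.turn > q.turn:
        return False
    if k.turn < q.turn:
        # Earlier turns expose only human tokens and context copies.
        return k.kind in ("h", "r_in")
    if k.kind == "h":
        return q.kind in ("t", "r_in", "r_out")
    if k.kind == "t":
        # Only the generation copy sees its own turn's reasoning.
        return q.kind == "r_out"
    return False


def build_mask(blocks):
    """Token-level boolean mask; mask[q][k] is True if q may attend to k."""
    owners = [b for b in blocks for _ in range(b.length)]
    n = len(owners)
    mask = [[False] * n for _ in range(n)]
    for qi, qb in enumerate(owners):
        for ki, kb in enumerate(owners):
            if qb is kb:
                mask[qi][ki] = ki <= qi      # causal inside a block
            else:
                mask[qi][ki] = block_visible(qb, kb)
    return mask


def position_ids(blocks):
    """Positions each token would have in the corresponding N-pass sequence."""
    pos, ctx, after_h, after_t = [], 0, 0, 0
    for b in blocks:
        if b.kind == "h":
            start = ctx
            ctx = after_h = start + b.length
        elif b.kind == "t":
            start = after_h
            after_t = start + b.length
        elif b.kind == "r_in":
            start = after_h                  # directly follows the human message
            ctx = start + b.length
        else:                                # "r_out" follows the reasoning
            start = after_t
        pos.extend(range(start, start + b.length))
    return pos


blocks = layout([(2, 3, 2), (1, 1, 1)])
mask = build_mask(blocks)
pos = position_ids(blocks)
```

Under these assumptions, the two copies of a response share token content but differ in both mask row and position IDs, and a training loss would be taken only on the reasoning and generation-copy tokens. The dense nested-list mask is for readability only; it does not exploit the block sparsity that the method's name refers to.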
## Results
- 1.05× to 1.22× faster than FlashAttention-2 N-Pass with packing (Qwen-3 4B, 8B, 32B)
- 1.44× to 1.54× faster than FlexAttention N-Pass with packing
- ~33% more GPU memory
- Speedups grow with conversation depth (O(N²) vs O(N³) theoretical advantage)
- K-Pass variant allows speed-memory trade-off
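
The asymptotic gap behind the O(N²) vs O(N³) claim can be illustrated with a back-of-the-envelope count of attention score entries. This is a rough sketch, not the paper's analysis: it assumes every turn has the same human, reasoning, and response lengths, counts dense attention (sequence length squared) for both approaches, and ignores mask sparsity, packing, and the non-attention layers. The function names are hypothetical.

```python
def n_pass_attention_cost(n_turns, h, t, r):
    """Sum of squared sequence lengths over the N separate passes."""
    total = 0
    for i in range(1, n_turns + 1):
        # Context of earlier turns (human + response) plus the full current turn.
        seq_len = (i - 1) * (h + r) + h + t + r
        total += seq_len * seq_len
    return total


def one_pass_attention_cost(n_turns, h, t, r):
    """Squared length of the single sequence with duplicated responses."""
    seq_len = n_turns * (h + t + 2 * r)
    return seq_len * seq_len


def cost_ratio(n_turns, h=1, t=1, r=1):
    """N-pass cost divided by one-pass cost under the dense-attention assumption."""
    return n_pass_attention_cost(n_turns, h, t, r) / one_pass_attention_cost(n_turns, h, t, r)


ratios = {n: cost_ratio(n) for n in (2, 10, 100)}
```

In this simplified count the ratio is below 1 for very short conversations, because the single pass carries the duplicated response tokens, and it grows roughly linearly with the number of turns. That matches the direction of the observation that speedups grow with conversation depth, but the numbers are not a prediction of the measured speedups.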
## Key Contributions
1. Theoretical framework for single-pass multi-turn reasoning training
2. MathChatSync Reasoning dataset (first public multi-turn reasoning dataset with explicit per-turn reasoning)
3. Comprehensive empirical validation on Qwen-3 models using QLoRA